support pruning
parent
6a5936c3e0
commit
ac4d44df83
|
|
@ -667,6 +667,14 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
|||
|
||||
return state_dict
|
||||
|
||||
def prune_model(model, isxl=False):
|
||||
keys = list(model.keys())
|
||||
base_prefix = "conditioner." if isxl else "cond_stage_model."
|
||||
for k in keys:
|
||||
if "diffusion_model." not in k and "first_stage_model." not in k and base_prefix not in k:
|
||||
model.pop(k, None)
|
||||
return model
|
||||
|
||||
def to_half(sd):
|
||||
for key in sd.keys():
|
||||
if 'model' in key and sd[key].dtype == torch.float:
|
||||
|
|
@ -752,6 +760,8 @@ def savemodel(state_dict,currentmodel,fname,savesets,metadata={}):
|
|||
|
||||
if "fp16" in savesets:
|
||||
state_dict = to_half(state_dict)
|
||||
if "prune" in savesets:
|
||||
state_dict = prune_model(state_dict, isxl)
|
||||
|
||||
try:
|
||||
if ext == ".safetensors":
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ def on_ui_tabs():
|
|||
with gr.Accordion("Save Settings", open=False):
|
||||
with gr.Row():
|
||||
with gr.Column(scale = 3):
|
||||
save_sets = gr.CheckboxGroup(["save model", "overwrite","safetensors","fp16","save metadata"], value=["safetensors"], show_label=False, label="save settings")
|
||||
save_sets = gr.CheckboxGroup(["save model", "overwrite","safetensors","fp16","save metadata","prune"], value=["safetensors"], show_label=False, label="save settings")
|
||||
with gr.Column(min_width = 50, scale = 1):
|
||||
components.id_sets = gr.CheckboxGroup(["image", "PNG info"], label="save merged model ID to")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue