support pruning

pull/220/head
Won-Kyu Park 2023-09-05 04:47:43 +09:00
parent 6a5936c3e0
commit ac4d44df83
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15
2 changed files with 11 additions and 1 deletions

View File

@ -667,6 +667,14 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
return state_dict
def prune_model(model, isxl=False):
keys = list(model.keys())
base_prefix = "conditioner." if isxl else "cond_stage_model."
for k in keys:
if "diffusion_model." not in k and "first_stage_model." not in k and base_prefix not in k:
model.pop(k, None)
return model
def to_half(sd):
for key in sd.keys():
if 'model' in key and sd[key].dtype == torch.float:
@ -752,6 +760,8 @@ def savemodel(state_dict,currentmodel,fname,savesets,metadata={}):
if "fp16" in savesets:
state_dict = to_half(state_dict)
if "prune" in savesets:
state_dict = prune_model(state_dict, isxl)
try:
if ext == ".safetensors":

View File

@ -86,7 +86,7 @@ def on_ui_tabs():
with gr.Accordion("Save Settings", open=False):
with gr.Row():
with gr.Column(scale = 3):
save_sets = gr.CheckboxGroup(["save model", "overwrite","safetensors","fp16","save metadata"], value=["safetensors"], show_label=False, label="save settings")
save_sets = gr.CheckboxGroup(["save model", "overwrite","safetensors","fp16","save metadata","prune"], value=["safetensors"], show_label=False, label="save settings")
with gr.Column(min_width = 50, scale = 1):
components.id_sets = gr.CheckboxGroup(["image", "PNG info"], label="save merged model ID to")