diff --git a/scripts/mergers/model_util.py b/scripts/mergers/model_util.py index 0db48a2..d6c2b8d 100644 --- a/scripts/mergers/model_util.py +++ b/scripts/mergers/model_util.py @@ -667,6 +667,14 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path): return state_dict +def prune_model(model, isxl=False): + keys = list(model.keys()) + base_prefix = "conditioner." if isxl else "cond_stage_model." + for k in keys: + if "diffusion_model." not in k and "first_stage_model." not in k and base_prefix not in k: + model.pop(k, None) + return model + def to_half(sd): for key in sd.keys(): if 'model' in key and sd[key].dtype == torch.float: @@ -752,6 +760,8 @@ def savemodel(state_dict,currentmodel,fname,savesets,metadata={}): if "fp16" in savesets: state_dict = to_half(state_dict) + if "prune" in savesets: + state_dict = prune_model(state_dict, isxl) try: if ext == ".safetensors": diff --git a/scripts/supermerger.py b/scripts/supermerger.py index 123864d..3d0cbaf 100644 --- a/scripts/supermerger.py +++ b/scripts/supermerger.py @@ -86,7 +86,7 @@ def on_ui_tabs(): with gr.Accordion("Save Settings", open=False): with gr.Row(): with gr.Column(scale = 3): - save_sets = gr.CheckboxGroup(["save model", "overwrite","safetensors","fp16","save metadata"], value=["safetensors"], show_label=False, label="save settings") + save_sets = gr.CheckboxGroup(["save model", "overwrite","safetensors","fp16","save metadata","prune"], value=["safetensors"], show_label=False, label="save settings") with gr.Column(min_width = 50, scale = 1): components.id_sets = gr.CheckboxGroup(["image", "PNG info"], label="save merged model ID to")