diff --git a/scripts/mbw/ui_mbw.py b/scripts/mbw/ui_mbw.py index 25cc9f4..5623cb5 100644 --- a/scripts/mbw/ui_mbw.py +++ b/scripts/mbw/ui_mbw.py @@ -17,7 +17,7 @@ def on_ui_tabs(): with gr.Row(): with gr.Column(variant="panel"): btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary") - btn_clear_weighted = gr.Button(value="Clear values") + btn_clear_weight = gr.Button(value="Clear values") btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint") html_output_block_weight_info = gr.HTML() with gr.Column(): @@ -114,6 +114,8 @@ def on_ui_tabs(): _model_B_info = model_B_info.model_name else: _model_B_info = "" + + txt_model_O = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', txt_model_O) model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O if ".ckpt" not in model_O: model_O = model_O + ".ckpt" @@ -132,15 +134,17 @@ def on_ui_tabs(): print(f"output_file: {_output}") print(f"weights : {_weights}") - result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw) + result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, + base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw + ) - sd_models.list_models() if result: ret_html = "merged.
" + f"{model_A}
" + f"{model_B}
" + f"{model_O}" else: ret_html = ret_message # save log to history.tsv + sd_models.list_models() model_O_info = sd_models.get_closet_checkpoint_match(model_O) model_O_hash = "" if not model_O_info else model_O_info.hash _names = presetWeights.find_names_by_weight(_weights) @@ -153,11 +157,13 @@ def on_ui_tabs(): return gr.update(value=f"{ret_html}") btn_do_merge_block_weighted.click( fn=onclick_btn_do_merge_block_weighted, - inputs=[model_A, model_B] + sl_IN + sl_MID + sl_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite], + inputs=[model_A, model_B] + + sl_IN + sl_MID + sl_OUT + + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite], outputs=[html_output_block_weight_info] ) - btn_clear_weighted.click( + btn_clear_weight.click( fn=lambda: [gr.update(value=0.5) for _ in range(25)], inputs=[], outputs=[ diff --git a/scripts/mbw_each/ui_mbw_each.py b/scripts/mbw_each/ui_mbw_each.py index dda106c..7e93b54 100644 --- a/scripts/mbw_each/ui_mbw_each.py +++ b/scripts/mbw_each/ui_mbw_each.py @@ -178,7 +178,7 @@ def on_ui_tabs(): ]]) # debug output - print( "#### Merge Block Weighted ####") + print( "#### Merge Block Weighted Each ####") if (not dd_model_A or not dd_model_B) and txt_multi_process_cmd == "": _err_msg = f"ERROR: model not found. [{dd_model_A}][{dd_model_B}]" @@ -198,7 +198,10 @@ def on_ui_tabs(): if len(_items) > 0: ret_html += _run_merge( weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B, - allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw, params=_items) + allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, + verbose=chk_verbose_mbw, + params=_items + ) else: _ret = f" multi-merge text found, but invalid params. skipped :[{_line}]" ret_html += _ret @@ -207,7 +210,9 @@ def on_ui_tabs(): # normal merge ret_html += _run_merge( weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B, - allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw) + allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, + verbose=chk_verbose_mbw + ) sd_models.list_models() print( "#### All merge process done. ####") @@ -215,22 +220,25 @@ def on_ui_tabs(): return gr.update(value=f"{ret_html}") btn_do_merge_block_weighted.click( fn=onclick_btn_do_merge_block_weighted, - inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd] + sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite], + inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd] + + sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT + + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite], outputs=[html_output_block_weight_info] ) - def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0, model_Output="", verbose=False, params=[]): + def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0, + model_Output="", verbose=False, params=[] + ): - # generate output file name from param - model_A_info = sd_models.get_closet_checkpoint_match(model_0) - _model_A_name = "" if not model_A_info else model_A_info.model_name # expect "" or "aaa.ckpt [abcdefgh]" + def validate_output_filename(output_filename): + output_filename = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', output_filename) + _, ext = os.path.splitext(output_filename) + _ret = output_filename + if ext != ".safetensors" and ext != ".ckpt": + _ret = f"{output_filename}.ckpt" + return _ret - model_B_info = sd_models.get_closet_checkpoint_match(model_1) - _model_B_name = "" if not model_B_info else model_B_info.model_name - - model_O = f"bw-merge-{_model_A_name}-{_model_B_name}-{base_alpha}.ckpt" if model_Output == "" else model_Output - model_O = model_O if ".ckpt" in model_O else f"{model_O}.ckpt" - model_O = re.sub(r'[\\|/|:|?|.|"|<|>|\|\*]', '-', model_O) + model_O = "" if params and len(params) > 0: for _item in params: @@ -250,7 +258,7 @@ def on_ui_tabs(): elif _item_l.lower() == "model_b": print(f" * Model changed: {model_1} -> {_model_info.title}") model_1 = _model_info.title - + elif _item_l.lower() == "preset_weights": _weights = presetWeights.find_weight_by_name(_item_r) if _weights != "" and len(_weights.split(',')) == 25: @@ -259,7 +267,7 @@ def on_ui_tabs(): weight_A = ",".join([str(1-float(x)) for x in _weights.split(',')]) else: print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]") - + elif _item_l.lower() == "weight_values": _weights = _item_r.strip() if _weights != "" and len(_weights.split(' ')) == 25: # this is work-around to use space as separator. Double-meaning issue on commna which already used as value separator and weights separator. @@ -268,17 +276,18 @@ def on_ui_tabs(): weight_A = ",".join([str(1-float(x)) for x in _weights.split(' ')]) else: print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]") - + elif _item_l.lower() == "base_alpha": if float(_item_r) >= 0: print(f" * base_alpha changed: {base_alpha} -> {_item_r}") base_alpha = float(_item_r) - + elif _item_l.upper() == "O": - _item_r = _item_r if ".ckpt" in _item_r else f"{_item_r}.ckpt" - print(f" * Output filename changed:[{model_O}] -> [{_item_r}]") - model_O = _item_r - + if _item_r.strip() != "": + _ret = validate_output_filename(_item_r.strip()) + print(f" * Output filename changed:[{model_O}] -> [{_ret}]") + model_O = _ret + elif len(_item_l.split("_")) == 3: _IMO = _item_l.split("_")[0] _AB = _item_l.split("_")[1] @@ -308,6 +317,19 @@ def on_ui_tabs(): # # Prepare params before run merge # + + # generate output file name from param + model_A_info = sd_models.get_closet_checkpoint_match(model_0) + _model_A_name = "" if not model_A_info else model_A_info.filename + + model_B_info = sd_models.get_closet_checkpoint_match(model_1) + _model_B_name = "" if not model_B_info else model_B_info.filename + + if model_O == "": + _a = os.path.splitext(os.path.basename(_model_A_name))[0] + _b = os.path.splitext(os.path.basename(_model_B_name))[0] + model_O = f"bw-merge-{_a}-{_b}-{base_alpha}.ckpt" if model_Output == "" else model_Output + model_O = validate_output_filename(model_O) output_file = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O) # # Check params