diff --git a/scripts/mbw/ui_mbw.py b/scripts/mbw/ui_mbw.py
index 25cc9f4..5623cb5 100644
--- a/scripts/mbw/ui_mbw.py
+++ b/scripts/mbw/ui_mbw.py
@@ -17,7 +17,7 @@ def on_ui_tabs():
with gr.Row():
with gr.Column(variant="panel"):
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
- btn_clear_weighted = gr.Button(value="Clear values")
+ btn_clear_weight = gr.Button(value="Clear values")
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
html_output_block_weight_info = gr.HTML()
with gr.Column():
@@ -114,6 +114,8 @@ def on_ui_tabs():
_model_B_info = model_B_info.model_name
else:
_model_B_info = ""
+
+ txt_model_O = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', txt_model_O)
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
if ".ckpt" not in model_O:
model_O = model_O + ".ckpt"
@@ -132,15 +134,17 @@ def on_ui_tabs():
print(f"output_file: {_output}")
print(f"weights : {_weights}")
- result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw)
+ result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite,
+ base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw
+ )
- sd_models.list_models()
if result:
ret_html = "merged.
" + f"{model_A}
" + f"{model_B}
" + f"{model_O}"
else:
ret_html = ret_message
# save log to history.tsv
+ sd_models.list_models()
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
model_O_hash = "" if not model_O_info else model_O_info.hash
_names = presetWeights.find_names_by_weight(_weights)
@@ -153,11 +157,13 @@ def on_ui_tabs():
return gr.update(value=f"{ret_html}")
btn_do_merge_block_weighted.click(
fn=onclick_btn_do_merge_block_weighted,
- inputs=[model_A, model_B] + sl_IN + sl_MID + sl_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
+ inputs=[model_A, model_B]
+ + sl_IN + sl_MID + sl_OUT
+ + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
outputs=[html_output_block_weight_info]
)
- btn_clear_weighted.click(
+ btn_clear_weight.click(
fn=lambda: [gr.update(value=0.5) for _ in range(25)],
inputs=[],
outputs=[
diff --git a/scripts/mbw_each/ui_mbw_each.py b/scripts/mbw_each/ui_mbw_each.py
index dda106c..7e93b54 100644
--- a/scripts/mbw_each/ui_mbw_each.py
+++ b/scripts/mbw_each/ui_mbw_each.py
@@ -178,7 +178,7 @@ def on_ui_tabs():
]])
# debug output
- print( "#### Merge Block Weighted ####")
+ print( "#### Merge Block Weighted Each ####")
if (not dd_model_A or not dd_model_B) and txt_multi_process_cmd == "":
_err_msg = f"ERROR: model not found. [{dd_model_A}][{dd_model_B}]"
@@ -198,7 +198,10 @@ def on_ui_tabs():
if len(_items) > 0:
ret_html += _run_merge(
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
- allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw, params=_items)
+ allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O,
+ verbose=chk_verbose_mbw,
+ params=_items
+ )
else:
_ret = f" multi-merge text found, but invalid params. skipped :[{_line}]"
ret_html += _ret
@@ -207,7 +210,9 @@ def on_ui_tabs():
# normal merge
ret_html += _run_merge(
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
- allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw)
+ allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O,
+ verbose=chk_verbose_mbw
+ )
sd_models.list_models()
print( "#### All merge process done. ####")
@@ -215,22 +220,25 @@ def on_ui_tabs():
return gr.update(value=f"{ret_html}")
btn_do_merge_block_weighted.click(
fn=onclick_btn_do_merge_block_weighted,
- inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd] + sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
+ inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd]
+ + sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT
+ + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
outputs=[html_output_block_weight_info]
)
- def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0, model_Output="", verbose=False, params=[]):
+ def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0,
+ model_Output="", verbose=False, params=[]
+ ):
- # generate output file name from param
- model_A_info = sd_models.get_closet_checkpoint_match(model_0)
- _model_A_name = "" if not model_A_info else model_A_info.model_name # expect "" or "aaa.ckpt [abcdefgh]"
+ def validate_output_filename(output_filename):
+ output_filename = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', output_filename)
+ _, ext = os.path.splitext(output_filename)
+ _ret = output_filename
+ if ext != ".safetensors" and ext != ".ckpt":
+ _ret = f"{output_filename}.ckpt"
+ return _ret
- model_B_info = sd_models.get_closet_checkpoint_match(model_1)
- _model_B_name = "" if not model_B_info else model_B_info.model_name
-
- model_O = f"bw-merge-{_model_A_name}-{_model_B_name}-{base_alpha}.ckpt" if model_Output == "" else model_Output
- model_O = model_O if ".ckpt" in model_O else f"{model_O}.ckpt"
- model_O = re.sub(r'[\\|/|:|?|.|"|<|>|\|\*]', '-', model_O)
+ model_O = ""
if params and len(params) > 0:
for _item in params:
@@ -250,7 +258,7 @@ def on_ui_tabs():
elif _item_l.lower() == "model_b":
print(f" * Model changed: {model_1} -> {_model_info.title}")
model_1 = _model_info.title
-
+
elif _item_l.lower() == "preset_weights":
_weights = presetWeights.find_weight_by_name(_item_r)
if _weights != "" and len(_weights.split(',')) == 25:
@@ -259,7 +267,7 @@ def on_ui_tabs():
weight_A = ",".join([str(1-float(x)) for x in _weights.split(',')])
else:
print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]")
-
+
elif _item_l.lower() == "weight_values":
_weights = _item_r.strip()
if _weights != "" and len(_weights.split(' ')) == 25: # this is work-around to use space as separator. Double-meaning issue on commna which already used as value separator and weights separator.
@@ -268,17 +276,18 @@ def on_ui_tabs():
weight_A = ",".join([str(1-float(x)) for x in _weights.split(' ')])
else:
print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]")
-
+
elif _item_l.lower() == "base_alpha":
if float(_item_r) >= 0:
print(f" * base_alpha changed: {base_alpha} -> {_item_r}")
base_alpha = float(_item_r)
-
+
elif _item_l.upper() == "O":
- _item_r = _item_r if ".ckpt" in _item_r else f"{_item_r}.ckpt"
- print(f" * Output filename changed:[{model_O}] -> [{_item_r}]")
- model_O = _item_r
-
+ if _item_r.strip() != "":
+ _ret = validate_output_filename(_item_r.strip())
+ print(f" * Output filename changed:[{model_O}] -> [{_ret}]")
+ model_O = _ret
+
elif len(_item_l.split("_")) == 3:
_IMO = _item_l.split("_")[0]
_AB = _item_l.split("_")[1]
@@ -308,6 +317,19 @@ def on_ui_tabs():
#
# Prepare params before run merge
#
+
+ # generate output file name from param
+ model_A_info = sd_models.get_closet_checkpoint_match(model_0)
+ _model_A_name = "" if not model_A_info else model_A_info.filename
+
+ model_B_info = sd_models.get_closet_checkpoint_match(model_1)
+ _model_B_name = "" if not model_B_info else model_B_info.filename
+
+ if model_O == "":
+ _a = os.path.splitext(os.path.basename(_model_A_name))[0]
+ _b = os.path.splitext(os.path.basename(_model_B_name))[0]
+ model_O = f"bw-merge-{_a}-{_b}-{base_alpha}.ckpt" if model_Output == "" else model_Output
+ model_O = validate_output_filename(model_O)
output_file = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O)
#
# Check params