fix: fix output filename restriction to allow "." and "/"
parent
69580a348a
commit
e18a67a95d
|
|
@ -17,7 +17,7 @@ def on_ui_tabs():
|
|||
with gr.Row():
|
||||
with gr.Column(variant="panel"):
|
||||
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
|
||||
btn_clear_weighted = gr.Button(value="Clear values")
|
||||
btn_clear_weight = gr.Button(value="Clear values")
|
||||
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
|
||||
html_output_block_weight_info = gr.HTML()
|
||||
with gr.Column():
|
||||
|
|
@ -114,6 +114,8 @@ def on_ui_tabs():
|
|||
_model_B_info = model_B_info.model_name
|
||||
else:
|
||||
_model_B_info = ""
|
||||
|
||||
txt_model_O = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', txt_model_O)
|
||||
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
|
||||
if ".ckpt" not in model_O:
|
||||
model_O = model_O + ".ckpt"
|
||||
|
|
@ -132,15 +134,17 @@ def on_ui_tabs():
|
|||
print(f"output_file: {_output}")
|
||||
print(f"weights : {_weights}")
|
||||
|
||||
result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw)
|
||||
result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite,
|
||||
base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw
|
||||
)
|
||||
|
||||
sd_models.list_models()
|
||||
if result:
|
||||
ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{model_O}"
|
||||
else:
|
||||
ret_html = ret_message
|
||||
|
||||
# save log to history.tsv
|
||||
sd_models.list_models()
|
||||
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
|
||||
model_O_hash = "" if not model_O_info else model_O_info.hash
|
||||
_names = presetWeights.find_names_by_weight(_weights)
|
||||
|
|
@ -153,11 +157,13 @@ def on_ui_tabs():
|
|||
return gr.update(value=f"{ret_html}")
|
||||
btn_do_merge_block_weighted.click(
|
||||
fn=onclick_btn_do_merge_block_weighted,
|
||||
inputs=[model_A, model_B] + sl_IN + sl_MID + sl_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
|
||||
inputs=[model_A, model_B]
|
||||
+ sl_IN + sl_MID + sl_OUT
|
||||
+ [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
|
||||
outputs=[html_output_block_weight_info]
|
||||
)
|
||||
|
||||
btn_clear_weighted.click(
|
||||
btn_clear_weight.click(
|
||||
fn=lambda: [gr.update(value=0.5) for _ in range(25)],
|
||||
inputs=[],
|
||||
outputs=[
|
||||
|
|
|
|||
|
|
@ -178,7 +178,7 @@ def on_ui_tabs():
|
|||
]])
|
||||
|
||||
# debug output
|
||||
print( "#### Merge Block Weighted ####")
|
||||
print( "#### Merge Block Weighted Each ####")
|
||||
|
||||
if (not dd_model_A or not dd_model_B) and txt_multi_process_cmd == "":
|
||||
_err_msg = f"ERROR: model not found. [{dd_model_A}][{dd_model_B}]"
|
||||
|
|
@ -198,7 +198,10 @@ def on_ui_tabs():
|
|||
if len(_items) > 0:
|
||||
ret_html += _run_merge(
|
||||
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
|
||||
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw, params=_items)
|
||||
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O,
|
||||
verbose=chk_verbose_mbw,
|
||||
params=_items
|
||||
)
|
||||
else:
|
||||
_ret = f" multi-merge text found, but invalid params. skipped :[{_line}]"
|
||||
ret_html += _ret
|
||||
|
|
@ -207,7 +210,9 @@ def on_ui_tabs():
|
|||
# normal merge
|
||||
ret_html += _run_merge(
|
||||
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
|
||||
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw)
|
||||
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O,
|
||||
verbose=chk_verbose_mbw
|
||||
)
|
||||
|
||||
sd_models.list_models()
|
||||
print( "#### All merge process done. ####")
|
||||
|
|
@ -215,22 +220,25 @@ def on_ui_tabs():
|
|||
return gr.update(value=f"{ret_html}")
|
||||
btn_do_merge_block_weighted.click(
|
||||
fn=onclick_btn_do_merge_block_weighted,
|
||||
inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd] + sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
|
||||
inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd]
|
||||
+ sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT
|
||||
+ [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
|
||||
outputs=[html_output_block_weight_info]
|
||||
)
|
||||
|
||||
def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0, model_Output="", verbose=False, params=[]):
|
||||
def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0,
|
||||
model_Output="", verbose=False, params=[]
|
||||
):
|
||||
|
||||
# generate output file name from param
|
||||
model_A_info = sd_models.get_closet_checkpoint_match(model_0)
|
||||
_model_A_name = "" if not model_A_info else model_A_info.model_name # expect "" or "aaa.ckpt [abcdefgh]"
|
||||
def validate_output_filename(output_filename):
|
||||
output_filename = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', output_filename)
|
||||
_, ext = os.path.splitext(output_filename)
|
||||
_ret = output_filename
|
||||
if ext != ".safetensors" and ext != ".ckpt":
|
||||
_ret = f"{output_filename}.ckpt"
|
||||
return _ret
|
||||
|
||||
model_B_info = sd_models.get_closet_checkpoint_match(model_1)
|
||||
_model_B_name = "" if not model_B_info else model_B_info.model_name
|
||||
|
||||
model_O = f"bw-merge-{_model_A_name}-{_model_B_name}-{base_alpha}.ckpt" if model_Output == "" else model_Output
|
||||
model_O = model_O if ".ckpt" in model_O else f"{model_O}.ckpt"
|
||||
model_O = re.sub(r'[\\|/|:|?|.|"|<|>|\|\*]', '-', model_O)
|
||||
model_O = ""
|
||||
|
||||
if params and len(params) > 0:
|
||||
for _item in params:
|
||||
|
|
@ -250,7 +258,7 @@ def on_ui_tabs():
|
|||
elif _item_l.lower() == "model_b":
|
||||
print(f" * Model changed: {model_1} -> {_model_info.title}")
|
||||
model_1 = _model_info.title
|
||||
|
||||
|
||||
elif _item_l.lower() == "preset_weights":
|
||||
_weights = presetWeights.find_weight_by_name(_item_r)
|
||||
if _weights != "" and len(_weights.split(',')) == 25:
|
||||
|
|
@ -259,7 +267,7 @@ def on_ui_tabs():
|
|||
weight_A = ",".join([str(1-float(x)) for x in _weights.split(',')])
|
||||
else:
|
||||
print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]")
|
||||
|
||||
|
||||
elif _item_l.lower() == "weight_values":
|
||||
_weights = _item_r.strip()
|
||||
if _weights != "" and len(_weights.split(' ')) == 25: # this is work-around to use space as separator. Double-meaning issue on commna which already used as value separator and weights separator.
|
||||
|
|
@ -268,17 +276,18 @@ def on_ui_tabs():
|
|||
weight_A = ",".join([str(1-float(x)) for x in _weights.split(' ')])
|
||||
else:
|
||||
print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]")
|
||||
|
||||
|
||||
elif _item_l.lower() == "base_alpha":
|
||||
if float(_item_r) >= 0:
|
||||
print(f" * base_alpha changed: {base_alpha} -> {_item_r}")
|
||||
base_alpha = float(_item_r)
|
||||
|
||||
|
||||
elif _item_l.upper() == "O":
|
||||
_item_r = _item_r if ".ckpt" in _item_r else f"{_item_r}.ckpt"
|
||||
print(f" * Output filename changed:[{model_O}] -> [{_item_r}]")
|
||||
model_O = _item_r
|
||||
|
||||
if _item_r.strip() != "":
|
||||
_ret = validate_output_filename(_item_r.strip())
|
||||
print(f" * Output filename changed:[{model_O}] -> [{_ret}]")
|
||||
model_O = _ret
|
||||
|
||||
elif len(_item_l.split("_")) == 3:
|
||||
_IMO = _item_l.split("_")[0]
|
||||
_AB = _item_l.split("_")[1]
|
||||
|
|
@ -308,6 +317,19 @@ def on_ui_tabs():
|
|||
#
|
||||
# Prepare params before run merge
|
||||
#
|
||||
|
||||
# generate output file name from param
|
||||
model_A_info = sd_models.get_closet_checkpoint_match(model_0)
|
||||
_model_A_name = "" if not model_A_info else model_A_info.filename
|
||||
|
||||
model_B_info = sd_models.get_closet_checkpoint_match(model_1)
|
||||
_model_B_name = "" if not model_B_info else model_B_info.filename
|
||||
|
||||
if model_O == "":
|
||||
_a = os.path.splitext(os.path.basename(_model_A_name))[0]
|
||||
_b = os.path.splitext(os.path.basename(_model_B_name))[0]
|
||||
model_O = f"bw-merge-{_a}-{_b}-{base_alpha}.ckpt" if model_Output == "" else model_Output
|
||||
model_O = validate_output_filename(model_O)
|
||||
output_file = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O)
|
||||
#
|
||||
# Check params
|
||||
|
|
|
|||
Loading…
Reference in New Issue