fix: fix output filename restriction to allow "." and "/"

fix_update
bbc_mc 2023-01-08 13:00:00 +09:00
parent 69580a348a
commit e18a67a95d
2 changed files with 55 additions and 27 deletions

View File

@ -17,7 +17,7 @@ def on_ui_tabs():
with gr.Row():
with gr.Column(variant="panel"):
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
btn_clear_weighted = gr.Button(value="Clear values")
btn_clear_weight = gr.Button(value="Clear values")
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
html_output_block_weight_info = gr.HTML()
with gr.Column():
@ -114,6 +114,8 @@ def on_ui_tabs():
_model_B_info = model_B_info.model_name
else:
_model_B_info = ""
txt_model_O = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', txt_model_O)
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
if ".ckpt" not in model_O:
model_O = model_O + ".ckpt"
@ -132,15 +134,17 @@ def on_ui_tabs():
print(f"output_file: {_output}")
print(f"weights : {_weights}")
result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw)
result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite,
base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw
)
sd_models.list_models()
if result:
ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{model_O}"
else:
ret_html = ret_message
# save log to history.tsv
sd_models.list_models()
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
model_O_hash = "" if not model_O_info else model_O_info.hash
_names = presetWeights.find_names_by_weight(_weights)
@ -153,11 +157,13 @@ def on_ui_tabs():
return gr.update(value=f"{ret_html}")
btn_do_merge_block_weighted.click(
fn=onclick_btn_do_merge_block_weighted,
inputs=[model_A, model_B] + sl_IN + sl_MID + sl_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
inputs=[model_A, model_B]
+ sl_IN + sl_MID + sl_OUT
+ [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
outputs=[html_output_block_weight_info]
)
btn_clear_weighted.click(
btn_clear_weight.click(
fn=lambda: [gr.update(value=0.5) for _ in range(25)],
inputs=[],
outputs=[

View File

@ -178,7 +178,7 @@ def on_ui_tabs():
]])
# debug output
print( "#### Merge Block Weighted ####")
print( "#### Merge Block Weighted Each ####")
if (not dd_model_A or not dd_model_B) and txt_multi_process_cmd == "":
_err_msg = f"ERROR: model not found. [{dd_model_A}][{dd_model_B}]"
@ -198,7 +198,10 @@ def on_ui_tabs():
if len(_items) > 0:
ret_html += _run_merge(
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw, params=_items)
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O,
verbose=chk_verbose_mbw,
params=_items
)
else:
_ret = f" multi-merge text found, but invalid params. skipped :[{_line}]"
ret_html += _ret
@ -207,7 +210,9 @@ def on_ui_tabs():
# normal merge
ret_html += _run_merge(
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw)
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O,
verbose=chk_verbose_mbw
)
sd_models.list_models()
print( "#### All merge process done. ####")
@ -215,22 +220,25 @@ def on_ui_tabs():
return gr.update(value=f"{ret_html}")
btn_do_merge_block_weighted.click(
fn=onclick_btn_do_merge_block_weighted,
inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd] + sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd]
+ sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT
+ [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
outputs=[html_output_block_weight_info]
)
def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0, model_Output="", verbose=False, params=[]):
def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0,
model_Output="", verbose=False, params=[]
):
# generate output file name from param
model_A_info = sd_models.get_closet_checkpoint_match(model_0)
_model_A_name = "" if not model_A_info else model_A_info.model_name # expect "" or "aaa.ckpt [abcdefgh]"
def validate_output_filename(output_filename):
output_filename = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', output_filename)
_, ext = os.path.splitext(output_filename)
_ret = output_filename
if ext != ".safetensors" and ext != ".ckpt":
_ret = f"{output_filename}.ckpt"
return _ret
model_B_info = sd_models.get_closet_checkpoint_match(model_1)
_model_B_name = "" if not model_B_info else model_B_info.model_name
model_O = f"bw-merge-{_model_A_name}-{_model_B_name}-{base_alpha}.ckpt" if model_Output == "" else model_Output
model_O = model_O if ".ckpt" in model_O else f"{model_O}.ckpt"
model_O = re.sub(r'[\\|/|:|?|.|"|<|>|\|\*]', '-', model_O)
model_O = ""
if params and len(params) > 0:
for _item in params:
@ -250,7 +258,7 @@ def on_ui_tabs():
elif _item_l.lower() == "model_b":
print(f" * Model changed: {model_1} -> {_model_info.title}")
model_1 = _model_info.title
elif _item_l.lower() == "preset_weights":
_weights = presetWeights.find_weight_by_name(_item_r)
if _weights != "" and len(_weights.split(',')) == 25:
@ -259,7 +267,7 @@ def on_ui_tabs():
weight_A = ",".join([str(1-float(x)) for x in _weights.split(',')])
else:
print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]")
elif _item_l.lower() == "weight_values":
_weights = _item_r.strip()
if _weights != "" and len(_weights.split(' ')) == 25: # this is work-around to use space as separator. Double-meaning issue on commna which already used as value separator and weights separator.
@ -268,17 +276,18 @@ def on_ui_tabs():
weight_A = ",".join([str(1-float(x)) for x in _weights.split(' ')])
else:
print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]")
elif _item_l.lower() == "base_alpha":
if float(_item_r) >= 0:
print(f" * base_alpha changed: {base_alpha} -> {_item_r}")
base_alpha = float(_item_r)
elif _item_l.upper() == "O":
_item_r = _item_r if ".ckpt" in _item_r else f"{_item_r}.ckpt"
print(f" * Output filename changed:[{model_O}] -> [{_item_r}]")
model_O = _item_r
if _item_r.strip() != "":
_ret = validate_output_filename(_item_r.strip())
print(f" * Output filename changed:[{model_O}] -> [{_ret}]")
model_O = _ret
elif len(_item_l.split("_")) == 3:
_IMO = _item_l.split("_")[0]
_AB = _item_l.split("_")[1]
@ -308,6 +317,19 @@ def on_ui_tabs():
#
# Prepare params before run merge
#
# generate output file name from param
model_A_info = sd_models.get_closet_checkpoint_match(model_0)
_model_A_name = "" if not model_A_info else model_A_info.filename
model_B_info = sd_models.get_closet_checkpoint_match(model_1)
_model_B_name = "" if not model_B_info else model_B_info.filename
if model_O == "":
_a = os.path.splitext(os.path.basename(_model_A_name))[0]
_b = os.path.splitext(os.path.basename(_model_B_name))[0]
model_O = f"bw-merge-{_a}-{_b}-{base_alpha}.ckpt" if model_Output == "" else model_Output
model_O = validate_output_filename(model_O)
output_file = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O)
#
# Check params