291 lines
14 KiB
Python
291 lines
14 KiB
Python
import gradio as gr
|
|
import os
|
|
import re
|
|
|
|
from modules import sd_models, shared
|
|
from tqdm import tqdm
|
|
try:
|
|
from modules import hashes
|
|
from modules.sd_models import CheckpointInfo
|
|
except:
|
|
pass
|
|
|
|
from scripts.mbw.merge_block_weighted import merge
|
|
from scripts.mbw_util.preset_weights import PresetWeights
|
|
from scripts.mbw_util.merge_history import MergeHistory
|
|
|
|
presetWeights = PresetWeights()
|
|
mergeHistory = MergeHistory()
|
|
|
|
|
|
def on_ui_tabs():
|
|
with gr.Column():
|
|
with gr.Row():
|
|
with gr.Column(variant="panel"):
|
|
html_output_block_weight_info = gr.HTML()
|
|
with gr.Row():
|
|
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
|
|
btn_clear_weight = gr.Button(value="Clear values")
|
|
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
|
|
with gr.Column():
|
|
dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list())
|
|
txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25")
|
|
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary")
|
|
with gr.Row():
|
|
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.01, value=0)
|
|
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
|
|
chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False)
|
|
with gr.Row():
|
|
with gr.Column(scale=3):
|
|
with gr.Row():
|
|
chk_save_as_half = gr.Checkbox(label="Save as half", value=False)
|
|
chk_save_as_safetensors = gr.Checkbox(label="Save as safetensors", value=False)
|
|
with gr.Column(scale=4):
|
|
radio_position_ids = gr.Radio(label="Skip/Reset CLIP position_ids", choices=["None", "Skip", "Force Reset"], value="None", type="index")
|
|
with gr.Row():
|
|
model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles())
|
|
model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles())
|
|
txt_model_O = gr.Text(label="Output Model Name")
|
|
with gr.Row():
|
|
with gr.Column():
|
|
sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
with gr.Column():
|
|
gr.Slider(visible=False)
|
|
gr.Slider(visible=False)
|
|
gr.Slider(visible=False)
|
|
gr.Slider(visible=False)
|
|
gr.Slider(visible=False)
|
|
gr.Slider(visible=False)
|
|
gr.Slider(visible=False)
|
|
gr.Slider(visible=False)
|
|
gr.Slider(visible=False)
|
|
gr.Slider(visible=False)
|
|
gr.Slider(visible=False)
|
|
sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="mbw_sl_M00")
|
|
with gr.Column():
|
|
sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.01, value=0.5)
|
|
|
|
sl_IN = [
|
|
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
|
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11]
|
|
sl_MID = [sl_M_00]
|
|
sl_OUT = [
|
|
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
|
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11]
|
|
|
|
# Events
|
|
def onclick_btn_do_merge_block_weighted(
|
|
model_A, model_B,
|
|
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
|
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
|
sl_M_00,
|
|
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
|
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
|
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite,
|
|
chk_save_as_safetensors, chk_save_as_half,
|
|
radio_position_ids
|
|
):
|
|
|
|
# debug output
|
|
print( "#### Merge Block Weighted ####")
|
|
|
|
_weights = ",".join(
|
|
[str(x) for x in [
|
|
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
|
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
|
sl_M_00,
|
|
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
|
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11
|
|
]])
|
|
#
|
|
if not model_A or not model_B:
|
|
return gr.update(value=f"ERROR: model not found. [{model_A}][{model_B}]")
|
|
|
|
#
|
|
# Prepare params before run merge
|
|
#
|
|
|
|
# generate output file name from param
|
|
model_A_info = sd_models.get_closet_checkpoint_match(model_A)
|
|
if model_A_info:
|
|
_model_A_name = model_A_info.model_name
|
|
else:
|
|
_model_A_name = ""
|
|
model_B_info = sd_models.get_closet_checkpoint_match(model_B)
|
|
if model_B_info:
|
|
_model_B_info = model_B_info.model_name
|
|
else:
|
|
_model_B_info = ""
|
|
|
|
def validate_output_filename(output_filename, save_as_safetensors=False, save_as_half=False):
|
|
output_filename = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', output_filename)
|
|
filename_body, filename_ext = os.path.splitext(output_filename)
|
|
_ret = output_filename
|
|
_footer = "-half" if save_as_half else ""
|
|
if filename_ext in [".safetensors", ".ckpt"]:
|
|
_ret = f"{filename_body}{_footer}{filename_ext}"
|
|
elif save_as_safetensors:
|
|
_ret = f"{output_filename}{_footer}.safetensors"
|
|
else:
|
|
_ret = f"{output_filename}{_footer}.ckpt"
|
|
return _ret
|
|
|
|
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
|
|
model_O = validate_output_filename(model_O, save_as_safetensors=chk_save_as_safetensors, save_as_half=chk_save_as_half)
|
|
|
|
_output = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O)
|
|
|
|
if not chk_allow_overwrite:
|
|
if os.path.exists(_output):
|
|
_err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort."
|
|
print(_err_msg)
|
|
return gr.update(value=f"{_err_msg} [{_output}]")
|
|
print(f" model_0 : {model_A}")
|
|
print(f" model_1 : {model_B}")
|
|
print(f" base_alpha : {sl_base_alpha}")
|
|
print(f" output_file: {_output}")
|
|
print(f" weights : {_weights}")
|
|
print(f" skip ids : {radio_position_ids} : 0:None, 1:Skip, 2:Reset")
|
|
|
|
result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite,
|
|
base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw,
|
|
save_as_safetensors=chk_save_as_safetensors,
|
|
save_as_half=chk_save_as_half,
|
|
skip_position_ids=radio_position_ids
|
|
)
|
|
|
|
if result:
|
|
ret_html = "merged.<br>" \
|
|
+ f"{model_A}<br>" \
|
|
+ f"{model_B}<br>" \
|
|
+ f"{model_O}<br>" \
|
|
+ f"base_alpha={sl_base_alpha}<br>" \
|
|
+ f"Weight_values={_weights}<br>"
|
|
print("merged.")
|
|
else:
|
|
ret_html = ret_message
|
|
print("merge failed.")
|
|
|
|
# save log to history.tsv
|
|
sd_models.list_models()
|
|
model_A_info = sd_models.get_closet_checkpoint_match(model_A)
|
|
model_B_info = sd_models.get_closet_checkpoint_match(model_B)
|
|
model_O_info = sd_models.get_closet_checkpoint_match(os.path.basename(_output))
|
|
if hasattr(model_O_info, "sha256") and model_O_info.sha256 is None:
|
|
model_O_info:CheckpointInfo = model_O_info
|
|
model_O_info.sha256 = hashes.sha256(model_O_info.filename, "checkpoint/" + model_O_info.title)
|
|
_names = presetWeights.find_names_by_weight(_weights)
|
|
if _names and len(_names) > 0:
|
|
weight_name = _names[0]
|
|
else:
|
|
weight_name = ""
|
|
|
|
def model_name(model_info):
|
|
return model_info.name if hasattr(model_info, "name") else model_info.title
|
|
def model_sha256(model_info):
|
|
return model_info.sha256 if hasattr(model_info, "sha256") else ""
|
|
mergeHistory.add_history(
|
|
model_name(model_A_info),
|
|
model_A_info.hash,
|
|
model_sha256(model_A_info),
|
|
model_name(model_B_info),
|
|
model_B_info.hash,
|
|
model_sha256(model_B_info),
|
|
model_name(model_O_info),
|
|
model_O_info.hash,
|
|
model_sha256(model_O_info),
|
|
sl_base_alpha,
|
|
_weights,
|
|
"",
|
|
weight_name
|
|
)
|
|
|
|
return gr.update(value=f"{ret_html}")
|
|
btn_do_merge_block_weighted.click(
|
|
fn=onclick_btn_do_merge_block_weighted,
|
|
inputs=[model_A, model_B]
|
|
+ sl_IN + sl_MID + sl_OUT
|
|
+ [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite]
|
|
+ [chk_save_as_safetensors, chk_save_as_half, radio_position_ids],
|
|
outputs=[html_output_block_weight_info]
|
|
)
|
|
|
|
btn_clear_weight.click(
|
|
fn=lambda: [gr.update(value=0.5) for _ in range(25)],
|
|
inputs=[],
|
|
outputs=[
|
|
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
|
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
|
sl_M_00,
|
|
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
|
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
|
]
|
|
)
|
|
|
|
def on_change_dd_preset_weight(dd_preset_weight):
|
|
_weights = presetWeights.find_weight_by_name(dd_preset_weight)
|
|
_ret = on_btn_apply_block_weight_from_txt(_weights)
|
|
return [gr.update(value=_weights)] + _ret
|
|
dd_preset_weight.change(
|
|
fn=on_change_dd_preset_weight,
|
|
inputs=[dd_preset_weight],
|
|
outputs=[txt_block_weight,
|
|
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
|
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
|
sl_M_00,
|
|
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
|
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
|
]
|
|
)
|
|
|
|
def on_btn_reload_checkpoint_mbw():
|
|
sd_models.list_models()
|
|
return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())]
|
|
btn_reload_checkpoint_mbw.click(
|
|
fn=on_btn_reload_checkpoint_mbw,
|
|
inputs=[],
|
|
outputs=[model_A, model_B]
|
|
)
|
|
|
|
def on_btn_apply_block_weight_from_txt(txt_block_weight):
|
|
if not txt_block_weight or txt_block_weight == "":
|
|
return [gr.update() for _ in range(25)]
|
|
_list = [x.strip() for x in txt_block_weight.split(",")]
|
|
if(len(_list) != 25):
|
|
return [gr.update() for _ in range(25)]
|
|
return [gr.update(value=x) for x in _list]
|
|
btn_apply_block_weithg_from_txt.click(
|
|
fn=on_btn_apply_block_weight_from_txt,
|
|
inputs=[txt_block_weight],
|
|
outputs=[
|
|
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
|
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
|
sl_M_00,
|
|
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
|
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
|
]
|
|
)
|
|
|