diff --git a/.gitignore b/.gitignore index 0a1b85c..8c92e0e 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,5 @@ /csv/history.tsv +/csv/preset_own.tsv + +# +_* diff --git a/javascript/js_mbw_each.js b/javascript/js_mbw_each.js new file mode 100644 index 0000000..51777c1 --- /dev/null +++ b/javascript/js_mbw_each.js @@ -0,0 +1,23 @@ +// +// fix position of sliders +// + +// +// UI +// +onUiUpdate(function () { + // check Extension loaded + if (gradioApp().querySelector("div#tab_mbw_each") == null ) return; + + // check already done + //if (gradioApp().querySelector("#div_mdl_size_a") != null) return; + + // apply + let _style = "min-width: min(200px, 100%); flex-grow: 1"; + gradioApp().querySelector("#sl_IN_A_00").parentElement.parentElement.setAttribute("style", _style) + gradioApp().querySelector("#sl_IN_B_00").parentElement.parentElement.setAttribute("style", _style) + gradioApp().querySelector("#sl_M_A_00").parentElement.parentElement.setAttribute("style", _style) + gradioApp().querySelector("#sl_M_B_00").parentElement.parentElement.setAttribute("style", _style) + gradioApp().querySelector("#sl_OUT_A_00").parentElement.parentElement.setAttribute("style", _style) + gradioApp().querySelector("#sl_OUT_B_00").parentElement.parentElement.setAttribute("style", _style) +}); diff --git a/scripts/merge_block_weighted.py b/scripts/mbw/merge_block_weighted.py similarity index 100% rename from scripts/merge_block_weighted.py rename to scripts/mbw/merge_block_weighted.py diff --git a/scripts/mbw/ui_mbw.py b/scripts/mbw/ui_mbw.py new file mode 100644 index 0000000..79742c0 --- /dev/null +++ b/scripts/mbw/ui_mbw.py @@ -0,0 +1,215 @@ +import gradio as gr +import os + +from modules import sd_models, shared +from tqdm import tqdm + +from scripts.mbw.merge_block_weighted import merge +from scripts.util.preset_weights import PresetWeights +from scripts.util.merge_history import MergeHistory + +presetWeights = PresetWeights() +mergeHistory = MergeHistory() + + +def on_ui_tabs(): + with gr.Column(): + with gr.Row(): + with gr.Column(variant="panel"): + btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary") + btn_clear_weighted = gr.Button(value="Clear values") + btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint") + html_output_block_weight_info = gr.HTML() + with gr.Column(): + dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list()) + txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25") + btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary") + with gr.Row(): + sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=1) + chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False) + chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False) + with gr.Row(): + model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles()) + model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles()) + txt_model_O = gr.Text(label="Output Model Name") + with gr.Row(): + with gr.Column(): + sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.00000000001, value=0.5) + with gr.Column(): + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="mbw_sl_M00") + with gr.Column(): + sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.00000000001, value=0.5) + sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.00000000001, value=0.5) + + sl_IN = [ + sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, + sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11] + sl_MID = [sl_M_00] + sl_OUT = [ + sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, + sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11] + + # Events + def onclick_btn_do_merge_block_weighted( + model_A, model_B, + sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, + sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, + sl_M_00, + sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, + sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, + txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite + ): + _weights = ",".join( + [str(x) for x in [ + sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, + sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, + sl_M_00, + sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, + sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11 + ]]) + # + if not model_A or not model_B: + return gr.update(value=f"ERROR: model not found. [{model_A}][{model_B}]") + + ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path + model_A_info = sd_models.get_closet_checkpoint_match(model_A) + if model_A_info: + _model_A_name = model_A_info.model_name + else: + _model_A_name = "" + model_B_info = sd_models.get_closet_checkpoint_match(model_B) + if model_B_info: + _model_B_info = model_B_info.model_name + else: + _model_B_info = "" + model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O + if ".ckpt" not in model_O: + model_O = model_O + ".ckpt" + + _output = os.path.join(ckpt_dir, model_O) + # debug output + print( "#### Merge Block Weighted ####") + if not chk_allow_overwrite: + if os.path.exists(_output): + _err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort." + print(_err_msg) + return gr.update(value=f"{_err_msg} [{_output}]") + print(f"model_0 : {model_A}") + print(f"model_1 : {model_B}") + print(f"base_alpha : {sl_base_alpha}") + print(f"output_file: {_output}") + print(f"weights : {_weights}") + + result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw) + + sd_models.list_models() + if result: + ret_html = "merged.
" + f"{model_A}
" + f"{model_B}
" + f"{model_O}" + else: + ret_html = ret_message + + # save log to history.tsv + model_O_info = sd_models.get_closet_checkpoint_match(model_O) + model_O_hash = "" if not model_O_info else model_O_info.hash + _names = presetWeights.find_names_by_weight(_weights) + if _names and len(_names) > 0: + weight_name = _names[0] + else: + weight_name = "" + mergeHistory.add_history(model_A, model_B, model_O, model_O_hash, sl_base_alpha, _weights, "", weight_name) + + return gr.update(value=f"{ret_html}") + btn_do_merge_block_weighted.click( + fn=onclick_btn_do_merge_block_weighted, + inputs=[model_A, model_B] + sl_IN + sl_MID + sl_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite], + outputs=[html_output_block_weight_info] + ) + + btn_clear_weighted.click( + fn=lambda: [gr.update(value=0.5) for _ in range(25)], + inputs=[], + outputs=[ + sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, + sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, + sl_M_00, + sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, + sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, + ] + ) + + def on_change_dd_preset_weight(dd_preset_weight): + _weights = presetWeights.find_weight_by_name(dd_preset_weight) + _ret = on_btn_apply_block_weithg_from_txt(_weights) + return [gr.update(value=_weights)] + _ret + dd_preset_weight.change( + fn=on_change_dd_preset_weight, + inputs=[dd_preset_weight], + outputs=[txt_block_weight, + sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, + sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, + sl_M_00, + sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, + sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, + ] + ) + + def on_btn_reload_checkpoint_mbw(): + sd_models.list_models() + return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())] + btn_reload_checkpoint_mbw.click( + fn=on_btn_reload_checkpoint_mbw, + inputs=[], + outputs=[model_A, model_B] + ) + + def on_btn_apply_block_weithg_from_txt(txt_block_weight): + if not txt_block_weight or txt_block_weight == "": + return [gr.update() for _ in range(25)] + _list = [x.strip() for x in txt_block_weight.split(",")] + if(len(_list) != 25): + return [gr.update() for _ in range(25)] + return [gr.update(value=x) for x in _list] + btn_apply_block_weithg_from_txt.click( + fn=on_btn_apply_block_weithg_from_txt, + inputs=[txt_block_weight], + outputs=[ + sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, + sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, + sl_M_00, + sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, + sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, + ] + ) + diff --git a/scripts/mbw_each/merge_block_weighted_mod.py b/scripts/mbw_each/merge_block_weighted_mod.py new file mode 100644 index 0000000..ebcaa38 --- /dev/null +++ b/scripts/mbw_each/merge_block_weighted_mod.py @@ -0,0 +1,160 @@ +# from https://note.com/kohya_ss/n/n9a485a066d5b +# kohya_ss +# original code: https://github.com/eyriewow/merge-models + +# use them as base of this code +# 2022/12/15 +# bbc-mc + +import os +import argparse +import re +import torch +from tqdm import tqdm + +from modules import sd_models, shared + + +NUM_INPUT_BLOCKS = 12 +NUM_MID_BLOCK = 1 +NUM_OUTPUT_BLOCKS = 12 +NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS + + +def dprint(str, flg): + if flg: + print(str) + + +def merge(weight_A:list, weight_B:list, model_0, model_1, device="cpu", base_alpha=0.5, + output_file="", allow_overwrite=False, verbose=False): + + def _check_arg_weight(weight): + if weight is None: + return None + else: + _weight = [float(w) for w in weight.split(",")] + if len(_weight) != NUM_TOTAL_BLOCKS: + return None + else: + return _weight + + weight_A = _check_arg_weight(weight_A) + if weight_A is None: + _err_msg = f"Weight A invalid. program abort. weights value must be {NUM_TOTAL_BLOCKS}." + print(_err_msg) + return False, _err_msg + weight_B = _check_arg_weight(weight_B) + if weight_B is None: + _err_msg = f"Weight B invalid. program abort. weights value must be {NUM_TOTAL_BLOCKS}." + print(_err_msg) + return False, _err_msg + + device = device if device in ["cpu", "cuda"] else "cpu" + + alpha = base_alpha + if not output_file or output_file == "": + output_file = f'bw-{model_0}-{model_1}-{str(alpha)[2:] + "0"}.ckpt' + else: + output_file = output_file if ".ckpt" in output_file else output_file + ".ckpt" + + # check if output file already exists + if os.path.isfile(output_file) and not allow_overwrite: + _err_msg = f"Exiting... [{output_file}]" + print(_err_msg) + return False, _err_msg + + def load_model(_model, _device): + model_info = sd_models.get_closet_checkpoint_match(_model) + if model_info: + model_file = model_info.filename + else: + return None + cache_enabled = shared.opts.sd_checkpoint_cache > 0 + if cache_enabled and model_info in sd_models.checkpoints_loaded: + print(" load from cache") + return sd_models.checkpoints_loaded[model_info].copy() + else: + print(" loading ...") + return sd_models.read_state_dict(model_file, map_location=_device) + + print("loading", model_0) + theta_0 = load_model(model_0, device) + + print("loading", model_1) + theta_1 = load_model(model_1, device) + + re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12 + re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1 + re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12 + + dprint(f"-- start Stage 1/2 --", verbose) + count_target_of_basealpha = 0 + for key in (tqdm(theta_0.keys(), desc="Stage 1/2") if not verbose else theta_0.keys()): + if "model" in key and key in theta_1: + dprint(f" key : {key}", verbose) + + current_alpha_A = 1 - alpha + current_alpha_B = alpha + current_alpha_I = 0 + + # check weighted and U-Net or not + if weight_A is not None and 'model.diffusion_model.' in key: + # check block index + weight_index = -1 + + if 'time_embed' in key: + weight_index = 0 # before input blocks + elif '.out.' in key: + weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks + else: + m = re_inp.search(key) + if m: + inp_idx = int(m.groups()[0]) + weight_index = inp_idx + else: + m = re_mid.search(key) + if m: + weight_index = NUM_INPUT_BLOCKS + else: + m = re_out.search(key) + if m: + out_idx = int(m.groups()[0]) + weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx + + if weight_index >= NUM_TOTAL_BLOCKS: + print(f"error. illegal block index: {key}") + if weight_index >= 0: + current_alpha_A = weight_A[weight_index] + current_alpha_B = weight_B[weight_index] + current_alpha_I = 1 - current_alpha_A - current_alpha_B + if verbose: + print(f"weighted '{key}': A{current_alpha_A} B{current_alpha_B} I{current_alpha_I}") + + # create I tensor + tensor_I_0 = torch.zeros_like(theta_0[key], dtype=theta_0[key].dtype) + _var1 = current_alpha_I * tensor_I_0 + _var2 = current_alpha_A * theta_0[key] + _var3 = current_alpha_B * theta_1[key] + theta_0[key] = _var1 + _var2 + _var3 + + # theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key] + + else: + dprint(f" key - {key}", verbose) + + dprint(f"-- start Stage 2/2 --", verbose) + for key in tqdm(theta_1.keys(), desc="Stage 2/2"): + if "model" in key and key not in theta_0: + dprint(f" key : {key}", verbose) + theta_0.update({key:theta_1[key]}) + else: + dprint(f" key - {key}", verbose) + + print("Saving...") + + torch.save({"state_dict": theta_0}, output_file) + + print("Done!") + + return True, f"{output_file}
base_alpha applied [{count_target_of_basealpha}] times." diff --git a/scripts/mbw_each/ui_mbw_each.py b/scripts/mbw_each/ui_mbw_each.py new file mode 100644 index 0000000..1c353be --- /dev/null +++ b/scripts/mbw_each/ui_mbw_each.py @@ -0,0 +1,426 @@ +import gradio as gr +import os +import re + +from modules import sd_models, shared +from tqdm import tqdm + +from scripts.mbw_each.merge_block_weighted_mod import merge +from scripts.util.preset_weights import PresetWeights +from scripts.util.merge_history import MergeHistory + +presetWeights = PresetWeights() +mergeHistory = MergeHistory() + + +def on_ui_tabs(): + with gr.Column(): + with gr.Row(): + with gr.Column(variant="panel"): + with gr.Row(): + btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary") + btn_clear_weighted = gr.Button(value="Clear values") + btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint") + with gr.Row(): + txt_multi_process_cmd = gr.TextArea(label="Multi Proc Cmd", placeholder="Keep empty if dont use.") + html_output_block_weight_info = gr.HTML() + with gr.Column(): + dd_preset_weight = gr.Dropdown(label="Preset_Weights", choices=presetWeights.get_preset_name_list()) + txt_block_weight = gr.Text(label="Weight_values", placeholder="Put weight sets. float number x 25") + btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary") + with gr.Row(): + sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=0) + chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False) + chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False) + with gr.Row(): + dd_model_A = gr.Dropdown(label="Model_A", choices=sd_models.checkpoint_tiles()) + dd_model_B = gr.Dropdown(label="Model_B", choices=sd_models.checkpoint_tiles()) + txt_model_O = gr.Text(label="(O)Output Model Name") + with gr.Row(): + with gr.Column(): + sl_IN_A_00 = gr.Slider(label="IN_A_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_00") + sl_IN_A_01 = gr.Slider(label="IN_A_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_01") + sl_IN_A_02 = gr.Slider(label="IN_A_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_02") + sl_IN_A_03 = gr.Slider(label="IN_A_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_03") + sl_IN_A_04 = gr.Slider(label="IN_A_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_04") + sl_IN_A_05 = gr.Slider(label="IN_A_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_05") + sl_IN_A_06 = gr.Slider(label="IN_A_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_06") + sl_IN_A_07 = gr.Slider(label="IN_A_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_07") + sl_IN_A_08 = gr.Slider(label="IN_A_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_08") + sl_IN_A_09 = gr.Slider(label="IN_A_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_09") + sl_IN_A_10 = gr.Slider(label="IN_A_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_10") + sl_IN_A_11 = gr.Slider(label="IN_A_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_11") + with gr.Column(): + sl_IN_B_00 = gr.Slider(label="IN_B_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_00") + sl_IN_B_01 = gr.Slider(label="IN_B_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_01") + sl_IN_B_02 = gr.Slider(label="IN_B_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_02") + sl_IN_B_03 = gr.Slider(label="IN_B_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_03") + sl_IN_B_04 = gr.Slider(label="IN_B_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_04") + sl_IN_B_05 = gr.Slider(label="IN_B_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_05") + sl_IN_B_06 = gr.Slider(label="IN_B_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_06") + sl_IN_B_07 = gr.Slider(label="IN_B_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_07") + sl_IN_B_08 = gr.Slider(label="IN_B_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_08") + sl_IN_B_09 = gr.Slider(label="IN_B_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_09") + sl_IN_B_10 = gr.Slider(label="IN_B_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_10") + sl_IN_B_11 = gr.Slider(label="IN_B_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_11") + with gr.Column(): + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + sl_M_A_00 = gr.Slider(label="M_A_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_M_A_00") + with gr.Column(): + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + gr.Slider(visible=False) + sl_M_B_00 = gr.Slider(label="M_B_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_M_B_00") + with gr.Column(): + sl_OUT_A_11 = gr.Slider(label="OUT_A_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_11") + sl_OUT_A_10 = gr.Slider(label="OUT_A_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_10") + sl_OUT_A_09 = gr.Slider(label="OUT_A_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_09") + sl_OUT_A_08 = gr.Slider(label="OUT_A_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_08") + sl_OUT_A_07 = gr.Slider(label="OUT_A_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_07") + sl_OUT_A_06 = gr.Slider(label="OUT_A_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_06") + sl_OUT_A_05 = gr.Slider(label="OUT_A_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_05") + sl_OUT_A_04 = gr.Slider(label="OUT_A_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_04") + sl_OUT_A_03 = gr.Slider(label="OUT_A_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_03") + sl_OUT_A_02 = gr.Slider(label="OUT_A_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_02") + sl_OUT_A_01 = gr.Slider(label="OUT_A_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_01") + sl_OUT_A_00 = gr.Slider(label="OUT_A_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_00") + with gr.Column(): + sl_OUT_B_11 = gr.Slider(label="OUT_B_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_11") + sl_OUT_B_10 = gr.Slider(label="OUT_B_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_10") + sl_OUT_B_09 = gr.Slider(label="OUT_B_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_09") + sl_OUT_B_08 = gr.Slider(label="OUT_B_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_08") + sl_OUT_B_07 = gr.Slider(label="OUT_B_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_07") + sl_OUT_B_06 = gr.Slider(label="OUT_B_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_06") + sl_OUT_B_05 = gr.Slider(label="OUT_B_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_05") + sl_OUT_B_04 = gr.Slider(label="OUT_B_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_04") + sl_OUT_B_03 = gr.Slider(label="OUT_B_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_03") + sl_OUT_B_02 = gr.Slider(label="OUT_B_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_02") + sl_OUT_B_01 = gr.Slider(label="OUT_B_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_01") + sl_OUT_B_00 = gr.Slider(label="OUT_B_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_00") + + # Footer + gr.HTML( + """ +

+ Merge Block Weighted extension by bbc_mc
+ MBW Each is experimental functions and NO PROOF of effectiveness.
+ You can try it by own, to dig more deeper into Abyss ...
+

+ """ + ) + + sl_A_IN = [ + sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, + sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11] + sl_A_MID = [sl_M_A_00] + sl_A_OUT = [ + sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, + sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11] + + sl_B_IN = [ + sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, + sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11] + sl_B_MID = [sl_M_B_00] + sl_B_OUT = [ + sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, + sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11] + + + # Events + def onclick_btn_do_merge_block_weighted( + dd_model_A, dd_model_B, txt_multi_process_cmd, + sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, + sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, + sl_M_A_00, + sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, + sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, + sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, + sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, + sl_M_B_00, + sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, + sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, + txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite + ): + base_alpha = sl_base_alpha + _weight_A = ",".join( + [str(x) for x in [ + sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, + sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, + sl_M_A_00, + sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, + sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, + ]]) + _weight_B = ",".join( + [str(x) for x in [ + sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, + sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, + sl_M_B_00, + sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, + sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, + ]]) + + # debug output + print( "#### Merge Block Weighted ####") + + if (not dd_model_A or not dd_model_B) and txt_multi_process_cmd == "": + _err_msg = f"ERROR: model not found. [{dd_model_A}][{dd_model_B}]" + print(_err_msg) + return gr.update(value=_err_msg) + + ret_html = "" + if txt_multi_process_cmd != "": + # need multi-merge + _lines = txt_multi_process_cmd.split('\n') + print(f"check multi-merge. {len(_lines)} lines found.") + for line_index, _line in enumerate(_lines): + if _line == "": + continue + print(f"\n== merge line {line_index+1}/{len(_lines)} ==") + _items = [x.strip() for x in _line.split(",") if x != ""] + if len(_items) > 0: + ret_html += _run_merge( + weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B, + allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw, params=_items) + else: + _ret = f" multi-merge text found, but invalid params. skipped :[{_line}]" + ret_html += _ret + print(_ret) + else: + # normal merge + ret_html += _run_merge( + weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B, + allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw) + + sd_models.list_models() + print( "#### All merge process done. ####") + + return gr.update(value=f"{ret_html}") + btn_do_merge_block_weighted.click( + fn=onclick_btn_do_merge_block_weighted, + inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd] + sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite], + outputs=[html_output_block_weight_info] + ) + + def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0, model_Output="", verbose=False, params=[]): + + # generate output file name from param + model_A_info = sd_models.get_closet_checkpoint_match(model_0) + _model_A_name = "" if not model_A_info else model_A_info.model_name # expect "" or "aaa.ckpt [abcdefgh]" + + model_B_info = sd_models.get_closet_checkpoint_match(model_1) + _model_B_name = "" if not model_B_info else model_B_info.model_name + + model_O = f"bw-merge-{_model_A_name}-{_model_B_name}-{base_alpha}.ckpt" if model_Output == "" else model_Output + model_O = model_O if ".ckpt" in model_O else f"{model_O}.ckpt" + model_O = re.sub(r'[\\|/|:|?|.|"|<|>|\|\*]', '-', model_O) + + if params and len(params) > 0: + for _item in params: + # expect "O=merge/test02, IN_B_00 = 0.12345" as params=["O=merge/test02", "IN_B_00 = 0.12345"] + if len(_item.split("=")) == 2: + _item_l = _item.split("=")[0].strip() + _item_r = _item.split("=")[1].strip() + if _item_r != "": + if _item_l.lower() == "model_a" or _item_l.lower() == "model_b": + _model_info = sd_models.get_closet_checkpoint_match(_item_r) + if _model_info: + _model_name = _model_info.title.split(" ")[0] + if _model_name and _model_name.strip() != "": + if _item_l.lower() == "model_a": + print(f" * Model changed: {model_0} -> {_model_info.title}") + model_0 = _model_info.title + elif _item_l.lower() == "model_b": + print(f" * Model changed: {model_1} -> {_model_info.title}") + model_1 = _model_info.title + + elif _item_l.lower() == "preset_weights": + _weights = presetWeights.find_weight_by_name(_item_r) + if _weights != "" and len(_weights.split(',')) == 25: + print(f" * Weights changed by preset-name: {_item_r}") + weight_B = _weights + weight_A = ",".join([str(1-float(x)) for x in _weights.split(',')]) + else: + print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]") + + elif _item_l.lower() == "weight_values": + _weights = _item_r.strip() + if _weights != "" and len(_weights.split(' ')) == 25: # this is work-around to use space as separator. Double-meaning issue on commna which already used as value separator and weights separator. + print(f" * Weights changed: {_item_r}") + weight_B = _weights + weight_A = ",".join([str(1-float(x)) for x in _weights.split(' ')]) + else: + print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]") + + elif _item_l.lower() == "base_alpha": + if float(_item_r) >= 0: + print(f" * base_alpha changed: {base_alpha} -> {_item_r}") + base_alpha = float(_item_r) + + elif _item_l.upper() == "O": + _item_r = _item_r if ".ckpt" in _item_r else f"{_item_r}.ckpt" + print(f" * Output filename changed:[{model_O}] -> [{_item_r}]") + model_O = _item_r + + elif len(_item_l.split("_")) == 3: + _IMO = _item_l.split("_")[0] + _AB = _item_l.split("_")[1] + _NUM = _item_l.split("_")[2] + + _index = int(_NUM) + _index = _index + 0 if _IMO == "IN" else _index + _index = _index + 12 if _IMO == "M" else _index + _index = _index + 13 if _IMO == "OUT" else _index + + def _apply_val(key, weight, index, new_value): + _weight = [x.strip() for x in weight.split(",")] + _new_weight = _weight[:] + _new_weight[index] = new_value + _new_weight = ",".join(_new_weight) + print(f" * weight_{key} changed:[{weight}]") + print(f" -> [{_new_weight}]") + return _new_weight + + if _AB == "A": + weight_A = _apply_val(_AB, weight_A, _index, _item_r) + elif _AB == "B": + weight_B = _apply_val(_AB, weight_B, _index, _item_r) + else: + print(f" * Waring: uncaught param found. ignored. [{_item_l}][{_item_r}]") + + # + # Prepare params before run merge + # + output_file = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O) + # + # Check params + # + if not os.path.exists(os.path.dirname(output_file)): + _err_msg = f"WARNING: target path not found: {os.path.dirname(output_file)}. skipped." + print(_err_msg) + return _err_msg + "
" + if not allow_overwrite: + if os.path.exists(output_file): + _err_msg = f"WARNING: output_file already exists. overwrite not allowed. skipped." + print(_err_msg) + return _err_msg + "
" + + # debug output + print(f" model_0 : {model_0}") + print(f" model_1 : {model_1}") + print(f" model_Out : {model_O}") + print(f" base_alpha : {base_alpha}") + print(f" output_file: {output_file}") + print(f" weight_A : {weight_A}") + print(f" weight_B : {weight_B}") + + result, ret_message = merge( + weight_A=weight_A, weight_B=weight_B, model_0=model_0, model_1=model_1, + allow_overwrite=allow_overwrite, base_alpha=base_alpha, output_file=output_file, verbose=verbose) + if result: + ret_html = f"merged. {model_0} + {model_1} = {model_O}
" + print("merged.") + else: + ret_html = ret_message + print("merge failed.") + + + # save log to history.tsv + model_O_info = sd_models.get_closet_checkpoint_match(model_O) + model_O_hash = "" if not model_O_info else model_O_info.hash + _names = presetWeights.find_names_by_weight(weight_A) + if _names and len(_names) > 0: + weight_name = _names[0] + else: + weight_name = "" + mergeHistory.add_history(model_0, model_1, model_O, model_O_hash, base_alpha, weight_A, weight_B, weight_name) + return ret_html + + btn_clear_weighted.click( + fn=lambda: [gr.update(value=0.5) for _ in range(25*2)], + inputs=[], + outputs=[ + sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, + sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, + sl_M_A_00, + sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, + sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, + sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, + sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, + sl_M_B_00, + sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, + sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, + ] + ) + + def on_change_dd_preset_weight(dd_preset_weight): + _weights = presetWeights.find_weight_by_name(dd_preset_weight) + _ret = on_btn_apply_block_weight_from_txt(_weights) + return [gr.update(value=_weights)] + _ret + dd_preset_weight.change( + fn=on_change_dd_preset_weight, + inputs=[dd_preset_weight], + outputs=[ + txt_block_weight, + sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, + sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, + sl_M_A_00, + sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, + sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, + sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, + sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, + sl_M_B_00, + sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, + sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, + ] + ) + + def on_btn_reload_checkpoint_mbw(): + sd_models.list_models() + return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())] + btn_reload_checkpoint_mbw.click( + fn=on_btn_reload_checkpoint_mbw, + inputs=[], + outputs=[dd_model_A, dd_model_B] + ) + + def on_btn_apply_block_weight_from_txt(txt_block_weight): + if not txt_block_weight or txt_block_weight == "": + return [gr.update() for _ in range(25*2)] + _list = [x.strip() for x in txt_block_weight.split(",")] + if(len(_list) != 25): + return [gr.update() for _ in range(25*2)] + return [gr.update(value=str(1-float(x))) for x in _list] + [gr.update(value=x) for x in _list] + btn_apply_block_weithg_from_txt.click( + fn=on_btn_apply_block_weight_from_txt, + inputs=[txt_block_weight], + outputs=[ + sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, + sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, + sl_M_A_00, + sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, + sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, + sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, + sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, + sl_M_B_00, + sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, + sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, + ] + ) diff --git a/scripts/merge_block_weighted_extension.py b/scripts/merge_block_weighted_extension.py index a9dd63a..e7a6fe5 100644 --- a/scripts/merge_block_weighted_extension.py +++ b/scripts/merge_block_weighted_extension.py @@ -8,17 +8,12 @@ import os import gradio as gr -from modules import scripts, script_callbacks -from modules import sd_models, shared +from modules import script_callbacks -from scripts.merge_block_weighted import merge -from scripts.merge_history import MergeHistory -from scripts.preset_weights import PresetWeights -path_root = scripts.basedir() +from scripts.mbw import ui_mbw +from scripts.mbw_each import ui_mbw_each -mergeHistory = MergeHistory() -presetWeights = PresetWeights() # # UI callback @@ -26,204 +21,11 @@ presetWeights = PresetWeights() def on_ui_tabs(): with gr.Blocks() as main_block: - with gr.Column(): - with gr.Row(): - with gr.Column(variant="panel"): - btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary") - btn_clear_weighted = gr.Button(value="Clear values") - btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint") - html_output_block_weight_info = gr.HTML() - with gr.Column(): - dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list()) - txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25") - btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary") - with gr.Row(): - sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=1) - chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False) - chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False) - with gr.Row(): - model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles()) - model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles()) - txt_model_O = gr.Text(label="Output Model Name") - with gr.Row(): - with gr.Column(): - sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.00000000001, value=0.5) - with gr.Column(): - gr.Slider(visible=False) - gr.Slider(visible=False) - gr.Slider(visible=False) - gr.Slider(visible=False) - gr.Slider(visible=False) - gr.Slider(visible=False) - gr.Slider(visible=False) - gr.Slider(visible=False) - gr.Slider(visible=False) - gr.Slider(visible=False) - gr.Slider(visible=False) - sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="mbw_sl_M00") - with gr.Column(): - sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.00000000001, value=0.5) - sl_IN = [ - sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, - sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11] - sl_MID = [sl_M_00] - sl_OUT = [ - sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, - sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11] + with gr.Tab("MBW", elem_id="tab_mbw"): + ui_mbw.on_ui_tabs() - # Events - def onclick_btn_do_merge_block_weighted( - model_A, model_B, - sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, - sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, - sl_M_00, - sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, - sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, - txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite - ): - _weights = ",".join( - [str(x) for x in [ - sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, - sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, - sl_M_00, - sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, - sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11 - ]]) - # - if not model_A or not model_B: - return gr.update(value=f"ERROR: model not found. [{model_A}][{model_B}]") - - ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - model_A_info = sd_models.get_closet_checkpoint_match(model_A) - if model_A_info: - _model_A_name = model_A_info.model_name - else: - _model_A_name = "" - model_B_info = sd_models.get_closet_checkpoint_match(model_B) - if model_B_info: - _model_B_info = model_B_info.model_name - else: - _model_B_info = "" - model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O - if ".ckpt" not in model_O: - model_O = model_O + ".ckpt" - - _output = os.path.join(ckpt_dir, model_O) - # debug output - print( "#### Merge Block Weighted ####") - if not chk_allow_overwrite: - if os.path.exists(_output): - _err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort." - print(_err_msg) - return gr.update(value=f"{_err_msg} [{_output}]") - print(f"model_0 : {model_A}") - print(f"model_1 : {model_B}") - print(f"base_alpha : {sl_base_alpha}") - print(f"output_file: {_output}") - print(f"weights : {_weights}") - - result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw) - - sd_models.list_models() - if result: - ret_html = "merged.
" + f"{model_A}
" + f"{model_B}
" + f"{model_O}" - else: - ret_html = ret_message - - # save log to history.tsv - model_O_info = sd_models.get_closet_checkpoint_match(model_O) - model_O_hash = "" if not model_O_info else model_O_info.hash - _names = presetWeights.find_names_by_weight(_weights) - if _names and len(_names) > 0: - weight_name = _names[0] - else: - weight_name = "" - mergeHistory.add_history(model_A, model_B, model_O, model_O_hash, sl_base_alpha, _weights, weight_name) - - return gr.update(value=f"{ret_html}") - btn_do_merge_block_weighted.click( - fn=onclick_btn_do_merge_block_weighted, - inputs=[model_A, model_B] + sl_IN + sl_MID + sl_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite], - outputs=[html_output_block_weight_info] - ) - - btn_clear_weighted.click( - fn=lambda: [gr.update(value=0.5) for _ in range(25)], - inputs=[], - outputs=[ - sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, - sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, - sl_M_00, - sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, - sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, - ] - ) - - def on_change_dd_preset_weight(dd_preset_weight): - _weights = presetWeights.find_weight_by_name(dd_preset_weight) - _ret = on_btn_apply_block_weithg_from_txt(_weights) - return [gr.update(value=_weights)] + _ret - dd_preset_weight.change( - fn=on_change_dd_preset_weight, - inputs=[dd_preset_weight], - outputs=[txt_block_weight, - sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, - sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, - sl_M_00, - sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, - sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, - ] - ) - - def on_btn_reload_checkpoint_mbw(): - sd_models.list_models() - return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())] - btn_reload_checkpoint_mbw.click( - fn=on_btn_reload_checkpoint_mbw, - inputs=[], - outputs=[model_A, model_B] - ) - - def on_btn_apply_block_weithg_from_txt(txt_block_weight): - if not txt_block_weight or txt_block_weight == "": - return [gr.update() for _ in range(25)] - _list = [x.strip() for x in txt_block_weight.split(",")] - if(len(_list) != 25): - return [gr.update() for _ in range(25)] - return [gr.update(value=x) for x in _list] - btn_apply_block_weithg_from_txt.click( - fn=on_btn_apply_block_weithg_from_txt, - inputs=[txt_block_weight], - outputs=[ - sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, - sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, - sl_M_00, - sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, - sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, - ] - ) + with gr.Tab("MBW Each", elem_id="tab_mbw_each"): + ui_mbw_each.on_ui_tabs() # return required as (gradio_component, title, elem_id) return (main_block, "Merge Block Weighted", "merge_block_weighted"), diff --git a/scripts/merge_history.py b/scripts/merge_history.py deleted file mode 100644 index 1aa3d11..0000000 --- a/scripts/merge_history.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# -# -import os -from csv import DictWriter, writer - -from modules import scripts - - -CSV_FILE_PATH = "csv/history.tsv" -HEADERS = ["model_A", "model_A_hash", "model_B", "model_B_hash", "model_O", "model_O_hash", "base_alpha", "weight_name", "weight_values"] -path_root = scripts.basedir() - - -class MergeHistory(): - def __init__(self): - self.filepath = os.path.join(path_root, CSV_FILE_PATH) - - def add_history(self, model_A, model_B, model_O, model_O_hash, sl_base_alpha, weight_values, weight_name=""): - _history_dict = {} - _history_dict.update({ - "model_A": f"{os.path.basename(model_A.split(' ')[0])}", - "model_A_hash": f"{model_A.split(' ')[1]}", - "model_B": f"{os.path.basename(model_B.split(' ')[0])}", - "model_B_hash": f"{model_B.split(' ')[1]}", - "model_O": model_O, - "model_O_hash": model_O_hash, - "base_alpha": sl_base_alpha, - "weight_name": weight_name, - "weight_values": weight_values, - }) - - if not os.path.exists(self.filepath): - with open(self.filepath, "w", newline="", encoding="utf-8") as f: - wr = DictWriter(f, fieldnames=HEADERS, delimiter='\t') - wr.writeheader() - # save to file - with open(self.filepath, "a", newline="", encoding='utf-8') as f: - dictwriter = DictWriter(f, fieldnames=HEADERS, delimiter='\t') - dictwriter.writerow(_history_dict) diff --git a/scripts/util/merge_history.py b/scripts/util/merge_history.py new file mode 100644 index 0000000..eabd16a --- /dev/null +++ b/scripts/util/merge_history.py @@ -0,0 +1,61 @@ +# +# +# +import os +import datetime +from csv import DictWriter, DictReader + +from modules import scripts + + +CSV_FILE_PATH = "csv/history.tsv" +HEADERS = ["model_A", "model_A_hash", "model_B", "model_B_hash", "model_O", "model_O_hash", "base_alpha", "weight_name", "weight_values", "weight_values2", "datetime"] +path_root = scripts.basedir() + + +class MergeHistory(): + def __init__(self): + self.filepath = os.path.join(path_root, CSV_FILE_PATH) + if os.path.exists(self.filepath): + self.update_header() + + def add_history(self, model_A, model_B, model_O, model_O_hash, sl_base_alpha, weight_value_A, weight_value_B, weight_name=""): + _history_dict = {} + _history_dict.update({ + "model_A": f"{os.path.basename(model_A.split(' ')[0])}", + "model_A_hash": f"{model_A.split(' ')[1]}", + "model_B": f"{os.path.basename(model_B.split(' ')[0])}", + "model_B_hash": f"{model_B.split(' ')[1]}", + "model_O": model_O, + "model_O_hash": model_O_hash, + "base_alpha": sl_base_alpha, + "weight_name": weight_name, + "weight_values": weight_value_A, + "weight_values2": weight_value_B, + "datetime": f"{datetime.datetime.now()}" + }) + + if not os.path.exists(self.filepath): + with open(self.filepath, "w", newline="", encoding="utf-8") as f: + dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t') + dw.writeheader() + # save to file + with open(self.filepath, "a", newline="", encoding='utf-8') as f: + dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t') + dw.writerow(_history_dict) + + def update_header(self): + hist_data = [] + if os.path.exists(self.filepath): + # check header in case HEADERS updated + with open(self.filepath, "r", newline="", encoding="utf-8") as f: + dr = DictReader(f, delimiter='\t') + new_header = [ x for x in HEADERS if x not in dr.fieldnames ] + if len(new_header) > 0: + # need update. + hist_data = [ x for x in dr] + if len(hist_data) > 0: + with open(self.filepath, "w", newline="", encoding="utf-8") as f: + dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t') + dw.writeheader() + dw.writerows(hist_data) diff --git a/scripts/preset_weights.py b/scripts/util/preset_weights.py similarity index 68% rename from scripts/preset_weights.py rename to scripts/util/preset_weights.py index 9f01fba..7867f32 100644 --- a/scripts/preset_weights.py +++ b/scripts/util/preset_weights.py @@ -8,14 +8,24 @@ from modules import scripts CSV_FILE_PATH = "csv/preset.tsv" +MYPRESET_PATH = "csv/preset_own.tsv" +HEADER = ["preset_name", "preset_weights"] path_root = scripts.basedir() class PresetWeights(): def __init__(self): - self.filepath = os.path.join(path_root, CSV_FILE_PATH) self.presets = {} - with open(self.filepath, "r") as f: + + if os.path.exists(os.path.join(path_root, MYPRESET_PATH)): + with open(os.path.join(path_root, MYPRESET_PATH), "r") as f: + reader = DictReader(f, delimiter="\t") + lines_dict = [row for row in reader] + for line_dict in lines_dict: + _w = ",".join([f"{x.strip()}" for x in line_dict["preset_weights"].split(",")]) + self.presets.update({line_dict["preset_name"]: _w}) + + with open(os.path.join(path_root, CSV_FILE_PATH), "r") as f: reader = DictReader(f, delimiter="\t") lines_dict = [row for row in reader] for line_dict in lines_dict: @@ -26,7 +36,7 @@ class PresetWeights(): return [k for k in self.presets.keys()] def find_weight_by_name(self, preset_name=""): - if preset_name and preset_name != "" and preset_name in self.presets: + if preset_name and preset_name != "" and preset_name in self.presets.keys(): return self.presets.get(preset_name, ",".join(["0.5" for _ in range(25)])) else: return "" diff --git a/style.css b/style.css index 22404b5..5015514 100644 --- a/style.css +++ b/style.css @@ -1,4 +1,29 @@ -#mbw_sl_M00 { +#mbw_sl_M00, #mbw_sl_a_M00, #mbw_sl_b_M00 { bottom:0; position:absolute; -} \ No newline at end of file +} + +/* +#sl_IN_A_00, #sl_IN_A_01, #sl_IN_A_02, #sl_IN_A_03, #sl_IN_A_04, #sl_IN_A_05, #sl_IN_A_06, #sl_IN_A_07, #sl_IN_A_08, #sl_IN_A_09, #sl_IN_A_10, #sl_IN_A_11 { + width: 220; +} + +#sl_IN_B_00, #sl_IN_B_01, #sl_IN_B_02, #sl_IN_B_03, #sl_IN_B_04, #sl_IN_B_05, #sl_IN_B_06, #sl_IN_B_07, #sl_IN_B_08, #sl_IN_B_09, #sl_IN_B_10, #sl_IN_B_11 { + width: 220; +} +*/ + +#sl_M_A_00, #sl_M_B_00 { + bottom:0; + position:absolute; +} + +/* +#sl_OUT_A_00, #sl_OUT_A_01, #sl_OUT_A_02, #sl_OUT_A_03, #sl_OUT_A_04, #sl_OUT_A_05, #sl_OUT_A_06, #sl_OUT_A_07, #sl_OUT_A_08, #sl_OUT_A_09, #sl_OUT_A_10, #sl_OUT_A_11 { + width: 220; +} + +#sl_OUT_B_00, #sl_OUT_B_01, #sl_OUT_B_02, #sl_OUT_B_03, #sl_OUT_B_04, #sl_OUT_B_05, #sl_OUT_B_06, #sl_OUT_B_07, #sl_OUT_B_08, #sl_OUT_B_09, #sl_OUT_B_10, #sl_OUT_B_11 { + width: 220; +} +*/ \ No newline at end of file