Big update with functions, keeping current function and UIs

add: add tab to divide feature. Current remain on "MBW" and new feature comes on "MBW Each"
refact: divide scripts to each feature
feature: "MBW Each" function, allow set percentage of model A and model B for each layer.
feature: add support for csv/preset_own.tsv, not included in git (so not overwrite by update)
feature: add datetime column on logfile
feature: add Multi-Merge feature (#5)
exp/feature-each-merge
bbc_mc 2022-12-25 20:00:00 +09:00
parent 63ba0926bb
commit 75a31b481a
11 changed files with 936 additions and 250 deletions

4
.gitignore vendored
View File

@ -1 +1,5 @@
/csv/history.tsv
/csv/preset_own.tsv
#
_*

23
javascript/js_mbw_each.js Normal file
View File

@ -0,0 +1,23 @@
//
// fix position of sliders
//
//
// UI
//
onUiUpdate(function () {
// check Extension loaded
if (gradioApp().querySelector("div#tab_mbw_each") == null ) return;
// check already done
//if (gradioApp().querySelector("#div_mdl_size_a") != null) return;
// apply
let _style = "min-width: min(200px, 100%); flex-grow: 1";
gradioApp().querySelector("#sl_IN_A_00").parentElement.parentElement.setAttribute("style", _style)
gradioApp().querySelector("#sl_IN_B_00").parentElement.parentElement.setAttribute("style", _style)
gradioApp().querySelector("#sl_M_A_00").parentElement.parentElement.setAttribute("style", _style)
gradioApp().querySelector("#sl_M_B_00").parentElement.parentElement.setAttribute("style", _style)
gradioApp().querySelector("#sl_OUT_A_00").parentElement.parentElement.setAttribute("style", _style)
gradioApp().querySelector("#sl_OUT_B_00").parentElement.parentElement.setAttribute("style", _style)
});

215
scripts/mbw/ui_mbw.py Normal file
View File

@ -0,0 +1,215 @@
import gradio as gr
import os
from modules import sd_models, shared
from tqdm import tqdm
from scripts.mbw.merge_block_weighted import merge
from scripts.util.preset_weights import PresetWeights
from scripts.util.merge_history import MergeHistory
presetWeights = PresetWeights()
mergeHistory = MergeHistory()
def on_ui_tabs():
with gr.Column():
with gr.Row():
with gr.Column(variant="panel"):
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
btn_clear_weighted = gr.Button(value="Clear values")
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
html_output_block_weight_info = gr.HTML()
with gr.Column():
dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list())
txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25")
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary")
with gr.Row():
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=1)
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False)
with gr.Row():
model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles())
model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles())
txt_model_O = gr.Text(label="Output Model Name")
with gr.Row():
with gr.Column():
sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
with gr.Column():
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="mbw_sl_M00")
with gr.Column():
sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN = [
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11]
sl_MID = [sl_M_00]
sl_OUT = [
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11]
# Events
def onclick_btn_do_merge_block_weighted(
model_A, model_B,
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
sl_M_00,
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite
):
_weights = ",".join(
[str(x) for x in [
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
sl_M_00,
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11
]])
#
if not model_A or not model_B:
return gr.update(value=f"ERROR: model not found. [{model_A}][{model_B}]")
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
model_A_info = sd_models.get_closet_checkpoint_match(model_A)
if model_A_info:
_model_A_name = model_A_info.model_name
else:
_model_A_name = ""
model_B_info = sd_models.get_closet_checkpoint_match(model_B)
if model_B_info:
_model_B_info = model_B_info.model_name
else:
_model_B_info = ""
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
if ".ckpt" not in model_O:
model_O = model_O + ".ckpt"
_output = os.path.join(ckpt_dir, model_O)
# debug output
print( "#### Merge Block Weighted ####")
if not chk_allow_overwrite:
if os.path.exists(_output):
_err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort."
print(_err_msg)
return gr.update(value=f"{_err_msg} [{_output}]")
print(f"model_0 : {model_A}")
print(f"model_1 : {model_B}")
print(f"base_alpha : {sl_base_alpha}")
print(f"output_file: {_output}")
print(f"weights : {_weights}")
result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw)
sd_models.list_models()
if result:
ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{model_O}"
else:
ret_html = ret_message
# save log to history.tsv
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
model_O_hash = "" if not model_O_info else model_O_info.hash
_names = presetWeights.find_names_by_weight(_weights)
if _names and len(_names) > 0:
weight_name = _names[0]
else:
weight_name = ""
mergeHistory.add_history(model_A, model_B, model_O, model_O_hash, sl_base_alpha, _weights, "", weight_name)
return gr.update(value=f"{ret_html}")
btn_do_merge_block_weighted.click(
fn=onclick_btn_do_merge_block_weighted,
inputs=[model_A, model_B] + sl_IN + sl_MID + sl_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
outputs=[html_output_block_weight_info]
)
btn_clear_weighted.click(
fn=lambda: [gr.update(value=0.5) for _ in range(25)],
inputs=[],
outputs=[
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
sl_M_00,
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
]
)
def on_change_dd_preset_weight(dd_preset_weight):
_weights = presetWeights.find_weight_by_name(dd_preset_weight)
_ret = on_btn_apply_block_weithg_from_txt(_weights)
return [gr.update(value=_weights)] + _ret
dd_preset_weight.change(
fn=on_change_dd_preset_weight,
inputs=[dd_preset_weight],
outputs=[txt_block_weight,
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
sl_M_00,
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
]
)
def on_btn_reload_checkpoint_mbw():
sd_models.list_models()
return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())]
btn_reload_checkpoint_mbw.click(
fn=on_btn_reload_checkpoint_mbw,
inputs=[],
outputs=[model_A, model_B]
)
def on_btn_apply_block_weithg_from_txt(txt_block_weight):
if not txt_block_weight or txt_block_weight == "":
return [gr.update() for _ in range(25)]
_list = [x.strip() for x in txt_block_weight.split(",")]
if(len(_list) != 25):
return [gr.update() for _ in range(25)]
return [gr.update(value=x) for x in _list]
btn_apply_block_weithg_from_txt.click(
fn=on_btn_apply_block_weithg_from_txt,
inputs=[txt_block_weight],
outputs=[
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
sl_M_00,
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
]
)

View File

@ -0,0 +1,160 @@
# from https://note.com/kohya_ss/n/n9a485a066d5b
# kohya_ss
# original code: https://github.com/eyriewow/merge-models
# use them as base of this code
# 2022/12/15
# bbc-mc
import os
import argparse
import re
import torch
from tqdm import tqdm
from modules import sd_models, shared
NUM_INPUT_BLOCKS = 12
NUM_MID_BLOCK = 1
NUM_OUTPUT_BLOCKS = 12
NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS
def dprint(str, flg):
if flg:
print(str)
def merge(weight_A:list, weight_B:list, model_0, model_1, device="cpu", base_alpha=0.5,
output_file="", allow_overwrite=False, verbose=False):
def _check_arg_weight(weight):
if weight is None:
return None
else:
_weight = [float(w) for w in weight.split(",")]
if len(_weight) != NUM_TOTAL_BLOCKS:
return None
else:
return _weight
weight_A = _check_arg_weight(weight_A)
if weight_A is None:
_err_msg = f"Weight A invalid. program abort. weights value must be {NUM_TOTAL_BLOCKS}."
print(_err_msg)
return False, _err_msg
weight_B = _check_arg_weight(weight_B)
if weight_B is None:
_err_msg = f"Weight B invalid. program abort. weights value must be {NUM_TOTAL_BLOCKS}."
print(_err_msg)
return False, _err_msg
device = device if device in ["cpu", "cuda"] else "cpu"
alpha = base_alpha
if not output_file or output_file == "":
output_file = f'bw-{model_0}-{model_1}-{str(alpha)[2:] + "0"}.ckpt'
else:
output_file = output_file if ".ckpt" in output_file else output_file + ".ckpt"
# check if output file already exists
if os.path.isfile(output_file) and not allow_overwrite:
_err_msg = f"Exiting... [{output_file}]"
print(_err_msg)
return False, _err_msg
def load_model(_model, _device):
model_info = sd_models.get_closet_checkpoint_match(_model)
if model_info:
model_file = model_info.filename
else:
return None
cache_enabled = shared.opts.sd_checkpoint_cache > 0
if cache_enabled and model_info in sd_models.checkpoints_loaded:
print(" load from cache")
return sd_models.checkpoints_loaded[model_info].copy()
else:
print(" loading ...")
return sd_models.read_state_dict(model_file, map_location=_device)
print("loading", model_0)
theta_0 = load_model(model_0, device)
print("loading", model_1)
theta_1 = load_model(model_1, device)
re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12
re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1
re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12
dprint(f"-- start Stage 1/2 --", verbose)
count_target_of_basealpha = 0
for key in (tqdm(theta_0.keys(), desc="Stage 1/2") if not verbose else theta_0.keys()):
if "model" in key and key in theta_1:
dprint(f" key : {key}", verbose)
current_alpha_A = 1 - alpha
current_alpha_B = alpha
current_alpha_I = 0
# check weighted and U-Net or not
if weight_A is not None and 'model.diffusion_model.' in key:
# check block index
weight_index = -1
if 'time_embed' in key:
weight_index = 0 # before input blocks
elif '.out.' in key:
weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks
else:
m = re_inp.search(key)
if m:
inp_idx = int(m.groups()[0])
weight_index = inp_idx
else:
m = re_mid.search(key)
if m:
weight_index = NUM_INPUT_BLOCKS
else:
m = re_out.search(key)
if m:
out_idx = int(m.groups()[0])
weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx
if weight_index >= NUM_TOTAL_BLOCKS:
print(f"error. illegal block index: {key}")
if weight_index >= 0:
current_alpha_A = weight_A[weight_index]
current_alpha_B = weight_B[weight_index]
current_alpha_I = 1 - current_alpha_A - current_alpha_B
if verbose:
print(f"weighted '{key}': A{current_alpha_A} B{current_alpha_B} I{current_alpha_I}")
# create I tensor
tensor_I_0 = torch.zeros_like(theta_0[key], dtype=theta_0[key].dtype)
_var1 = current_alpha_I * tensor_I_0
_var2 = current_alpha_A * theta_0[key]
_var3 = current_alpha_B * theta_1[key]
theta_0[key] = _var1 + _var2 + _var3
# theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key]
else:
dprint(f" key - {key}", verbose)
dprint(f"-- start Stage 2/2 --", verbose)
for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
if "model" in key and key not in theta_0:
dprint(f" key : {key}", verbose)
theta_0.update({key:theta_1[key]})
else:
dprint(f" key - {key}", verbose)
print("Saving...")
torch.save({"state_dict": theta_0}, output_file)
print("Done!")
return True, f"{output_file}<br>base_alpha applied [{count_target_of_basealpha}] times."

View File

@ -0,0 +1,426 @@
import gradio as gr
import os
import re
from modules import sd_models, shared
from tqdm import tqdm
from scripts.mbw_each.merge_block_weighted_mod import merge
from scripts.util.preset_weights import PresetWeights
from scripts.util.merge_history import MergeHistory
presetWeights = PresetWeights()
mergeHistory = MergeHistory()
def on_ui_tabs():
with gr.Column():
with gr.Row():
with gr.Column(variant="panel"):
with gr.Row():
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
btn_clear_weighted = gr.Button(value="Clear values")
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
with gr.Row():
txt_multi_process_cmd = gr.TextArea(label="Multi Proc Cmd", placeholder="Keep empty if dont use.")
html_output_block_weight_info = gr.HTML()
with gr.Column():
dd_preset_weight = gr.Dropdown(label="Preset_Weights", choices=presetWeights.get_preset_name_list())
txt_block_weight = gr.Text(label="Weight_values", placeholder="Put weight sets. float number x 25")
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary")
with gr.Row():
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=0)
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False)
with gr.Row():
dd_model_A = gr.Dropdown(label="Model_A", choices=sd_models.checkpoint_tiles())
dd_model_B = gr.Dropdown(label="Model_B", choices=sd_models.checkpoint_tiles())
txt_model_O = gr.Text(label="(O)Output Model Name")
with gr.Row():
with gr.Column():
sl_IN_A_00 = gr.Slider(label="IN_A_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_00")
sl_IN_A_01 = gr.Slider(label="IN_A_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_01")
sl_IN_A_02 = gr.Slider(label="IN_A_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_02")
sl_IN_A_03 = gr.Slider(label="IN_A_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_03")
sl_IN_A_04 = gr.Slider(label="IN_A_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_04")
sl_IN_A_05 = gr.Slider(label="IN_A_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_05")
sl_IN_A_06 = gr.Slider(label="IN_A_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_06")
sl_IN_A_07 = gr.Slider(label="IN_A_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_07")
sl_IN_A_08 = gr.Slider(label="IN_A_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_08")
sl_IN_A_09 = gr.Slider(label="IN_A_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_09")
sl_IN_A_10 = gr.Slider(label="IN_A_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_10")
sl_IN_A_11 = gr.Slider(label="IN_A_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_11")
with gr.Column():
sl_IN_B_00 = gr.Slider(label="IN_B_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_00")
sl_IN_B_01 = gr.Slider(label="IN_B_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_01")
sl_IN_B_02 = gr.Slider(label="IN_B_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_02")
sl_IN_B_03 = gr.Slider(label="IN_B_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_03")
sl_IN_B_04 = gr.Slider(label="IN_B_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_04")
sl_IN_B_05 = gr.Slider(label="IN_B_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_05")
sl_IN_B_06 = gr.Slider(label="IN_B_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_06")
sl_IN_B_07 = gr.Slider(label="IN_B_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_07")
sl_IN_B_08 = gr.Slider(label="IN_B_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_08")
sl_IN_B_09 = gr.Slider(label="IN_B_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_09")
sl_IN_B_10 = gr.Slider(label="IN_B_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_10")
sl_IN_B_11 = gr.Slider(label="IN_B_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_11")
with gr.Column():
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
sl_M_A_00 = gr.Slider(label="M_A_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_M_A_00")
with gr.Column():
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
sl_M_B_00 = gr.Slider(label="M_B_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_M_B_00")
with gr.Column():
sl_OUT_A_11 = gr.Slider(label="OUT_A_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_11")
sl_OUT_A_10 = gr.Slider(label="OUT_A_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_10")
sl_OUT_A_09 = gr.Slider(label="OUT_A_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_09")
sl_OUT_A_08 = gr.Slider(label="OUT_A_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_08")
sl_OUT_A_07 = gr.Slider(label="OUT_A_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_07")
sl_OUT_A_06 = gr.Slider(label="OUT_A_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_06")
sl_OUT_A_05 = gr.Slider(label="OUT_A_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_05")
sl_OUT_A_04 = gr.Slider(label="OUT_A_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_04")
sl_OUT_A_03 = gr.Slider(label="OUT_A_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_03")
sl_OUT_A_02 = gr.Slider(label="OUT_A_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_02")
sl_OUT_A_01 = gr.Slider(label="OUT_A_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_01")
sl_OUT_A_00 = gr.Slider(label="OUT_A_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_00")
with gr.Column():
sl_OUT_B_11 = gr.Slider(label="OUT_B_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_11")
sl_OUT_B_10 = gr.Slider(label="OUT_B_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_10")
sl_OUT_B_09 = gr.Slider(label="OUT_B_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_09")
sl_OUT_B_08 = gr.Slider(label="OUT_B_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_08")
sl_OUT_B_07 = gr.Slider(label="OUT_B_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_07")
sl_OUT_B_06 = gr.Slider(label="OUT_B_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_06")
sl_OUT_B_05 = gr.Slider(label="OUT_B_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_05")
sl_OUT_B_04 = gr.Slider(label="OUT_B_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_04")
sl_OUT_B_03 = gr.Slider(label="OUT_B_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_03")
sl_OUT_B_02 = gr.Slider(label="OUT_B_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_02")
sl_OUT_B_01 = gr.Slider(label="OUT_B_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_01")
sl_OUT_B_00 = gr.Slider(label="OUT_B_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_00")
# Footer
gr.HTML(
"""
<p style="font-size: 12px" align="right">
<b>Merge Block Weighted</b> extension by <a href="https://github.com/bbc-mc" target="_blank">bbc_mc</a><br />
<b>MBW Each</b> is experimental functions and <b>NO PROOF</b> of effectiveness.<br />
You can try it by own, to dig more deeper into Abyss ...<br />
</p>
"""
)
sl_A_IN = [
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11]
sl_A_MID = [sl_M_A_00]
sl_A_OUT = [
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11]
sl_B_IN = [
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11]
sl_B_MID = [sl_M_B_00]
sl_B_OUT = [
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11]
# Events
def onclick_btn_do_merge_block_weighted(
dd_model_A, dd_model_B, txt_multi_process_cmd,
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
sl_M_A_00,
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
sl_M_B_00,
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite
):
base_alpha = sl_base_alpha
_weight_A = ",".join(
[str(x) for x in [
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
sl_M_A_00,
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
]])
_weight_B = ",".join(
[str(x) for x in [
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
sl_M_B_00,
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
]])
# debug output
print( "#### Merge Block Weighted ####")
if (not dd_model_A or not dd_model_B) and txt_multi_process_cmd == "":
_err_msg = f"ERROR: model not found. [{dd_model_A}][{dd_model_B}]"
print(_err_msg)
return gr.update(value=_err_msg)
ret_html = ""
if txt_multi_process_cmd != "":
# need multi-merge
_lines = txt_multi_process_cmd.split('\n')
print(f"check multi-merge. {len(_lines)} lines found.")
for line_index, _line in enumerate(_lines):
if _line == "":
continue
print(f"\n== merge line {line_index+1}/{len(_lines)} ==")
_items = [x.strip() for x in _line.split(",") if x != ""]
if len(_items) > 0:
ret_html += _run_merge(
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw, params=_items)
else:
_ret = f" multi-merge text found, but invalid params. skipped :[{_line}]"
ret_html += _ret
print(_ret)
else:
# normal merge
ret_html += _run_merge(
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw)
sd_models.list_models()
print( "#### All merge process done. ####")
return gr.update(value=f"{ret_html}")
btn_do_merge_block_weighted.click(
fn=onclick_btn_do_merge_block_weighted,
inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd] + sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
outputs=[html_output_block_weight_info]
)
def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0, model_Output="", verbose=False, params=[]):
# generate output file name from param
model_A_info = sd_models.get_closet_checkpoint_match(model_0)
_model_A_name = "" if not model_A_info else model_A_info.model_name # expect "" or "aaa.ckpt [abcdefgh]"
model_B_info = sd_models.get_closet_checkpoint_match(model_1)
_model_B_name = "" if not model_B_info else model_B_info.model_name
model_O = f"bw-merge-{_model_A_name}-{_model_B_name}-{base_alpha}.ckpt" if model_Output == "" else model_Output
model_O = model_O if ".ckpt" in model_O else f"{model_O}.ckpt"
model_O = re.sub(r'[\\|/|:|?|.|"|<|>|\|\*]', '-', model_O)
if params and len(params) > 0:
for _item in params:
# expect "O=merge/test02, IN_B_00 = 0.12345" as params=["O=merge/test02", "IN_B_00 = 0.12345"]
if len(_item.split("=")) == 2:
_item_l = _item.split("=")[0].strip()
_item_r = _item.split("=")[1].strip()
if _item_r != "":
if _item_l.lower() == "model_a" or _item_l.lower() == "model_b":
_model_info = sd_models.get_closet_checkpoint_match(_item_r)
if _model_info:
_model_name = _model_info.title.split(" ")[0]
if _model_name and _model_name.strip() != "":
if _item_l.lower() == "model_a":
print(f" * Model changed: {model_0} -> {_model_info.title}")
model_0 = _model_info.title
elif _item_l.lower() == "model_b":
print(f" * Model changed: {model_1} -> {_model_info.title}")
model_1 = _model_info.title
elif _item_l.lower() == "preset_weights":
_weights = presetWeights.find_weight_by_name(_item_r)
if _weights != "" and len(_weights.split(',')) == 25:
print(f" * Weights changed by preset-name: {_item_r}")
weight_B = _weights
weight_A = ",".join([str(1-float(x)) for x in _weights.split(',')])
else:
print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]")
elif _item_l.lower() == "weight_values":
_weights = _item_r.strip()
if _weights != "" and len(_weights.split(' ')) == 25: # this is work-around to use space as separator. Double-meaning issue on commna which already used as value separator and weights separator.
print(f" * Weights changed: {_item_r}")
weight_B = _weights
weight_A = ",".join([str(1-float(x)) for x in _weights.split(' ')])
else:
print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]")
elif _item_l.lower() == "base_alpha":
if float(_item_r) >= 0:
print(f" * base_alpha changed: {base_alpha} -> {_item_r}")
base_alpha = float(_item_r)
elif _item_l.upper() == "O":
_item_r = _item_r if ".ckpt" in _item_r else f"{_item_r}.ckpt"
print(f" * Output filename changed:[{model_O}] -> [{_item_r}]")
model_O = _item_r
elif len(_item_l.split("_")) == 3:
_IMO = _item_l.split("_")[0]
_AB = _item_l.split("_")[1]
_NUM = _item_l.split("_")[2]
_index = int(_NUM)
_index = _index + 0 if _IMO == "IN" else _index
_index = _index + 12 if _IMO == "M" else _index
_index = _index + 13 if _IMO == "OUT" else _index
def _apply_val(key, weight, index, new_value):
_weight = [x.strip() for x in weight.split(",")]
_new_weight = _weight[:]
_new_weight[index] = new_value
_new_weight = ",".join(_new_weight)
print(f" * weight_{key} changed:[{weight}]")
print(f" -> [{_new_weight}]")
return _new_weight
if _AB == "A":
weight_A = _apply_val(_AB, weight_A, _index, _item_r)
elif _AB == "B":
weight_B = _apply_val(_AB, weight_B, _index, _item_r)
else:
print(f" * Waring: uncaught param found. ignored. [{_item_l}][{_item_r}]")
#
# Prepare params before run merge
#
output_file = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O)
#
# Check params
#
if not os.path.exists(os.path.dirname(output_file)):
_err_msg = f"WARNING: target path not found: {os.path.dirname(output_file)}. skipped."
print(_err_msg)
return _err_msg + "<br />"
if not allow_overwrite:
if os.path.exists(output_file):
_err_msg = f"WARNING: output_file already exists. overwrite not allowed. skipped."
print(_err_msg)
return _err_msg + "<br />"
# debug output
print(f" model_0 : {model_0}")
print(f" model_1 : {model_1}")
print(f" model_Out : {model_O}")
print(f" base_alpha : {base_alpha}")
print(f" output_file: {output_file}")
print(f" weight_A : {weight_A}")
print(f" weight_B : {weight_B}")
result, ret_message = merge(
weight_A=weight_A, weight_B=weight_B, model_0=model_0, model_1=model_1,
allow_overwrite=allow_overwrite, base_alpha=base_alpha, output_file=output_file, verbose=verbose)
if result:
ret_html = f"merged. {model_0} + {model_1} = {model_O} <br>"
print("merged.")
else:
ret_html = ret_message
print("merge failed.")
# save log to history.tsv
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
model_O_hash = "" if not model_O_info else model_O_info.hash
_names = presetWeights.find_names_by_weight(weight_A)
if _names and len(_names) > 0:
weight_name = _names[0]
else:
weight_name = ""
mergeHistory.add_history(model_0, model_1, model_O, model_O_hash, base_alpha, weight_A, weight_B, weight_name)
return ret_html
btn_clear_weighted.click(
fn=lambda: [gr.update(value=0.5) for _ in range(25*2)],
inputs=[],
outputs=[
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
sl_M_A_00,
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
sl_M_B_00,
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
]
)
def on_change_dd_preset_weight(dd_preset_weight):
_weights = presetWeights.find_weight_by_name(dd_preset_weight)
_ret = on_btn_apply_block_weight_from_txt(_weights)
return [gr.update(value=_weights)] + _ret
dd_preset_weight.change(
fn=on_change_dd_preset_weight,
inputs=[dd_preset_weight],
outputs=[
txt_block_weight,
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
sl_M_A_00,
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
sl_M_B_00,
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
]
)
def on_btn_reload_checkpoint_mbw():
sd_models.list_models()
return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())]
btn_reload_checkpoint_mbw.click(
fn=on_btn_reload_checkpoint_mbw,
inputs=[],
outputs=[dd_model_A, dd_model_B]
)
def on_btn_apply_block_weight_from_txt(txt_block_weight):
if not txt_block_weight or txt_block_weight == "":
return [gr.update() for _ in range(25*2)]
_list = [x.strip() for x in txt_block_weight.split(",")]
if(len(_list) != 25):
return [gr.update() for _ in range(25*2)]
return [gr.update(value=str(1-float(x))) for x in _list] + [gr.update(value=x) for x in _list]
btn_apply_block_weithg_from_txt.click(
fn=on_btn_apply_block_weight_from_txt,
inputs=[txt_block_weight],
outputs=[
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
sl_M_A_00,
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
sl_M_B_00,
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
]
)

View File

@ -8,17 +8,12 @@
import os
import gradio as gr
from modules import scripts, script_callbacks
from modules import sd_models, shared
from modules import script_callbacks
from scripts.merge_block_weighted import merge
from scripts.merge_history import MergeHistory
from scripts.preset_weights import PresetWeights
path_root = scripts.basedir()
from scripts.mbw import ui_mbw
from scripts.mbw_each import ui_mbw_each
mergeHistory = MergeHistory()
presetWeights = PresetWeights()
#
# UI callback
@ -26,204 +21,11 @@ presetWeights = PresetWeights()
def on_ui_tabs():
with gr.Blocks() as main_block:
with gr.Column():
with gr.Row():
with gr.Column(variant="panel"):
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
btn_clear_weighted = gr.Button(value="Clear values")
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
html_output_block_weight_info = gr.HTML()
with gr.Column():
dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list())
txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25")
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary")
with gr.Row():
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=1)
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False)
with gr.Row():
model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles())
model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles())
txt_model_O = gr.Text(label="Output Model Name")
with gr.Row():
with gr.Column():
sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
with gr.Column():
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
gr.Slider(visible=False)
sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="mbw_sl_M00")
with gr.Column():
sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.00000000001, value=0.5)
sl_IN = [
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11]
sl_MID = [sl_M_00]
sl_OUT = [
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11]
with gr.Tab("MBW", elem_id="tab_mbw"):
ui_mbw.on_ui_tabs()
# Events
def onclick_btn_do_merge_block_weighted(
model_A, model_B,
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
sl_M_00,
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite
):
_weights = ",".join(
[str(x) for x in [
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
sl_M_00,
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11
]])
#
if not model_A or not model_B:
return gr.update(value=f"ERROR: model not found. [{model_A}][{model_B}]")
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
model_A_info = sd_models.get_closet_checkpoint_match(model_A)
if model_A_info:
_model_A_name = model_A_info.model_name
else:
_model_A_name = ""
model_B_info = sd_models.get_closet_checkpoint_match(model_B)
if model_B_info:
_model_B_info = model_B_info.model_name
else:
_model_B_info = ""
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
if ".ckpt" not in model_O:
model_O = model_O + ".ckpt"
_output = os.path.join(ckpt_dir, model_O)
# debug output
print( "#### Merge Block Weighted ####")
if not chk_allow_overwrite:
if os.path.exists(_output):
_err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort."
print(_err_msg)
return gr.update(value=f"{_err_msg} [{_output}]")
print(f"model_0 : {model_A}")
print(f"model_1 : {model_B}")
print(f"base_alpha : {sl_base_alpha}")
print(f"output_file: {_output}")
print(f"weights : {_weights}")
result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw)
sd_models.list_models()
if result:
ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{model_O}"
else:
ret_html = ret_message
# save log to history.tsv
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
model_O_hash = "" if not model_O_info else model_O_info.hash
_names = presetWeights.find_names_by_weight(_weights)
if _names and len(_names) > 0:
weight_name = _names[0]
else:
weight_name = ""
mergeHistory.add_history(model_A, model_B, model_O, model_O_hash, sl_base_alpha, _weights, weight_name)
return gr.update(value=f"{ret_html}")
btn_do_merge_block_weighted.click(
fn=onclick_btn_do_merge_block_weighted,
inputs=[model_A, model_B] + sl_IN + sl_MID + sl_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
outputs=[html_output_block_weight_info]
)
btn_clear_weighted.click(
fn=lambda: [gr.update(value=0.5) for _ in range(25)],
inputs=[],
outputs=[
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
sl_M_00,
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
]
)
def on_change_dd_preset_weight(dd_preset_weight):
_weights = presetWeights.find_weight_by_name(dd_preset_weight)
_ret = on_btn_apply_block_weithg_from_txt(_weights)
return [gr.update(value=_weights)] + _ret
dd_preset_weight.change(
fn=on_change_dd_preset_weight,
inputs=[dd_preset_weight],
outputs=[txt_block_weight,
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
sl_M_00,
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
]
)
def on_btn_reload_checkpoint_mbw():
sd_models.list_models()
return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())]
btn_reload_checkpoint_mbw.click(
fn=on_btn_reload_checkpoint_mbw,
inputs=[],
outputs=[model_A, model_B]
)
def on_btn_apply_block_weithg_from_txt(txt_block_weight):
if not txt_block_weight or txt_block_weight == "":
return [gr.update() for _ in range(25)]
_list = [x.strip() for x in txt_block_weight.split(",")]
if(len(_list) != 25):
return [gr.update() for _ in range(25)]
return [gr.update(value=x) for x in _list]
btn_apply_block_weithg_from_txt.click(
fn=on_btn_apply_block_weithg_from_txt,
inputs=[txt_block_weight],
outputs=[
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
sl_M_00,
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
]
)
with gr.Tab("MBW Each", elem_id="tab_mbw_each"):
ui_mbw_each.on_ui_tabs()
# return required as (gradio_component, title, elem_id)
return (main_block, "Merge Block Weighted", "merge_block_weighted"),

View File

@ -1,40 +0,0 @@
#
#
#
import os
from csv import DictWriter, writer
from modules import scripts
CSV_FILE_PATH = "csv/history.tsv"
HEADERS = ["model_A", "model_A_hash", "model_B", "model_B_hash", "model_O", "model_O_hash", "base_alpha", "weight_name", "weight_values"]
path_root = scripts.basedir()
class MergeHistory():
def __init__(self):
self.filepath = os.path.join(path_root, CSV_FILE_PATH)
def add_history(self, model_A, model_B, model_O, model_O_hash, sl_base_alpha, weight_values, weight_name=""):
_history_dict = {}
_history_dict.update({
"model_A": f"{os.path.basename(model_A.split(' ')[0])}",
"model_A_hash": f"{model_A.split(' ')[1]}",
"model_B": f"{os.path.basename(model_B.split(' ')[0])}",
"model_B_hash": f"{model_B.split(' ')[1]}",
"model_O": model_O,
"model_O_hash": model_O_hash,
"base_alpha": sl_base_alpha,
"weight_name": weight_name,
"weight_values": weight_values,
})
if not os.path.exists(self.filepath):
with open(self.filepath, "w", newline="", encoding="utf-8") as f:
wr = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
wr.writeheader()
# save to file
with open(self.filepath, "a", newline="", encoding='utf-8') as f:
dictwriter = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
dictwriter.writerow(_history_dict)

View File

@ -0,0 +1,61 @@
#
#
#
import os
import datetime
from csv import DictWriter, DictReader
from modules import scripts
CSV_FILE_PATH = "csv/history.tsv"
HEADERS = ["model_A", "model_A_hash", "model_B", "model_B_hash", "model_O", "model_O_hash", "base_alpha", "weight_name", "weight_values", "weight_values2", "datetime"]
path_root = scripts.basedir()
class MergeHistory():
def __init__(self):
self.filepath = os.path.join(path_root, CSV_FILE_PATH)
if os.path.exists(self.filepath):
self.update_header()
def add_history(self, model_A, model_B, model_O, model_O_hash, sl_base_alpha, weight_value_A, weight_value_B, weight_name=""):
_history_dict = {}
_history_dict.update({
"model_A": f"{os.path.basename(model_A.split(' ')[0])}",
"model_A_hash": f"{model_A.split(' ')[1]}",
"model_B": f"{os.path.basename(model_B.split(' ')[0])}",
"model_B_hash": f"{model_B.split(' ')[1]}",
"model_O": model_O,
"model_O_hash": model_O_hash,
"base_alpha": sl_base_alpha,
"weight_name": weight_name,
"weight_values": weight_value_A,
"weight_values2": weight_value_B,
"datetime": f"{datetime.datetime.now()}"
})
if not os.path.exists(self.filepath):
with open(self.filepath, "w", newline="", encoding="utf-8") as f:
dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
dw.writeheader()
# save to file
with open(self.filepath, "a", newline="", encoding='utf-8') as f:
dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
dw.writerow(_history_dict)
def update_header(self):
hist_data = []
if os.path.exists(self.filepath):
# check header in case HEADERS updated
with open(self.filepath, "r", newline="", encoding="utf-8") as f:
dr = DictReader(f, delimiter='\t')
new_header = [ x for x in HEADERS if x not in dr.fieldnames ]
if len(new_header) > 0:
# need update.
hist_data = [ x for x in dr]
if len(hist_data) > 0:
with open(self.filepath, "w", newline="", encoding="utf-8") as f:
dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
dw.writeheader()
dw.writerows(hist_data)

View File

@ -8,14 +8,24 @@ from modules import scripts
CSV_FILE_PATH = "csv/preset.tsv"
MYPRESET_PATH = "csv/preset_own.tsv"
HEADER = ["preset_name", "preset_weights"]
path_root = scripts.basedir()
class PresetWeights():
def __init__(self):
self.filepath = os.path.join(path_root, CSV_FILE_PATH)
self.presets = {}
with open(self.filepath, "r") as f:
if os.path.exists(os.path.join(path_root, MYPRESET_PATH)):
with open(os.path.join(path_root, MYPRESET_PATH), "r") as f:
reader = DictReader(f, delimiter="\t")
lines_dict = [row for row in reader]
for line_dict in lines_dict:
_w = ",".join([f"{x.strip()}" for x in line_dict["preset_weights"].split(",")])
self.presets.update({line_dict["preset_name"]: _w})
with open(os.path.join(path_root, CSV_FILE_PATH), "r") as f:
reader = DictReader(f, delimiter="\t")
lines_dict = [row for row in reader]
for line_dict in lines_dict:
@ -26,7 +36,7 @@ class PresetWeights():
return [k for k in self.presets.keys()]
def find_weight_by_name(self, preset_name=""):
if preset_name and preset_name != "" and preset_name in self.presets:
if preset_name and preset_name != "" and preset_name in self.presets.keys():
return self.presets.get(preset_name, ",".join(["0.5" for _ in range(25)]))
else:
return ""

View File

@ -1,4 +1,29 @@
#mbw_sl_M00 {
#mbw_sl_M00, #mbw_sl_a_M00, #mbw_sl_b_M00 {
bottom:0;
position:absolute;
}
}
/*
#sl_IN_A_00, #sl_IN_A_01, #sl_IN_A_02, #sl_IN_A_03, #sl_IN_A_04, #sl_IN_A_05, #sl_IN_A_06, #sl_IN_A_07, #sl_IN_A_08, #sl_IN_A_09, #sl_IN_A_10, #sl_IN_A_11 {
width: 220;
}
#sl_IN_B_00, #sl_IN_B_01, #sl_IN_B_02, #sl_IN_B_03, #sl_IN_B_04, #sl_IN_B_05, #sl_IN_B_06, #sl_IN_B_07, #sl_IN_B_08, #sl_IN_B_09, #sl_IN_B_10, #sl_IN_B_11 {
width: 220;
}
*/
#sl_M_A_00, #sl_M_B_00 {
bottom:0;
position:absolute;
}
/*
#sl_OUT_A_00, #sl_OUT_A_01, #sl_OUT_A_02, #sl_OUT_A_03, #sl_OUT_A_04, #sl_OUT_A_05, #sl_OUT_A_06, #sl_OUT_A_07, #sl_OUT_A_08, #sl_OUT_A_09, #sl_OUT_A_10, #sl_OUT_A_11 {
width: 220;
}
#sl_OUT_B_00, #sl_OUT_B_01, #sl_OUT_B_02, #sl_OUT_B_03, #sl_OUT_B_04, #sl_OUT_B_05, #sl_OUT_B_06, #sl_OUT_B_07, #sl_OUT_B_08, #sl_OUT_B_09, #sl_OUT_B_10, #sl_OUT_B_11 {
width: 220;
}
*/