Big update with functions, keeping current function and UIs
add: add tab to divide feature. Current remain on "MBW" and new feature comes on "MBW Each" refact: divide scripts to each feature feature: "MBW Each" function, allow set percentage of model A and model B for each layer. feature: add support for csv/preset_own.tsv, not included in git (so not overwrite by update) feature: add datetime column on logfile feature: add Multi-Merge feature (#5)exp/feature-each-merge
parent
63ba0926bb
commit
75a31b481a
|
|
@ -1 +1,5 @@
|
||||||
/csv/history.tsv
|
/csv/history.tsv
|
||||||
|
/csv/preset_own.tsv
|
||||||
|
|
||||||
|
#
|
||||||
|
_*
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
//
|
||||||
|
// fix position of sliders
|
||||||
|
//
|
||||||
|
|
||||||
|
//
|
||||||
|
// UI
|
||||||
|
//
|
||||||
|
onUiUpdate(function () {
|
||||||
|
// check Extension loaded
|
||||||
|
if (gradioApp().querySelector("div#tab_mbw_each") == null ) return;
|
||||||
|
|
||||||
|
// check already done
|
||||||
|
//if (gradioApp().querySelector("#div_mdl_size_a") != null) return;
|
||||||
|
|
||||||
|
// apply
|
||||||
|
let _style = "min-width: min(200px, 100%); flex-grow: 1";
|
||||||
|
gradioApp().querySelector("#sl_IN_A_00").parentElement.parentElement.setAttribute("style", _style)
|
||||||
|
gradioApp().querySelector("#sl_IN_B_00").parentElement.parentElement.setAttribute("style", _style)
|
||||||
|
gradioApp().querySelector("#sl_M_A_00").parentElement.parentElement.setAttribute("style", _style)
|
||||||
|
gradioApp().querySelector("#sl_M_B_00").parentElement.parentElement.setAttribute("style", _style)
|
||||||
|
gradioApp().querySelector("#sl_OUT_A_00").parentElement.parentElement.setAttribute("style", _style)
|
||||||
|
gradioApp().querySelector("#sl_OUT_B_00").parentElement.parentElement.setAttribute("style", _style)
|
||||||
|
});
|
||||||
|
|
@ -0,0 +1,215 @@
|
||||||
|
import gradio as gr
|
||||||
|
import os
|
||||||
|
|
||||||
|
from modules import sd_models, shared
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from scripts.mbw.merge_block_weighted import merge
|
||||||
|
from scripts.util.preset_weights import PresetWeights
|
||||||
|
from scripts.util.merge_history import MergeHistory
|
||||||
|
|
||||||
|
presetWeights = PresetWeights()
|
||||||
|
mergeHistory = MergeHistory()
|
||||||
|
|
||||||
|
|
||||||
|
def on_ui_tabs():
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(variant="panel"):
|
||||||
|
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
|
||||||
|
btn_clear_weighted = gr.Button(value="Clear values")
|
||||||
|
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
|
||||||
|
html_output_block_weight_info = gr.HTML()
|
||||||
|
with gr.Column():
|
||||||
|
dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list())
|
||||||
|
txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25")
|
||||||
|
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary")
|
||||||
|
with gr.Row():
|
||||||
|
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=1)
|
||||||
|
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
|
||||||
|
chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False)
|
||||||
|
with gr.Row():
|
||||||
|
model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles())
|
||||||
|
model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles())
|
||||||
|
txt_model_O = gr.Text(label="Output Model Name")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
with gr.Column():
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="mbw_sl_M00")
|
||||||
|
with gr.Column():
|
||||||
|
sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||||
|
|
||||||
|
sl_IN = [
|
||||||
|
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||||
|
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11]
|
||||||
|
sl_MID = [sl_M_00]
|
||||||
|
sl_OUT = [
|
||||||
|
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||||
|
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11]
|
||||||
|
|
||||||
|
# Events
|
||||||
|
def onclick_btn_do_merge_block_weighted(
|
||||||
|
model_A, model_B,
|
||||||
|
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||||
|
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||||
|
sl_M_00,
|
||||||
|
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||||
|
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
||||||
|
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite
|
||||||
|
):
|
||||||
|
_weights = ",".join(
|
||||||
|
[str(x) for x in [
|
||||||
|
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||||
|
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||||
|
sl_M_00,
|
||||||
|
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||||
|
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11
|
||||||
|
]])
|
||||||
|
#
|
||||||
|
if not model_A or not model_B:
|
||||||
|
return gr.update(value=f"ERROR: model not found. [{model_A}][{model_B}]")
|
||||||
|
|
||||||
|
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||||
|
model_A_info = sd_models.get_closet_checkpoint_match(model_A)
|
||||||
|
if model_A_info:
|
||||||
|
_model_A_name = model_A_info.model_name
|
||||||
|
else:
|
||||||
|
_model_A_name = ""
|
||||||
|
model_B_info = sd_models.get_closet_checkpoint_match(model_B)
|
||||||
|
if model_B_info:
|
||||||
|
_model_B_info = model_B_info.model_name
|
||||||
|
else:
|
||||||
|
_model_B_info = ""
|
||||||
|
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
|
||||||
|
if ".ckpt" not in model_O:
|
||||||
|
model_O = model_O + ".ckpt"
|
||||||
|
|
||||||
|
_output = os.path.join(ckpt_dir, model_O)
|
||||||
|
# debug output
|
||||||
|
print( "#### Merge Block Weighted ####")
|
||||||
|
if not chk_allow_overwrite:
|
||||||
|
if os.path.exists(_output):
|
||||||
|
_err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort."
|
||||||
|
print(_err_msg)
|
||||||
|
return gr.update(value=f"{_err_msg} [{_output}]")
|
||||||
|
print(f"model_0 : {model_A}")
|
||||||
|
print(f"model_1 : {model_B}")
|
||||||
|
print(f"base_alpha : {sl_base_alpha}")
|
||||||
|
print(f"output_file: {_output}")
|
||||||
|
print(f"weights : {_weights}")
|
||||||
|
|
||||||
|
result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw)
|
||||||
|
|
||||||
|
sd_models.list_models()
|
||||||
|
if result:
|
||||||
|
ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{model_O}"
|
||||||
|
else:
|
||||||
|
ret_html = ret_message
|
||||||
|
|
||||||
|
# save log to history.tsv
|
||||||
|
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
|
||||||
|
model_O_hash = "" if not model_O_info else model_O_info.hash
|
||||||
|
_names = presetWeights.find_names_by_weight(_weights)
|
||||||
|
if _names and len(_names) > 0:
|
||||||
|
weight_name = _names[0]
|
||||||
|
else:
|
||||||
|
weight_name = ""
|
||||||
|
mergeHistory.add_history(model_A, model_B, model_O, model_O_hash, sl_base_alpha, _weights, "", weight_name)
|
||||||
|
|
||||||
|
return gr.update(value=f"{ret_html}")
|
||||||
|
btn_do_merge_block_weighted.click(
|
||||||
|
fn=onclick_btn_do_merge_block_weighted,
|
||||||
|
inputs=[model_A, model_B] + sl_IN + sl_MID + sl_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
|
||||||
|
outputs=[html_output_block_weight_info]
|
||||||
|
)
|
||||||
|
|
||||||
|
btn_clear_weighted.click(
|
||||||
|
fn=lambda: [gr.update(value=0.5) for _ in range(25)],
|
||||||
|
inputs=[],
|
||||||
|
outputs=[
|
||||||
|
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||||
|
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||||
|
sl_M_00,
|
||||||
|
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||||
|
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_change_dd_preset_weight(dd_preset_weight):
|
||||||
|
_weights = presetWeights.find_weight_by_name(dd_preset_weight)
|
||||||
|
_ret = on_btn_apply_block_weithg_from_txt(_weights)
|
||||||
|
return [gr.update(value=_weights)] + _ret
|
||||||
|
dd_preset_weight.change(
|
||||||
|
fn=on_change_dd_preset_weight,
|
||||||
|
inputs=[dd_preset_weight],
|
||||||
|
outputs=[txt_block_weight,
|
||||||
|
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||||
|
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||||
|
sl_M_00,
|
||||||
|
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||||
|
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_btn_reload_checkpoint_mbw():
|
||||||
|
sd_models.list_models()
|
||||||
|
return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())]
|
||||||
|
btn_reload_checkpoint_mbw.click(
|
||||||
|
fn=on_btn_reload_checkpoint_mbw,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[model_A, model_B]
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_btn_apply_block_weithg_from_txt(txt_block_weight):
|
||||||
|
if not txt_block_weight or txt_block_weight == "":
|
||||||
|
return [gr.update() for _ in range(25)]
|
||||||
|
_list = [x.strip() for x in txt_block_weight.split(",")]
|
||||||
|
if(len(_list) != 25):
|
||||||
|
return [gr.update() for _ in range(25)]
|
||||||
|
return [gr.update(value=x) for x in _list]
|
||||||
|
btn_apply_block_weithg_from_txt.click(
|
||||||
|
fn=on_btn_apply_block_weithg_from_txt,
|
||||||
|
inputs=[txt_block_weight],
|
||||||
|
outputs=[
|
||||||
|
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||||
|
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||||
|
sl_M_00,
|
||||||
|
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||||
|
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,160 @@
|
||||||
|
# from https://note.com/kohya_ss/n/n9a485a066d5b
|
||||||
|
# kohya_ss
|
||||||
|
# original code: https://github.com/eyriewow/merge-models
|
||||||
|
|
||||||
|
# use them as base of this code
|
||||||
|
# 2022/12/15
|
||||||
|
# bbc-mc
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from modules import sd_models, shared
|
||||||
|
|
||||||
|
|
||||||
|
NUM_INPUT_BLOCKS = 12
|
||||||
|
NUM_MID_BLOCK = 1
|
||||||
|
NUM_OUTPUT_BLOCKS = 12
|
||||||
|
NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS
|
||||||
|
|
||||||
|
|
||||||
|
def dprint(str, flg):
|
||||||
|
if flg:
|
||||||
|
print(str)
|
||||||
|
|
||||||
|
|
||||||
|
def merge(weight_A:list, weight_B:list, model_0, model_1, device="cpu", base_alpha=0.5,
|
||||||
|
output_file="", allow_overwrite=False, verbose=False):
|
||||||
|
|
||||||
|
def _check_arg_weight(weight):
|
||||||
|
if weight is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
_weight = [float(w) for w in weight.split(",")]
|
||||||
|
if len(_weight) != NUM_TOTAL_BLOCKS:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return _weight
|
||||||
|
|
||||||
|
weight_A = _check_arg_weight(weight_A)
|
||||||
|
if weight_A is None:
|
||||||
|
_err_msg = f"Weight A invalid. program abort. weights value must be {NUM_TOTAL_BLOCKS}."
|
||||||
|
print(_err_msg)
|
||||||
|
return False, _err_msg
|
||||||
|
weight_B = _check_arg_weight(weight_B)
|
||||||
|
if weight_B is None:
|
||||||
|
_err_msg = f"Weight B invalid. program abort. weights value must be {NUM_TOTAL_BLOCKS}."
|
||||||
|
print(_err_msg)
|
||||||
|
return False, _err_msg
|
||||||
|
|
||||||
|
device = device if device in ["cpu", "cuda"] else "cpu"
|
||||||
|
|
||||||
|
alpha = base_alpha
|
||||||
|
if not output_file or output_file == "":
|
||||||
|
output_file = f'bw-{model_0}-{model_1}-{str(alpha)[2:] + "0"}.ckpt'
|
||||||
|
else:
|
||||||
|
output_file = output_file if ".ckpt" in output_file else output_file + ".ckpt"
|
||||||
|
|
||||||
|
# check if output file already exists
|
||||||
|
if os.path.isfile(output_file) and not allow_overwrite:
|
||||||
|
_err_msg = f"Exiting... [{output_file}]"
|
||||||
|
print(_err_msg)
|
||||||
|
return False, _err_msg
|
||||||
|
|
||||||
|
def load_model(_model, _device):
|
||||||
|
model_info = sd_models.get_closet_checkpoint_match(_model)
|
||||||
|
if model_info:
|
||||||
|
model_file = model_info.filename
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
||||||
|
if cache_enabled and model_info in sd_models.checkpoints_loaded:
|
||||||
|
print(" load from cache")
|
||||||
|
return sd_models.checkpoints_loaded[model_info].copy()
|
||||||
|
else:
|
||||||
|
print(" loading ...")
|
||||||
|
return sd_models.read_state_dict(model_file, map_location=_device)
|
||||||
|
|
||||||
|
print("loading", model_0)
|
||||||
|
theta_0 = load_model(model_0, device)
|
||||||
|
|
||||||
|
print("loading", model_1)
|
||||||
|
theta_1 = load_model(model_1, device)
|
||||||
|
|
||||||
|
re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12
|
||||||
|
re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1
|
||||||
|
re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12
|
||||||
|
|
||||||
|
dprint(f"-- start Stage 1/2 --", verbose)
|
||||||
|
count_target_of_basealpha = 0
|
||||||
|
for key in (tqdm(theta_0.keys(), desc="Stage 1/2") if not verbose else theta_0.keys()):
|
||||||
|
if "model" in key and key in theta_1:
|
||||||
|
dprint(f" key : {key}", verbose)
|
||||||
|
|
||||||
|
current_alpha_A = 1 - alpha
|
||||||
|
current_alpha_B = alpha
|
||||||
|
current_alpha_I = 0
|
||||||
|
|
||||||
|
# check weighted and U-Net or not
|
||||||
|
if weight_A is not None and 'model.diffusion_model.' in key:
|
||||||
|
# check block index
|
||||||
|
weight_index = -1
|
||||||
|
|
||||||
|
if 'time_embed' in key:
|
||||||
|
weight_index = 0 # before input blocks
|
||||||
|
elif '.out.' in key:
|
||||||
|
weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks
|
||||||
|
else:
|
||||||
|
m = re_inp.search(key)
|
||||||
|
if m:
|
||||||
|
inp_idx = int(m.groups()[0])
|
||||||
|
weight_index = inp_idx
|
||||||
|
else:
|
||||||
|
m = re_mid.search(key)
|
||||||
|
if m:
|
||||||
|
weight_index = NUM_INPUT_BLOCKS
|
||||||
|
else:
|
||||||
|
m = re_out.search(key)
|
||||||
|
if m:
|
||||||
|
out_idx = int(m.groups()[0])
|
||||||
|
weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx
|
||||||
|
|
||||||
|
if weight_index >= NUM_TOTAL_BLOCKS:
|
||||||
|
print(f"error. illegal block index: {key}")
|
||||||
|
if weight_index >= 0:
|
||||||
|
current_alpha_A = weight_A[weight_index]
|
||||||
|
current_alpha_B = weight_B[weight_index]
|
||||||
|
current_alpha_I = 1 - current_alpha_A - current_alpha_B
|
||||||
|
if verbose:
|
||||||
|
print(f"weighted '{key}': A{current_alpha_A} B{current_alpha_B} I{current_alpha_I}")
|
||||||
|
|
||||||
|
# create I tensor
|
||||||
|
tensor_I_0 = torch.zeros_like(theta_0[key], dtype=theta_0[key].dtype)
|
||||||
|
_var1 = current_alpha_I * tensor_I_0
|
||||||
|
_var2 = current_alpha_A * theta_0[key]
|
||||||
|
_var3 = current_alpha_B * theta_1[key]
|
||||||
|
theta_0[key] = _var1 + _var2 + _var3
|
||||||
|
|
||||||
|
# theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key]
|
||||||
|
|
||||||
|
else:
|
||||||
|
dprint(f" key - {key}", verbose)
|
||||||
|
|
||||||
|
dprint(f"-- start Stage 2/2 --", verbose)
|
||||||
|
for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
|
||||||
|
if "model" in key and key not in theta_0:
|
||||||
|
dprint(f" key : {key}", verbose)
|
||||||
|
theta_0.update({key:theta_1[key]})
|
||||||
|
else:
|
||||||
|
dprint(f" key - {key}", verbose)
|
||||||
|
|
||||||
|
print("Saving...")
|
||||||
|
|
||||||
|
torch.save({"state_dict": theta_0}, output_file)
|
||||||
|
|
||||||
|
print("Done!")
|
||||||
|
|
||||||
|
return True, f"{output_file}<br>base_alpha applied [{count_target_of_basealpha}] times."
|
||||||
|
|
@ -0,0 +1,426 @@
|
||||||
|
import gradio as gr
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
from modules import sd_models, shared
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from scripts.mbw_each.merge_block_weighted_mod import merge
|
||||||
|
from scripts.util.preset_weights import PresetWeights
|
||||||
|
from scripts.util.merge_history import MergeHistory
|
||||||
|
|
||||||
|
presetWeights = PresetWeights()
|
||||||
|
mergeHistory = MergeHistory()
|
||||||
|
|
||||||
|
|
||||||
|
def on_ui_tabs():
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(variant="panel"):
|
||||||
|
with gr.Row():
|
||||||
|
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
|
||||||
|
btn_clear_weighted = gr.Button(value="Clear values")
|
||||||
|
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
|
||||||
|
with gr.Row():
|
||||||
|
txt_multi_process_cmd = gr.TextArea(label="Multi Proc Cmd", placeholder="Keep empty if dont use.")
|
||||||
|
html_output_block_weight_info = gr.HTML()
|
||||||
|
with gr.Column():
|
||||||
|
dd_preset_weight = gr.Dropdown(label="Preset_Weights", choices=presetWeights.get_preset_name_list())
|
||||||
|
txt_block_weight = gr.Text(label="Weight_values", placeholder="Put weight sets. float number x 25")
|
||||||
|
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary")
|
||||||
|
with gr.Row():
|
||||||
|
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=0)
|
||||||
|
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
|
||||||
|
chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False)
|
||||||
|
with gr.Row():
|
||||||
|
dd_model_A = gr.Dropdown(label="Model_A", choices=sd_models.checkpoint_tiles())
|
||||||
|
dd_model_B = gr.Dropdown(label="Model_B", choices=sd_models.checkpoint_tiles())
|
||||||
|
txt_model_O = gr.Text(label="(O)Output Model Name")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
sl_IN_A_00 = gr.Slider(label="IN_A_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_00")
|
||||||
|
sl_IN_A_01 = gr.Slider(label="IN_A_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_01")
|
||||||
|
sl_IN_A_02 = gr.Slider(label="IN_A_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_02")
|
||||||
|
sl_IN_A_03 = gr.Slider(label="IN_A_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_03")
|
||||||
|
sl_IN_A_04 = gr.Slider(label="IN_A_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_04")
|
||||||
|
sl_IN_A_05 = gr.Slider(label="IN_A_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_05")
|
||||||
|
sl_IN_A_06 = gr.Slider(label="IN_A_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_06")
|
||||||
|
sl_IN_A_07 = gr.Slider(label="IN_A_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_07")
|
||||||
|
sl_IN_A_08 = gr.Slider(label="IN_A_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_08")
|
||||||
|
sl_IN_A_09 = gr.Slider(label="IN_A_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_09")
|
||||||
|
sl_IN_A_10 = gr.Slider(label="IN_A_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_10")
|
||||||
|
sl_IN_A_11 = gr.Slider(label="IN_A_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_11")
|
||||||
|
with gr.Column():
|
||||||
|
sl_IN_B_00 = gr.Slider(label="IN_B_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_00")
|
||||||
|
sl_IN_B_01 = gr.Slider(label="IN_B_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_01")
|
||||||
|
sl_IN_B_02 = gr.Slider(label="IN_B_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_02")
|
||||||
|
sl_IN_B_03 = gr.Slider(label="IN_B_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_03")
|
||||||
|
sl_IN_B_04 = gr.Slider(label="IN_B_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_04")
|
||||||
|
sl_IN_B_05 = gr.Slider(label="IN_B_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_05")
|
||||||
|
sl_IN_B_06 = gr.Slider(label="IN_B_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_06")
|
||||||
|
sl_IN_B_07 = gr.Slider(label="IN_B_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_07")
|
||||||
|
sl_IN_B_08 = gr.Slider(label="IN_B_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_08")
|
||||||
|
sl_IN_B_09 = gr.Slider(label="IN_B_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_09")
|
||||||
|
sl_IN_B_10 = gr.Slider(label="IN_B_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_10")
|
||||||
|
sl_IN_B_11 = gr.Slider(label="IN_B_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_11")
|
||||||
|
with gr.Column():
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
sl_M_A_00 = gr.Slider(label="M_A_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_M_A_00")
|
||||||
|
with gr.Column():
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
gr.Slider(visible=False)
|
||||||
|
sl_M_B_00 = gr.Slider(label="M_B_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_M_B_00")
|
||||||
|
with gr.Column():
|
||||||
|
sl_OUT_A_11 = gr.Slider(label="OUT_A_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_11")
|
||||||
|
sl_OUT_A_10 = gr.Slider(label="OUT_A_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_10")
|
||||||
|
sl_OUT_A_09 = gr.Slider(label="OUT_A_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_09")
|
||||||
|
sl_OUT_A_08 = gr.Slider(label="OUT_A_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_08")
|
||||||
|
sl_OUT_A_07 = gr.Slider(label="OUT_A_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_07")
|
||||||
|
sl_OUT_A_06 = gr.Slider(label="OUT_A_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_06")
|
||||||
|
sl_OUT_A_05 = gr.Slider(label="OUT_A_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_05")
|
||||||
|
sl_OUT_A_04 = gr.Slider(label="OUT_A_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_04")
|
||||||
|
sl_OUT_A_03 = gr.Slider(label="OUT_A_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_03")
|
||||||
|
sl_OUT_A_02 = gr.Slider(label="OUT_A_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_02")
|
||||||
|
sl_OUT_A_01 = gr.Slider(label="OUT_A_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_01")
|
||||||
|
sl_OUT_A_00 = gr.Slider(label="OUT_A_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_00")
|
||||||
|
with gr.Column():
|
||||||
|
sl_OUT_B_11 = gr.Slider(label="OUT_B_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_11")
|
||||||
|
sl_OUT_B_10 = gr.Slider(label="OUT_B_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_10")
|
||||||
|
sl_OUT_B_09 = gr.Slider(label="OUT_B_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_09")
|
||||||
|
sl_OUT_B_08 = gr.Slider(label="OUT_B_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_08")
|
||||||
|
sl_OUT_B_07 = gr.Slider(label="OUT_B_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_07")
|
||||||
|
sl_OUT_B_06 = gr.Slider(label="OUT_B_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_06")
|
||||||
|
sl_OUT_B_05 = gr.Slider(label="OUT_B_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_05")
|
||||||
|
sl_OUT_B_04 = gr.Slider(label="OUT_B_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_04")
|
||||||
|
sl_OUT_B_03 = gr.Slider(label="OUT_B_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_03")
|
||||||
|
sl_OUT_B_02 = gr.Slider(label="OUT_B_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_02")
|
||||||
|
sl_OUT_B_01 = gr.Slider(label="OUT_B_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_01")
|
||||||
|
sl_OUT_B_00 = gr.Slider(label="OUT_B_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_00")
|
||||||
|
|
||||||
|
# Footer
|
||||||
|
gr.HTML(
|
||||||
|
"""
|
||||||
|
<p style="font-size: 12px" align="right">
|
||||||
|
<b>Merge Block Weighted</b> extension by <a href="https://github.com/bbc-mc" target="_blank">bbc_mc</a><br />
|
||||||
|
<b>MBW Each</b> is experimental functions and <b>NO PROOF</b> of effectiveness.<br />
|
||||||
|
You can try it by own, to dig more deeper into Abyss ...<br />
|
||||||
|
</p>
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
sl_A_IN = [
|
||||||
|
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
|
||||||
|
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11]
|
||||||
|
sl_A_MID = [sl_M_A_00]
|
||||||
|
sl_A_OUT = [
|
||||||
|
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
|
||||||
|
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11]
|
||||||
|
|
||||||
|
sl_B_IN = [
|
||||||
|
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
|
||||||
|
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11]
|
||||||
|
sl_B_MID = [sl_M_B_00]
|
||||||
|
sl_B_OUT = [
|
||||||
|
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
|
||||||
|
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11]
|
||||||
|
|
||||||
|
|
||||||
|
# Events
|
||||||
|
def onclick_btn_do_merge_block_weighted(
|
||||||
|
dd_model_A, dd_model_B, txt_multi_process_cmd,
|
||||||
|
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
|
||||||
|
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
|
||||||
|
sl_M_A_00,
|
||||||
|
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
|
||||||
|
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
|
||||||
|
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
|
||||||
|
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
|
||||||
|
sl_M_B_00,
|
||||||
|
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
|
||||||
|
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
|
||||||
|
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite
|
||||||
|
):
|
||||||
|
base_alpha = sl_base_alpha
|
||||||
|
_weight_A = ",".join(
|
||||||
|
[str(x) for x in [
|
||||||
|
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
|
||||||
|
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
|
||||||
|
sl_M_A_00,
|
||||||
|
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
|
||||||
|
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
|
||||||
|
]])
|
||||||
|
_weight_B = ",".join(
|
||||||
|
[str(x) for x in [
|
||||||
|
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
|
||||||
|
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
|
||||||
|
sl_M_B_00,
|
||||||
|
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
|
||||||
|
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
|
||||||
|
]])
|
||||||
|
|
||||||
|
# debug output
|
||||||
|
print( "#### Merge Block Weighted ####")
|
||||||
|
|
||||||
|
if (not dd_model_A or not dd_model_B) and txt_multi_process_cmd == "":
|
||||||
|
_err_msg = f"ERROR: model not found. [{dd_model_A}][{dd_model_B}]"
|
||||||
|
print(_err_msg)
|
||||||
|
return gr.update(value=_err_msg)
|
||||||
|
|
||||||
|
ret_html = ""
|
||||||
|
if txt_multi_process_cmd != "":
|
||||||
|
# need multi-merge
|
||||||
|
_lines = txt_multi_process_cmd.split('\n')
|
||||||
|
print(f"check multi-merge. {len(_lines)} lines found.")
|
||||||
|
for line_index, _line in enumerate(_lines):
|
||||||
|
if _line == "":
|
||||||
|
continue
|
||||||
|
print(f"\n== merge line {line_index+1}/{len(_lines)} ==")
|
||||||
|
_items = [x.strip() for x in _line.split(",") if x != ""]
|
||||||
|
if len(_items) > 0:
|
||||||
|
ret_html += _run_merge(
|
||||||
|
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
|
||||||
|
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw, params=_items)
|
||||||
|
else:
|
||||||
|
_ret = f" multi-merge text found, but invalid params. skipped :[{_line}]"
|
||||||
|
ret_html += _ret
|
||||||
|
print(_ret)
|
||||||
|
else:
|
||||||
|
# normal merge
|
||||||
|
ret_html += _run_merge(
|
||||||
|
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
|
||||||
|
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw)
|
||||||
|
|
||||||
|
sd_models.list_models()
|
||||||
|
print( "#### All merge process done. ####")
|
||||||
|
|
||||||
|
return gr.update(value=f"{ret_html}")
|
||||||
|
btn_do_merge_block_weighted.click(
|
||||||
|
fn=onclick_btn_do_merge_block_weighted,
|
||||||
|
inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd] + sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
|
||||||
|
outputs=[html_output_block_weight_info]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0, model_Output="", verbose=False, params=[]):
|
||||||
|
|
||||||
|
# generate output file name from param
|
||||||
|
model_A_info = sd_models.get_closet_checkpoint_match(model_0)
|
||||||
|
_model_A_name = "" if not model_A_info else model_A_info.model_name # expect "" or "aaa.ckpt [abcdefgh]"
|
||||||
|
|
||||||
|
model_B_info = sd_models.get_closet_checkpoint_match(model_1)
|
||||||
|
_model_B_name = "" if not model_B_info else model_B_info.model_name
|
||||||
|
|
||||||
|
model_O = f"bw-merge-{_model_A_name}-{_model_B_name}-{base_alpha}.ckpt" if model_Output == "" else model_Output
|
||||||
|
model_O = model_O if ".ckpt" in model_O else f"{model_O}.ckpt"
|
||||||
|
model_O = re.sub(r'[\\|/|:|?|.|"|<|>|\|\*]', '-', model_O)
|
||||||
|
|
||||||
|
if params and len(params) > 0:
|
||||||
|
for _item in params:
|
||||||
|
# expect "O=merge/test02, IN_B_00 = 0.12345" as params=["O=merge/test02", "IN_B_00 = 0.12345"]
|
||||||
|
if len(_item.split("=")) == 2:
|
||||||
|
_item_l = _item.split("=")[0].strip()
|
||||||
|
_item_r = _item.split("=")[1].strip()
|
||||||
|
if _item_r != "":
|
||||||
|
if _item_l.lower() == "model_a" or _item_l.lower() == "model_b":
|
||||||
|
_model_info = sd_models.get_closet_checkpoint_match(_item_r)
|
||||||
|
if _model_info:
|
||||||
|
_model_name = _model_info.title.split(" ")[0]
|
||||||
|
if _model_name and _model_name.strip() != "":
|
||||||
|
if _item_l.lower() == "model_a":
|
||||||
|
print(f" * Model changed: {model_0} -> {_model_info.title}")
|
||||||
|
model_0 = _model_info.title
|
||||||
|
elif _item_l.lower() == "model_b":
|
||||||
|
print(f" * Model changed: {model_1} -> {_model_info.title}")
|
||||||
|
model_1 = _model_info.title
|
||||||
|
|
||||||
|
elif _item_l.lower() == "preset_weights":
|
||||||
|
_weights = presetWeights.find_weight_by_name(_item_r)
|
||||||
|
if _weights != "" and len(_weights.split(',')) == 25:
|
||||||
|
print(f" * Weights changed by preset-name: {_item_r}")
|
||||||
|
weight_B = _weights
|
||||||
|
weight_A = ",".join([str(1-float(x)) for x in _weights.split(',')])
|
||||||
|
else:
|
||||||
|
print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]")
|
||||||
|
|
||||||
|
elif _item_l.lower() == "weight_values":
|
||||||
|
_weights = _item_r.strip()
|
||||||
|
if _weights != "" and len(_weights.split(' ')) == 25: # this is work-around to use space as separator. Double-meaning issue on commna which already used as value separator and weights separator.
|
||||||
|
print(f" * Weights changed: {_item_r}")
|
||||||
|
weight_B = _weights
|
||||||
|
weight_A = ",".join([str(1-float(x)) for x in _weights.split(' ')])
|
||||||
|
else:
|
||||||
|
print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]")
|
||||||
|
|
||||||
|
elif _item_l.lower() == "base_alpha":
|
||||||
|
if float(_item_r) >= 0:
|
||||||
|
print(f" * base_alpha changed: {base_alpha} -> {_item_r}")
|
||||||
|
base_alpha = float(_item_r)
|
||||||
|
|
||||||
|
elif _item_l.upper() == "O":
|
||||||
|
_item_r = _item_r if ".ckpt" in _item_r else f"{_item_r}.ckpt"
|
||||||
|
print(f" * Output filename changed:[{model_O}] -> [{_item_r}]")
|
||||||
|
model_O = _item_r
|
||||||
|
|
||||||
|
elif len(_item_l.split("_")) == 3:
|
||||||
|
_IMO = _item_l.split("_")[0]
|
||||||
|
_AB = _item_l.split("_")[1]
|
||||||
|
_NUM = _item_l.split("_")[2]
|
||||||
|
|
||||||
|
_index = int(_NUM)
|
||||||
|
_index = _index + 0 if _IMO == "IN" else _index
|
||||||
|
_index = _index + 12 if _IMO == "M" else _index
|
||||||
|
_index = _index + 13 if _IMO == "OUT" else _index
|
||||||
|
|
||||||
|
def _apply_val(key, weight, index, new_value):
|
||||||
|
_weight = [x.strip() for x in weight.split(",")]
|
||||||
|
_new_weight = _weight[:]
|
||||||
|
_new_weight[index] = new_value
|
||||||
|
_new_weight = ",".join(_new_weight)
|
||||||
|
print(f" * weight_{key} changed:[{weight}]")
|
||||||
|
print(f" -> [{_new_weight}]")
|
||||||
|
return _new_weight
|
||||||
|
|
||||||
|
if _AB == "A":
|
||||||
|
weight_A = _apply_val(_AB, weight_A, _index, _item_r)
|
||||||
|
elif _AB == "B":
|
||||||
|
weight_B = _apply_val(_AB, weight_B, _index, _item_r)
|
||||||
|
else:
|
||||||
|
print(f" * Waring: uncaught param found. ignored. [{_item_l}][{_item_r}]")
|
||||||
|
|
||||||
|
#
|
||||||
|
# Prepare params before run merge
|
||||||
|
#
|
||||||
|
output_file = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O)
|
||||||
|
#
|
||||||
|
# Check params
|
||||||
|
#
|
||||||
|
if not os.path.exists(os.path.dirname(output_file)):
|
||||||
|
_err_msg = f"WARNING: target path not found: {os.path.dirname(output_file)}. skipped."
|
||||||
|
print(_err_msg)
|
||||||
|
return _err_msg + "<br />"
|
||||||
|
if not allow_overwrite:
|
||||||
|
if os.path.exists(output_file):
|
||||||
|
_err_msg = f"WARNING: output_file already exists. overwrite not allowed. skipped."
|
||||||
|
print(_err_msg)
|
||||||
|
return _err_msg + "<br />"
|
||||||
|
|
||||||
|
# debug output
|
||||||
|
print(f" model_0 : {model_0}")
|
||||||
|
print(f" model_1 : {model_1}")
|
||||||
|
print(f" model_Out : {model_O}")
|
||||||
|
print(f" base_alpha : {base_alpha}")
|
||||||
|
print(f" output_file: {output_file}")
|
||||||
|
print(f" weight_A : {weight_A}")
|
||||||
|
print(f" weight_B : {weight_B}")
|
||||||
|
|
||||||
|
result, ret_message = merge(
|
||||||
|
weight_A=weight_A, weight_B=weight_B, model_0=model_0, model_1=model_1,
|
||||||
|
allow_overwrite=allow_overwrite, base_alpha=base_alpha, output_file=output_file, verbose=verbose)
|
||||||
|
if result:
|
||||||
|
ret_html = f"merged. {model_0} + {model_1} = {model_O} <br>"
|
||||||
|
print("merged.")
|
||||||
|
else:
|
||||||
|
ret_html = ret_message
|
||||||
|
print("merge failed.")
|
||||||
|
|
||||||
|
|
||||||
|
# save log to history.tsv
|
||||||
|
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
|
||||||
|
model_O_hash = "" if not model_O_info else model_O_info.hash
|
||||||
|
_names = presetWeights.find_names_by_weight(weight_A)
|
||||||
|
if _names and len(_names) > 0:
|
||||||
|
weight_name = _names[0]
|
||||||
|
else:
|
||||||
|
weight_name = ""
|
||||||
|
mergeHistory.add_history(model_0, model_1, model_O, model_O_hash, base_alpha, weight_A, weight_B, weight_name)
|
||||||
|
return ret_html
|
||||||
|
|
||||||
|
btn_clear_weighted.click(
|
||||||
|
fn=lambda: [gr.update(value=0.5) for _ in range(25*2)],
|
||||||
|
inputs=[],
|
||||||
|
outputs=[
|
||||||
|
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
|
||||||
|
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
|
||||||
|
sl_M_A_00,
|
||||||
|
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
|
||||||
|
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
|
||||||
|
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
|
||||||
|
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
|
||||||
|
sl_M_B_00,
|
||||||
|
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
|
||||||
|
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_change_dd_preset_weight(dd_preset_weight):
|
||||||
|
_weights = presetWeights.find_weight_by_name(dd_preset_weight)
|
||||||
|
_ret = on_btn_apply_block_weight_from_txt(_weights)
|
||||||
|
return [gr.update(value=_weights)] + _ret
|
||||||
|
dd_preset_weight.change(
|
||||||
|
fn=on_change_dd_preset_weight,
|
||||||
|
inputs=[dd_preset_weight],
|
||||||
|
outputs=[
|
||||||
|
txt_block_weight,
|
||||||
|
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
|
||||||
|
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
|
||||||
|
sl_M_A_00,
|
||||||
|
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
|
||||||
|
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
|
||||||
|
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
|
||||||
|
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
|
||||||
|
sl_M_B_00,
|
||||||
|
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
|
||||||
|
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_btn_reload_checkpoint_mbw():
|
||||||
|
sd_models.list_models()
|
||||||
|
return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())]
|
||||||
|
btn_reload_checkpoint_mbw.click(
|
||||||
|
fn=on_btn_reload_checkpoint_mbw,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[dd_model_A, dd_model_B]
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_btn_apply_block_weight_from_txt(txt_block_weight):
|
||||||
|
if not txt_block_weight or txt_block_weight == "":
|
||||||
|
return [gr.update() for _ in range(25*2)]
|
||||||
|
_list = [x.strip() for x in txt_block_weight.split(",")]
|
||||||
|
if(len(_list) != 25):
|
||||||
|
return [gr.update() for _ in range(25*2)]
|
||||||
|
return [gr.update(value=str(1-float(x))) for x in _list] + [gr.update(value=x) for x in _list]
|
||||||
|
btn_apply_block_weithg_from_txt.click(
|
||||||
|
fn=on_btn_apply_block_weight_from_txt,
|
||||||
|
inputs=[txt_block_weight],
|
||||||
|
outputs=[
|
||||||
|
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
|
||||||
|
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
|
||||||
|
sl_M_A_00,
|
||||||
|
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
|
||||||
|
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
|
||||||
|
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
|
||||||
|
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
|
||||||
|
sl_M_B_00,
|
||||||
|
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
|
||||||
|
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
@ -8,17 +8,12 @@
|
||||||
import os
|
import os
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import scripts, script_callbacks
|
from modules import script_callbacks
|
||||||
from modules import sd_models, shared
|
|
||||||
|
|
||||||
from scripts.merge_block_weighted import merge
|
|
||||||
from scripts.merge_history import MergeHistory
|
|
||||||
from scripts.preset_weights import PresetWeights
|
|
||||||
|
|
||||||
path_root = scripts.basedir()
|
from scripts.mbw import ui_mbw
|
||||||
|
from scripts.mbw_each import ui_mbw_each
|
||||||
|
|
||||||
mergeHistory = MergeHistory()
|
|
||||||
presetWeights = PresetWeights()
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# UI callback
|
# UI callback
|
||||||
|
|
@ -26,204 +21,11 @@ presetWeights = PresetWeights()
|
||||||
def on_ui_tabs():
|
def on_ui_tabs():
|
||||||
|
|
||||||
with gr.Blocks() as main_block:
|
with gr.Blocks() as main_block:
|
||||||
with gr.Column():
|
with gr.Tab("MBW", elem_id="tab_mbw"):
|
||||||
with gr.Row():
|
ui_mbw.on_ui_tabs()
|
||||||
with gr.Column(variant="panel"):
|
|
||||||
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
|
|
||||||
btn_clear_weighted = gr.Button(value="Clear values")
|
|
||||||
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
|
|
||||||
html_output_block_weight_info = gr.HTML()
|
|
||||||
with gr.Column():
|
|
||||||
dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list())
|
|
||||||
txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25")
|
|
||||||
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary")
|
|
||||||
with gr.Row():
|
|
||||||
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=1)
|
|
||||||
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
|
|
||||||
chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False)
|
|
||||||
with gr.Row():
|
|
||||||
model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles())
|
|
||||||
model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles())
|
|
||||||
txt_model_O = gr.Text(label="Output Model Name")
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
with gr.Column():
|
|
||||||
gr.Slider(visible=False)
|
|
||||||
gr.Slider(visible=False)
|
|
||||||
gr.Slider(visible=False)
|
|
||||||
gr.Slider(visible=False)
|
|
||||||
gr.Slider(visible=False)
|
|
||||||
gr.Slider(visible=False)
|
|
||||||
gr.Slider(visible=False)
|
|
||||||
gr.Slider(visible=False)
|
|
||||||
gr.Slider(visible=False)
|
|
||||||
gr.Slider(visible=False)
|
|
||||||
gr.Slider(visible=False)
|
|
||||||
sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="mbw_sl_M00")
|
|
||||||
with gr.Column():
|
|
||||||
sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
|
||||||
sl_IN = [
|
|
||||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
|
||||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11]
|
|
||||||
sl_MID = [sl_M_00]
|
|
||||||
sl_OUT = [
|
|
||||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
|
||||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11]
|
|
||||||
|
|
||||||
# Events
|
with gr.Tab("MBW Each", elem_id="tab_mbw_each"):
|
||||||
def onclick_btn_do_merge_block_weighted(
|
ui_mbw_each.on_ui_tabs()
|
||||||
model_A, model_B,
|
|
||||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
|
||||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
|
||||||
sl_M_00,
|
|
||||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
|
||||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
|
||||||
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite
|
|
||||||
):
|
|
||||||
_weights = ",".join(
|
|
||||||
[str(x) for x in [
|
|
||||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
|
||||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
|
||||||
sl_M_00,
|
|
||||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
|
||||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11
|
|
||||||
]])
|
|
||||||
#
|
|
||||||
if not model_A or not model_B:
|
|
||||||
return gr.update(value=f"ERROR: model not found. [{model_A}][{model_B}]")
|
|
||||||
|
|
||||||
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
|
||||||
model_A_info = sd_models.get_closet_checkpoint_match(model_A)
|
|
||||||
if model_A_info:
|
|
||||||
_model_A_name = model_A_info.model_name
|
|
||||||
else:
|
|
||||||
_model_A_name = ""
|
|
||||||
model_B_info = sd_models.get_closet_checkpoint_match(model_B)
|
|
||||||
if model_B_info:
|
|
||||||
_model_B_info = model_B_info.model_name
|
|
||||||
else:
|
|
||||||
_model_B_info = ""
|
|
||||||
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
|
|
||||||
if ".ckpt" not in model_O:
|
|
||||||
model_O = model_O + ".ckpt"
|
|
||||||
|
|
||||||
_output = os.path.join(ckpt_dir, model_O)
|
|
||||||
# debug output
|
|
||||||
print( "#### Merge Block Weighted ####")
|
|
||||||
if not chk_allow_overwrite:
|
|
||||||
if os.path.exists(_output):
|
|
||||||
_err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort."
|
|
||||||
print(_err_msg)
|
|
||||||
return gr.update(value=f"{_err_msg} [{_output}]")
|
|
||||||
print(f"model_0 : {model_A}")
|
|
||||||
print(f"model_1 : {model_B}")
|
|
||||||
print(f"base_alpha : {sl_base_alpha}")
|
|
||||||
print(f"output_file: {_output}")
|
|
||||||
print(f"weights : {_weights}")
|
|
||||||
|
|
||||||
result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw)
|
|
||||||
|
|
||||||
sd_models.list_models()
|
|
||||||
if result:
|
|
||||||
ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{model_O}"
|
|
||||||
else:
|
|
||||||
ret_html = ret_message
|
|
||||||
|
|
||||||
# save log to history.tsv
|
|
||||||
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
|
|
||||||
model_O_hash = "" if not model_O_info else model_O_info.hash
|
|
||||||
_names = presetWeights.find_names_by_weight(_weights)
|
|
||||||
if _names and len(_names) > 0:
|
|
||||||
weight_name = _names[0]
|
|
||||||
else:
|
|
||||||
weight_name = ""
|
|
||||||
mergeHistory.add_history(model_A, model_B, model_O, model_O_hash, sl_base_alpha, _weights, weight_name)
|
|
||||||
|
|
||||||
return gr.update(value=f"{ret_html}")
|
|
||||||
btn_do_merge_block_weighted.click(
|
|
||||||
fn=onclick_btn_do_merge_block_weighted,
|
|
||||||
inputs=[model_A, model_B] + sl_IN + sl_MID + sl_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
|
|
||||||
outputs=[html_output_block_weight_info]
|
|
||||||
)
|
|
||||||
|
|
||||||
btn_clear_weighted.click(
|
|
||||||
fn=lambda: [gr.update(value=0.5) for _ in range(25)],
|
|
||||||
inputs=[],
|
|
||||||
outputs=[
|
|
||||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
|
||||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
|
||||||
sl_M_00,
|
|
||||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
|
||||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_change_dd_preset_weight(dd_preset_weight):
|
|
||||||
_weights = presetWeights.find_weight_by_name(dd_preset_weight)
|
|
||||||
_ret = on_btn_apply_block_weithg_from_txt(_weights)
|
|
||||||
return [gr.update(value=_weights)] + _ret
|
|
||||||
dd_preset_weight.change(
|
|
||||||
fn=on_change_dd_preset_weight,
|
|
||||||
inputs=[dd_preset_weight],
|
|
||||||
outputs=[txt_block_weight,
|
|
||||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
|
||||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
|
||||||
sl_M_00,
|
|
||||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
|
||||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_btn_reload_checkpoint_mbw():
|
|
||||||
sd_models.list_models()
|
|
||||||
return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())]
|
|
||||||
btn_reload_checkpoint_mbw.click(
|
|
||||||
fn=on_btn_reload_checkpoint_mbw,
|
|
||||||
inputs=[],
|
|
||||||
outputs=[model_A, model_B]
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_btn_apply_block_weithg_from_txt(txt_block_weight):
|
|
||||||
if not txt_block_weight or txt_block_weight == "":
|
|
||||||
return [gr.update() for _ in range(25)]
|
|
||||||
_list = [x.strip() for x in txt_block_weight.split(",")]
|
|
||||||
if(len(_list) != 25):
|
|
||||||
return [gr.update() for _ in range(25)]
|
|
||||||
return [gr.update(value=x) for x in _list]
|
|
||||||
btn_apply_block_weithg_from_txt.click(
|
|
||||||
fn=on_btn_apply_block_weithg_from_txt,
|
|
||||||
inputs=[txt_block_weight],
|
|
||||||
outputs=[
|
|
||||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
|
||||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
|
||||||
sl_M_00,
|
|
||||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
|
||||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# return required as (gradio_component, title, elem_id)
|
# return required as (gradio_component, title, elem_id)
|
||||||
return (main_block, "Merge Block Weighted", "merge_block_weighted"),
|
return (main_block, "Merge Block Weighted", "merge_block_weighted"),
|
||||||
|
|
|
||||||
|
|
@ -1,40 +0,0 @@
|
||||||
#
|
|
||||||
#
|
|
||||||
#
|
|
||||||
import os
|
|
||||||
from csv import DictWriter, writer
|
|
||||||
|
|
||||||
from modules import scripts
|
|
||||||
|
|
||||||
|
|
||||||
CSV_FILE_PATH = "csv/history.tsv"
|
|
||||||
HEADERS = ["model_A", "model_A_hash", "model_B", "model_B_hash", "model_O", "model_O_hash", "base_alpha", "weight_name", "weight_values"]
|
|
||||||
path_root = scripts.basedir()
|
|
||||||
|
|
||||||
|
|
||||||
class MergeHistory():
|
|
||||||
def __init__(self):
|
|
||||||
self.filepath = os.path.join(path_root, CSV_FILE_PATH)
|
|
||||||
|
|
||||||
def add_history(self, model_A, model_B, model_O, model_O_hash, sl_base_alpha, weight_values, weight_name=""):
|
|
||||||
_history_dict = {}
|
|
||||||
_history_dict.update({
|
|
||||||
"model_A": f"{os.path.basename(model_A.split(' ')[0])}",
|
|
||||||
"model_A_hash": f"{model_A.split(' ')[1]}",
|
|
||||||
"model_B": f"{os.path.basename(model_B.split(' ')[0])}",
|
|
||||||
"model_B_hash": f"{model_B.split(' ')[1]}",
|
|
||||||
"model_O": model_O,
|
|
||||||
"model_O_hash": model_O_hash,
|
|
||||||
"base_alpha": sl_base_alpha,
|
|
||||||
"weight_name": weight_name,
|
|
||||||
"weight_values": weight_values,
|
|
||||||
})
|
|
||||||
|
|
||||||
if not os.path.exists(self.filepath):
|
|
||||||
with open(self.filepath, "w", newline="", encoding="utf-8") as f:
|
|
||||||
wr = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
|
|
||||||
wr.writeheader()
|
|
||||||
# save to file
|
|
||||||
with open(self.filepath, "a", newline="", encoding='utf-8') as f:
|
|
||||||
dictwriter = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
|
|
||||||
dictwriter.writerow(_history_dict)
|
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
import os
|
||||||
|
import datetime
|
||||||
|
from csv import DictWriter, DictReader
|
||||||
|
|
||||||
|
from modules import scripts
|
||||||
|
|
||||||
|
|
||||||
|
CSV_FILE_PATH = "csv/history.tsv"
|
||||||
|
HEADERS = ["model_A", "model_A_hash", "model_B", "model_B_hash", "model_O", "model_O_hash", "base_alpha", "weight_name", "weight_values", "weight_values2", "datetime"]
|
||||||
|
path_root = scripts.basedir()
|
||||||
|
|
||||||
|
|
||||||
|
class MergeHistory():
|
||||||
|
def __init__(self):
|
||||||
|
self.filepath = os.path.join(path_root, CSV_FILE_PATH)
|
||||||
|
if os.path.exists(self.filepath):
|
||||||
|
self.update_header()
|
||||||
|
|
||||||
|
def add_history(self, model_A, model_B, model_O, model_O_hash, sl_base_alpha, weight_value_A, weight_value_B, weight_name=""):
|
||||||
|
_history_dict = {}
|
||||||
|
_history_dict.update({
|
||||||
|
"model_A": f"{os.path.basename(model_A.split(' ')[0])}",
|
||||||
|
"model_A_hash": f"{model_A.split(' ')[1]}",
|
||||||
|
"model_B": f"{os.path.basename(model_B.split(' ')[0])}",
|
||||||
|
"model_B_hash": f"{model_B.split(' ')[1]}",
|
||||||
|
"model_O": model_O,
|
||||||
|
"model_O_hash": model_O_hash,
|
||||||
|
"base_alpha": sl_base_alpha,
|
||||||
|
"weight_name": weight_name,
|
||||||
|
"weight_values": weight_value_A,
|
||||||
|
"weight_values2": weight_value_B,
|
||||||
|
"datetime": f"{datetime.datetime.now()}"
|
||||||
|
})
|
||||||
|
|
||||||
|
if not os.path.exists(self.filepath):
|
||||||
|
with open(self.filepath, "w", newline="", encoding="utf-8") as f:
|
||||||
|
dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
|
||||||
|
dw.writeheader()
|
||||||
|
# save to file
|
||||||
|
with open(self.filepath, "a", newline="", encoding='utf-8') as f:
|
||||||
|
dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
|
||||||
|
dw.writerow(_history_dict)
|
||||||
|
|
||||||
|
def update_header(self):
|
||||||
|
hist_data = []
|
||||||
|
if os.path.exists(self.filepath):
|
||||||
|
# check header in case HEADERS updated
|
||||||
|
with open(self.filepath, "r", newline="", encoding="utf-8") as f:
|
||||||
|
dr = DictReader(f, delimiter='\t')
|
||||||
|
new_header = [ x for x in HEADERS if x not in dr.fieldnames ]
|
||||||
|
if len(new_header) > 0:
|
||||||
|
# need update.
|
||||||
|
hist_data = [ x for x in dr]
|
||||||
|
if len(hist_data) > 0:
|
||||||
|
with open(self.filepath, "w", newline="", encoding="utf-8") as f:
|
||||||
|
dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
|
||||||
|
dw.writeheader()
|
||||||
|
dw.writerows(hist_data)
|
||||||
|
|
@ -8,14 +8,24 @@ from modules import scripts
|
||||||
|
|
||||||
|
|
||||||
CSV_FILE_PATH = "csv/preset.tsv"
|
CSV_FILE_PATH = "csv/preset.tsv"
|
||||||
|
MYPRESET_PATH = "csv/preset_own.tsv"
|
||||||
|
HEADER = ["preset_name", "preset_weights"]
|
||||||
path_root = scripts.basedir()
|
path_root = scripts.basedir()
|
||||||
|
|
||||||
|
|
||||||
class PresetWeights():
|
class PresetWeights():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.filepath = os.path.join(path_root, CSV_FILE_PATH)
|
|
||||||
self.presets = {}
|
self.presets = {}
|
||||||
with open(self.filepath, "r") as f:
|
|
||||||
|
if os.path.exists(os.path.join(path_root, MYPRESET_PATH)):
|
||||||
|
with open(os.path.join(path_root, MYPRESET_PATH), "r") as f:
|
||||||
|
reader = DictReader(f, delimiter="\t")
|
||||||
|
lines_dict = [row for row in reader]
|
||||||
|
for line_dict in lines_dict:
|
||||||
|
_w = ",".join([f"{x.strip()}" for x in line_dict["preset_weights"].split(",")])
|
||||||
|
self.presets.update({line_dict["preset_name"]: _w})
|
||||||
|
|
||||||
|
with open(os.path.join(path_root, CSV_FILE_PATH), "r") as f:
|
||||||
reader = DictReader(f, delimiter="\t")
|
reader = DictReader(f, delimiter="\t")
|
||||||
lines_dict = [row for row in reader]
|
lines_dict = [row for row in reader]
|
||||||
for line_dict in lines_dict:
|
for line_dict in lines_dict:
|
||||||
|
|
@ -26,7 +36,7 @@ class PresetWeights():
|
||||||
return [k for k in self.presets.keys()]
|
return [k for k in self.presets.keys()]
|
||||||
|
|
||||||
def find_weight_by_name(self, preset_name=""):
|
def find_weight_by_name(self, preset_name=""):
|
||||||
if preset_name and preset_name != "" and preset_name in self.presets:
|
if preset_name and preset_name != "" and preset_name in self.presets.keys():
|
||||||
return self.presets.get(preset_name, ",".join(["0.5" for _ in range(25)]))
|
return self.presets.get(preset_name, ",".join(["0.5" for _ in range(25)]))
|
||||||
else:
|
else:
|
||||||
return ""
|
return ""
|
||||||
29
style.css
29
style.css
|
|
@ -1,4 +1,29 @@
|
||||||
#mbw_sl_M00 {
|
#mbw_sl_M00, #mbw_sl_a_M00, #mbw_sl_b_M00 {
|
||||||
bottom:0;
|
bottom:0;
|
||||||
position:absolute;
|
position:absolute;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
#sl_IN_A_00, #sl_IN_A_01, #sl_IN_A_02, #sl_IN_A_03, #sl_IN_A_04, #sl_IN_A_05, #sl_IN_A_06, #sl_IN_A_07, #sl_IN_A_08, #sl_IN_A_09, #sl_IN_A_10, #sl_IN_A_11 {
|
||||||
|
width: 220;
|
||||||
|
}
|
||||||
|
|
||||||
|
#sl_IN_B_00, #sl_IN_B_01, #sl_IN_B_02, #sl_IN_B_03, #sl_IN_B_04, #sl_IN_B_05, #sl_IN_B_06, #sl_IN_B_07, #sl_IN_B_08, #sl_IN_B_09, #sl_IN_B_10, #sl_IN_B_11 {
|
||||||
|
width: 220;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
#sl_M_A_00, #sl_M_B_00 {
|
||||||
|
bottom:0;
|
||||||
|
position:absolute;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
#sl_OUT_A_00, #sl_OUT_A_01, #sl_OUT_A_02, #sl_OUT_A_03, #sl_OUT_A_04, #sl_OUT_A_05, #sl_OUT_A_06, #sl_OUT_A_07, #sl_OUT_A_08, #sl_OUT_A_09, #sl_OUT_A_10, #sl_OUT_A_11 {
|
||||||
|
width: 220;
|
||||||
|
}
|
||||||
|
|
||||||
|
#sl_OUT_B_00, #sl_OUT_B_01, #sl_OUT_B_02, #sl_OUT_B_03, #sl_OUT_B_04, #sl_OUT_B_05, #sl_OUT_B_06, #sl_OUT_B_07, #sl_OUT_B_08, #sl_OUT_B_09, #sl_OUT_B_10, #sl_OUT_B_11 {
|
||||||
|
width: 220;
|
||||||
|
}
|
||||||
|
*/
|
||||||
Loading…
Reference in New Issue