Big update with functions, keeping current function and UIs
add: add tab to divide feature. Current remain on "MBW" and new feature comes on "MBW Each" refact: divide scripts to each feature feature: "MBW Each" function, allow set percentage of model A and model B for each layer. feature: add support for csv/preset_own.tsv, not included in git (so not overwrite by update) feature: add datetime column on logfile feature: add Multi-Merge feature (#5)exp/feature-each-merge
parent
63ba0926bb
commit
75a31b481a
|
|
@ -1 +1,5 @@
|
|||
/csv/history.tsv
|
||||
/csv/preset_own.tsv
|
||||
|
||||
#
|
||||
_*
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
//
|
||||
// fix position of sliders
|
||||
//
|
||||
|
||||
//
|
||||
// UI
|
||||
//
|
||||
onUiUpdate(function () {
|
||||
// check Extension loaded
|
||||
if (gradioApp().querySelector("div#tab_mbw_each") == null ) return;
|
||||
|
||||
// check already done
|
||||
//if (gradioApp().querySelector("#div_mdl_size_a") != null) return;
|
||||
|
||||
// apply
|
||||
let _style = "min-width: min(200px, 100%); flex-grow: 1";
|
||||
gradioApp().querySelector("#sl_IN_A_00").parentElement.parentElement.setAttribute("style", _style)
|
||||
gradioApp().querySelector("#sl_IN_B_00").parentElement.parentElement.setAttribute("style", _style)
|
||||
gradioApp().querySelector("#sl_M_A_00").parentElement.parentElement.setAttribute("style", _style)
|
||||
gradioApp().querySelector("#sl_M_B_00").parentElement.parentElement.setAttribute("style", _style)
|
||||
gradioApp().querySelector("#sl_OUT_A_00").parentElement.parentElement.setAttribute("style", _style)
|
||||
gradioApp().querySelector("#sl_OUT_B_00").parentElement.parentElement.setAttribute("style", _style)
|
||||
});
|
||||
|
|
@ -0,0 +1,215 @@
|
|||
import gradio as gr
|
||||
import os
|
||||
|
||||
from modules import sd_models, shared
|
||||
from tqdm import tqdm
|
||||
|
||||
from scripts.mbw.merge_block_weighted import merge
|
||||
from scripts.util.preset_weights import PresetWeights
|
||||
from scripts.util.merge_history import MergeHistory
|
||||
|
||||
presetWeights = PresetWeights()
|
||||
mergeHistory = MergeHistory()
|
||||
|
||||
|
||||
def on_ui_tabs():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Column(variant="panel"):
|
||||
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
|
||||
btn_clear_weighted = gr.Button(value="Clear values")
|
||||
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
|
||||
html_output_block_weight_info = gr.HTML()
|
||||
with gr.Column():
|
||||
dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list())
|
||||
txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25")
|
||||
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary")
|
||||
with gr.Row():
|
||||
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=1)
|
||||
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
|
||||
chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False)
|
||||
with gr.Row():
|
||||
model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles())
|
||||
model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles())
|
||||
txt_model_O = gr.Text(label="Output Model Name")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
with gr.Column():
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="mbw_sl_M00")
|
||||
with gr.Column():
|
||||
sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
|
||||
sl_IN = [
|
||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11]
|
||||
sl_MID = [sl_M_00]
|
||||
sl_OUT = [
|
||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11]
|
||||
|
||||
# Events
|
||||
def onclick_btn_do_merge_block_weighted(
|
||||
model_A, model_B,
|
||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||
sl_M_00,
|
||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
||||
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite
|
||||
):
|
||||
_weights = ",".join(
|
||||
[str(x) for x in [
|
||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||
sl_M_00,
|
||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11
|
||||
]])
|
||||
#
|
||||
if not model_A or not model_B:
|
||||
return gr.update(value=f"ERROR: model not found. [{model_A}][{model_B}]")
|
||||
|
||||
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||
model_A_info = sd_models.get_closet_checkpoint_match(model_A)
|
||||
if model_A_info:
|
||||
_model_A_name = model_A_info.model_name
|
||||
else:
|
||||
_model_A_name = ""
|
||||
model_B_info = sd_models.get_closet_checkpoint_match(model_B)
|
||||
if model_B_info:
|
||||
_model_B_info = model_B_info.model_name
|
||||
else:
|
||||
_model_B_info = ""
|
||||
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
|
||||
if ".ckpt" not in model_O:
|
||||
model_O = model_O + ".ckpt"
|
||||
|
||||
_output = os.path.join(ckpt_dir, model_O)
|
||||
# debug output
|
||||
print( "#### Merge Block Weighted ####")
|
||||
if not chk_allow_overwrite:
|
||||
if os.path.exists(_output):
|
||||
_err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort."
|
||||
print(_err_msg)
|
||||
return gr.update(value=f"{_err_msg} [{_output}]")
|
||||
print(f"model_0 : {model_A}")
|
||||
print(f"model_1 : {model_B}")
|
||||
print(f"base_alpha : {sl_base_alpha}")
|
||||
print(f"output_file: {_output}")
|
||||
print(f"weights : {_weights}")
|
||||
|
||||
result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw)
|
||||
|
||||
sd_models.list_models()
|
||||
if result:
|
||||
ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{model_O}"
|
||||
else:
|
||||
ret_html = ret_message
|
||||
|
||||
# save log to history.tsv
|
||||
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
|
||||
model_O_hash = "" if not model_O_info else model_O_info.hash
|
||||
_names = presetWeights.find_names_by_weight(_weights)
|
||||
if _names and len(_names) > 0:
|
||||
weight_name = _names[0]
|
||||
else:
|
||||
weight_name = ""
|
||||
mergeHistory.add_history(model_A, model_B, model_O, model_O_hash, sl_base_alpha, _weights, "", weight_name)
|
||||
|
||||
return gr.update(value=f"{ret_html}")
|
||||
btn_do_merge_block_weighted.click(
|
||||
fn=onclick_btn_do_merge_block_weighted,
|
||||
inputs=[model_A, model_B] + sl_IN + sl_MID + sl_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
|
||||
outputs=[html_output_block_weight_info]
|
||||
)
|
||||
|
||||
btn_clear_weighted.click(
|
||||
fn=lambda: [gr.update(value=0.5) for _ in range(25)],
|
||||
inputs=[],
|
||||
outputs=[
|
||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||
sl_M_00,
|
||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
||||
]
|
||||
)
|
||||
|
||||
def on_change_dd_preset_weight(dd_preset_weight):
|
||||
_weights = presetWeights.find_weight_by_name(dd_preset_weight)
|
||||
_ret = on_btn_apply_block_weithg_from_txt(_weights)
|
||||
return [gr.update(value=_weights)] + _ret
|
||||
dd_preset_weight.change(
|
||||
fn=on_change_dd_preset_weight,
|
||||
inputs=[dd_preset_weight],
|
||||
outputs=[txt_block_weight,
|
||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||
sl_M_00,
|
||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
||||
]
|
||||
)
|
||||
|
||||
def on_btn_reload_checkpoint_mbw():
|
||||
sd_models.list_models()
|
||||
return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())]
|
||||
btn_reload_checkpoint_mbw.click(
|
||||
fn=on_btn_reload_checkpoint_mbw,
|
||||
inputs=[],
|
||||
outputs=[model_A, model_B]
|
||||
)
|
||||
|
||||
def on_btn_apply_block_weithg_from_txt(txt_block_weight):
|
||||
if not txt_block_weight or txt_block_weight == "":
|
||||
return [gr.update() for _ in range(25)]
|
||||
_list = [x.strip() for x in txt_block_weight.split(",")]
|
||||
if(len(_list) != 25):
|
||||
return [gr.update() for _ in range(25)]
|
||||
return [gr.update(value=x) for x in _list]
|
||||
btn_apply_block_weithg_from_txt.click(
|
||||
fn=on_btn_apply_block_weithg_from_txt,
|
||||
inputs=[txt_block_weight],
|
||||
outputs=[
|
||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||
sl_M_00,
|
||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -0,0 +1,160 @@
|
|||
# from https://note.com/kohya_ss/n/n9a485a066d5b
|
||||
# kohya_ss
|
||||
# original code: https://github.com/eyriewow/merge-models
|
||||
|
||||
# use them as base of this code
|
||||
# 2022/12/15
|
||||
# bbc-mc
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import re
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import sd_models, shared
|
||||
|
||||
|
||||
NUM_INPUT_BLOCKS = 12
|
||||
NUM_MID_BLOCK = 1
|
||||
NUM_OUTPUT_BLOCKS = 12
|
||||
NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS
|
||||
|
||||
|
||||
def dprint(str, flg):
|
||||
if flg:
|
||||
print(str)
|
||||
|
||||
|
||||
def merge(weight_A:list, weight_B:list, model_0, model_1, device="cpu", base_alpha=0.5,
|
||||
output_file="", allow_overwrite=False, verbose=False):
|
||||
|
||||
def _check_arg_weight(weight):
|
||||
if weight is None:
|
||||
return None
|
||||
else:
|
||||
_weight = [float(w) for w in weight.split(",")]
|
||||
if len(_weight) != NUM_TOTAL_BLOCKS:
|
||||
return None
|
||||
else:
|
||||
return _weight
|
||||
|
||||
weight_A = _check_arg_weight(weight_A)
|
||||
if weight_A is None:
|
||||
_err_msg = f"Weight A invalid. program abort. weights value must be {NUM_TOTAL_BLOCKS}."
|
||||
print(_err_msg)
|
||||
return False, _err_msg
|
||||
weight_B = _check_arg_weight(weight_B)
|
||||
if weight_B is None:
|
||||
_err_msg = f"Weight B invalid. program abort. weights value must be {NUM_TOTAL_BLOCKS}."
|
||||
print(_err_msg)
|
||||
return False, _err_msg
|
||||
|
||||
device = device if device in ["cpu", "cuda"] else "cpu"
|
||||
|
||||
alpha = base_alpha
|
||||
if not output_file or output_file == "":
|
||||
output_file = f'bw-{model_0}-{model_1}-{str(alpha)[2:] + "0"}.ckpt'
|
||||
else:
|
||||
output_file = output_file if ".ckpt" in output_file else output_file + ".ckpt"
|
||||
|
||||
# check if output file already exists
|
||||
if os.path.isfile(output_file) and not allow_overwrite:
|
||||
_err_msg = f"Exiting... [{output_file}]"
|
||||
print(_err_msg)
|
||||
return False, _err_msg
|
||||
|
||||
def load_model(_model, _device):
|
||||
model_info = sd_models.get_closet_checkpoint_match(_model)
|
||||
if model_info:
|
||||
model_file = model_info.filename
|
||||
else:
|
||||
return None
|
||||
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
||||
if cache_enabled and model_info in sd_models.checkpoints_loaded:
|
||||
print(" load from cache")
|
||||
return sd_models.checkpoints_loaded[model_info].copy()
|
||||
else:
|
||||
print(" loading ...")
|
||||
return sd_models.read_state_dict(model_file, map_location=_device)
|
||||
|
||||
print("loading", model_0)
|
||||
theta_0 = load_model(model_0, device)
|
||||
|
||||
print("loading", model_1)
|
||||
theta_1 = load_model(model_1, device)
|
||||
|
||||
re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12
|
||||
re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1
|
||||
re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12
|
||||
|
||||
dprint(f"-- start Stage 1/2 --", verbose)
|
||||
count_target_of_basealpha = 0
|
||||
for key in (tqdm(theta_0.keys(), desc="Stage 1/2") if not verbose else theta_0.keys()):
|
||||
if "model" in key and key in theta_1:
|
||||
dprint(f" key : {key}", verbose)
|
||||
|
||||
current_alpha_A = 1 - alpha
|
||||
current_alpha_B = alpha
|
||||
current_alpha_I = 0
|
||||
|
||||
# check weighted and U-Net or not
|
||||
if weight_A is not None and 'model.diffusion_model.' in key:
|
||||
# check block index
|
||||
weight_index = -1
|
||||
|
||||
if 'time_embed' in key:
|
||||
weight_index = 0 # before input blocks
|
||||
elif '.out.' in key:
|
||||
weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks
|
||||
else:
|
||||
m = re_inp.search(key)
|
||||
if m:
|
||||
inp_idx = int(m.groups()[0])
|
||||
weight_index = inp_idx
|
||||
else:
|
||||
m = re_mid.search(key)
|
||||
if m:
|
||||
weight_index = NUM_INPUT_BLOCKS
|
||||
else:
|
||||
m = re_out.search(key)
|
||||
if m:
|
||||
out_idx = int(m.groups()[0])
|
||||
weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx
|
||||
|
||||
if weight_index >= NUM_TOTAL_BLOCKS:
|
||||
print(f"error. illegal block index: {key}")
|
||||
if weight_index >= 0:
|
||||
current_alpha_A = weight_A[weight_index]
|
||||
current_alpha_B = weight_B[weight_index]
|
||||
current_alpha_I = 1 - current_alpha_A - current_alpha_B
|
||||
if verbose:
|
||||
print(f"weighted '{key}': A{current_alpha_A} B{current_alpha_B} I{current_alpha_I}")
|
||||
|
||||
# create I tensor
|
||||
tensor_I_0 = torch.zeros_like(theta_0[key], dtype=theta_0[key].dtype)
|
||||
_var1 = current_alpha_I * tensor_I_0
|
||||
_var2 = current_alpha_A * theta_0[key]
|
||||
_var3 = current_alpha_B * theta_1[key]
|
||||
theta_0[key] = _var1 + _var2 + _var3
|
||||
|
||||
# theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key]
|
||||
|
||||
else:
|
||||
dprint(f" key - {key}", verbose)
|
||||
|
||||
dprint(f"-- start Stage 2/2 --", verbose)
|
||||
for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
|
||||
if "model" in key and key not in theta_0:
|
||||
dprint(f" key : {key}", verbose)
|
||||
theta_0.update({key:theta_1[key]})
|
||||
else:
|
||||
dprint(f" key - {key}", verbose)
|
||||
|
||||
print("Saving...")
|
||||
|
||||
torch.save({"state_dict": theta_0}, output_file)
|
||||
|
||||
print("Done!")
|
||||
|
||||
return True, f"{output_file}<br>base_alpha applied [{count_target_of_basealpha}] times."
|
||||
|
|
@ -0,0 +1,426 @@
|
|||
import gradio as gr
|
||||
import os
|
||||
import re
|
||||
|
||||
from modules import sd_models, shared
|
||||
from tqdm import tqdm
|
||||
|
||||
from scripts.mbw_each.merge_block_weighted_mod import merge
|
||||
from scripts.util.preset_weights import PresetWeights
|
||||
from scripts.util.merge_history import MergeHistory
|
||||
|
||||
presetWeights = PresetWeights()
|
||||
mergeHistory = MergeHistory()
|
||||
|
||||
|
||||
def on_ui_tabs():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Column(variant="panel"):
|
||||
with gr.Row():
|
||||
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
|
||||
btn_clear_weighted = gr.Button(value="Clear values")
|
||||
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
|
||||
with gr.Row():
|
||||
txt_multi_process_cmd = gr.TextArea(label="Multi Proc Cmd", placeholder="Keep empty if dont use.")
|
||||
html_output_block_weight_info = gr.HTML()
|
||||
with gr.Column():
|
||||
dd_preset_weight = gr.Dropdown(label="Preset_Weights", choices=presetWeights.get_preset_name_list())
|
||||
txt_block_weight = gr.Text(label="Weight_values", placeholder="Put weight sets. float number x 25")
|
||||
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary")
|
||||
with gr.Row():
|
||||
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=0)
|
||||
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
|
||||
chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False)
|
||||
with gr.Row():
|
||||
dd_model_A = gr.Dropdown(label="Model_A", choices=sd_models.checkpoint_tiles())
|
||||
dd_model_B = gr.Dropdown(label="Model_B", choices=sd_models.checkpoint_tiles())
|
||||
txt_model_O = gr.Text(label="(O)Output Model Name")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
sl_IN_A_00 = gr.Slider(label="IN_A_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_00")
|
||||
sl_IN_A_01 = gr.Slider(label="IN_A_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_01")
|
||||
sl_IN_A_02 = gr.Slider(label="IN_A_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_02")
|
||||
sl_IN_A_03 = gr.Slider(label="IN_A_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_03")
|
||||
sl_IN_A_04 = gr.Slider(label="IN_A_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_04")
|
||||
sl_IN_A_05 = gr.Slider(label="IN_A_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_05")
|
||||
sl_IN_A_06 = gr.Slider(label="IN_A_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_06")
|
||||
sl_IN_A_07 = gr.Slider(label="IN_A_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_07")
|
||||
sl_IN_A_08 = gr.Slider(label="IN_A_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_08")
|
||||
sl_IN_A_09 = gr.Slider(label="IN_A_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_09")
|
||||
sl_IN_A_10 = gr.Slider(label="IN_A_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_10")
|
||||
sl_IN_A_11 = gr.Slider(label="IN_A_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_11")
|
||||
with gr.Column():
|
||||
sl_IN_B_00 = gr.Slider(label="IN_B_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_00")
|
||||
sl_IN_B_01 = gr.Slider(label="IN_B_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_01")
|
||||
sl_IN_B_02 = gr.Slider(label="IN_B_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_02")
|
||||
sl_IN_B_03 = gr.Slider(label="IN_B_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_03")
|
||||
sl_IN_B_04 = gr.Slider(label="IN_B_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_04")
|
||||
sl_IN_B_05 = gr.Slider(label="IN_B_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_05")
|
||||
sl_IN_B_06 = gr.Slider(label="IN_B_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_06")
|
||||
sl_IN_B_07 = gr.Slider(label="IN_B_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_07")
|
||||
sl_IN_B_08 = gr.Slider(label="IN_B_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_08")
|
||||
sl_IN_B_09 = gr.Slider(label="IN_B_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_09")
|
||||
sl_IN_B_10 = gr.Slider(label="IN_B_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_10")
|
||||
sl_IN_B_11 = gr.Slider(label="IN_B_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_IN_A_11")
|
||||
with gr.Column():
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
sl_M_A_00 = gr.Slider(label="M_A_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_M_A_00")
|
||||
with gr.Column():
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
sl_M_B_00 = gr.Slider(label="M_B_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_M_B_00")
|
||||
with gr.Column():
|
||||
sl_OUT_A_11 = gr.Slider(label="OUT_A_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_11")
|
||||
sl_OUT_A_10 = gr.Slider(label="OUT_A_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_10")
|
||||
sl_OUT_A_09 = gr.Slider(label="OUT_A_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_09")
|
||||
sl_OUT_A_08 = gr.Slider(label="OUT_A_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_08")
|
||||
sl_OUT_A_07 = gr.Slider(label="OUT_A_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_07")
|
||||
sl_OUT_A_06 = gr.Slider(label="OUT_A_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_06")
|
||||
sl_OUT_A_05 = gr.Slider(label="OUT_A_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_05")
|
||||
sl_OUT_A_04 = gr.Slider(label="OUT_A_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_04")
|
||||
sl_OUT_A_03 = gr.Slider(label="OUT_A_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_03")
|
||||
sl_OUT_A_02 = gr.Slider(label="OUT_A_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_02")
|
||||
sl_OUT_A_01 = gr.Slider(label="OUT_A_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_01")
|
||||
sl_OUT_A_00 = gr.Slider(label="OUT_A_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_A_00")
|
||||
with gr.Column():
|
||||
sl_OUT_B_11 = gr.Slider(label="OUT_B_11", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_11")
|
||||
sl_OUT_B_10 = gr.Slider(label="OUT_B_10", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_10")
|
||||
sl_OUT_B_09 = gr.Slider(label="OUT_B_09", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_09")
|
||||
sl_OUT_B_08 = gr.Slider(label="OUT_B_08", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_08")
|
||||
sl_OUT_B_07 = gr.Slider(label="OUT_B_07", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_07")
|
||||
sl_OUT_B_06 = gr.Slider(label="OUT_B_06", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_06")
|
||||
sl_OUT_B_05 = gr.Slider(label="OUT_B_05", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_05")
|
||||
sl_OUT_B_04 = gr.Slider(label="OUT_B_04", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_04")
|
||||
sl_OUT_B_03 = gr.Slider(label="OUT_B_03", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_03")
|
||||
sl_OUT_B_02 = gr.Slider(label="OUT_B_02", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_02")
|
||||
sl_OUT_B_01 = gr.Slider(label="OUT_B_01", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_01")
|
||||
sl_OUT_B_00 = gr.Slider(label="OUT_B_00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="sl_OUT_B_00")
|
||||
|
||||
# Footer
|
||||
gr.HTML(
|
||||
"""
|
||||
<p style="font-size: 12px" align="right">
|
||||
<b>Merge Block Weighted</b> extension by <a href="https://github.com/bbc-mc" target="_blank">bbc_mc</a><br />
|
||||
<b>MBW Each</b> is experimental functions and <b>NO PROOF</b> of effectiveness.<br />
|
||||
You can try it by own, to dig more deeper into Abyss ...<br />
|
||||
</p>
|
||||
"""
|
||||
)
|
||||
|
||||
sl_A_IN = [
|
||||
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
|
||||
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11]
|
||||
sl_A_MID = [sl_M_A_00]
|
||||
sl_A_OUT = [
|
||||
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
|
||||
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11]
|
||||
|
||||
sl_B_IN = [
|
||||
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
|
||||
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11]
|
||||
sl_B_MID = [sl_M_B_00]
|
||||
sl_B_OUT = [
|
||||
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
|
||||
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11]
|
||||
|
||||
|
||||
# Events
|
||||
def onclick_btn_do_merge_block_weighted(
|
||||
dd_model_A, dd_model_B, txt_multi_process_cmd,
|
||||
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
|
||||
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
|
||||
sl_M_A_00,
|
||||
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
|
||||
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
|
||||
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
|
||||
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
|
||||
sl_M_B_00,
|
||||
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
|
||||
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
|
||||
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite
|
||||
):
|
||||
base_alpha = sl_base_alpha
|
||||
_weight_A = ",".join(
|
||||
[str(x) for x in [
|
||||
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
|
||||
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
|
||||
sl_M_A_00,
|
||||
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
|
||||
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
|
||||
]])
|
||||
_weight_B = ",".join(
|
||||
[str(x) for x in [
|
||||
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
|
||||
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
|
||||
sl_M_B_00,
|
||||
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
|
||||
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
|
||||
]])
|
||||
|
||||
# debug output
|
||||
print( "#### Merge Block Weighted ####")
|
||||
|
||||
if (not dd_model_A or not dd_model_B) and txt_multi_process_cmd == "":
|
||||
_err_msg = f"ERROR: model not found. [{dd_model_A}][{dd_model_B}]"
|
||||
print(_err_msg)
|
||||
return gr.update(value=_err_msg)
|
||||
|
||||
ret_html = ""
|
||||
if txt_multi_process_cmd != "":
|
||||
# need multi-merge
|
||||
_lines = txt_multi_process_cmd.split('\n')
|
||||
print(f"check multi-merge. {len(_lines)} lines found.")
|
||||
for line_index, _line in enumerate(_lines):
|
||||
if _line == "":
|
||||
continue
|
||||
print(f"\n== merge line {line_index+1}/{len(_lines)} ==")
|
||||
_items = [x.strip() for x in _line.split(",") if x != ""]
|
||||
if len(_items) > 0:
|
||||
ret_html += _run_merge(
|
||||
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
|
||||
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw, params=_items)
|
||||
else:
|
||||
_ret = f" multi-merge text found, but invalid params. skipped :[{_line}]"
|
||||
ret_html += _ret
|
||||
print(_ret)
|
||||
else:
|
||||
# normal merge
|
||||
ret_html += _run_merge(
|
||||
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
|
||||
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, verbose=chk_verbose_mbw)
|
||||
|
||||
sd_models.list_models()
|
||||
print( "#### All merge process done. ####")
|
||||
|
||||
return gr.update(value=f"{ret_html}")
|
||||
btn_do_merge_block_weighted.click(
|
||||
fn=onclick_btn_do_merge_block_weighted,
|
||||
inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd] + sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
|
||||
outputs=[html_output_block_weight_info]
|
||||
)
|
||||
|
||||
def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0, model_Output="", verbose=False, params=[]):
|
||||
|
||||
# generate output file name from param
|
||||
model_A_info = sd_models.get_closet_checkpoint_match(model_0)
|
||||
_model_A_name = "" if not model_A_info else model_A_info.model_name # expect "" or "aaa.ckpt [abcdefgh]"
|
||||
|
||||
model_B_info = sd_models.get_closet_checkpoint_match(model_1)
|
||||
_model_B_name = "" if not model_B_info else model_B_info.model_name
|
||||
|
||||
model_O = f"bw-merge-{_model_A_name}-{_model_B_name}-{base_alpha}.ckpt" if model_Output == "" else model_Output
|
||||
model_O = model_O if ".ckpt" in model_O else f"{model_O}.ckpt"
|
||||
model_O = re.sub(r'[\\|/|:|?|.|"|<|>|\|\*]', '-', model_O)
|
||||
|
||||
if params and len(params) > 0:
|
||||
for _item in params:
|
||||
# expect "O=merge/test02, IN_B_00 = 0.12345" as params=["O=merge/test02", "IN_B_00 = 0.12345"]
|
||||
if len(_item.split("=")) == 2:
|
||||
_item_l = _item.split("=")[0].strip()
|
||||
_item_r = _item.split("=")[1].strip()
|
||||
if _item_r != "":
|
||||
if _item_l.lower() == "model_a" or _item_l.lower() == "model_b":
|
||||
_model_info = sd_models.get_closet_checkpoint_match(_item_r)
|
||||
if _model_info:
|
||||
_model_name = _model_info.title.split(" ")[0]
|
||||
if _model_name and _model_name.strip() != "":
|
||||
if _item_l.lower() == "model_a":
|
||||
print(f" * Model changed: {model_0} -> {_model_info.title}")
|
||||
model_0 = _model_info.title
|
||||
elif _item_l.lower() == "model_b":
|
||||
print(f" * Model changed: {model_1} -> {_model_info.title}")
|
||||
model_1 = _model_info.title
|
||||
|
||||
elif _item_l.lower() == "preset_weights":
|
||||
_weights = presetWeights.find_weight_by_name(_item_r)
|
||||
if _weights != "" and len(_weights.split(',')) == 25:
|
||||
print(f" * Weights changed by preset-name: {_item_r}")
|
||||
weight_B = _weights
|
||||
weight_A = ",".join([str(1-float(x)) for x in _weights.split(',')])
|
||||
else:
|
||||
print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]")
|
||||
|
||||
elif _item_l.lower() == "weight_values":
|
||||
_weights = _item_r.strip()
|
||||
if _weights != "" and len(_weights.split(' ')) == 25: # this is work-around to use space as separator. Double-meaning issue on commna which already used as value separator and weights separator.
|
||||
print(f" * Weights changed: {_item_r}")
|
||||
weight_B = _weights
|
||||
weight_A = ",".join([str(1-float(x)) for x in _weights.split(' ')])
|
||||
else:
|
||||
print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]")
|
||||
|
||||
elif _item_l.lower() == "base_alpha":
|
||||
if float(_item_r) >= 0:
|
||||
print(f" * base_alpha changed: {base_alpha} -> {_item_r}")
|
||||
base_alpha = float(_item_r)
|
||||
|
||||
elif _item_l.upper() == "O":
|
||||
_item_r = _item_r if ".ckpt" in _item_r else f"{_item_r}.ckpt"
|
||||
print(f" * Output filename changed:[{model_O}] -> [{_item_r}]")
|
||||
model_O = _item_r
|
||||
|
||||
elif len(_item_l.split("_")) == 3:
|
||||
_IMO = _item_l.split("_")[0]
|
||||
_AB = _item_l.split("_")[1]
|
||||
_NUM = _item_l.split("_")[2]
|
||||
|
||||
_index = int(_NUM)
|
||||
_index = _index + 0 if _IMO == "IN" else _index
|
||||
_index = _index + 12 if _IMO == "M" else _index
|
||||
_index = _index + 13 if _IMO == "OUT" else _index
|
||||
|
||||
def _apply_val(key, weight, index, new_value):
|
||||
_weight = [x.strip() for x in weight.split(",")]
|
||||
_new_weight = _weight[:]
|
||||
_new_weight[index] = new_value
|
||||
_new_weight = ",".join(_new_weight)
|
||||
print(f" * weight_{key} changed:[{weight}]")
|
||||
print(f" -> [{_new_weight}]")
|
||||
return _new_weight
|
||||
|
||||
if _AB == "A":
|
||||
weight_A = _apply_val(_AB, weight_A, _index, _item_r)
|
||||
elif _AB == "B":
|
||||
weight_B = _apply_val(_AB, weight_B, _index, _item_r)
|
||||
else:
|
||||
print(f" * Waring: uncaught param found. ignored. [{_item_l}][{_item_r}]")
|
||||
|
||||
#
|
||||
# Prepare params before run merge
|
||||
#
|
||||
output_file = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O)
|
||||
#
|
||||
# Check params
|
||||
#
|
||||
if not os.path.exists(os.path.dirname(output_file)):
|
||||
_err_msg = f"WARNING: target path not found: {os.path.dirname(output_file)}. skipped."
|
||||
print(_err_msg)
|
||||
return _err_msg + "<br />"
|
||||
if not allow_overwrite:
|
||||
if os.path.exists(output_file):
|
||||
_err_msg = f"WARNING: output_file already exists. overwrite not allowed. skipped."
|
||||
print(_err_msg)
|
||||
return _err_msg + "<br />"
|
||||
|
||||
# debug output
|
||||
print(f" model_0 : {model_0}")
|
||||
print(f" model_1 : {model_1}")
|
||||
print(f" model_Out : {model_O}")
|
||||
print(f" base_alpha : {base_alpha}")
|
||||
print(f" output_file: {output_file}")
|
||||
print(f" weight_A : {weight_A}")
|
||||
print(f" weight_B : {weight_B}")
|
||||
|
||||
result, ret_message = merge(
|
||||
weight_A=weight_A, weight_B=weight_B, model_0=model_0, model_1=model_1,
|
||||
allow_overwrite=allow_overwrite, base_alpha=base_alpha, output_file=output_file, verbose=verbose)
|
||||
if result:
|
||||
ret_html = f"merged. {model_0} + {model_1} = {model_O} <br>"
|
||||
print("merged.")
|
||||
else:
|
||||
ret_html = ret_message
|
||||
print("merge failed.")
|
||||
|
||||
|
||||
# save log to history.tsv
|
||||
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
|
||||
model_O_hash = "" if not model_O_info else model_O_info.hash
|
||||
_names = presetWeights.find_names_by_weight(weight_A)
|
||||
if _names and len(_names) > 0:
|
||||
weight_name = _names[0]
|
||||
else:
|
||||
weight_name = ""
|
||||
mergeHistory.add_history(model_0, model_1, model_O, model_O_hash, base_alpha, weight_A, weight_B, weight_name)
|
||||
return ret_html
|
||||
|
||||
btn_clear_weighted.click(
|
||||
fn=lambda: [gr.update(value=0.5) for _ in range(25*2)],
|
||||
inputs=[],
|
||||
outputs=[
|
||||
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
|
||||
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
|
||||
sl_M_A_00,
|
||||
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
|
||||
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
|
||||
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
|
||||
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
|
||||
sl_M_B_00,
|
||||
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
|
||||
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
|
||||
]
|
||||
)
|
||||
|
||||
def on_change_dd_preset_weight(dd_preset_weight):
|
||||
_weights = presetWeights.find_weight_by_name(dd_preset_weight)
|
||||
_ret = on_btn_apply_block_weight_from_txt(_weights)
|
||||
return [gr.update(value=_weights)] + _ret
|
||||
dd_preset_weight.change(
|
||||
fn=on_change_dd_preset_weight,
|
||||
inputs=[dd_preset_weight],
|
||||
outputs=[
|
||||
txt_block_weight,
|
||||
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
|
||||
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
|
||||
sl_M_A_00,
|
||||
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
|
||||
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
|
||||
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
|
||||
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
|
||||
sl_M_B_00,
|
||||
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
|
||||
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
|
||||
]
|
||||
)
|
||||
|
||||
def on_btn_reload_checkpoint_mbw():
|
||||
sd_models.list_models()
|
||||
return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())]
|
||||
btn_reload_checkpoint_mbw.click(
|
||||
fn=on_btn_reload_checkpoint_mbw,
|
||||
inputs=[],
|
||||
outputs=[dd_model_A, dd_model_B]
|
||||
)
|
||||
|
||||
def on_btn_apply_block_weight_from_txt(txt_block_weight):
|
||||
if not txt_block_weight or txt_block_weight == "":
|
||||
return [gr.update() for _ in range(25*2)]
|
||||
_list = [x.strip() for x in txt_block_weight.split(",")]
|
||||
if(len(_list) != 25):
|
||||
return [gr.update() for _ in range(25*2)]
|
||||
return [gr.update(value=str(1-float(x))) for x in _list] + [gr.update(value=x) for x in _list]
|
||||
btn_apply_block_weithg_from_txt.click(
|
||||
fn=on_btn_apply_block_weight_from_txt,
|
||||
inputs=[txt_block_weight],
|
||||
outputs=[
|
||||
sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
|
||||
sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
|
||||
sl_M_A_00,
|
||||
sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
|
||||
sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
|
||||
sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
|
||||
sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
|
||||
sl_M_B_00,
|
||||
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
|
||||
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
|
||||
]
|
||||
)
|
||||
|
|
@ -8,17 +8,12 @@
|
|||
import os
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts, script_callbacks
|
||||
from modules import sd_models, shared
|
||||
from modules import script_callbacks
|
||||
|
||||
from scripts.merge_block_weighted import merge
|
||||
from scripts.merge_history import MergeHistory
|
||||
from scripts.preset_weights import PresetWeights
|
||||
|
||||
path_root = scripts.basedir()
|
||||
from scripts.mbw import ui_mbw
|
||||
from scripts.mbw_each import ui_mbw_each
|
||||
|
||||
mergeHistory = MergeHistory()
|
||||
presetWeights = PresetWeights()
|
||||
|
||||
#
|
||||
# UI callback
|
||||
|
|
@ -26,204 +21,11 @@ presetWeights = PresetWeights()
|
|||
def on_ui_tabs():
|
||||
|
||||
with gr.Blocks() as main_block:
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Column(variant="panel"):
|
||||
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
|
||||
btn_clear_weighted = gr.Button(value="Clear values")
|
||||
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
|
||||
html_output_block_weight_info = gr.HTML()
|
||||
with gr.Column():
|
||||
dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list())
|
||||
txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25")
|
||||
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary")
|
||||
with gr.Row():
|
||||
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=1)
|
||||
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
|
||||
chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False)
|
||||
with gr.Row():
|
||||
model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles())
|
||||
model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles())
|
||||
txt_model_O = gr.Text(label="Output Model Name")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
with gr.Column():
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
gr.Slider(visible=False)
|
||||
sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="mbw_sl_M00")
|
||||
with gr.Column():
|
||||
sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.00000000001, value=0.5)
|
||||
sl_IN = [
|
||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11]
|
||||
sl_MID = [sl_M_00]
|
||||
sl_OUT = [
|
||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11]
|
||||
with gr.Tab("MBW", elem_id="tab_mbw"):
|
||||
ui_mbw.on_ui_tabs()
|
||||
|
||||
# Events
|
||||
def onclick_btn_do_merge_block_weighted(
|
||||
model_A, model_B,
|
||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||
sl_M_00,
|
||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
||||
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite
|
||||
):
|
||||
_weights = ",".join(
|
||||
[str(x) for x in [
|
||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||
sl_M_00,
|
||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11
|
||||
]])
|
||||
#
|
||||
if not model_A or not model_B:
|
||||
return gr.update(value=f"ERROR: model not found. [{model_A}][{model_B}]")
|
||||
|
||||
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||
model_A_info = sd_models.get_closet_checkpoint_match(model_A)
|
||||
if model_A_info:
|
||||
_model_A_name = model_A_info.model_name
|
||||
else:
|
||||
_model_A_name = ""
|
||||
model_B_info = sd_models.get_closet_checkpoint_match(model_B)
|
||||
if model_B_info:
|
||||
_model_B_info = model_B_info.model_name
|
||||
else:
|
||||
_model_B_info = ""
|
||||
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
|
||||
if ".ckpt" not in model_O:
|
||||
model_O = model_O + ".ckpt"
|
||||
|
||||
_output = os.path.join(ckpt_dir, model_O)
|
||||
# debug output
|
||||
print( "#### Merge Block Weighted ####")
|
||||
if not chk_allow_overwrite:
|
||||
if os.path.exists(_output):
|
||||
_err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort."
|
||||
print(_err_msg)
|
||||
return gr.update(value=f"{_err_msg} [{_output}]")
|
||||
print(f"model_0 : {model_A}")
|
||||
print(f"model_1 : {model_B}")
|
||||
print(f"base_alpha : {sl_base_alpha}")
|
||||
print(f"output_file: {_output}")
|
||||
print(f"weights : {_weights}")
|
||||
|
||||
result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw)
|
||||
|
||||
sd_models.list_models()
|
||||
if result:
|
||||
ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{model_O}"
|
||||
else:
|
||||
ret_html = ret_message
|
||||
|
||||
# save log to history.tsv
|
||||
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
|
||||
model_O_hash = "" if not model_O_info else model_O_info.hash
|
||||
_names = presetWeights.find_names_by_weight(_weights)
|
||||
if _names and len(_names) > 0:
|
||||
weight_name = _names[0]
|
||||
else:
|
||||
weight_name = ""
|
||||
mergeHistory.add_history(model_A, model_B, model_O, model_O_hash, sl_base_alpha, _weights, weight_name)
|
||||
|
||||
return gr.update(value=f"{ret_html}")
|
||||
btn_do_merge_block_weighted.click(
|
||||
fn=onclick_btn_do_merge_block_weighted,
|
||||
inputs=[model_A, model_B] + sl_IN + sl_MID + sl_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
|
||||
outputs=[html_output_block_weight_info]
|
||||
)
|
||||
|
||||
btn_clear_weighted.click(
|
||||
fn=lambda: [gr.update(value=0.5) for _ in range(25)],
|
||||
inputs=[],
|
||||
outputs=[
|
||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||
sl_M_00,
|
||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
||||
]
|
||||
)
|
||||
|
||||
def on_change_dd_preset_weight(dd_preset_weight):
|
||||
_weights = presetWeights.find_weight_by_name(dd_preset_weight)
|
||||
_ret = on_btn_apply_block_weithg_from_txt(_weights)
|
||||
return [gr.update(value=_weights)] + _ret
|
||||
dd_preset_weight.change(
|
||||
fn=on_change_dd_preset_weight,
|
||||
inputs=[dd_preset_weight],
|
||||
outputs=[txt_block_weight,
|
||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||
sl_M_00,
|
||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
||||
]
|
||||
)
|
||||
|
||||
def on_btn_reload_checkpoint_mbw():
|
||||
sd_models.list_models()
|
||||
return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())]
|
||||
btn_reload_checkpoint_mbw.click(
|
||||
fn=on_btn_reload_checkpoint_mbw,
|
||||
inputs=[],
|
||||
outputs=[model_A, model_B]
|
||||
)
|
||||
|
||||
def on_btn_apply_block_weithg_from_txt(txt_block_weight):
|
||||
if not txt_block_weight or txt_block_weight == "":
|
||||
return [gr.update() for _ in range(25)]
|
||||
_list = [x.strip() for x in txt_block_weight.split(",")]
|
||||
if(len(_list) != 25):
|
||||
return [gr.update() for _ in range(25)]
|
||||
return [gr.update(value=x) for x in _list]
|
||||
btn_apply_block_weithg_from_txt.click(
|
||||
fn=on_btn_apply_block_weithg_from_txt,
|
||||
inputs=[txt_block_weight],
|
||||
outputs=[
|
||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||
sl_M_00,
|
||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
||||
]
|
||||
)
|
||||
with gr.Tab("MBW Each", elem_id="tab_mbw_each"):
|
||||
ui_mbw_each.on_ui_tabs()
|
||||
|
||||
# return required as (gradio_component, title, elem_id)
|
||||
return (main_block, "Merge Block Weighted", "merge_block_weighted"),
|
||||
|
|
|
|||
|
|
@ -1,40 +0,0 @@
|
|||
#
|
||||
#
|
||||
#
|
||||
import os
|
||||
from csv import DictWriter, writer
|
||||
|
||||
from modules import scripts
|
||||
|
||||
|
||||
CSV_FILE_PATH = "csv/history.tsv"
|
||||
HEADERS = ["model_A", "model_A_hash", "model_B", "model_B_hash", "model_O", "model_O_hash", "base_alpha", "weight_name", "weight_values"]
|
||||
path_root = scripts.basedir()
|
||||
|
||||
|
||||
class MergeHistory():
|
||||
def __init__(self):
|
||||
self.filepath = os.path.join(path_root, CSV_FILE_PATH)
|
||||
|
||||
def add_history(self, model_A, model_B, model_O, model_O_hash, sl_base_alpha, weight_values, weight_name=""):
|
||||
_history_dict = {}
|
||||
_history_dict.update({
|
||||
"model_A": f"{os.path.basename(model_A.split(' ')[0])}",
|
||||
"model_A_hash": f"{model_A.split(' ')[1]}",
|
||||
"model_B": f"{os.path.basename(model_B.split(' ')[0])}",
|
||||
"model_B_hash": f"{model_B.split(' ')[1]}",
|
||||
"model_O": model_O,
|
||||
"model_O_hash": model_O_hash,
|
||||
"base_alpha": sl_base_alpha,
|
||||
"weight_name": weight_name,
|
||||
"weight_values": weight_values,
|
||||
})
|
||||
|
||||
if not os.path.exists(self.filepath):
|
||||
with open(self.filepath, "w", newline="", encoding="utf-8") as f:
|
||||
wr = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
|
||||
wr.writeheader()
|
||||
# save to file
|
||||
with open(self.filepath, "a", newline="", encoding='utf-8') as f:
|
||||
dictwriter = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
|
||||
dictwriter.writerow(_history_dict)
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
#
|
||||
#
|
||||
#
|
||||
import os
|
||||
import datetime
|
||||
from csv import DictWriter, DictReader
|
||||
|
||||
from modules import scripts
|
||||
|
||||
|
||||
CSV_FILE_PATH = "csv/history.tsv"
|
||||
HEADERS = ["model_A", "model_A_hash", "model_B", "model_B_hash", "model_O", "model_O_hash", "base_alpha", "weight_name", "weight_values", "weight_values2", "datetime"]
|
||||
path_root = scripts.basedir()
|
||||
|
||||
|
||||
class MergeHistory():
|
||||
def __init__(self):
|
||||
self.filepath = os.path.join(path_root, CSV_FILE_PATH)
|
||||
if os.path.exists(self.filepath):
|
||||
self.update_header()
|
||||
|
||||
def add_history(self, model_A, model_B, model_O, model_O_hash, sl_base_alpha, weight_value_A, weight_value_B, weight_name=""):
|
||||
_history_dict = {}
|
||||
_history_dict.update({
|
||||
"model_A": f"{os.path.basename(model_A.split(' ')[0])}",
|
||||
"model_A_hash": f"{model_A.split(' ')[1]}",
|
||||
"model_B": f"{os.path.basename(model_B.split(' ')[0])}",
|
||||
"model_B_hash": f"{model_B.split(' ')[1]}",
|
||||
"model_O": model_O,
|
||||
"model_O_hash": model_O_hash,
|
||||
"base_alpha": sl_base_alpha,
|
||||
"weight_name": weight_name,
|
||||
"weight_values": weight_value_A,
|
||||
"weight_values2": weight_value_B,
|
||||
"datetime": f"{datetime.datetime.now()}"
|
||||
})
|
||||
|
||||
if not os.path.exists(self.filepath):
|
||||
with open(self.filepath, "w", newline="", encoding="utf-8") as f:
|
||||
dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
|
||||
dw.writeheader()
|
||||
# save to file
|
||||
with open(self.filepath, "a", newline="", encoding='utf-8') as f:
|
||||
dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
|
||||
dw.writerow(_history_dict)
|
||||
|
||||
def update_header(self):
|
||||
hist_data = []
|
||||
if os.path.exists(self.filepath):
|
||||
# check header in case HEADERS updated
|
||||
with open(self.filepath, "r", newline="", encoding="utf-8") as f:
|
||||
dr = DictReader(f, delimiter='\t')
|
||||
new_header = [ x for x in HEADERS if x not in dr.fieldnames ]
|
||||
if len(new_header) > 0:
|
||||
# need update.
|
||||
hist_data = [ x for x in dr]
|
||||
if len(hist_data) > 0:
|
||||
with open(self.filepath, "w", newline="", encoding="utf-8") as f:
|
||||
dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
|
||||
dw.writeheader()
|
||||
dw.writerows(hist_data)
|
||||
|
|
@ -8,14 +8,24 @@ from modules import scripts
|
|||
|
||||
|
||||
CSV_FILE_PATH = "csv/preset.tsv"
|
||||
MYPRESET_PATH = "csv/preset_own.tsv"
|
||||
HEADER = ["preset_name", "preset_weights"]
|
||||
path_root = scripts.basedir()
|
||||
|
||||
|
||||
class PresetWeights():
|
||||
def __init__(self):
|
||||
self.filepath = os.path.join(path_root, CSV_FILE_PATH)
|
||||
self.presets = {}
|
||||
with open(self.filepath, "r") as f:
|
||||
|
||||
if os.path.exists(os.path.join(path_root, MYPRESET_PATH)):
|
||||
with open(os.path.join(path_root, MYPRESET_PATH), "r") as f:
|
||||
reader = DictReader(f, delimiter="\t")
|
||||
lines_dict = [row for row in reader]
|
||||
for line_dict in lines_dict:
|
||||
_w = ",".join([f"{x.strip()}" for x in line_dict["preset_weights"].split(",")])
|
||||
self.presets.update({line_dict["preset_name"]: _w})
|
||||
|
||||
with open(os.path.join(path_root, CSV_FILE_PATH), "r") as f:
|
||||
reader = DictReader(f, delimiter="\t")
|
||||
lines_dict = [row for row in reader]
|
||||
for line_dict in lines_dict:
|
||||
|
|
@ -26,7 +36,7 @@ class PresetWeights():
|
|||
return [k for k in self.presets.keys()]
|
||||
|
||||
def find_weight_by_name(self, preset_name=""):
|
||||
if preset_name and preset_name != "" and preset_name in self.presets:
|
||||
if preset_name and preset_name != "" and preset_name in self.presets.keys():
|
||||
return self.presets.get(preset_name, ",".join(["0.5" for _ in range(25)]))
|
||||
else:
|
||||
return ""
|
||||
29
style.css
29
style.css
|
|
@ -1,4 +1,29 @@
|
|||
#mbw_sl_M00 {
|
||||
#mbw_sl_M00, #mbw_sl_a_M00, #mbw_sl_b_M00 {
|
||||
bottom:0;
|
||||
position:absolute;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
#sl_IN_A_00, #sl_IN_A_01, #sl_IN_A_02, #sl_IN_A_03, #sl_IN_A_04, #sl_IN_A_05, #sl_IN_A_06, #sl_IN_A_07, #sl_IN_A_08, #sl_IN_A_09, #sl_IN_A_10, #sl_IN_A_11 {
|
||||
width: 220;
|
||||
}
|
||||
|
||||
#sl_IN_B_00, #sl_IN_B_01, #sl_IN_B_02, #sl_IN_B_03, #sl_IN_B_04, #sl_IN_B_05, #sl_IN_B_06, #sl_IN_B_07, #sl_IN_B_08, #sl_IN_B_09, #sl_IN_B_10, #sl_IN_B_11 {
|
||||
width: 220;
|
||||
}
|
||||
*/
|
||||
|
||||
#sl_M_A_00, #sl_M_B_00 {
|
||||
bottom:0;
|
||||
position:absolute;
|
||||
}
|
||||
|
||||
/*
|
||||
#sl_OUT_A_00, #sl_OUT_A_01, #sl_OUT_A_02, #sl_OUT_A_03, #sl_OUT_A_04, #sl_OUT_A_05, #sl_OUT_A_06, #sl_OUT_A_07, #sl_OUT_A_08, #sl_OUT_A_09, #sl_OUT_A_10, #sl_OUT_A_11 {
|
||||
width: 220;
|
||||
}
|
||||
|
||||
#sl_OUT_B_00, #sl_OUT_B_01, #sl_OUT_B_02, #sl_OUT_B_03, #sl_OUT_B_04, #sl_OUT_B_05, #sl_OUT_B_06, #sl_OUT_B_07, #sl_OUT_B_08, #sl_OUT_B_09, #sl_OUT_B_10, #sl_OUT_B_11 {
|
||||
width: 220;
|
||||
}
|
||||
*/
|
||||
Loading…
Reference in New Issue