Add feature: Weight-Preset, save merge log

feature: can select weights from Dropdown with 'preset name'
feature: save log of merge models, hashes, settings to history.tsv
pull/1/head
bbc_mc 2022-12-17 21:00:00 +09:00
parent e99bc8bd41
commit 561f40ae6c
12 changed files with 183 additions and 18 deletions

View File

@ -4,7 +4,7 @@
- Implementation GUI of [Merge Block Weighted] (https://note.com/kohya_ss/n/n9a485a066d5b) idea by kohya_ss
- change some part of script to adjust for AUTO1111, basic method is not changed.
![](misc/bw01.png)
![](misc/bw01-1.png)
## How to Install
@ -22,6 +22,15 @@
### Set merge ratio for each block of U-Net
- Select Presets by Dropdown
![](misc/bw08.png)
You can manage presets on tsv file (tab separated file) at `extention/<this extension>/csv/preset.tsv`
![](misc/bw06.png)
- or Input at GUI Slider
![](misc/bw03.png)
- "INxx" is input blocks. 12 blocks
@ -33,7 +42,7 @@
![](misc/bw04.png)
- You can write your weights in "Textbox" and "Apply block weight from text"
- Weights must have 25 values and comma separated
### Setting values
@ -42,10 +51,10 @@
- set "base_alpha"
| base_alpha | |
| ----- | ----------------------------------------------------------------------------- |
| 1 | merged model uses (Text Encoder、Auto Encoder) 100% from `model_A`|
| 0 | marged model uses (Text Encoder、Auto Encoder) 100% from `model_B`|
| base_alpha | |
| ---------- | ----------------------------------------------------------------- |
| 1 | merged model uses (Text Encoder、Auto Encoder) 100% from `model_A` |
| 0 | marged model uses (Text Encoder、Auto Encoder) 100% from `model_B` |
### Other settings
@ -56,15 +65,25 @@
- Merged output is saved in normal "Model" folder.
## Other function
### Save Merge Log
- save log about operated merge, as below,
![](misc/bw07.png)
- log is saved at `extension/<this extension>/csv/history.tsv`
## Sample/Example
- kohya_ss さんのテストを再現してみる
- Compare SD15 and WD13 / Stable Diffusion 1.5 と WD 1.3 の結果を見る
- ※元記事は SD14 を使用 (WD13はSD14ベース)
- see also [Stable DiffusionのモデルをU-Netの深さに応じて比率を変えてマージするKohya S.note](https://note.com/kohya_ss/n/n9a485a066d5b)
- 準備する/マージして作るモデルは、以下の通り / Prepare models as below,
| Model Name | |
| --------------- | ----------------------------------------------------------------- |
| sd-v1.5-pruned | Stable Diffusion v1.5 |
@ -74,7 +93,7 @@
| bw-merge2-2-2 | Merge Block Weighted<br>SD15 and WD13. base_alpha=0<br>weightは後述2 |
- テスト用のGeneration Info, Seedは 14 の4つ
```
masterpiece, best quality, beautiful anime girl, school uniform, strong rim light, intense shadows, highly detailed, cinematic lighting, taken by Canon EOS 5D Simga Art Lens 50mm f1.8 ISO 100 Shutter Speed 1000
Negative prompt: lowres, bad anatomy, bad hands, error, missing fingers, cropped, worst quality, low quality, normal quality, jpeg artifacts, blurry
@ -85,7 +104,8 @@
![](misc/xy_plus-0000-40-7_1.png)
- 変化傾向は、
- 変化傾向は、
- bw-merge1 で、顔立ちがややアニメ化 (sd15-wd13-ws50と比較して)
- bw-merge2 で、ややリアル風(特に seed=3 の目が良い)

1
csv/history.tsv Normal file
View File

@ -0,0 +1 @@
model_A model_A_hash model_B model_B_hash model_O model_O_hash base_alpha weight_name weight_values
1 model_A model_A_hash model_B model_B_hash model_O model_O_hash base_alpha weight_name weight_values

18
csv/preset.tsv Normal file
View File

@ -0,0 +1,18 @@
preset_name preset_weights
GRAD_V 1,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0
GRAD_A 0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0
FLAT_25 0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25
FLAT_75 0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75
WRAP08 1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1
WRAP12 1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1
WRAP14 1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
WRAP16 1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1
MID12_50 0,0,0,0,0,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0,0,0,0,0,0
OUT07 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
OUT12 0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1
OUT12_5 0,0,0,0,0,0,0,0,0,0,0,0,0.5,1,1,1,1,1,1,1,1,1,1,1,1
RING08_SOFT 0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0
RING08_5 0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0
RING10_5 0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0
RING10_3 0,0,0,0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0,0,0,0
RING10_3 0,0,0,0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0,0,0,0
1 preset_name preset_weights
2 GRAD_V 1,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0
3 GRAD_A 0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0
4 FLAT_25 0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25
5 FLAT_75 0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75
6 WRAP08 1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1
7 WRAP12 1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1
8 WRAP14 1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
9 WRAP16 1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1
10 MID12_50 0,0,0,0,0,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0,0,0,0,0,0
11 OUT07 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
12 OUT12 0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1
13 OUT12_5 0,0,0,0,0,0,0,0,0,0,0,0,0.5,1,1,1,1,1,1,1,1,1,1,1,1
14 RING08_SOFT 0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0
15 RING08_5 0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0
16 RING10_5 0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0
17 RING10_3 0,0,0,0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0,0,0,0
18 RING10_3 0,0,0,0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0,0,0,0

BIN
misc/bw01-1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 127 KiB

BIN
misc/bw06.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.1 KiB

BIN
misc/bw07.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.2 KiB

BIN
misc/bw08.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 MiB

View File

@ -67,6 +67,7 @@ def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5, output_f
re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12
dprint(f"-- start Stage 1/2 --", verbose)
count_target_of_basealpha = 0
for key in (tqdm(theta_0.keys(), desc="Stage 1/2") if not verbose else theta_0.keys()):
if "model" in key and key in theta_1:
dprint(f" key : {key}", verbose)
@ -102,6 +103,9 @@ def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5, output_f
if weight_index >= 0:
current_alpha = weights[weight_index]
dprint(f"weighted '{key}': {current_alpha}", verbose)
else:
count_target_of_basealpha = count_target_of_basealpha + 1
dprint(f"base_alpha applied: [{key}]", verbose)
theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key]
@ -122,4 +126,4 @@ def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5, output_f
print("Done!")
return True, output_file
return True, f"{output_file}<br>base_alpha applied [{count_target_of_basealpha}] times."

View File

@ -12,9 +12,13 @@ from modules import scripts, script_callbacks
from modules import sd_models, shared
from scripts.merge_block_weighted import merge
from scripts.merge_history import MergeHistory
from scripts.preset_weights import PresetWeights
path_root = scripts.basedir()
mergeHistory = MergeHistory()
presetWeights = PresetWeights()
#
# UI callback
@ -30,8 +34,9 @@ def on_ui_tabs():
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
html_output_block_weight_info = gr.HTML()
with gr.Column():
txt_block_weight = gr.Text(placeholder="Put weight sets. float number x 25")
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text")
dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list())
txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25")
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary")
with gr.Row():
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=1)
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
@ -121,13 +126,18 @@ def on_ui_tabs():
_model_B_info = model_B_info.model_name
else:
_model_B_info = ""
filename = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
if ".ckpt" not in filename:
filename = filename + ".ckpt"
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
if ".ckpt" not in model_O:
model_O = model_O + ".ckpt"
_output = os.path.join(ckpt_dir, filename)
_output = os.path.join(ckpt_dir, model_O)
# debug output
print( "#### Merge Block Weighted ####")
if not chk_allow_overwrite:
if os.path.exists(_output):
_err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort."
print(_err_msg)
return gr.update(value=f"{_err_msg} [{_output}]")
print(f"model_0 : {model_A}")
print(f"model_1 : {model_B}")
print(f"base_alpha : {sl_base_alpha}")
@ -138,9 +148,20 @@ def on_ui_tabs():
sd_models.list_models()
if result:
ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{filename}"
ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{model_O}"
else:
ret_html = ret_message
# save log to history.tsv
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
model_O_hash = "" if not model_O_info else model_O_info.hash
_names = presetWeights.find_names_by_weight(_weights)
if _names and len(_names) > 0:
weight_name = _names[0]
else:
weight_name = ""
mergeHistory.add_history(model_A, model_B, model_O, model_O_hash, sl_base_alpha, _weights, weight_name)
return gr.update(value=f"{ret_html}")
btn_do_merge_block_weighted.click(
fn=onclick_btn_do_merge_block_weighted,
@ -160,6 +181,22 @@ def on_ui_tabs():
]
)
def on_change_dd_preset_weight(dd_preset_weight):
_weights = presetWeights.find_weight_by_name(dd_preset_weight)
_ret = on_btn_apply_block_weithg_from_txt(_weights)
return [gr.update(value=_weights)] + _ret
dd_preset_weight.change(
fn=on_change_dd_preset_weight,
inputs=[dd_preset_weight],
outputs=[txt_block_weight,
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
sl_M_00,
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
]
)
def on_btn_reload_checkpoint_mbw():
sd_models.list_models()
return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())]

40
scripts/merge_history.py Normal file
View File

@ -0,0 +1,40 @@
#
#
#
import os
from csv import DictWriter, writer
from modules import scripts
CSV_FILE_PATH = "csv/history.tsv"
HEADERS = ["model_A", "model_A_hash", "model_B", "model_B_hash", "model_O", "model_O_hash", "base_alpha", "weight_name", "weight_values"]
path_root = scripts.basedir()
class MergeHistory():
def __init__(self):
self.filepath = os.path.join(path_root, CSV_FILE_PATH)
def add_history(self, model_A, model_B, model_O, model_O_hash, sl_base_alpha, weight_values, weight_name=""):
_history_dict = {}
_history_dict.update({
"model_A": f"{os.path.basename(model_A.split(' ')[0])}",
"model_A_hash": f"{model_A.split(' ')[1]}",
"model_B": f"{os.path.basename(model_B.split(' ')[0])}",
"model_B_hash": f"{model_B.split(' ')[1]}",
"model_O": model_O,
"model_O_hash": model_O_hash,
"base_alpha": sl_base_alpha,
"weight_name": weight_name,
"weight_values": weight_values,
})
if not os.path.exists(self.filepath):
with open(self.filepath, "w", newline="", encoding="utf-8") as f:
wr = writer(f, fieldnames=HEADERS, delimiter='\t')
wr.writerow(HEADERS)
# save to file
with open(self.filepath, "a", newline="", encoding='utf-8') as f:
dictwriter = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
dictwriter.writerow(_history_dict)

45
scripts/preset_weights.py Normal file
View File

@ -0,0 +1,45 @@
#
#
#
import os
from csv import DictReader
from modules import scripts
CSV_FILE_PATH = "csv/preset.tsv"
path_root = scripts.basedir()
class PresetWeights():
def __init__(self):
self.filepath = os.path.join(path_root, CSV_FILE_PATH)
self.presets = {}
with open(self.filepath, "r") as f:
reader = DictReader(f, delimiter="\t")
lines_dict = [row for row in reader]
for line_dict in lines_dict:
_w = ",".join([f"{x.strip()}" for x in line_dict["preset_weights"].split(",")])
self.presets.update({line_dict["preset_name"]: _w})
def get_preset_name_list(self):
return [k for k in self.presets.keys()]
def find_weight_by_name(self, preset_name=""):
if preset_name and preset_name != "" and preset_name in self.presets:
return self.presets.get(preset_name, ",".join(["0.5" for _ in range(25)]))
else:
return ""
def find_names_by_weight(self, weights=""):
if weights and weights != "":
if weights in self.presets.values():
return [k for k, v in self.presets.items() if v == weights]
else:
_val = ",".join([f"{x.strip()}" for x in weights.split(",")])
if _val in self.presets.values():
return [k for k, v in self.presets.items() if v == _val]
else:
return []
else:
return []