Add feature: Weight-Preset, save merge log
feature: can select weights from Dropdown with 'preset name' feature: save log of merge models, hashes, settings to history.tsvpull/1/head
parent
e99bc8bd41
commit
561f40ae6c
40
README.md
40
README.md
|
|
@ -4,7 +4,7 @@
|
|||
- Implementation GUI of [Merge Block Weighted] (https://note.com/kohya_ss/n/n9a485a066d5b) idea by kohya_ss
|
||||
- change some part of script to adjust for AUTO1111, basic method is not changed.
|
||||
|
||||

|
||||

|
||||
|
||||
## How to Install
|
||||
|
||||
|
|
@ -22,6 +22,15 @@
|
|||
|
||||
### Set merge ratio for each block of U-Net
|
||||
|
||||
- Select Presets by Dropdown
|
||||
|
||||

|
||||
|
||||
You can manage presets on tsv file (tab separated file) at `extention/<this extension>/csv/preset.tsv`
|
||||

|
||||
|
||||
- or Input at GUI Slider
|
||||
|
||||

|
||||
|
||||
- "INxx" is input blocks. 12 blocks
|
||||
|
|
@ -33,7 +42,7 @@
|
|||

|
||||
|
||||
- You can write your weights in "Textbox" and "Apply block weight from text"
|
||||
|
||||
|
||||
- Weights must have 25 values and comma separated
|
||||
|
||||
### Setting values
|
||||
|
|
@ -42,10 +51,10 @@
|
|||
|
||||
- set "base_alpha"
|
||||
|
||||
| base_alpha | |
|
||||
| ----- | ----------------------------------------------------------------------------- |
|
||||
| 1 | merged model uses (Text Encoder、Auto Encoder) 100% from `model_A`|
|
||||
| 0 | marged model uses (Text Encoder、Auto Encoder) 100% from `model_B`|
|
||||
| base_alpha | |
|
||||
| ---------- | ----------------------------------------------------------------- |
|
||||
| 1 | merged model uses (Text Encoder、Auto Encoder) 100% from `model_A` |
|
||||
| 0 | marged model uses (Text Encoder、Auto Encoder) 100% from `model_B` |
|
||||
|
||||
### Other settings
|
||||
|
||||
|
|
@ -56,15 +65,25 @@
|
|||
|
||||
- Merged output is saved in normal "Model" folder.
|
||||
|
||||
## Other function
|
||||
|
||||
### Save Merge Log
|
||||
|
||||
- save log about operated merge, as below,
|
||||

|
||||
|
||||
- log is saved at `extension/<this extension>/csv/history.tsv`
|
||||
|
||||
## Sample/Example
|
||||
|
||||
- kohya_ss さんのテストを再現してみる
|
||||
|
||||
|
||||
- Compare SD15 and WD13 / Stable Diffusion 1.5 と WD 1.3 の結果を見る
|
||||
- ※元記事は SD14 を使用 (WD13はSD14ベース)
|
||||
- see also [Stable DiffusionのモデルをU-Netの深さに応じて比率を変えてマージする|Kohya S.|note](https://note.com/kohya_ss/n/n9a485a066d5b)
|
||||
|
||||
- 準備する/マージして作るモデルは、以下の通り / Prepare models as below,
|
||||
|
||||
|
||||
| Model Name | |
|
||||
| --------------- | ----------------------------------------------------------------- |
|
||||
| sd-v1.5-pruned | Stable Diffusion v1.5 |
|
||||
|
|
@ -74,7 +93,7 @@
|
|||
| bw-merge2-2-2 | Merge Block Weighted<br>SD15 and WD13. base_alpha=0<br>weightは後述2 |
|
||||
|
||||
- テスト用のGeneration Info, Seedは 1~4 の4つ
|
||||
|
||||
|
||||
```
|
||||
masterpiece, best quality, beautiful anime girl, school uniform, strong rim light, intense shadows, highly detailed, cinematic lighting, taken by Canon EOS 5D Simga Art Lens 50mm f1.8 ISO 100 Shutter Speed 1000
|
||||
Negative prompt: lowres, bad anatomy, bad hands, error, missing fingers, cropped, worst quality, low quality, normal quality, jpeg artifacts, blurry
|
||||
|
|
@ -85,7 +104,8 @@
|
|||
|
||||

|
||||
|
||||
- 変化傾向は、
|
||||
- 変化傾向は、
|
||||
|
||||
- bw-merge1 で、顔立ちがややアニメ化 (sd15-wd13-ws50と比較して)
|
||||
- bw-merge2 で、ややリアル風(特に seed=3 の目が良い)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
model_A model_A_hash model_B model_B_hash model_O model_O_hash base_alpha weight_name weight_values
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
preset_name preset_weights
|
||||
GRAD_V 1,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0
|
||||
GRAD_A 0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0
|
||||
FLAT_25 0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25
|
||||
FLAT_75 0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75
|
||||
WRAP08 1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1
|
||||
WRAP12 1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1
|
||||
WRAP14 1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
|
||||
WRAP16 1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1
|
||||
MID12_50 0,0,0,0,0,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0,0,0,0,0,0
|
||||
OUT07 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
|
||||
OUT12 0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1
|
||||
OUT12_5 0,0,0,0,0,0,0,0,0,0,0,0,0.5,1,1,1,1,1,1,1,1,1,1,1,1
|
||||
RING08_SOFT 0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0
|
||||
RING08_5 0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0
|
||||
RING10_5 0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0
|
||||
RING10_3 0,0,0,0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0,0,0,0
|
||||
RING10_3 0,0,0,0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0,0,0,0
|
||||
|
Binary file not shown.
|
After Width: | Height: | Size: 127 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 9.1 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 9.2 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 31 MiB |
|
|
@ -67,6 +67,7 @@ def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5, output_f
|
|||
re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12
|
||||
|
||||
dprint(f"-- start Stage 1/2 --", verbose)
|
||||
count_target_of_basealpha = 0
|
||||
for key in (tqdm(theta_0.keys(), desc="Stage 1/2") if not verbose else theta_0.keys()):
|
||||
if "model" in key and key in theta_1:
|
||||
dprint(f" key : {key}", verbose)
|
||||
|
|
@ -102,6 +103,9 @@ def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5, output_f
|
|||
if weight_index >= 0:
|
||||
current_alpha = weights[weight_index]
|
||||
dprint(f"weighted '{key}': {current_alpha}", verbose)
|
||||
else:
|
||||
count_target_of_basealpha = count_target_of_basealpha + 1
|
||||
dprint(f"base_alpha applied: [{key}]", verbose)
|
||||
|
||||
theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key]
|
||||
|
||||
|
|
@ -122,4 +126,4 @@ def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5, output_f
|
|||
|
||||
print("Done!")
|
||||
|
||||
return True, output_file
|
||||
return True, f"{output_file}<br>base_alpha applied [{count_target_of_basealpha}] times."
|
||||
|
|
|
|||
|
|
@ -12,9 +12,13 @@ from modules import scripts, script_callbacks
|
|||
from modules import sd_models, shared
|
||||
|
||||
from scripts.merge_block_weighted import merge
|
||||
from scripts.merge_history import MergeHistory
|
||||
from scripts.preset_weights import PresetWeights
|
||||
|
||||
path_root = scripts.basedir()
|
||||
|
||||
mergeHistory = MergeHistory()
|
||||
presetWeights = PresetWeights()
|
||||
|
||||
#
|
||||
# UI callback
|
||||
|
|
@ -30,8 +34,9 @@ def on_ui_tabs():
|
|||
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
|
||||
html_output_block_weight_info = gr.HTML()
|
||||
with gr.Column():
|
||||
txt_block_weight = gr.Text(placeholder="Put weight sets. float number x 25")
|
||||
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text")
|
||||
dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list())
|
||||
txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25")
|
||||
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary")
|
||||
with gr.Row():
|
||||
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=1)
|
||||
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
|
||||
|
|
@ -121,13 +126,18 @@ def on_ui_tabs():
|
|||
_model_B_info = model_B_info.model_name
|
||||
else:
|
||||
_model_B_info = ""
|
||||
filename = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
|
||||
if ".ckpt" not in filename:
|
||||
filename = filename + ".ckpt"
|
||||
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
|
||||
if ".ckpt" not in model_O:
|
||||
model_O = model_O + ".ckpt"
|
||||
|
||||
_output = os.path.join(ckpt_dir, filename)
|
||||
_output = os.path.join(ckpt_dir, model_O)
|
||||
# debug output
|
||||
print( "#### Merge Block Weighted ####")
|
||||
if not chk_allow_overwrite:
|
||||
if os.path.exists(_output):
|
||||
_err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort."
|
||||
print(_err_msg)
|
||||
return gr.update(value=f"{_err_msg} [{_output}]")
|
||||
print(f"model_0 : {model_A}")
|
||||
print(f"model_1 : {model_B}")
|
||||
print(f"base_alpha : {sl_base_alpha}")
|
||||
|
|
@ -138,9 +148,20 @@ def on_ui_tabs():
|
|||
|
||||
sd_models.list_models()
|
||||
if result:
|
||||
ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{filename}"
|
||||
ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{model_O}"
|
||||
else:
|
||||
ret_html = ret_message
|
||||
|
||||
# save log to history.tsv
|
||||
model_O_info = sd_models.get_closet_checkpoint_match(model_O)
|
||||
model_O_hash = "" if not model_O_info else model_O_info.hash
|
||||
_names = presetWeights.find_names_by_weight(_weights)
|
||||
if _names and len(_names) > 0:
|
||||
weight_name = _names[0]
|
||||
else:
|
||||
weight_name = ""
|
||||
mergeHistory.add_history(model_A, model_B, model_O, model_O_hash, sl_base_alpha, _weights, weight_name)
|
||||
|
||||
return gr.update(value=f"{ret_html}")
|
||||
btn_do_merge_block_weighted.click(
|
||||
fn=onclick_btn_do_merge_block_weighted,
|
||||
|
|
@ -160,6 +181,22 @@ def on_ui_tabs():
|
|||
]
|
||||
)
|
||||
|
||||
def on_change_dd_preset_weight(dd_preset_weight):
|
||||
_weights = presetWeights.find_weight_by_name(dd_preset_weight)
|
||||
_ret = on_btn_apply_block_weithg_from_txt(_weights)
|
||||
return [gr.update(value=_weights)] + _ret
|
||||
dd_preset_weight.change(
|
||||
fn=on_change_dd_preset_weight,
|
||||
inputs=[dd_preset_weight],
|
||||
outputs=[txt_block_weight,
|
||||
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
|
||||
sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
|
||||
sl_M_00,
|
||||
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
|
||||
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
|
||||
]
|
||||
)
|
||||
|
||||
def on_btn_reload_checkpoint_mbw():
|
||||
sd_models.list_models()
|
||||
return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
#
|
||||
#
|
||||
#
|
||||
import os
|
||||
from csv import DictWriter, writer
|
||||
|
||||
from modules import scripts
|
||||
|
||||
|
||||
CSV_FILE_PATH = "csv/history.tsv"
|
||||
HEADERS = ["model_A", "model_A_hash", "model_B", "model_B_hash", "model_O", "model_O_hash", "base_alpha", "weight_name", "weight_values"]
|
||||
path_root = scripts.basedir()
|
||||
|
||||
|
||||
class MergeHistory():
|
||||
def __init__(self):
|
||||
self.filepath = os.path.join(path_root, CSV_FILE_PATH)
|
||||
|
||||
def add_history(self, model_A, model_B, model_O, model_O_hash, sl_base_alpha, weight_values, weight_name=""):
|
||||
_history_dict = {}
|
||||
_history_dict.update({
|
||||
"model_A": f"{os.path.basename(model_A.split(' ')[0])}",
|
||||
"model_A_hash": f"{model_A.split(' ')[1]}",
|
||||
"model_B": f"{os.path.basename(model_B.split(' ')[0])}",
|
||||
"model_B_hash": f"{model_B.split(' ')[1]}",
|
||||
"model_O": model_O,
|
||||
"model_O_hash": model_O_hash,
|
||||
"base_alpha": sl_base_alpha,
|
||||
"weight_name": weight_name,
|
||||
"weight_values": weight_values,
|
||||
})
|
||||
|
||||
if not os.path.exists(self.filepath):
|
||||
with open(self.filepath, "w", newline="", encoding="utf-8") as f:
|
||||
wr = writer(f, fieldnames=HEADERS, delimiter='\t')
|
||||
wr.writerow(HEADERS)
|
||||
# save to file
|
||||
with open(self.filepath, "a", newline="", encoding='utf-8') as f:
|
||||
dictwriter = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
|
||||
dictwriter.writerow(_history_dict)
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
#
|
||||
#
|
||||
#
|
||||
import os
|
||||
from csv import DictReader
|
||||
|
||||
from modules import scripts
|
||||
|
||||
|
||||
CSV_FILE_PATH = "csv/preset.tsv"
|
||||
path_root = scripts.basedir()
|
||||
|
||||
|
||||
class PresetWeights():
|
||||
def __init__(self):
|
||||
self.filepath = os.path.join(path_root, CSV_FILE_PATH)
|
||||
self.presets = {}
|
||||
with open(self.filepath, "r") as f:
|
||||
reader = DictReader(f, delimiter="\t")
|
||||
lines_dict = [row for row in reader]
|
||||
for line_dict in lines_dict:
|
||||
_w = ",".join([f"{x.strip()}" for x in line_dict["preset_weights"].split(",")])
|
||||
self.presets.update({line_dict["preset_name"]: _w})
|
||||
|
||||
def get_preset_name_list(self):
|
||||
return [k for k in self.presets.keys()]
|
||||
|
||||
def find_weight_by_name(self, preset_name=""):
|
||||
if preset_name and preset_name != "" and preset_name in self.presets:
|
||||
return self.presets.get(preset_name, ",".join(["0.5" for _ in range(25)]))
|
||||
else:
|
||||
return ""
|
||||
|
||||
def find_names_by_weight(self, weights=""):
|
||||
if weights and weights != "":
|
||||
if weights in self.presets.values():
|
||||
return [k for k, v in self.presets.items() if v == weights]
|
||||
else:
|
||||
_val = ",".join([f"{x.strip()}" for x in weights.split(",")])
|
||||
if _val in self.presets.values():
|
||||
return [k for k, v in self.presets.items() if v == _val]
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
return []
|
||||
Loading…
Reference in New Issue