Init repo

pull/1/head
bbc_mc 2022-12-15 22:25:00 +09:00
commit e99bc8bd41
10 changed files with 442 additions and 0 deletions

118
README.md Normal file

@ -0,0 +1,118 @@
# Merge Block Weighted - GUI
- This is an extension for [AUTOMATIC1111's Stable Diffusion Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
- It implements a GUI for the [Merge Block Weighted](https://note.com/kohya_ss/n/n9a485a066d5b) idea by kohya_ss.
- Some parts of the script were changed to fit AUTO1111; the basic method is unchanged.
![](misc/bw01.png)
## How to Install
- Go to the `Extensions` tab of your web UI
- Choose `Install from URL` and enter this repo's URL
- Install
## How to use
### Select `model_A` and `model_B`, and input `Output model name`
![](misc/bw02.png)
- If the checkpoint list has changed, press the `Reload Checkpoint` button to refresh the dropdown choices.
### Set merge ratio for each block of U-Net
![](misc/bw03.png)
- "INxx" are the input blocks (12 blocks)
- "M00" is the middle block (1 block)
- "OUTxx" are the output blocks (12 blocks)
![](misc/bw04.png)
- You can also type your weights into the textbox and press "Apply block weight from text"
- Weights must be 25 comma-separated values
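
The 25-value rule above can be checked before pasting a weight string into the textbox. The following is a minimal sketch, not the extension's own code; the function name `parse_block_weights` is a hypothetical helper introduced only for illustration.

```python
from typing import List

NUM_BLOCKS = 25  # 12 input blocks + 1 middle block + 12 output blocks


def parse_block_weights(text: str) -> List[float]:
    """Parse a comma-separated weight string into exactly 25 floats.

    Hypothetical helper: whitespace and line breaks around values are
    tolerated, and empty entries (e.g. a trailing comma) are ignored.
    """
    values = [float(v) for v in text.split(",") if v.strip()]
    if len(values) != NUM_BLOCKS:
        raise ValueError(f"expected {NUM_BLOCKS} weights, got {len(values)}")
    return values


# Example: a flat preset where every block uses ratio 0.5
weights = parse_block_weights(",".join(["0.5"] * NUM_BLOCKS))
```

The order of the values is assumed to be IN00 to IN11, then M00, then OUT00 to OUT11, which matches how the sliders are passed to the merge function in this commit.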
### Setting values
![](misc/bw05.png)
- Set "base_alpha"
| base_alpha | |
| ----- | ----------------------------------------------------------------------------- |
| 0 | merged model uses the Text Encoder and Auto Encoder 100% from `model_A` |
| 1 | merged model uses the Text Encoder and Auto Encoder 100% from `model_B` |
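
In the merge script of this commit, every merged tensor is computed as `(1 - alpha) * A + alpha * B`, where `alpha` is either `base_alpha` or the per-block weight. The sketch below shows that formula with plain floats standing in for tensors; the function name `interpolate` is an assumption made for illustration.

```python
def interpolate(a: float, b: float, alpha: float) -> float:
    """Linear interpolation: alpha=0 returns a, alpha=1 returns b."""
    return (1 - alpha) * a + alpha * b


# Plain numbers stand in for one weight taken from model_A and model_B.
value_a, value_b = 2.0, 4.0
only_a = interpolate(value_a, value_b, 0.0)   # all from model_A
only_b = interpolate(value_a, value_b, 1.0)   # all from model_B
halfway = interpolate(value_a, value_b, 0.5)  # equal mix
```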
### Other settings
| Settings | |
| ---------------------------- | -------------------------------------------------------------- |
| verbose console output | Check this to print additional information to the console. |
| Allow overwrite output-model | Check this to allow overwriting an existing model file with the same name. |
- The merged model is saved in the normal "Model" folder.
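
How the output name and the overwrite setting interact can be sketched as follows. This is an illustration under stated assumptions, not the extension's code: `resolve_output_path` and `can_write` are hypothetical names, and the sketch tests the suffix with `endswith`, whereas the script in this commit tests whether ".ckpt" appears anywhere in the name.

```python
import os


def resolve_output_path(model_dir: str, name: str) -> str:
    """Append ".ckpt" when it is missing and place the file in model_dir."""
    filename = name if name.endswith(".ckpt") else name + ".ckpt"
    return os.path.join(model_dir, filename)


def can_write(path: str, allow_overwrite: bool) -> bool:
    """An existing file may only be replaced when overwriting is allowed."""
    return allow_overwrite or not os.path.isfile(path)


# "models" is a placeholder directory name, not the web UI's real path.
target = resolve_output_path("models", "my-merge")
```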
## Sample/Example
- Reproducing kohya_ss's test
- Compare the results of SD15 (Stable Diffusion 1.5) and WD13 (WD 1.3)
- see also [Stable DiffusionのモデルをU-Netの深さに応じて比率を変えてマージするKohya S.note](https://note.com/kohya_ss/n/n9a485a066d5b)
- Prepare the following models (some are used as-is, some are created by merging):
| Model Name | |
| --------------- | ----------------------------------------------------------------- |
| sd-v1.5-pruned | Stable Diffusion v1.5 |
| wd-v1.3-float32 | wd v1.3-float32 |
| SD15-WD13-ws50 | Normal merge<br>SD15 + WD13, 0.5 # Weighted sum 0.5 |
| bw-merge1-2-2 | Merge Block Weighted<br>SD15 and WD13. base_alpha=1<br>weights: see weight1 below |
| bw-merge2-2-2 | Merge Block Weighted<br>SD15 and WD13. base_alpha=0<br>weights: see weight2 below |
- Generation info used for the test; the seeds are 1 to 4 (four images)
```
masterpiece, best quality, beautiful anime girl, school uniform, strong rim light, intense shadows, highly detailed, cinematic lighting, taken by Canon EOS 5D Simga Art Lens 50mm f1.8 ISO 100 Shutter Speed 1000
Negative prompt: lowres, bad anatomy, bad hands, error, missing fingers, cropped, worst quality, low quality, normal quality, jpeg artifacts, blurry
Steps: 40, Sampler: Euler a, CFG scale: 7, Seed: 1, Face restoration: CodeFormer, Size: 512x512, Batch size: 4
```
### result (x/y)
![](misc/xy_plus-0000-40-7_1.png)
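- Observed tendencies:
- bw-merge1: facial features look slightly more anime-like (compared with sd15-wd13-ws50)
- bw-merge2: slightly more realistic (the eyes at seed=3 are especially good)
- Broadly, the results point in the same direction as kohya_ss's results, so the implementation is judged to be working correctly.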
### Reference 1: weight1
```
1, 0.9166666667, 0.8333333333, 0.75, 0.6666666667,
0.5833333333, 0.5, 0.4166666667, 0.3333333333, 0.25, 0.1666666667,
0.0833333333,
0,
0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,
0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0
```
### Reference 2: weight2
```
0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,
0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,
1.0,
0.9166666667, 0.8333333333, 0.75, 0.6666666667,
0.5833333333, 0.5, 0.4166666667, 0.3333333333, 0.25, 0.1666666667,
0.0833333333, 0
```
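
Both presets follow a simple pattern: weight1 falls from 1 to 0 across the input blocks in steps of 1/12 and rises back to 1 across the output blocks, and weight2 is its mirror image (`1 - w`). A minimal sketch that regenerates them, with `v_shape_weights` as a hypothetical helper name:

```python
from typing import List


def v_shape_weights(n_side: int = 12) -> List[float]:
    """Return 2 * n_side + 1 weights shaped 1 -> 0 -> 1."""
    down = [1 - i / n_side for i in range(n_side)]   # 1, 11/12, ..., 1/12
    up = [i / n_side for i in range(1, n_side + 1)]  # 1/12, ..., 1
    return down + [0.0] + up


weight1 = v_shape_weights()
weight2 = [1 - w for w in weight1]

# Text form that can be pasted into the weight textbox
weight1_text = ",".join(f"{w:.10f}" for w in weight1)
```

The lists above are printed with ten decimal places, so values regenerated this way agree with them only up to rounding.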
## Special Thanks
- kohya_ss, [Stable DiffusionのモデルをU-Netの深さに応じて比率を変えてマージするKohya S.note](https://note.com/kohya_ss/n/n9a485a066d5b)

BIN misc/bw01.png (152 KiB), binary file not shown
BIN misc/bw02.png (6.5 KiB), binary file not shown
BIN misc/bw03.png (63 KiB), binary file not shown
BIN misc/bw04.png (37 KiB), binary file not shown
BIN misc/bw05.png (9.1 KiB), binary file not shown
BIN (file name not shown) (6.6 MiB), binary file not shown


@ -0,0 +1,125 @@
# from https://note.com/kohya_ss/n/n9a485a066d5b
# kohya_ss
# original code: https://github.com/eyriewow/merge-models
# use them as base of this code
# 2022/12/15
# bbc-mc
import os
import argparse
import re
import torch
from tqdm import tqdm
from modules import sd_models
NUM_INPUT_BLOCKS = 12
NUM_MID_BLOCK = 1
NUM_OUTPUT_BLOCKS = 12
NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS


def dprint(msg, flg):
    """Print msg only when flg is truthy (used for verbose output)."""
    if flg:
        print(msg)


def merge(weights: str, model_0, model_1, device="cpu", base_alpha=0.5, output_file="", allow_overwrite=False, verbose=False):
    # weights: comma-separated string of per-block ratios, or None to use base_alpha everywhere
    if weights is not None:
        weights = [float(w) for w in weights.split(',')]
        if len(weights) != NUM_TOTAL_BLOCKS:
            _err_msg = f"weights value must be {NUM_TOTAL_BLOCKS}."
            print(_err_msg)
            return False, _err_msg

    device = device if device in ["cpu", "cuda"] else "cpu"

    def load_model(_model, _device):
        # resolve the checkpoint title to a file and load its state dict
        model_info = sd_models.get_closet_checkpoint_match(_model)
        if model_info is None:
            raise ValueError(f"checkpoint not found: {_model}")
        return sd_models.read_state_dict(model_info.filename, map_location=_device)

    print("loading", model_0)
    theta_0 = load_model(model_0, device)

    print("loading", model_1)
    theta_1 = load_model(model_1, device)

    alpha = base_alpha

    if not output_file:
        output_file = f'bw-{model_0}-{model_1}-{str(alpha)[2:] + "0"}.ckpt'
    else:
        output_file = output_file if ".ckpt" in output_file else output_file + ".ckpt"

    # check if output file already exists
    if os.path.isfile(output_file) and not allow_overwrite:
        _err_msg = f"Output file already exists. Exiting... [{output_file}]"
        print(_err_msg)
        return False, _err_msg

    re_inp = re.compile(r'\.input_blocks\.(\d+)\.')  # 12
    re_mid = re.compile(r'\.middle_block\.(\d+)\.')  # 1
    re_out = re.compile(r'\.output_blocks\.(\d+)\.')  # 12

    # Stage 1: merge every key that exists in both models
    dprint("-- start Stage 1/2 --", verbose)
    for key in (tqdm(theta_0.keys(), desc="Stage 1/2") if not verbose else theta_0.keys()):
        if "model" in key and key in theta_1:
            dprint(f" key : {key}", verbose)
            current_alpha = alpha

            # check weighted and U-Net or not
            if weights is not None and 'model.diffusion_model.' in key:
                # check block index
                weight_index = -1

                if 'time_embed' in key:
                    weight_index = 0  # before input blocks
                elif '.out.' in key:
                    weight_index = NUM_TOTAL_BLOCKS - 1  # after output blocks
                else:
                    m = re_inp.search(key)
                    if m:
                        inp_idx = int(m.groups()[0])
                        weight_index = inp_idx
                    else:
                        m = re_mid.search(key)
                        if m:
                            weight_index = NUM_INPUT_BLOCKS
                        else:
                            m = re_out.search(key)
                            if m:
                                out_idx = int(m.groups()[0])
                                weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx

                if weight_index >= NUM_TOTAL_BLOCKS:
                    print(f"error. illegal block index: {key}")
                    return False, ""
                if weight_index >= 0:
                    current_alpha = weights[weight_index]
                    dprint(f"weighted '{key}': {current_alpha}", verbose)

            # linear interpolation: alpha=0 keeps model_0, alpha=1 takes model_1
            theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key]
        else:
            dprint(f" key - {key}", verbose)

    # Stage 2: copy keys that exist only in model_1
    dprint("-- start Stage 2/2 --", verbose)
    for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
        if "model" in key and key not in theta_0:
            dprint(f" key : {key}", verbose)
            theta_0[key] = theta_1[key]
        else:
            dprint(f" key - {key}", verbose)

    print("Saving...")
    torch.save({"state_dict": theta_0}, output_file)

    print("Done!")
    return True, output_file
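
The key-to-block mapping used in Stage 1 can be isolated into a small function that runs without the web UI. This is a sketch of the same rules under stated assumptions: `block_index` is a hypothetical name, and the example keys are only shaped like Stable Diffusion v1 state-dict keys.

```python
import re

NUM_INPUT_BLOCKS = 12
NUM_MID_BLOCK = 1
NUM_OUTPUT_BLOCKS = 12
NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS

_RE_INP = re.compile(r"\.input_blocks\.(\d+)\.")
_RE_MID = re.compile(r"\.middle_block\.(\d+)\.")
_RE_OUT = re.compile(r"\.output_blocks\.(\d+)\.")


def block_index(key: str) -> int:
    """Return the weight-list index for a state-dict key, or -1.

    -1 means the key is not a weighted U-Net key, so base_alpha applies.
    """
    if "model.diffusion_model." not in key:
        return -1
    if "time_embed" in key:
        return 0  # grouped with the first input block
    if ".out." in key:
        return NUM_TOTAL_BLOCKS - 1  # grouped with the last output block
    m = _RE_INP.search(key)
    if m:
        return int(m.group(1))
    if _RE_MID.search(key):
        return NUM_INPUT_BLOCKS
    m = _RE_OUT.search(key)
    if m:
        return NUM_INPUT_BLOCKS + NUM_MID_BLOCK + int(m.group(1))
    return -1


# Hypothetical example keys
idx_in = block_index("model.diffusion_model.input_blocks.3.0.weight")
idx_mid = block_index("model.diffusion_model.middle_block.1.proj.weight")
idx_out = block_index("model.diffusion_model.output_blocks.11.0.weight")
idx_other = block_index("cond_stage_model.transformer.weight")
```

Returning -1 instead of raising keeps the caller simple: any key without a block index falls back to `base_alpha`.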


@ -0,0 +1,195 @@
# Merge block weighted Board
#
# extension of AUTOMATIC1111 web ui
#
# 2022/12/14 bbc_mc
#
import os
import gradio as gr
from modules import scripts, script_callbacks
from modules import sd_models, shared
from scripts.merge_block_weighted import merge
path_root = scripts.basedir()
#
# UI callback
#
def on_ui_tabs():
    with gr.Blocks() as main_block:
        with gr.Column():
            with gr.Row():
                with gr.Column(variant="panel"):
                    btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
                    btn_clear_weighted = gr.Button(value="Clear values")
                    btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
                    html_output_block_weight_info = gr.HTML()
                with gr.Column():
                    txt_block_weight = gr.Text(placeholder="Put weight sets. float number x 25")
                    btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text")
                    with gr.Row():
                        sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.00000000001, value=1)
                        chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
                        chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False)
            with gr.Row():
                model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles())
                model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles())
                txt_model_O = gr.Text(label="Output Model Name")
            with gr.Row():
                with gr.Column():
                    sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                with gr.Column():
                    # invisible sliders used as spacers above the middle-block slider
                    gr.Slider(visible=False)
                    gr.Slider(visible=False)
                    gr.Slider(visible=False)
                    gr.Slider(visible=False)
                    gr.Slider(visible=False)
                    gr.Slider(visible=False)
                    gr.Slider(visible=False)
                    gr.Slider(visible=False)
                    gr.Slider(visible=False)
                    gr.Slider(visible=False)
                    gr.Slider(visible=False)
                    sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.00000000001, value=0.5, elem_id="mbw_sl_M00")
                with gr.Column():
                    sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.00000000001, value=0.5)
                    sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.00000000001, value=0.5)

        sl_IN = [
            sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
            sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11]
        sl_MID = [sl_M_00]
        sl_OUT = [
            sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
            sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11]

        # Events
        def onclick_btn_do_merge_block_weighted(
            model_A, model_B,
            sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
            sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
            sl_M_00,
            sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
            sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
            txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite
        ):
            # join the 25 slider values in IN, M, OUT order
            _weights = ",".join(
                [str(x) for x in [
                    sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
                    sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
                    sl_M_00,
                    sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
                    sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11
                ]])
            #
            if not model_A or not model_B:
                return gr.update(value=f"ERROR: model not found. [{model_A}][{model_B}]")

            ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path

            model_A_info = sd_models.get_closet_checkpoint_match(model_A)
            if model_A_info:
                _model_A_name = model_A_info.model_name
            else:
                _model_A_name = ""
            model_B_info = sd_models.get_closet_checkpoint_match(model_B)
            if model_B_info:
                _model_B_name = model_B_info.model_name
            else:
                _model_B_name = ""

            # default output name is built from both model names and base_alpha
            filename = f"bw-merge-{_model_A_name}-{_model_B_name}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
            if ".ckpt" not in filename:
                filename = filename + ".ckpt"
            _output = os.path.join(ckpt_dir, filename)

            # debug output
            print( "#### Merge Block Weighted ####")
            print(f"model_0 : {model_A}")
            print(f"model_1 : {model_B}")
            print(f"base_alpha : {sl_base_alpha}")
            print(f"output_file: {_output}")
            print(f"weights : {_weights}")

            result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw)

            # refresh the checkpoint list so the new model becomes selectable
            sd_models.list_models()
            if result:
                ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{filename}"
            else:
                ret_html = ret_message
            return gr.update(value=f"{ret_html}")

        btn_do_merge_block_weighted.click(
            fn=onclick_btn_do_merge_block_weighted,
            inputs=[model_A, model_B] + sl_IN + sl_MID + sl_OUT + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
            outputs=[html_output_block_weight_info]
        )

        # reset all 25 sliders to 0.5
        btn_clear_weighted.click(
            fn=lambda: [gr.update(value=0.5) for _ in range(25)],
            inputs=[],
            outputs=[
                sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
                sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
                sl_M_00,
                sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
                sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
            ]
        )

        def on_btn_reload_checkpoint_mbw():
            # rescan checkpoints and refresh both dropdowns
            sd_models.list_models()
            return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())]

        btn_reload_checkpoint_mbw.click(
            fn=on_btn_reload_checkpoint_mbw,
            inputs=[],
            outputs=[model_A, model_B]
        )

        def on_btn_apply_block_weithg_from_txt(txt_block_weight):
            # leave the sliders unchanged unless the text holds exactly 25 numbers
            if not txt_block_weight:
                return [gr.update() for _ in range(25)]
            _list = [x.strip() for x in txt_block_weight.split(",")]
            if len(_list) != 25:
                return [gr.update() for _ in range(25)]
            try:
                _values = [float(x) for x in _list]
            except ValueError:
                return [gr.update() for _ in range(25)]
            return [gr.update(value=x) for x in _values]

        btn_apply_block_weithg_from_txt.click(
            fn=on_btn_apply_block_weithg_from_txt,
            inputs=[txt_block_weight],
            outputs=[
                sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
                sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11,
                sl_M_00,
                sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
                sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
            ]
        )

    # return required as a list of (gradio_component, title, elem_id)
    return [(main_block, "Merge Block Weighted", "merge_block_weighted")]


# on_UI
script_callbacks.on_ui_tabs(on_ui_tabs)

4
style.css Normal file

@ -0,0 +1,4 @@
#mbw_sl_M00 {
    bottom: 0;
    position: absolute;
}