mirror of https://github.com/bmaltais/kohya_ss
Add new Merge LyCORIS models
parent
8b1ceee5bd
commit
55d6d7a95d
|
|
@ -310,6 +310,7 @@ This will store a backup file with your current locally installed pip packages a
|
|||
- Fixed `lora_interrogator.py` not working. Please refer to [PR #392](https://github.com/kohya-ss/sd-scripts/pull/392) for details. Thank you A2va and heyalexchoi!
|
||||
- Fixed the handling of tags containing `_` in `tag_images_by_wd14_tagger.py`.
|
||||
- Add new Extract DyLoRA gui to the Utilities tab.
|
||||
- Add new Merge LyCORIS models into checkpoint gui to the Utilities tab.
|
||||
* 2023/04/17 (v21.5.4)
|
||||
- Fixed a bug that caused an error when loading DyLoRA with the `--network_weight` option in `train_network.py`.
|
||||
- Added the `--recursive` option to each script in the `finetune` folder to process folders recursively. Please refer to [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) for details. Thanks to Linaqruf!
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
|
|||
from library.merge_lora_gui import gradio_merge_lora_tab
|
||||
from library.resize_lora_gui import gradio_resize_lora_tab
|
||||
from library.extract_lora_from_dylora_gui import gradio_extract_dylora_tab
|
||||
from library.merge_lycoris_gui import gradio_merge_lycoris_tab
|
||||
from lora_gui import lora_tab
|
||||
|
||||
|
||||
|
|
@ -49,6 +50,7 @@ def UI(**kwargs):
|
|||
gradio_extract_lora_tab()
|
||||
gradio_extract_lycoris_locon_tab()
|
||||
gradio_merge_lora_tab()
|
||||
gradio_merge_lycoris_tab()
|
||||
gradio_resize_lora_tab()
|
||||
|
||||
# Show the interface
|
||||
|
|
|
|||
|
|
@ -200,7 +200,7 @@ def gradio_merge_lora_tab():
|
|||
|
||||
with gr.Row():
|
||||
ratio_c = gr.Slider(
|
||||
label='Model C erge ratio (eg: 0.5 mean 50%)',
|
||||
label='Model C merge ratio (eg: 0.5 mean 50%)',
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,152 @@
|
|||
import gradio as gr
|
||||
from easygui import msgbox
|
||||
import subprocess
|
||||
import os
|
||||
from .common_gui import (
|
||||
get_saveasfilename_path,
|
||||
get_file_path,
|
||||
)
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||
|
||||
def merge_lycoris(
|
||||
base_model,
|
||||
lycoris_model,
|
||||
weight,
|
||||
output_name,
|
||||
dtype,
|
||||
device,
|
||||
is_v2,
|
||||
):
|
||||
print('Merge model...')
|
||||
|
||||
run_cmd = f'{PYTHON} "{os.path.join("tools","merge_lycoris.py")}"'
|
||||
run_cmd += f' {base_model}'
|
||||
run_cmd += f' "{lycoris_model}"'
|
||||
run_cmd += f' "{output_name}"'
|
||||
run_cmd += f' --weight {weight}'
|
||||
run_cmd += f' --device {device}'
|
||||
run_cmd += f' --dtype {dtype}'
|
||||
if is_v2:
|
||||
run_cmd += f' --is_v2'
|
||||
|
||||
print(run_cmd)
|
||||
|
||||
# Run the command
|
||||
if os.name == 'posix':
|
||||
os.system(run_cmd)
|
||||
else:
|
||||
subprocess.run(run_cmd)
|
||||
|
||||
print('Done merging...')
|
||||
|
||||
###
|
||||
# Gradio UI
|
||||
###
|
||||
|
||||
|
||||
def gradio_merge_lycoris_tab():
|
||||
with gr.Tab('Merge LyCORIS'):
|
||||
gr.Markdown('This utility can merge a LyCORIS model into a SD checkpoint.')
|
||||
|
||||
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
|
||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||
ckpt_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
|
||||
ckpt_ext_name = gr.Textbox(value='SD model types', visible=False)
|
||||
|
||||
with gr.Row():
|
||||
base_model = gr.Textbox(
|
||||
label='SD Model',
|
||||
placeholder='(Optional) Stable Diffusion base model',
|
||||
interactive=True,
|
||||
info='Provide a SD file path that you want to merge with the LyCORIS file'
|
||||
)
|
||||
base_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
base_model_file.click(
|
||||
get_file_path,
|
||||
inputs=[base_model, ckpt_ext, ckpt_ext_name],
|
||||
outputs=base_model,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
lycoris_model = gr.Textbox(
|
||||
label='LyCORIS model',
|
||||
placeholder='Path to the LyCORIS model',
|
||||
interactive=True,
|
||||
)
|
||||
button_lycoris_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lycoris_model_file.click(
|
||||
get_file_path,
|
||||
inputs=[lycoris_model, lora_ext, lora_ext_name],
|
||||
outputs=lycoris_model,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
weight = gr.Slider(
|
||||
label='Model A merge ratio (eg: 0.5 mean 50%)',
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
value=1.0,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
output_name = gr.Textbox(
|
||||
label='Save to',
|
||||
placeholder='path for the checkpoint file to save...',
|
||||
interactive=True,
|
||||
)
|
||||
button_output_name = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_output_name.click(
|
||||
get_saveasfilename_path,
|
||||
inputs=[output_name, lora_ext, lora_ext_name],
|
||||
outputs=output_name,
|
||||
show_progress=False,
|
||||
)
|
||||
dtype = gr.Dropdown(
|
||||
label='Save dtype',
|
||||
choices=['float', 'float16', 'float32', 'float64', 'bfloat', 'bfloat16'],
|
||||
value='float16',
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
device = gr.Dropdown(
|
||||
label='Device',
|
||||
choices=[
|
||||
'cpu',
|
||||
# 'cuda',
|
||||
],
|
||||
value='cpu',
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True)
|
||||
|
||||
merge_button = gr.Button('Merge model')
|
||||
|
||||
merge_button.click(
|
||||
merge_lycoris,
|
||||
inputs=[
|
||||
base_model,
|
||||
lycoris_model,
|
||||
weight,
|
||||
output_name,
|
||||
dtype,
|
||||
device,
|
||||
is_v2,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
@ -0,0 +1,504 @@
|
|||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import torch.linalg as linalg
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def make_sparse(t: torch.Tensor, sparsity=0.95):
|
||||
abs_t = torch.abs(t)
|
||||
np_array = abs_t.detach().cpu().numpy()
|
||||
quan = float(np.quantile(np_array, sparsity))
|
||||
sparse_t = t.masked_fill(abs_t < quan, 0)
|
||||
return sparse_t
|
||||
|
||||
|
||||
def extract_conv(
|
||||
weight: Union[torch.Tensor, nn.Parameter],
|
||||
mode = 'fixed',
|
||||
mode_param = 0,
|
||||
device = 'cpu',
|
||||
is_cp = False,
|
||||
) -> Tuple[nn.Parameter, nn.Parameter]:
|
||||
weight = weight.to(device)
|
||||
out_ch, in_ch, kernel_size, _ = weight.shape
|
||||
|
||||
U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))
|
||||
|
||||
if mode=='fixed':
|
||||
lora_rank = mode_param
|
||||
elif mode=='threshold':
|
||||
assert mode_param>=0
|
||||
lora_rank = torch.sum(S>mode_param)
|
||||
elif mode=='ratio':
|
||||
assert 1>=mode_param>=0
|
||||
min_s = torch.max(S)*mode_param
|
||||
lora_rank = torch.sum(S>min_s)
|
||||
elif mode=='quantile' or mode=='percentile':
|
||||
assert 1>=mode_param>=0
|
||||
s_cum = torch.cumsum(S, dim=0)
|
||||
min_cum_sum = mode_param * torch.sum(S)
|
||||
lora_rank = torch.sum(s_cum<min_cum_sum)
|
||||
else:
|
||||
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
|
||||
lora_rank = max(1, lora_rank)
|
||||
lora_rank = min(out_ch, in_ch, lora_rank)
|
||||
if lora_rank>=out_ch/2 and not is_cp:
|
||||
return weight, 'full'
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
|
||||
extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
|
||||
extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
|
||||
del U, S, Vh, weight
|
||||
return (extract_weight_A, extract_weight_B, diff), 'low rank'
|
||||
|
||||
|
||||
def extract_linear(
|
||||
weight: Union[torch.Tensor, nn.Parameter],
|
||||
mode = 'fixed',
|
||||
mode_param = 0,
|
||||
device = 'cpu',
|
||||
) -> Tuple[nn.Parameter, nn.Parameter]:
|
||||
weight = weight.to(device)
|
||||
out_ch, in_ch = weight.shape
|
||||
|
||||
U, S, Vh = linalg.svd(weight)
|
||||
|
||||
if mode=='fixed':
|
||||
lora_rank = mode_param
|
||||
elif mode=='threshold':
|
||||
assert mode_param>=0
|
||||
lora_rank = torch.sum(S>mode_param)
|
||||
elif mode=='ratio':
|
||||
assert 1>=mode_param>=0
|
||||
min_s = torch.max(S)*mode_param
|
||||
lora_rank = torch.sum(S>min_s)
|
||||
elif mode=='quantile' or mode=='percentile':
|
||||
assert 1>=mode_param>=0
|
||||
s_cum = torch.cumsum(S, dim=0)
|
||||
min_cum_sum = mode_param * torch.sum(S)
|
||||
lora_rank = torch.sum(s_cum<min_cum_sum)
|
||||
else:
|
||||
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
|
||||
lora_rank = max(1, lora_rank)
|
||||
lora_rank = min(out_ch, in_ch, lora_rank)
|
||||
if lora_rank>=out_ch/2:
|
||||
return weight, 'full'
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
diff = (weight - U @ Vh).detach()
|
||||
extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
|
||||
extract_weight_B = U.reshape(out_ch, lora_rank).detach()
|
||||
del U, S, Vh, weight
|
||||
return (extract_weight_A, extract_weight_B, diff), 'low rank'
|
||||
|
||||
|
||||
def extract_diff(
|
||||
base_model,
|
||||
db_model,
|
||||
mode = 'fixed',
|
||||
linear_mode_param = 0,
|
||||
conv_mode_param = 0,
|
||||
extract_device = 'cpu',
|
||||
use_bias = False,
|
||||
sparsity = 0.98,
|
||||
small_conv = True
|
||||
):
|
||||
UNET_TARGET_REPLACE_MODULE = [
|
||||
"Transformer2DModel",
|
||||
"Attention",
|
||||
"ResnetBlock2D",
|
||||
"Downsample2D",
|
||||
"Upsample2D"
|
||||
]
|
||||
UNET_TARGET_REPLACE_NAME = [
|
||||
"conv_in",
|
||||
"conv_out",
|
||||
"time_embedding.linear_1",
|
||||
"time_embedding.linear_2",
|
||||
]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = 'lora_unet'
|
||||
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
||||
def make_state_dict(
|
||||
prefix,
|
||||
root_module: torch.nn.Module,
|
||||
target_module: torch.nn.Module,
|
||||
target_replace_modules,
|
||||
target_replace_names = []
|
||||
):
|
||||
loras = {}
|
||||
temp = {}
|
||||
temp_name = {}
|
||||
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
temp[name] = {}
|
||||
for child_name, child_module in module.named_modules():
|
||||
if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
|
||||
continue
|
||||
temp[name][child_name] = child_module.weight
|
||||
elif name in target_replace_names:
|
||||
temp_name[name] = module.weight
|
||||
|
||||
for name, module in tqdm(list(target_module.named_modules())):
|
||||
if name in temp:
|
||||
weights = temp[name]
|
||||
for child_name, child_module in module.named_modules():
|
||||
lora_name = prefix + '.' + name + '.' + child_name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
layer = child_module.__class__.__name__
|
||||
if layer in {'Linear', 'Conv2d'}:
|
||||
root_weight = child_module.weight
|
||||
if torch.allclose(root_weight, weights[child_name]):
|
||||
continue
|
||||
|
||||
if layer == 'Linear':
|
||||
weight, decompose_mode = extract_linear(
|
||||
(child_module.weight - weights[child_name]),
|
||||
mode,
|
||||
linear_mode_param,
|
||||
device = extract_device,
|
||||
)
|
||||
if decompose_mode == 'low rank':
|
||||
extract_a, extract_b, diff = weight
|
||||
elif layer == 'Conv2d':
|
||||
is_linear = (child_module.weight.shape[2] == 1
|
||||
and child_module.weight.shape[3] == 1)
|
||||
weight, decompose_mode = extract_conv(
|
||||
(child_module.weight - weights[child_name]),
|
||||
mode,
|
||||
linear_mode_param if is_linear else conv_mode_param,
|
||||
device = extract_device,
|
||||
)
|
||||
if decompose_mode == 'low rank':
|
||||
extract_a, extract_b, diff = weight
|
||||
if small_conv and not is_linear and decompose_mode == 'low rank':
|
||||
dim = extract_a.size(0)
|
||||
(extract_c, extract_a, _), _ = extract_conv(
|
||||
extract_a.transpose(0, 1),
|
||||
'fixed', dim,
|
||||
extract_device, True
|
||||
)
|
||||
extract_a = extract_a.transpose(0, 1)
|
||||
extract_c = extract_c.transpose(0, 1)
|
||||
loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
|
||||
diff = child_module.weight - torch.einsum(
|
||||
'i j k l, j r, p i -> p r k l',
|
||||
extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
|
||||
).detach().cpu().contiguous()
|
||||
del extract_c
|
||||
else:
|
||||
continue
|
||||
if decompose_mode == 'low rank':
|
||||
loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
|
||||
loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
|
||||
loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
|
||||
if use_bias:
|
||||
diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
|
||||
sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
|
||||
|
||||
indices = sparse_diff.indices().to(torch.int16)
|
||||
values = sparse_diff.values().half()
|
||||
loras[f'{lora_name}.bias_indices'] = indices
|
||||
loras[f'{lora_name}.bias_values'] = values
|
||||
loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
|
||||
del extract_a, extract_b, diff
|
||||
elif decompose_mode == 'full':
|
||||
loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
elif name in temp_name:
|
||||
weights = temp_name[name]
|
||||
lora_name = prefix + '.' + name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
layer = module.__class__.__name__
|
||||
|
||||
if layer in {'Linear', 'Conv2d'}:
|
||||
root_weight = module.weight
|
||||
if torch.allclose(root_weight, weights):
|
||||
continue
|
||||
|
||||
if layer == 'Linear':
|
||||
weight, decompose_mode = extract_linear(
|
||||
(root_weight - weights),
|
||||
mode,
|
||||
linear_mode_param,
|
||||
device = extract_device,
|
||||
)
|
||||
if decompose_mode == 'low rank':
|
||||
extract_a, extract_b, diff = weight
|
||||
elif layer == 'Conv2d':
|
||||
is_linear = (
|
||||
root_weight.shape[2] == 1
|
||||
and root_weight.shape[3] == 1
|
||||
)
|
||||
weight, decompose_mode = extract_conv(
|
||||
(root_weight - weights),
|
||||
mode,
|
||||
linear_mode_param if is_linear else conv_mode_param,
|
||||
device = extract_device,
|
||||
)
|
||||
if decompose_mode == 'low rank':
|
||||
extract_a, extract_b, diff = weight
|
||||
if small_conv and not is_linear and decompose_mode == 'low rank':
|
||||
dim = extract_a.size(0)
|
||||
(extract_c, extract_a, _), _ = extract_conv(
|
||||
extract_a.transpose(0, 1),
|
||||
'fixed', dim,
|
||||
extract_device, True
|
||||
)
|
||||
extract_a = extract_a.transpose(0, 1)
|
||||
extract_c = extract_c.transpose(0, 1)
|
||||
loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
|
||||
diff = root_weight - torch.einsum(
|
||||
'i j k l, j r, p i -> p r k l',
|
||||
extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
|
||||
).detach().cpu().contiguous()
|
||||
del extract_c
|
||||
else:
|
||||
continue
|
||||
if decompose_mode == 'low rank':
|
||||
loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
|
||||
loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
|
||||
loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
|
||||
if use_bias:
|
||||
diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
|
||||
sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
|
||||
|
||||
indices = sparse_diff.indices().to(torch.int16)
|
||||
values = sparse_diff.values().half()
|
||||
loras[f'{lora_name}.bias_indices'] = indices
|
||||
loras[f'{lora_name}.bias_values'] = values
|
||||
loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
|
||||
del extract_a, extract_b, diff
|
||||
elif decompose_mode == 'full':
|
||||
loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return loras
|
||||
|
||||
text_encoder_loras = make_state_dict(
|
||||
LORA_PREFIX_TEXT_ENCODER,
|
||||
base_model[0], db_model[0],
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
)
|
||||
|
||||
unet_loras = make_state_dict(
|
||||
LORA_PREFIX_UNET,
|
||||
base_model[2], db_model[2],
|
||||
UNET_TARGET_REPLACE_MODULE,
|
||||
UNET_TARGET_REPLACE_NAME
|
||||
)
|
||||
print(len(text_encoder_loras), len(unet_loras))
|
||||
return text_encoder_loras|unet_loras
|
||||
|
||||
|
||||
def get_module(
|
||||
lyco_state_dict: Dict,
|
||||
lora_name
|
||||
):
|
||||
if f'{lora_name}.lora_up.weight' in lyco_state_dict:
|
||||
up = lyco_state_dict[f'{lora_name}.lora_up.weight']
|
||||
down = lyco_state_dict[f'{lora_name}.lora_down.weight']
|
||||
mid = lyco_state_dict.get(f'{lora_name}.lora_mid.weight', None)
|
||||
alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
|
||||
return 'locon', (up, down, mid, alpha)
|
||||
elif f'{lora_name}.hada_w1_a' in lyco_state_dict:
|
||||
w1a = lyco_state_dict[f'{lora_name}.hada_w1_a']
|
||||
w1b = lyco_state_dict[f'{lora_name}.hada_w1_b']
|
||||
w2a = lyco_state_dict[f'{lora_name}.hada_w2_a']
|
||||
w2b = lyco_state_dict[f'{lora_name}.hada_w2_b']
|
||||
t1 = lyco_state_dict.get(f'{lora_name}.hada_t1', None)
|
||||
t2 = lyco_state_dict.get(f'{lora_name}.hada_t2', None)
|
||||
alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
|
||||
return 'hada', (w1a, w1b, w2a, w2b, t1, t2, alpha)
|
||||
elif f'{lora_name}.weight' in lyco_state_dict:
|
||||
weight = lyco_state_dict[f'{lora_name}.weight']
|
||||
on_input = lyco_state_dict.get(f'{lora_name}.on_input', False)
|
||||
return 'ia3', (weight, on_input)
|
||||
elif (f'{lora_name}.lokr_w1' in lyco_state_dict
|
||||
or f'{lora_name}.lokr_w1_a' in lyco_state_dict):
|
||||
w1 = lyco_state_dict.get(f'{lora_name}.lokr_w1', None)
|
||||
w1a = lyco_state_dict.get(f'{lora_name}.lokr_w1_a', None)
|
||||
w1b = lyco_state_dict.get(f'{lora_name}.lokr_w1_b', None)
|
||||
w2 = lyco_state_dict.get(f'{lora_name}.lokr_w2', None)
|
||||
w2a = lyco_state_dict.get(f'{lora_name}.lokr_w2_a', None)
|
||||
w2b = lyco_state_dict.get(f'{lora_name}.lokr_w2_b', None)
|
||||
t1 = lyco_state_dict.get(f'{lora_name}.lokr_t1', None)
|
||||
t2 = lyco_state_dict.get(f'{lora_name}.lokr_t2', None)
|
||||
alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
|
||||
return 'kron', (w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha)
|
||||
elif f'{lora_name}.diff' in lyco_state_dict:
|
||||
return 'full', lyco_state_dict[f'{lora_name}.diff']
|
||||
else:
|
||||
return 'None', ()
|
||||
|
||||
|
||||
def cp_weight_from_conv(
|
||||
up, down, mid
|
||||
):
|
||||
up = up.reshape(up.size(0), up.size(1))
|
||||
down = down.reshape(down.size(0), down.size(1))
|
||||
return torch.einsum('m n w h, i m, n j -> i j w h', mid, up, down)
|
||||
|
||||
def cp_weight(
|
||||
wa, wb, t
|
||||
):
|
||||
temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
|
||||
return torch.einsum('i j k l, i r -> r j k l', temp, wa)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def rebuild_weight(module_type, params, orig_weight, scale=1):
|
||||
if orig_weight is None:
|
||||
return orig_weight
|
||||
merged = orig_weight
|
||||
if module_type == 'locon':
|
||||
up, down, mid, alpha = params
|
||||
if alpha is not None:
|
||||
scale *= alpha/up.size(1)
|
||||
if mid is not None:
|
||||
rebuild = cp_weight_from_conv(up, down, mid)
|
||||
else:
|
||||
rebuild = up.reshape(up.size(0),-1) @ down.reshape(down.size(0), -1)
|
||||
merged = orig_weight + rebuild.reshape(orig_weight.shape) * scale
|
||||
del up, down, mid, alpha, params, rebuild
|
||||
elif module_type == 'hada':
|
||||
w1a, w1b, w2a, w2b, t1, t2, alpha = params
|
||||
if alpha is not None:
|
||||
scale *= alpha / w1b.size(0)
|
||||
if t1 is not None:
|
||||
rebuild1 = cp_weight(w1a, w1b, t1)
|
||||
else:
|
||||
rebuild1 = w1a @ w1b
|
||||
if t2 is not None:
|
||||
rebuild2 = cp_weight(w2a, w2b, t2)
|
||||
else:
|
||||
rebuild2 = w2a @ w2b
|
||||
rebuild = (rebuild1 * rebuild2).reshape(orig_weight.shape)
|
||||
merged = orig_weight + rebuild * scale
|
||||
del w1a, w1b, w2a, w2b, t1, t2, alpha, params, rebuild, rebuild1, rebuild2
|
||||
elif module_type == 'ia3':
|
||||
weight, on_input = params
|
||||
if not on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
merged = orig_weight + weight * orig_weight * scale
|
||||
del weight, on_input, params
|
||||
elif module_type == 'kron':
|
||||
w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha = params
|
||||
if alpha is not None and (w1b is not None or w2b is not None):
|
||||
scale *= alpha / (w1b.size(0) if w1b else w2b.size(0))
|
||||
if w1a is not None and w1b is not None:
|
||||
if t1:
|
||||
w1 = cp_weight(w1a, w1b, t1)
|
||||
else:
|
||||
w1 = w1a @ w1b
|
||||
if w2a is not None and w2b is not None:
|
||||
if t2:
|
||||
w2 = cp_weight(w2a, w2b, t2)
|
||||
else:
|
||||
w2 = w2a @ w2b
|
||||
rebuild = torch.kron(w1, w2).reshape(orig_weight.shape)
|
||||
merged = orig_weight + rebuild* scale
|
||||
del w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha, params, rebuild
|
||||
elif module_type == 'full':
|
||||
rebuild = params.reshape(orig_weight.shape)
|
||||
merged = orig_weight + rebuild * scale
|
||||
del params, rebuild
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def merge(
|
||||
base_model,
|
||||
lyco_state_dict,
|
||||
scale: float = 1.0,
|
||||
device = 'cpu'
|
||||
):
|
||||
UNET_TARGET_REPLACE_MODULE = [
|
||||
"Transformer2DModel",
|
||||
"Attention",
|
||||
"ResnetBlock2D",
|
||||
"Downsample2D",
|
||||
"Upsample2D"
|
||||
]
|
||||
UNET_TARGET_REPLACE_NAME = [
|
||||
"conv_in",
|
||||
"conv_out",
|
||||
"time_embedding.linear_1",
|
||||
"time_embedding.linear_2",
|
||||
]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = 'lora_unet'
|
||||
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
||||
merged = 0
|
||||
def merge_state_dict(
|
||||
prefix,
|
||||
root_module: torch.nn.Module,
|
||||
lyco_state_dict: Dict[str,torch.Tensor],
|
||||
target_replace_modules,
|
||||
target_replace_names = []
|
||||
):
|
||||
nonlocal merged
|
||||
for name, module in tqdm(list(root_module.named_modules()), desc=f'Merging {prefix}'):
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
|
||||
continue
|
||||
lora_name = prefix + '.' + name + '.' + child_name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
|
||||
result = rebuild_weight(*get_module(
|
||||
lyco_state_dict, lora_name
|
||||
), getattr(child_module, 'weight'), scale)
|
||||
if result is not None:
|
||||
merged += 1
|
||||
child_module.requires_grad_(False)
|
||||
child_module.weight.copy_(result)
|
||||
elif name in target_replace_names:
|
||||
lora_name = prefix + '.' + name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
|
||||
result = rebuild_weight(*get_module(
|
||||
lyco_state_dict, lora_name
|
||||
), getattr(module, 'weight'), scale)
|
||||
if result is not None:
|
||||
merged += 1
|
||||
module.requires_grad_(False)
|
||||
module.weight.copy_(result)
|
||||
|
||||
if device == 'cpu':
|
||||
for k, v in tqdm(list(lyco_state_dict.items()), desc='Converting Dtype'):
|
||||
lyco_state_dict[k] = v.float()
|
||||
|
||||
merge_state_dict(
|
||||
LORA_PREFIX_TEXT_ENCODER,
|
||||
base_model[0],
|
||||
lyco_state_dict,
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE,
|
||||
UNET_TARGET_REPLACE_NAME
|
||||
)
|
||||
merge_state_dict(
|
||||
LORA_PREFIX_UNET,
|
||||
base_model[2],
|
||||
lyco_state_dict,
|
||||
UNET_TARGET_REPLACE_MODULE,
|
||||
UNET_TARGET_REPLACE_NAME
|
||||
)
|
||||
print(f'{merged} Modules been merged')
|
||||
|
|
@ -1,35 +1,60 @@
|
|||
import os
|
||||
import sys
|
||||
import os, sys
|
||||
sys.path.insert(0, os.getcwd())
|
||||
import argparse
|
||||
import torch
|
||||
from lycoris.utils import merge_loha, merge_locon
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"base_model", help="The model you want to merge with loha",
|
||||
default='', type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"lycoris_model", help="the lyco model you want to merge into sd model",
|
||||
default='', type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_name", help="the output model",
|
||||
default='./out.pt', type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"--is_v2", help="Your base model is sd v2 or not",
|
||||
default=False, action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", help="Which device you want to use to merge the weight",
|
||||
default='cpu', type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype", help='dtype to save',
|
||||
default='float', type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight", help='weight for the lyco model to merge',
|
||||
default='1.0', type=float
|
||||
)
|
||||
return parser.parse_args()
|
||||
ARGS = get_args()
|
||||
|
||||
|
||||
from lycoris_utils import merge
|
||||
from lycoris.kohya_model_utils import (
|
||||
load_models_from_stable_diffusion_checkpoint,
|
||||
save_stable_diffusion_checkpoint,
|
||||
load_file
|
||||
)
|
||||
import gradio as gr
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def merge_models(base_model, lycoris_model, output_name, is_v2, device, dtype, weight):
|
||||
base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model)
|
||||
if lycoris_model.rsplit('.', 1)[-1] == 'safetensors':
|
||||
lyco = load_file(lycoris_model)
|
||||
def main():
|
||||
base = load_models_from_stable_diffusion_checkpoint(ARGS.is_v2, ARGS.base_model)
|
||||
if ARGS.lycoris_model.rsplit('.', 1)[-1] == 'safetensors':
|
||||
lyco = load_file(ARGS.lycoris_model)
|
||||
else:
|
||||
lyco = torch.load(lycoris_model)
|
||||
lyco = torch.load(ARGS.lycoris_model)
|
||||
|
||||
algo = None
|
||||
for key in lyco:
|
||||
if 'hada' in key:
|
||||
algo = 'loha'
|
||||
break
|
||||
elif 'lora_up' in key:
|
||||
algo = 'lora'
|
||||
break
|
||||
else:
|
||||
raise NotImplementedError('Cannot find the algo for this lycoris model file.')
|
||||
|
||||
dtype_str = dtype.replace('fp', 'float').replace('bf', 'bfloat')
|
||||
dtype_str = ARGS.dtype.replace('fp', 'float').replace('bf', 'bfloat')
|
||||
dtype = {
|
||||
'float': torch.float,
|
||||
'float16': torch.float16,
|
||||
|
|
@ -41,40 +66,20 @@ def merge_models(base_model, lycoris_model, output_name, is_v2, device, dtype, w
|
|||
if dtype is None:
|
||||
raise ValueError(f'Cannot Find the dtype "{dtype}"')
|
||||
|
||||
if algo == 'loha':
|
||||
merge_loha(base, lyco, weight, device)
|
||||
elif algo == 'lora':
|
||||
merge_locon(base, lyco, weight, device)
|
||||
merge(
|
||||
base,
|
||||
lyco,
|
||||
ARGS.weight,
|
||||
ARGS.device
|
||||
)
|
||||
|
||||
save_stable_diffusion_checkpoint(
|
||||
is_v2, output_name,
|
||||
ARGS.is_v2, ARGS.output_name,
|
||||
base[0], base[2],
|
||||
None, 0, 0, dtype,
|
||||
base[1]
|
||||
)
|
||||
|
||||
return output_name
|
||||
|
||||
|
||||
def main():
|
||||
iface = gr.Interface(
|
||||
fn=merge_models,
|
||||
inputs=[
|
||||
gr.inputs.Textbox(label="Base Model Path"),
|
||||
gr.inputs.Textbox(label="Lycoris Model Path"),
|
||||
gr.inputs.Textbox(label="Output Model Path", default='./out.pt'),
|
||||
gr.inputs.Checkbox(label="Is base model SD V2?", default=False),
|
||||
gr.inputs.Textbox(label="Device", default='cpu'),
|
||||
gr.inputs.Dropdown(choices=['float', 'float16', 'float32', 'float64', 'bfloat', 'bfloat16'], label="Dtype", default='float'),
|
||||
gr.inputs.Number(label="Weight", default=1.0)
|
||||
],
|
||||
outputs=gr.outputs.Textbox(label="Merged Model Path"),
|
||||
title="Model Merger",
|
||||
description="Merge Lycoris and Stable Diffusion models",
|
||||
)
|
||||
|
||||
iface.launch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Loading…
Reference in New Issue