diff --git a/README.md b/README.md index 1724fd2..c93a718 100644 --- a/README.md +++ b/README.md @@ -310,6 +310,7 @@ This will store a backup file with your current locally installed pip packages a - Fixed `lora_interrogator.py` not working. Please refer to [PR #392](https://github.com/kohya-ss/sd-scripts/pull/392) for details. Thank you A2va and heyalexchoi! - Fixed the handling of tags containing `_` in `tag_images_by_wd14_tagger.py`. - Add new Extract DyLoRA gui to the Utilities tab. + - Add new Merge LyCORIS models into checkpoint gui to the Utilities tab. * 2023/04/17 (v21.5.4) - Fixed a bug that caused an error when loading DyLoRA with the `--network_weight` option in `train_network.py`. - Added the `--recursive` option to each script in the `finetune` folder to process folders recursively. Please refer to [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) for details. Thanks to Linaqruf! diff --git a/kohya_gui.py b/kohya_gui.py index ba4076c..f7e9769 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -10,6 +10,7 @@ from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab from library.merge_lora_gui import gradio_merge_lora_tab from library.resize_lora_gui import gradio_resize_lora_tab from library.extract_lora_from_dylora_gui import gradio_extract_dylora_tab +from library.merge_lycoris_gui import gradio_merge_lycoris_tab from lora_gui import lora_tab @@ -49,6 +50,7 @@ def UI(**kwargs): gradio_extract_lora_tab() gradio_extract_lycoris_locon_tab() gradio_merge_lora_tab() + gradio_merge_lycoris_tab() gradio_resize_lora_tab() # Show the interface diff --git a/library/merge_lora_gui.py b/library/merge_lora_gui.py index 9e4ee02..fc6fc2f 100644 --- a/library/merge_lora_gui.py +++ b/library/merge_lora_gui.py @@ -200,7 +200,7 @@ def gradio_merge_lora_tab(): with gr.Row(): ratio_c = gr.Slider( - label='Model C erge ratio (eg: 0.5 mean 50%)', + label='Model C merge ratio (eg: 0.5 mean 50%)', minimum=0, maximum=1, step=0.01, diff --git 
a/library/merge_lycoris_gui.py b/library/merge_lycoris_gui.py new file mode 100644 index 0000000..1cce39c --- /dev/null +++ b/library/merge_lycoris_gui.py @@ -0,0 +1,152 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_file_path, +) + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +document_symbol = '\U0001F4C4' # 📄 +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' + +def merge_lycoris( + base_model, + lycoris_model, + weight, + output_name, + dtype, + device, + is_v2, +): + print('Merge model...') + + run_cmd = f'{PYTHON} "{os.path.join("tools","merge_lycoris.py")}"' + run_cmd += f' {base_model}' + run_cmd += f' "{lycoris_model}"' + run_cmd += f' "{output_name}"' + run_cmd += f' --weight {weight}' + run_cmd += f' --device {device}' + run_cmd += f' --dtype {dtype}' + if is_v2: + run_cmd += f' --is_v2' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + print('Done merging...') + +### +# Gradio UI +### + + +def gradio_merge_lycoris_tab(): + with gr.Tab('Merge LyCORIS'): + gr.Markdown('This utility can merge a LyCORIS model into a SD checkpoint.') + + lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) + ckpt_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False) + ckpt_ext_name = gr.Textbox(value='SD model types', visible=False) + + with gr.Row(): + base_model = gr.Textbox( + label='SD Model', + placeholder='(Optional) Stable Diffusion base model', + interactive=True, + info='Provide a SD file path that you want to merge with the LyCORIS file' + ) + base_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + base_model_file.click( + get_file_path, + inputs=[base_model, ckpt_ext, ckpt_ext_name], + 
outputs=base_model, + show_progress=False, + ) + + with gr.Row(): + lycoris_model = gr.Textbox( + label='LyCORIS model', + placeholder='Path to the LyCORIS model', + interactive=True, + ) + button_lycoris_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_lycoris_model_file.click( + get_file_path, + inputs=[lycoris_model, lora_ext, lora_ext_name], + outputs=lycoris_model, + show_progress=False, + ) + + with gr.Row(): + weight = gr.Slider( + label='Model A merge ratio (eg: 0.5 mean 50%)', + minimum=0, + maximum=1, + step=0.01, + value=1.0, + interactive=True, + ) + + with gr.Row(): + output_name = gr.Textbox( + label='Save to', + placeholder='path for the checkpoint file to save...', + interactive=True, + ) + button_output_name = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_output_name.click( + get_saveasfilename_path, + inputs=[output_name, lora_ext, lora_ext_name], + outputs=output_name, + show_progress=False, + ) + dtype = gr.Dropdown( + label='Save dtype', + choices=['float', 'float16', 'float32', 'float64', 'bfloat', 'bfloat16'], + value='float16', + interactive=True, + ) + + device = gr.Dropdown( + label='Device', + choices=[ + 'cpu', + # 'cuda', + ], + value='cpu', + interactive=True, + ) + + is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True) + + merge_button = gr.Button('Merge model') + + merge_button.click( + merge_lycoris, + inputs=[ + base_model, + lycoris_model, + weight, + output_name, + dtype, + device, + is_v2, + ], + show_progress=False, + ) diff --git a/tools/lycoris_utils.py b/tools/lycoris_utils.py new file mode 100644 index 0000000..25b066b --- /dev/null +++ b/tools/lycoris_utils.py @@ -0,0 +1,504 @@ +from typing import * + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torch.linalg as linalg + +from tqdm import tqdm + + +def make_sparse(t: torch.Tensor, sparsity=0.95): + abs_t = torch.abs(t) + np_array = 
abs_t.detach().cpu().numpy() + quan = float(np.quantile(np_array, sparsity)) + sparse_t = t.masked_fill(abs_t < quan, 0) + return sparse_t + + +def extract_conv( + weight: Union[torch.Tensor, nn.Parameter], + mode = 'fixed', + mode_param = 0, + device = 'cpu', + is_cp = False, +) -> Tuple[nn.Parameter, nn.Parameter]: + weight = weight.to(device) + out_ch, in_ch, kernel_size, _ = weight.shape + + U, S, Vh = linalg.svd(weight.reshape(out_ch, -1)) + + if mode=='fixed': + lora_rank = mode_param + elif mode=='threshold': + assert mode_param>=0 + lora_rank = torch.sum(S>mode_param) + elif mode=='ratio': + assert 1>=mode_param>=0 + min_s = torch.max(S)*mode_param + lora_rank = torch.sum(S>min_s) + elif mode=='quantile' or mode=='percentile': + assert 1>=mode_param>=0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S) + lora_rank = torch.sum(s_cum<min_cum_sum) + else: + raise NotImplementedError('Extract mode you provided is not supported') + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, in_ch, lora_rank) + + if lora_rank>=out_ch/2 and not is_cp: + return weight, 'full' + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach() + extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach() + del U, S, Vh, weight + return (extract_weight_A, extract_weight_B, diff), 'low rank' + + +def extract_linear( + weight: Union[torch.Tensor, nn.Parameter], + mode = 'fixed', + mode_param = 0, + device = 'cpu', +) -> Tuple[nn.Parameter, nn.Parameter]: + weight = weight.to(device) + out_ch, in_ch = weight.shape + + U, S, Vh = linalg.svd(weight) + + if mode=='fixed': + lora_rank = mode_param + elif mode=='threshold': + assert mode_param>=0 + lora_rank = torch.sum(S>mode_param) + elif mode=='ratio': + assert 1>=mode_param>=0 + min_s = torch.max(S)*mode_param + lora_rank = torch.sum(S>min_s) + elif mode=='quantile' or mode=='percentile': + assert 1>=mode_param>=0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S)
+ lora_rank = torch.sum(s_cum<min_cum_sum) + else: + raise NotImplementedError('Extract mode you provided is not supported') + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, in_ch, lora_rank) + + if lora_rank>=out_ch/2: + return weight, 'full' + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + diff = (weight - U @ Vh).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch).detach() + extract_weight_B = U.reshape(out_ch, lora_rank).detach() + del U, S, Vh, weight + return (extract_weight_A, extract_weight_B, diff), 'low rank' + + +def extract_diff( + base_model, + db_model, + mode = 'fixed', + linear_mode_param = 0, + conv_mode_param = 0, + extract_device = 'cpu', + use_bias = False, + sparsity = 0.98, + small_conv = True +): + UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "Attention", + "ResnetBlock2D", + "Downsample2D", + "Upsample2D" + ] + UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + "time_embedding.linear_1", + "time_embedding.linear_2", + ] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = 'lora_unet' + LORA_PREFIX_TEXT_ENCODER = 'lora_te' + def make_state_dict( + prefix, + root_module: torch.nn.Module, + target_module: torch.nn.Module, + target_replace_modules, + target_replace_names = [] + ): + loras = {} + temp = {} + temp_name = {} + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + temp[name] = {} + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}: + continue + temp[name][child_name] = child_module.weight + elif name in target_replace_names: + temp_name[name] = module.weight + + for name, module in tqdm(list(target_module.named_modules())): + if name in temp: + weights = temp[name] + for child_name, child_module in module.named_modules(): + lora_name = prefix + '.' + name + '.'
+ child_name + lora_name = lora_name.replace('.', '_') + layer = child_module.__class__.__name__ + if layer in {'Linear', 'Conv2d'}: + root_weight = child_module.weight + if torch.allclose(root_weight, weights[child_name]): + continue + + if layer == 'Linear': + weight, decompose_mode = extract_linear( + (child_module.weight - weights[child_name]), + mode, + linear_mode_param, + device = extract_device, + ) + if decompose_mode == 'low rank': + extract_a, extract_b, diff = weight + elif layer == 'Conv2d': + is_linear = (child_module.weight.shape[2] == 1 + and child_module.weight.shape[3] == 1) + weight, decompose_mode = extract_conv( + (child_module.weight - weights[child_name]), + mode, + linear_mode_param if is_linear else conv_mode_param, + device = extract_device, + ) + if decompose_mode == 'low rank': + extract_a, extract_b, diff = weight + if small_conv and not is_linear and decompose_mode == 'low rank': + dim = extract_a.size(0) + (extract_c, extract_a, _), _ = extract_conv( + extract_a.transpose(0, 1), + 'fixed', dim, + extract_device, True + ) + extract_a = extract_a.transpose(0, 1) + extract_c = extract_c.transpose(0, 1) + loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half() + diff = child_module.weight - torch.einsum( + 'i j k l, j r, p i -> p r k l', + extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1) + ).detach().cpu().contiguous() + del extract_c + else: + continue + if decompose_mode == 'low rank': + loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half() + loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half() + loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half() + if use_bias: + diff = diff.detach().cpu().reshape(extract_b.size(0), -1) + sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce() + + indices = sparse_diff.indices().to(torch.int16) + values = sparse_diff.values().half() + loras[f'{lora_name}.bias_indices'] = 
indices + loras[f'{lora_name}.bias_values'] = values + loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16) + del extract_a, extract_b, diff + elif decompose_mode == 'full': + loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half() + else: + raise NotImplementedError + elif name in temp_name: + weights = temp_name[name] + lora_name = prefix + '.' + name + lora_name = lora_name.replace('.', '_') + layer = module.__class__.__name__ + + if layer in {'Linear', 'Conv2d'}: + root_weight = module.weight + if torch.allclose(root_weight, weights): + continue + + if layer == 'Linear': + weight, decompose_mode = extract_linear( + (root_weight - weights), + mode, + linear_mode_param, + device = extract_device, + ) + if decompose_mode == 'low rank': + extract_a, extract_b, diff = weight + elif layer == 'Conv2d': + is_linear = ( + root_weight.shape[2] == 1 + and root_weight.shape[3] == 1 + ) + weight, decompose_mode = extract_conv( + (root_weight - weights), + mode, + linear_mode_param if is_linear else conv_mode_param, + device = extract_device, + ) + if decompose_mode == 'low rank': + extract_a, extract_b, diff = weight + if small_conv and not is_linear and decompose_mode == 'low rank': + dim = extract_a.size(0) + (extract_c, extract_a, _), _ = extract_conv( + extract_a.transpose(0, 1), + 'fixed', dim, + extract_device, True + ) + extract_a = extract_a.transpose(0, 1) + extract_c = extract_c.transpose(0, 1) + loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half() + diff = root_weight - torch.einsum( + 'i j k l, j r, p i -> p r k l', + extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1) + ).detach().cpu().contiguous() + del extract_c + else: + continue + if decompose_mode == 'low rank': + loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half() + loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half() + loras[f'{lora_name}.alpha'] = 
torch.Tensor([extract_a.shape[0]]).half() + if use_bias: + diff = diff.detach().cpu().reshape(extract_b.size(0), -1) + sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce() + + indices = sparse_diff.indices().to(torch.int16) + values = sparse_diff.values().half() + loras[f'{lora_name}.bias_indices'] = indices + loras[f'{lora_name}.bias_values'] = values + loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16) + del extract_a, extract_b, diff + elif decompose_mode == 'full': + loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half() + else: + raise NotImplementedError + return loras + + text_encoder_loras = make_state_dict( + LORA_PREFIX_TEXT_ENCODER, + base_model[0], db_model[0], + TEXT_ENCODER_TARGET_REPLACE_MODULE + ) + + unet_loras = make_state_dict( + LORA_PREFIX_UNET, + base_model[2], db_model[2], + UNET_TARGET_REPLACE_MODULE, + UNET_TARGET_REPLACE_NAME + ) + print(len(text_encoder_loras), len(unet_loras)) + return text_encoder_loras|unet_loras + + +def get_module( + lyco_state_dict: Dict, + lora_name +): + if f'{lora_name}.lora_up.weight' in lyco_state_dict: + up = lyco_state_dict[f'{lora_name}.lora_up.weight'] + down = lyco_state_dict[f'{lora_name}.lora_down.weight'] + mid = lyco_state_dict.get(f'{lora_name}.lora_mid.weight', None) + alpha = lyco_state_dict.get(f'{lora_name}.alpha', None) + return 'locon', (up, down, mid, alpha) + elif f'{lora_name}.hada_w1_a' in lyco_state_dict: + w1a = lyco_state_dict[f'{lora_name}.hada_w1_a'] + w1b = lyco_state_dict[f'{lora_name}.hada_w1_b'] + w2a = lyco_state_dict[f'{lora_name}.hada_w2_a'] + w2b = lyco_state_dict[f'{lora_name}.hada_w2_b'] + t1 = lyco_state_dict.get(f'{lora_name}.hada_t1', None) + t2 = lyco_state_dict.get(f'{lora_name}.hada_t2', None) + alpha = lyco_state_dict.get(f'{lora_name}.alpha', None) + return 'hada', (w1a, w1b, w2a, w2b, t1, t2, alpha) + elif f'{lora_name}.weight' in lyco_state_dict: + weight = lyco_state_dict[f'{lora_name}.weight'] + on_input = 
lyco_state_dict.get(f'{lora_name}.on_input', False) + return 'ia3', (weight, on_input) + elif (f'{lora_name}.lokr_w1' in lyco_state_dict + or f'{lora_name}.lokr_w1_a' in lyco_state_dict): + w1 = lyco_state_dict.get(f'{lora_name}.lokr_w1', None) + w1a = lyco_state_dict.get(f'{lora_name}.lokr_w1_a', None) + w1b = lyco_state_dict.get(f'{lora_name}.lokr_w1_b', None) + w2 = lyco_state_dict.get(f'{lora_name}.lokr_w2', None) + w2a = lyco_state_dict.get(f'{lora_name}.lokr_w2_a', None) + w2b = lyco_state_dict.get(f'{lora_name}.lokr_w2_b', None) + t1 = lyco_state_dict.get(f'{lora_name}.lokr_t1', None) + t2 = lyco_state_dict.get(f'{lora_name}.lokr_t2', None) + alpha = lyco_state_dict.get(f'{lora_name}.alpha', None) + return 'kron', (w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha) + elif f'{lora_name}.diff' in lyco_state_dict: + return 'full', lyco_state_dict[f'{lora_name}.diff'] + else: + return 'None', () + + +def cp_weight_from_conv( + up, down, mid +): + up = up.reshape(up.size(0), up.size(1)) + down = down.reshape(down.size(0), down.size(1)) + return torch.einsum('m n w h, i m, n j -> i j w h', mid, up, down) + +def cp_weight( + wa, wb, t +): + temp = torch.einsum('i j k l, j r -> i r k l', t, wb) + return torch.einsum('i j k l, i r -> r j k l', temp, wa) + + +@torch.no_grad() +def rebuild_weight(module_type, params, orig_weight, scale=1): + if orig_weight is None: + return orig_weight + merged = orig_weight + if module_type == 'locon': + up, down, mid, alpha = params + if alpha is not None: + scale *= alpha/up.size(1) + if mid is not None: + rebuild = cp_weight_from_conv(up, down, mid) + else: + rebuild = up.reshape(up.size(0),-1) @ down.reshape(down.size(0), -1) + merged = orig_weight + rebuild.reshape(orig_weight.shape) * scale + del up, down, mid, alpha, params, rebuild + elif module_type == 'hada': + w1a, w1b, w2a, w2b, t1, t2, alpha = params + if alpha is not None: + scale *= alpha / w1b.size(0) + if t1 is not None: + rebuild1 = cp_weight(w1a, w1b, t1) + else: + rebuild1 
= w1a @ w1b + if t2 is not None: + rebuild2 = cp_weight(w2a, w2b, t2) + else: + rebuild2 = w2a @ w2b + rebuild = (rebuild1 * rebuild2).reshape(orig_weight.shape) + merged = orig_weight + rebuild * scale + del w1a, w1b, w2a, w2b, t1, t2, alpha, params, rebuild, rebuild1, rebuild2 + elif module_type == 'ia3': + weight, on_input = params + if not on_input: + weight = weight.reshape(-1, 1) + merged = orig_weight + weight * orig_weight * scale + del weight, on_input, params + elif module_type == 'kron': + w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha = params + if alpha is not None and (w1b is not None or w2b is not None): + scale *= alpha / (w1b.size(0) if w1b else w2b.size(0)) + if w1a is not None and w1b is not None: + if t1: + w1 = cp_weight(w1a, w1b, t1) + else: + w1 = w1a @ w1b + if w2a is not None and w2b is not None: + if t2: + w2 = cp_weight(w2a, w2b, t2) + else: + w2 = w2a @ w2b + rebuild = torch.kron(w1, w2).reshape(orig_weight.shape) + merged = orig_weight + rebuild* scale + del w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha, params, rebuild + elif module_type == 'full': + rebuild = params.reshape(orig_weight.shape) + merged = orig_weight + rebuild * scale + del params, rebuild + + return merged + + +def merge( + base_model, + lyco_state_dict, + scale: float = 1.0, + device = 'cpu' +): + UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "Attention", + "ResnetBlock2D", + "Downsample2D", + "Upsample2D" + ] + UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + "time_embedding.linear_1", + "time_embedding.linear_2", + ] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = 'lora_unet' + LORA_PREFIX_TEXT_ENCODER = 'lora_te' + merged = 0 + def merge_state_dict( + prefix, + root_module: torch.nn.Module, + lyco_state_dict: Dict[str,torch.Tensor], + target_replace_modules, + target_replace_names = [] + ): + nonlocal merged + for name, module in tqdm(list(root_module.named_modules()), desc=f'Merging {prefix}'): + if 
module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}: + continue + lora_name = prefix + '.' + name + '.' + child_name + lora_name = lora_name.replace('.', '_') + + result = rebuild_weight(*get_module( + lyco_state_dict, lora_name + ), getattr(child_module, 'weight'), scale) + if result is not None: + merged += 1 + child_module.requires_grad_(False) + child_module.weight.copy_(result) + elif name in target_replace_names: + lora_name = prefix + '.' + name + lora_name = lora_name.replace('.', '_') + + result = rebuild_weight(*get_module( + lyco_state_dict, lora_name + ), getattr(module, 'weight'), scale) + if result is not None: + merged += 1 + module.requires_grad_(False) + module.weight.copy_(result) + + if device == 'cpu': + for k, v in tqdm(list(lyco_state_dict.items()), desc='Converting Dtype'): + lyco_state_dict[k] = v.float() + + merge_state_dict( + LORA_PREFIX_TEXT_ENCODER, + base_model[0], + lyco_state_dict, + TEXT_ENCODER_TARGET_REPLACE_MODULE, + UNET_TARGET_REPLACE_NAME + ) + merge_state_dict( + LORA_PREFIX_UNET, + base_model[2], + lyco_state_dict, + UNET_TARGET_REPLACE_MODULE, + UNET_TARGET_REPLACE_NAME + ) + print(f'{merged} Modules been merged') \ No newline at end of file diff --git a/tools/merge_lycoris.py b/tools/merge_lycoris.py index 570fa2b..92223ca 100644 --- a/tools/merge_lycoris.py +++ b/tools/merge_lycoris.py @@ -1,35 +1,60 @@ -import os -import sys +import os, sys +sys.path.insert(0, os.getcwd()) import argparse -import torch -from lycoris.utils import merge_loha, merge_locon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "base_model", help="The model you want to merge with loha", + default='', type=str + ) + parser.add_argument( + "lycoris_model", help="the lyco model you want to merge into sd model", + default='', type=str + ) + parser.add_argument( + "output_name", help="the output 
model", + default='./out.pt', type=str + ) + parser.add_argument( + "--is_v2", help="Your base model is sd v2 or not", + default=False, action="store_true" + ) + parser.add_argument( + "--device", help="Which device you want to use to merge the weight", + default='cpu', type=str + ) + parser.add_argument( + "--dtype", help='dtype to save', + default='float', type=str + ) + parser.add_argument( + "--weight", help='weight for the lyco model to merge', + default='1.0', type=float + ) + return parser.parse_args() +ARGS = get_args() + + +from lycoris_utils import merge from lycoris.kohya_model_utils import ( load_models_from_stable_diffusion_checkpoint, save_stable_diffusion_checkpoint, load_file ) -import gradio as gr + +import torch -def merge_models(base_model, lycoris_model, output_name, is_v2, device, dtype, weight): - base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model) - if lycoris_model.rsplit('.', 1)[-1] == 'safetensors': - lyco = load_file(lycoris_model) +def main(): + base = load_models_from_stable_diffusion_checkpoint(ARGS.is_v2, ARGS.base_model) + if ARGS.lycoris_model.rsplit('.', 1)[-1] == 'safetensors': + lyco = load_file(ARGS.lycoris_model) else: - lyco = torch.load(lycoris_model) - - algo = None - for key in lyco: - if 'hada' in key: - algo = 'loha' - break - elif 'lora_up' in key: - algo = 'lora' - break - else: - raise NotImplementedError('Cannot find the algo for this lycoris model file.') - - dtype_str = dtype.replace('fp', 'float').replace('bf', 'bfloat') + lyco = torch.load(ARGS.lycoris_model) + + dtype_str = ARGS.dtype.replace('fp', 'float').replace('bf', 'bfloat') dtype = { 'float': torch.float, 'float16': torch.float16, @@ -40,41 +65,21 @@ def merge_models(base_model, lycoris_model, output_name, is_v2, device, dtype, w }.get(dtype_str, None) if dtype is None: raise ValueError(f'Cannot Find the dtype "{dtype}"') - - if algo == 'loha': - merge_loha(base, lyco, weight, device) - elif algo == 'lora': - merge_locon(base, lyco, 
weight, device) - + + merge( + base, + lyco, + ARGS.weight, + ARGS.device + ) + save_stable_diffusion_checkpoint( - is_v2, output_name, - base[0], base[2], - None, 0, 0, dtype, + ARGS.is_v2, ARGS.output_name, + base[0], base[2], + None, 0, 0, dtype, base[1] ) - return output_name - - -def main(): - iface = gr.Interface( - fn=merge_models, - inputs=[ - gr.inputs.Textbox(label="Base Model Path"), - gr.inputs.Textbox(label="Lycoris Model Path"), - gr.inputs.Textbox(label="Output Model Path", default='./out.pt'), - gr.inputs.Checkbox(label="Is base model SD V2?", default=False), - gr.inputs.Textbox(label="Device", default='cpu'), - gr.inputs.Dropdown(choices=['float', 'float16', 'float32', 'float64', 'bfloat', 'bfloat16'], label="Dtype", default='float'), - gr.inputs.Number(label="Weight", default=1.0) - ], - outputs=gr.outputs.Textbox(label="Merged Model Path"), - title="Model Merger", - description="Merge Lycoris and Stable Diffusion models", - ) - - iface.launch() - if __name__ == '__main__': - main() + main() \ No newline at end of file