commit 3d9e665e6e7339e648282977e299585e8db0ca6a Author: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun Apr 9 16:21:14 2023 +0800 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ed8ebf5 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..60f8f07 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2023] [KohakuBlueLeaf] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..5280f5d --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +# a1111-sd-webui-lycoris + +An extension for loading lycoris model in sd-webui. +I made this stand alone extension (Use sd-webui's extra networks api) to avoid some conflict with other loras extensions. + +### LyCORIS +https://github.com/KohakuBlueleaf/LyCORIS + +### usage +Install it and restart the webui +**Don't use "Apply and restart UI", please restart the webui process** + +And you will find "LyCORIS" tab in the extra networks page \ No newline at end of file diff --git a/extra_networks_lyco.py b/extra_networks_lyco.py new file mode 100644 index 0000000..dc46106 --- /dev/null +++ b/extra_networks_lyco.py @@ -0,0 +1,26 @@ +from modules import extra_networks, shared +import lycoris + +class ExtraNetworkLyCORIS(extra_networks.ExtraNetwork): + def __init__(self): + super().__init__('lyco') + + def activate(self, p, params_list): + additional = shared.opts.sd_lyco + + if additional != "" and additional in lycoris.available_lycos and len([x for x in params_list if x.items[0] == additional]) == 0: + p.all_prompts = [x + f"" for x in p.all_prompts] + params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) + + names = [] + multipliers = [] + for params in params_list: + assert len(params.items) > 0 + + names.append(params.items[0]) + multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) + + lycoris.load_lycos(names, multipliers) + + def deactivate(self, p): + pass diff --git a/lycoris.py b/lycoris.py new file mode 100644 index 0000000..9c75d6e --- /dev/null +++ b/lycoris.py @@ -0,0 +1,724 @@ +from typing import * +import os +import re +import glob + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modules import shared, devices, sd_models, errors + +metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} + + +re_digits = re.compile(r"\d+") +re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") + +re_unet_conv_in = re.compile(r"lora_unet_conv_in(.+)") +re_unet_conv_out = re.compile(r"lora_unet_conv_out(.+)") +re_unet_time_embed = re.compile(r"lora_unet_time_embedding_linear_(\d+)(.+)") + +re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") +re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)") +re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)") + +re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)") +re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)") +re_unet_up_blocks_res = re.compile(r"lora_unet_up_blocks_(\d+)_resnets_(\d+)_(.+)") + +re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)") +re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)") + +re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") + + +def convert_diffusers_name_to_compvis(key, is_sd2): + # I don't know why but some state dict has this kind of thing + key = key.replace('text_model_text_model', 'text_model') + def match(match_list, regex): + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, re_unet_conv_in): + return f'diffusion_model_input_blocks_0_0{m[0]}' + + if match(m, re_unet_conv_out): + return f'diffusion_model_out_2{m[0]}' + + if match(m, re_unet_time_embed): + return f"diffusion_model_time_embed_{m[0]*2-2}{m[1]}" + + if match(m, re_unet_down_blocks): + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}" + + if match(m, re_unet_mid_blocks): + return f"diffusion_model_middle_block_1_{m[1]}" + + if match(m, re_unet_up_blocks): + return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" + + if match(m, re_unet_down_blocks_res): + block = f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_0_" + if m[2].startswith('conv1'): + return f"{block}in_layers_2{m[2][len('conv1'):]}" + elif m[2].startswith('conv2'): + return f"{block}out_layers_3{m[2][len('conv2'):]}" + elif m[2].startswith('time_emb_proj'): + return f"{block}emb_layers_1{m[2][len('time_emb_proj'):]}" + elif m[2].startswith('conv_shortcut'): + return f"{block}skip_connection{m[2][len('conv_shortcut'):]}" + + if match(m, re_unet_mid_blocks_res): + block = f"diffusion_model_middle_block_{m[0]*2}_" + if m[1].startswith('conv1'): + return f"{block}in_layers_2{m[1][len('conv1'):]}" + elif m[1].startswith('conv2'): + return f"{block}out_layers_3{m[1][len('conv2'):]}" + elif m[1].startswith('time_emb_proj'): + return f"{block}emb_layers_1{m[1][len('time_emb_proj'):]}" + elif m[1].startswith('conv_shortcut'): + return f"{block}skip_connection{m[1][len('conv_shortcut'):]}" + + if match(m, re_unet_up_blocks_res): + block = f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_0_" + if m[2].startswith('conv1'): + return f"{block}in_layers_2{m[2][len('conv1'):]}" + elif m[2].startswith('conv2'): + return f"{block}out_layers_3{m[2][len('conv2'):]}" + elif m[2].startswith('time_emb_proj'): + return f"{block}emb_layers_1{m[2][len('time_emb_proj'):]}" + elif m[2].startswith('conv_shortcut'): + return f"{block}skip_connection{m[2][len('conv_shortcut'):]}" + + if match(m, re_unet_downsample): + return f"diffusion_model_input_blocks_{m[0]*3+3}_0_op{m[1]}" + + if match(m, re_unet_upsample): + return f"diffusion_model_output_blocks_{m[0]*3 + 2}_{1+(m[0]!=0)}_conv{m[1]}" + + if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): + if is_sd2: + if 'mlp_fc1' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + return key + + +def assign_lyco_names_to_compvis_modules(sd_model): + lyco_layer_mapping = {} + + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + lyco_name = name.replace(".", "_") + lyco_layer_mapping[lyco_name] = module + module.lyco_layer_name = lyco_name + + for name, module in shared.sd_model.model.named_modules(): + lyco_name = name.replace(".", "_") + lyco_layer_mapping[lyco_name] = module + module.lyco_layer_name = lyco_name + + sd_model.lyco_layer_mapping = lyco_layer_mapping + + +class LycoOnDisk: + def __init__(self, name, filename): + self.name = name + self.filename = filename + self.metadata = {} + + _, ext = os.path.splitext(filename) + if ext.lower() == ".safetensors": + try: + self.metadata = sd_models.read_metadata_from_safetensors(filename) + except Exception as e: + errors.display(e, f"reading lora {filename}") + + if self.metadata: + m = {} + for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)): + m[k] = v + + self.metadata = m + + self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text + + +class LycoModule: + def __init__(self, name): + self.name = name + self.multiplier = 1.0 + self.modules = {} + self.mtime = None + + +class FullModule: + def __init__(self): + self.weight = None + self.alpha = None + self.dim = None + self.shape = None + + +class LycoUpDownModule: + def __init__(self): + self.up_model = None + self.mid_model = None + self.down_model = None + self.alpha = None + self.dim = None + self.shape = None + self.bias = None + + +def make_weight_cp(t, wa, wb): + temp = torch.einsum('i j k l, j r -> i r k l', t, wb) + return torch.einsum('i j k l, i r -> r j k l', temp, wa) + + +class LycoHadaModule: + def __init__(self): + self.t1 = None + self.w1a = None + self.w1b = None + self.t2 = None + self.w2a = None + self.w2b = None + self.alpha = None + self.dim = None + self.shape = None + self.bias = None + + +class IA3Module: + def __init__(self): + self.w = None + self.alpha = None + self.dim = None + self.on_input = None + + +def make_kron(orig_shape, w1, w2): + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + w2 = w2.contiguous() + return torch.kron(w1, w2).reshape(orig_shape) + + +class LycoKronModule: + def __init__(self): + self.w1 = None + self.w1a = None + self.w1b = None + self.w2 = None + self.t2 = None + self.w2a = None + self.w2b = None + self._alpha = None + self.dim = None + self.shape = None + self.bias = None + + @property + def alpha(self): + if self.w1a is None and self.w2a is None: + return None + else: + return self._alpha + + @alpha.setter + def alpha(self, x): + self._alpha = x + + +CON_KEY = { + "lora_up.weight", + "lora_down.weight", + "lora_mid.weight" +} +HADA_KEY = { + "hada_t1", + "hada_w1_a", + "hada_w1_b", + "hada_t2", + "hada_w2_a", + "hada_w2_b", +} +IA3_KEY = { + "weight", + "on_input" +} +KRON_KEY = { + "lokr_w1", + "lokr_w1_a", + "lokr_w1_b", + "lokr_t2", + "lokr_w2", + "lokr_w2_a", + "lokr_w2_b", +} + +def load_lyco(name, filename): + print('locon load lyco method') + lyco = LycoModule(name) + lyco.mtime = os.path.getmtime(filename) + + sd = sd_models.read_state_dict(filename) + is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lyco_layer_mapping + + keys_failed_to_match = [] + + for key_diffusers, weight in sd.items(): + fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2) + key, lyco_key = fullkey.split(".", 1) + + sd_module = shared.sd_model.lyco_layer_mapping.get(key, None) + + if sd_module is None: + m = re_x_proj.match(key) + if m: + sd_module = shared.sd_model.lyco_layer_mapping.get(m.group(1), None) + + if sd_module is None: + print(key) + keys_failed_to_match.append(key_diffusers) + continue + + lyco_module = lyco.modules.get(key, None) + if lyco_module is None: + lyco_module = LycoUpDownModule() + lyco.modules[key] = lyco_module + + if lyco_key == "alpha": + lyco_module.alpha = weight.item() + continue + + if lyco_key == "diff": + weight = weight.to(device=devices.device, dtype=devices.dtype) + weight.requires_grad_(False) + lyco_module = FullModule() + lyco.modules[key] = lyco_module + lyco_module.weight = weight + continue + + if 'bias_' in lyco_key: + if lyco_module.bias is None: + lyco_module.bias = [None, None, None] + if 'bias_indices' == lyco_key: + lyco_module.bias[0] = weight + elif 'bias_values' == lyco_key: + lyco_module.bias[1] = weight + elif 'bias_size' == lyco_key: + lyco_module.bias[2] = weight + + if all((i is not None) for i in lyco_module.bias): + print('build bias') + lyco_module.bias = torch.sparse_coo_tensor( + lyco_module.bias[0], + lyco_module.bias[1], + tuple(lyco_module.bias[2]), + ).to(device=devices.cpu, dtype=devices.dtype) + lyco_module.bias.requires_grad_(False) + continue + + if lyco_key in CON_KEY: + if (type(sd_module) == torch.nn.Linear + or type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear + or type(sd_module) == torch.nn.MultiheadAttention): + weight = weight.reshape(weight.shape[0], -1) + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.Conv2d: + if lyco_key == "lora_down.weight": + if len(weight.shape) == 2: + weight = weight.reshape(weight.shape[0], -1, 1, 1) + if weight.shape[2] != 1 or weight.shape[3] != 1: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False) + else: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif lyco_key == "lora_mid.weight": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False) + elif lyco_key == "lora_up.weight": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + else: + assert False, f'Lyco layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' + + if hasattr(sd_module, 'weight'): + lyco_module.shape = sd_module.weight.shape + with torch.no_grad(): + if weight.shape != module.weight.shape: + weight = weight.reshape(module.weight.shape) + module.weight.copy_(weight) + + module.to(device=devices.cpu, dtype=devices.dtype) + module.requires_grad_(False) + + if lyco_key == "lora_up.weight": + lyco_module.up_model = module + elif lyco_key == "lora_mid.weight": + lyco_module.mid_model = module + elif lyco_key == "lora_down.weight": + lyco_module.down_model = module + lyco_module.dim = weight.shape[0] + else: + print(lyco_key) + elif lyco_key in HADA_KEY: + if type(lyco_module) != LycoHadaModule: + alpha = lyco_module.alpha + bias = lyco_module.bias + lyco_module = LycoHadaModule() + lyco_module.alpha = alpha + lyco_module.bias = bias + lyco.modules[key] = lyco_module + if hasattr(sd_module, 'weight'): + lyco_module.shape = sd_module.weight.shape + + weight = weight.to(device=devices.cpu, dtype=devices.dtype) + weight.requires_grad_(False) + + if lyco_key == 'hada_w1_a': + lyco_module.w1a = weight + elif lyco_key == 'hada_w1_b': + lyco_module.w1b = weight + lyco_module.dim = weight.shape[0] + elif lyco_key == 'hada_w2_a': + lyco_module.w2a = weight + elif lyco_key == 'hada_w2_b': + lyco_module.w2b = weight + lyco_module.dim = weight.shape[0] + elif lyco_key == 'hada_t1': + lyco_module.t1 = weight + elif lyco_key == 'hada_t2': + lyco_module.t2 = weight + + elif lyco_key in IA3_KEY: + if type(lyco_module) != IA3Module: + lyco_module = IA3Module() + lyco.modules[key] = lyco_module + + if lyco_key == "weight": + lyco_module.w = weight.to(devices.device, dtype=devices.dtype) + elif lyco_key == "on_input": + lyco_module.on_input = weight + elif lyco_key in KRON_KEY: + if not isinstance(lyco_module, LycoKronModule): + alpha = lyco_module.alpha + bias = lyco_module.bias + lyco_module = LycoKronModule() + lyco_module.alpha = alpha + lyco_module.bias = bias + lyco.modules[key] = lyco_module + if hasattr(sd_module, 'weight'): + lyco_module.shape = sd_module.weight.shape + + weight = weight.to(device=devices.cpu, dtype=devices.dtype) + weight.requires_grad_(False) + + if lyco_key == 'lokr_w1': + lyco_module.w1 = weight + elif lyco_key == 'lokr_w1_a': + lyco_module.w1a = weight + elif lyco_key == 'lokr_w1_b': + lyco_module.w1b = weight + lyco_module.dim = weight.shape[0] + elif lyco_key == 'lokr_w2': + lyco_module.w2 = weight + elif lyco_key == 'lokr_w2_a': + lyco_module.w2a = weight + elif lyco_key == 'lokr_w2_b': + lyco_module.w2b = weight + lyco_module.dim = weight.shape[0] + elif lyco_key == 'lokr_t2': + lyco_module.t2 = weight + else: + assert False, f'Bad Lyco layer name: {key_diffusers} - must end in lyco_up.weight, lyco_down.weight or alpha' + + if len(keys_failed_to_match) > 0: + print(shared.sd_model.lyco_layer_mapping) + print(f"Failed to match keys when loading Lyco {filename}: {keys_failed_to_match}") + + return lyco + + +def load_lycos(names, multipliers=None): + already_loaded = {} + + for lyco in loaded_lycos: + if lyco.name in names: + already_loaded[lyco.name] = lyco + + loaded_lycos.clear() + + lycos_on_disk = [available_lycos.get(name, None) for name in names] + if any([x is None for x in lycos_on_disk]): + list_available_lycos() + + lycos_on_disk = [available_lycos.get(name, None) for name in names] + + for i, name in enumerate(names): + lyco = already_loaded.get(name, None) + + lyco_on_disk = lycos_on_disk[i] + if lyco_on_disk is not None: + if lyco is None or os.path.getmtime(lyco_on_disk.filename) > lyco.mtime: + lyco = load_lyco(name, lyco_on_disk.filename) + + if lyco is None: + print(f"Couldn't find Lora with name {name}") + continue + + lyco.multiplier = multipliers[i] if multipliers else 1.0 + loaded_lycos.append(lyco) + + +def _rebuild_conventional(up, down, shape): + return (up.reshape(up.size(0), -1) @ down.reshape(down.size(0), -1)).reshape(shape) + + +def _rebuild_cp_decomposition(up, down, mid): + up = up.reshape(up.size(0), -1) + down = down.reshape(down.size(0), -1) + return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) + + +def rebuild_weight(module, orig_weight: torch.Tensor) -> torch.Tensor: + if module.__class__.__name__ == 'LycoUpDownModule': + up = module.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + down = module.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + + output_shape = [up.size(0), down.size(1)] + if (mid:=module.mid_model) is not None: + # cp-decomposition + mid = mid.weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = _rebuild_cp_decomposition(up, down, mid) + output_shape += mid.shape[2:] + else: + if len(down.shape) == 4: + output_shape += down.shape[2:] + updown = _rebuild_conventional(up, down, output_shape) + + elif module.__class__.__name__ == 'LycoHadaModule': + w1a = module.w1a.to(orig_weight.device, dtype=orig_weight.dtype) + w1b = module.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = module.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = module.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + + output_shape = [w1a.size(0), w1b.size(1)] + + if module.t1 is not None: + output_shape = [w1a.size(1), w1b.size(1)] + t1 = module.t1.to(orig_weight.device, dtype=orig_weight.dtype) + updown1 = make_weight_cp(t1, w1a, w1b) + output_shape += t1.shape[2:] + else: + if len(w1b.shape) == 4: + output_shape += w1b.shape[2:] + updown1 = _rebuild_conventional(w1a, w1b, output_shape) + + if module.t2 is not None: + t2 = module.t2.to(orig_weight.device, dtype=orig_weight.dtype) + updown2 = make_weight_cp(t2, w2a, w2b) + else: + updown2 = _rebuild_conventional(w2a, w2b, output_shape) + + updown = updown1 * updown2 + + elif module.__class__.__name__ == 'FullModule': + output_shape = module.weight.shape + updown = module.weight.to(orig_weight.device, dtype=orig_weight.dtype) + + elif module.__class__.__name__ == 'IA3Module': + output_shape = [module.w.size(0), orig_weight.size(1)] + if module.on_input: + output_shape.reverse() + else: + module.w = module.w.reshape(-1, 1) + updown = orig_weight * module.w + + elif module.__class__.__name__ == 'LycoKronModule': + if module.w1 is not None: + w1 = module.w1.to(orig_weight.device, dtype=orig_weight.dtype) + else: + w1a = module.w1a.to(orig_weight.device, dtype=orig_weight.dtype) + w1b = module.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w1 = w1a @ w1b + + if module.w2 is not None: + w2 = module.w2.to(orig_weight.device, dtype=orig_weight.dtype) + elif module.t2 is None: + w2a = module.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = module.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = w2a @ w2b + else: + t2 = module.t2.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = module.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = module.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = make_weight_cp(t2, w2a, w2b) + + output_shape = [w1.size(0)*w2.size(0), w1.size(1)*w2.size(1)] + if len(orig_weight.shape) == 4: + output_shape = orig_weight.shape + + updown = make_kron( + output_shape, w1, w2 + ) + + else: + raise NotImplementedError( + f"Unknown module type: {module.__class__.__name__}\n" + "If the type is one of " + "'LycoUpDownModule', 'LycoHadaModule', 'FullModule', 'IA3Module', 'LycoKronModule'" + "You may have other lyco extension that conflict with locon extension." + ) + + if hasattr(module, 'bias') and module.bias != None: + updown = updown.reshape(module.bias.shape) + updown += module.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown = updown.reshape(output_shape) + + if len(output_shape) == 4: + updown = updown.reshape(output_shape) + + if orig_weight.size().numel() == updown.size().numel(): + updown = updown.reshape(orig_weight.shape) + # print(torch.sum(updown)) + return updown + + +def lyco_calc_updown(lyco, module, target): + with torch.no_grad(): + updown = rebuild_weight(module, target) + updown = updown * lyco.multiplier * (module.alpha / module.dim if module.alpha and module.dim else 1.0) + return updown + + +def lyco_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + """ + Applies the currently selected set of Lycos to the weights of torch layer self. + If weights already have this particular set of lycos applied, does nothing. + If not, restores orginal weights from backup and alters weights according to lycos. + """ + + lyco_layer_name = getattr(self, 'lyco_layer_name', None) + if lyco_layer_name is None: + return + + current_names = getattr(self, "lyco_current_names", ()) + lora_names = getattr(self, "lora_current_names", ()) + wanted_names = tuple((x.name, x.multiplier) for x in loaded_lycos) + + weights_backup = getattr(self, "lora_backup_weights", None) + + if current_names != wanted_names: + if weights_backup is not None and lora_names == (): + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) + + for lyco in loaded_lycos: + module = lyco.modules.get(lyco_layer_name, None) + if module is not None and hasattr(self, 'weight'): + self.weight += lyco_calc_updown(lyco, module, self.weight) + continue + + module_q = lyco.modules.get(lyco_layer_name + "_q_proj", None) + module_k = lyco.modules.get(lyco_layer_name + "_k_proj", None) + module_v = lyco.modules.get(lyco_layer_name + "_v_proj", None) + module_out = lyco.modules.get(lyco_layer_name + "_out_proj", None) + + if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: + updown_q = lyco_calc_updown(lyco, module_q, self.in_proj_weight) + updown_k = lyco_calc_updown(lyco, module_k, self.in_proj_weight) + updown_v = lyco_calc_updown(lyco, module_v, self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + + self.in_proj_weight += updown_qkv + self.out_proj.weight += lyco_calc_updown(lyco, module_out, self.out_proj.weight) + continue + + if module is None: + continue + + print(f'failed to calculate lyco weights for layer {lyco_layer_name}') + + setattr(self, "lyco_current_names", wanted_names) + + +def lyco_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): + setattr(self, "lyco_current_names", ()) + setattr(self, "lora_weights_backup", None) + + +def lyco_Linear_forward(self, input): + lyco_apply_weights(self) + + return torch.nn.Linear_forward_before_lyco(self, input) + + +def lyco_Linear_load_state_dict(self, *args, **kwargs): + lyco_reset_cached_weight(self) + + return torch.nn.Linear_load_state_dict_before_lyco(self, *args, **kwargs) + + +def lyco_Conv2d_forward(self, input): + lyco_apply_weights(self) + + return torch.nn.Conv2d_forward_before_lyco(self, input) + + +def lyco_Conv2d_load_state_dict(self, *args, **kwargs): + lyco_reset_cached_weight(self) + + return torch.nn.Conv2d_load_state_dict_before_lyco(self, *args, **kwargs) + + +def lyco_MultiheadAttention_forward(self, *args, **kwargs): + lyco_apply_weights(self) + + return torch.nn.MultiheadAttention_forward_before_lyco(self, *args, **kwargs) + + +def lyco_MultiheadAttention_load_state_dict(self, *args, **kwargs): + lyco_reset_cached_weight(self) + + return torch.nn.MultiheadAttention_load_state_dict_before_lyco(self, *args, **kwargs) + + +def list_available_lycos(): + available_lycos.clear() + + os.makedirs(shared.cmd_opts.lyco_dir, exist_ok=True) + + candidates = \ + glob.glob(os.path.join(shared.cmd_opts.lyco_dir, '**/*.pt'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lyco_dir, '**/*.safetensors'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lyco_dir, '**/*.ckpt'), recursive=True) + + for filename in sorted(candidates, key=str.lower): + if os.path.isdir(filename): + continue + + name = os.path.splitext(os.path.basename(filename))[0] + + available_lycos[name] = LycoOnDisk(name, filename) + + +available_lycos = {} +loaded_lycos = [] + +list_available_lycos() \ No newline at end of file diff --git a/preload.py b/preload.py new file mode 100644 index 0000000..1638c3a --- /dev/null +++ b/preload.py @@ -0,0 +1,6 @@ +import os +from modules import paths + + +def preload(parser): + parser.add_argument("--lyco-dir", type=str, help="Path to directory with LyCORIS networks.", default=os.path.join(paths.models_path, 'LyCORIS')) diff --git a/scripts/lycoris_script.py b/scripts/lycoris_script.py new file mode 100644 index 0000000..8fa1742 --- /dev/null +++ b/scripts/lycoris_script.py @@ -0,0 +1,56 @@ +import torch +import gradio as gr + +import lycoris +import extra_networks_lyco +import ui_extra_networks_lyco +from modules import script_callbacks, ui_extra_networks, extra_networks, shared + + +def unload(): + torch.nn.Linear.forward = torch.nn.Linear_forward_before_lyco + torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lyco + torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lyco + torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lyco + torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lyco + torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lyco + + +def before_ui(): + ui_extra_networks.register_page(ui_extra_networks_lyco.ExtraNetworksPageLyCORIS()) + extra_networks.register_extra_network(extra_networks_lyco.ExtraNetworkLyCORIS()) + + +if not hasattr(torch.nn, 'Linear_forward_before_lyco'): + torch.nn.Linear_forward_before_lyco = torch.nn.Linear.forward + +if not hasattr(torch.nn, 'Linear_load_state_dict_before_lyco'): + torch.nn.Linear_load_state_dict_before_lyco = torch.nn.Linear._load_from_state_dict + +if not hasattr(torch.nn, 'Conv2d_forward_before_lyco'): + torch.nn.Conv2d_forward_before_lyco = torch.nn.Conv2d.forward + +if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lyco'): + torch.nn.Conv2d_load_state_dict_before_lyco = torch.nn.Conv2d._load_from_state_dict + +if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lyco'): + torch.nn.MultiheadAttention_forward_before_lyco = torch.nn.MultiheadAttention.forward + +if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lyco'): + torch.nn.MultiheadAttention_load_state_dict_before_lyco = torch.nn.MultiheadAttention._load_from_state_dict + +torch.nn.Linear.forward = lycoris.lyco_Linear_forward +torch.nn.Linear._load_from_state_dict = lycoris.lyco_Linear_load_state_dict +torch.nn.Conv2d.forward = lycoris.lyco_Conv2d_forward +torch.nn.Conv2d._load_from_state_dict = lycoris.lyco_Conv2d_load_state_dict +torch.nn.MultiheadAttention.forward = lycoris.lyco_MultiheadAttention_forward +torch.nn.MultiheadAttention._load_from_state_dict = lycoris.lyco_MultiheadAttention_load_state_dict + +script_callbacks.on_model_loaded(lycoris.assign_lyco_names_to_compvis_modules) +script_callbacks.on_script_unloaded(unload) +script_callbacks.on_before_ui(before_ui) + + +shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { + "sd_lyco": shared.OptionInfo("None", "Add LyCORIS to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lycoris.available_lycos]}, refresh=lycoris.list_available_lycos), +})) diff --git a/ui_extra_networks_lyco.py b/ui_extra_networks_lyco.py new file mode 100644 index 0000000..33ae09d --- /dev/null +++ b/ui_extra_networks_lyco.py @@ -0,0 +1,31 @@ +import json +import os +import lycoris + +from modules import shared, ui_extra_networks + + +class ExtraNetworksPageLyCORIS(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('LyCORIS') + + def refresh(self): + lycoris.list_available_lycos() + + def list_items(self): + for name, lyco_on_disk in lycoris.available_lycos.items(): + path, ext = os.path.splitext(lyco_on_disk.filename) + yield { + "name": name, + "filename": path, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(lyco_on_disk.filename), + "prompt": json.dumps(f""), + "local_preview": f"{path}.{shared.opts.samples_format}", + "metadata": json.dumps(lyco_on_disk.metadata, indent=4) if lyco_on_disk.metadata else None, + } + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.lyco_dir] +