From 0ddfb5d4ed3fefc05f2f4e82d1c013f8a7dce196 Mon Sep 17 00:00:00 2001
From: AI-Casanova <54461896+AI-Casanova@users.noreply.github.com>
Date: Sat, 7 Oct 2023 12:10:31 -0500
Subject: [PATCH] Update networks.py for LyCORIS loading on Diffusers Backend

---
 extensions-builtin/Lora/networks.py | 227 ++++++++++++++++++++++------
 1 file changed, 182 insertions(+), 45 deletions(-)

diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 63c078981..a4c98dfe9 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -1,7 +1,8 @@
-from typing import Union
+from typing import Dict, Union
 import logging
 import os
 import re
+import bisect
 import lora_patches
 import network
 import network_lora
@@ -41,6 +42,147 @@ suffix_conversion = {
 }
 
 
+def make_unet_conversion_map() -> Dict[str, str]:
+    unet_conversion_map_layer = []
+
+    for i in range(3):  # num_blocks is 3 in sdxl
+        # loop over downblocks/upblocks
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+            if i < 3:
+                # no attention layers in down_blocks.3
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+        for j in range(3):
+            # loop over resnets/attentions for upblocks
+            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+            sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
+            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+            # if i > 0: commentout for sdxl
+            # no attention layers in up_blocks.0
+            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
+            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+        if i < 3:
+            # no downsample in down_blocks.3
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+            # no upsample in up_blocks.3
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{2}."  # change for sdxl
+            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2 * j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0.", "norm1."),
+        ("in_layers.2.", "conv1."),
+        ("out_layers.0.", "norm2."),
+        ("out_layers.3.", "conv2."),
+        ("emb_layers.1.", "time_emb_proj."),
+        ("skip_connection.", "conv_shortcut."),
+    ]
+
+    unet_conversion_map = []
+    for sd, hf in unet_conversion_map_layer:
+        if "resnets" in hf:
+            for sd_res, hf_res in unet_conversion_map_resnet:
+                unet_conversion_map.append((sd + sd_res, hf + hf_res))
+        else:
+            unet_conversion_map.append((sd, hf))
+
+    for j in range(2):
+        hf_time_embed_prefix = f"time_embedding.linear_{j + 1}."
+        sd_time_embed_prefix = f"time_embed.{j * 2}."
+        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+
+    for j in range(2):
+        hf_label_embed_prefix = f"add_embedding.linear_{j + 1}."
+        sd_label_embed_prefix = f"label_emb.0.{j * 2}."
+        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+
+    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+    unet_conversion_map.append(("out.0.", "conv_norm_out."))
+    unet_conversion_map.append(("out.2.", "conv_out."))
+
+    sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
+    return sd_hf_conversion_map
+
+
+class KeyConvert:
+    def __init__(self):
+        if shared.backend == shared.Backend.ORIGINAL:
+            self.converter = self.original
+            self.is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
+        else:
+            self.converter = self.diffusers
+            self.is_sdxl = shared.sd_model_type == "sdxl"
+            self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
+            self.LORA_PREFIX_UNET = "lora_unet"
+            self.LORA_PREFIX_TEXT_ENCODER = "lora_te"
+
+            # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
+            self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
+            self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
+
+    def original(self, key):
+        key = convert_diffusers_name_to_compvis(key, self.is_sd2)
+        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        if sd_module is None:
+            m = re_x_proj.match(key)
+            if m:
+                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
+        # SDXL loras seem to already have correct compvis keys, so we only need to replace "lora_unet" with "diffusion_model"
+        if sd_module is None and "lora_unet" in key:
+            key = key.replace("lora_unet", "diffusion_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        elif sd_module is None and "lora_te1_text_model" in key:
+            key = key.replace("lora_te1_text_model", "0_transformer_text_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+            # some SD1 Loras also have correct compvis keys
+            if sd_module is None:
+                key = key.replace("lora_te1_text_model", "transformer_text_model")
+                sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        return key, sd_module
+
+    def diffusers(self, key):
+        if self.is_sdxl:
+            # UNET_CONVERSION_MAP is only built for SDXL, so keep the lookup
+            # inside this branch; SD1/SD2 keys pass through unchanged
+            map_keys = sorted(self.UNET_CONVERSION_MAP.keys())  # prefixes of SD U-Net module names
+            search_key = (
+                key.replace(self.LORA_PREFIX_UNET + "_", "")
+                .replace(self.LORA_PREFIX_TEXT_ENCODER1 + "_", "")
+                .replace(self.LORA_PREFIX_TEXT_ENCODER2 + "_", "")
+            )
+            position = bisect.bisect_right(map_keys, search_key)
+            map_key = map_keys[position - 1]
+            if search_key.startswith(map_key):
+                key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key])
+        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        return key, sd_module
+
+    def __call__(self, key):
+        return self.converter(key)
+
+
 def convert_diffusers_name_to_compvis(key, is_sd2):
     def match(match_list, regex_text):
         regex = re_compiled.get(regex_text)
@@ -109,17 +251,33 @@ def assign_network_names_to_compvis_modules(sd_model):
         network_layer_mapping[network_name] = module
         module.network_layer_name = network_name
     """
-    if not hasattr(shared.sd_model, 'cond_stage_model'):
-        return
     network_layer_mapping = {}
-    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
-        network_name = name.replace(".", "_")
-        network_layer_mapping[network_name] = module
-        module.network_layer_name = network_name
-    for name, module in shared.sd_model.model.named_modules():
-        network_name = name.replace(".", "_")
-        network_layer_mapping[network_name] = module
-        module.network_layer_name = network_name
+    if shared.backend == shared.Backend.DIFFUSERS:
+        for name, module in shared.sd_model.text_encoder.named_modules():
+            prefix = "lora_te1_" if shared.sd_model_type == "sdxl" else "lora_te_"
+            network_name = prefix + name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
+        if shared.sd_model_type == "sdxl":
+            for name, module in shared.sd_model.text_encoder_2.named_modules():
+                network_name = "lora_te2_" + name.replace(".", "_")
+                network_layer_mapping[network_name] = module
+                module.network_layer_name = network_name
+        for name, module in shared.sd_model.unet.named_modules():
+            network_name = "lora_unet_" + name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
+    else:
+        if not hasattr(shared.sd_model, 'cond_stage_model'):
+            return
+        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+            network_name = name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
+        for name, module in shared.sd_model.model.named_modules():
+            network_name = name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
     sd_model.network_layer_mapping = network_layer_mapping
 
 
@@ -128,30 +286,13 @@ def load_network(name, network_on_disk):
     net.mtime = os.path.getmtime(network_on_disk.filename)
     sd = sd_models.read_state_dict(network_on_disk.filename)
     # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
-    if not hasattr(shared.sd_model, 'network_layer_mapping'):
-        assign_network_names_to_compvis_modules(shared.sd_model)
+    assign_network_names_to_compvis_modules(shared.sd_model)
     keys_failed_to_match = {}
-    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
     matched_networks = {}
+    convert = KeyConvert()
     for key_network, weight in sd.items():
         key_network_without_network_parts, network_part = key_network.split(".", 1)
-        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
-        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-        if sd_module is None:
-            m = re_x_proj.match(key)
-            if m:
-                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
-        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
-        if sd_module is None and "lora_unet" in key_network_without_network_parts:
-            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
-            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
-            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
-            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-            # some SD1 Loras also have correct compvis keys
-            if sd_module is None:
-                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
-                sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        key, sd_module = convert(key_network_without_network_parts)
         if sd_module is None:
             keys_failed_to_match[key_network] = key
             continue
@@ -222,12 +363,9 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
         net = networks_in_memory.get(name)
         if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
             try:
-                if shared.backend == shared.Backend.ORIGINAL:
-                    net = load_network(name, network_on_disk)
-                    networks_in_memory.pop(name, None)
-                    networks_in_memory[name] = net
-                elif shared.backend == shared.Backend.DIFFUSERS:
-                    net = load_diffusers(name, network_on_disk, te_multipliers[i] if te_multipliers else 1.0, unet_multipliers[i] if unet_multipliers else 1.0, dyn_dims[i] if dyn_dims else 1.0)
+                net = load_network(name, network_on_disk)
+                networks_in_memory.pop(name, None)
+                networks_in_memory[name] = net
             except Exception as e:
                 errors.display(e, f"loading network {network_on_disk.filename}")
                 continue
@@ -237,12 +375,10 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
             failed_to_load_networks.append(name)
             logging.info(f"Couldn't find network with name {name}")
             continue
-        else:
-            net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
-            net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
-            net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
-            if shared.backend == shared.Backend.ORIGINAL:  # load_diffusers cache is handled separately
-                loaded_networks.append(net)
+        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
+        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
+        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
+        loaded_networks.append(net)
     if failed_to_load_networks:
         sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
     purge_networks_from_memory()
@@ -251,7 +387,8 @@
         shared.log.info("Networks: Recompiling model")
         sd_models.compile_diffusers(shared.sd_model)
 
-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
+
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear, diffusers_lora.LoRACompatibleConv]):
     weights_backup = getattr(self, "network_weights_backup", None)
     bias_backup = getattr(self, "network_bias_backup", None)
     if weights_backup is None and bias_backup is None:
@@ -275,7 +412,7 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
         self.bias = None
 
 
-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear, diffusers_lora.LoRACompatibleConv]):
     """
     Applies the currently selected set of networks to the weights of torch layer self.
     If weights already have this particular set of networks applied, does nothing.
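
To illustrate the key translation this patch performs: KeyConvert.diffusers() sorts the map's SD-style prefixes and uses bisect to locate the candidate prefix in O(log n), relying on the fact that any prefix of a key sorts at or before the key itself. Below is a minimal standalone sketch (not part of the diff) using a hypothetical three-entry map; the real map comes from make_unet_conversion_map().

    import bisect

    # toy subset of the SDXL map: SD/compvis-style prefixes -> Diffusers-style
    # prefixes, already flattened with "." replaced by "_" as in the patch
    CONVERSION_MAP = {
        "input_blocks_1_0_emb_layers_1": "down_blocks_0_resnets_0_time_emb_proj",
        "input_blocks_1_0_in_layers_2": "down_blocks_0_resnets_0_conv1",
        "middle_block_1": "mid_block_attentions_0",
    }

    def convert_key(key: str) -> str:
        # strip the LoRA prefix the same way KeyConvert.diffusers() does
        search_key = key.replace("lora_unet_", "")
        map_keys = sorted(CONVERSION_MAP.keys())
        # greatest map key that sorts <= search_key; a true prefix of
        # search_key always sorts at or before it, so no full scan is needed
        position = bisect.bisect_right(map_keys, search_key)
        map_key = map_keys[position - 1]
        if search_key.startswith(map_key):
            return key.replace(map_key, CONVERSION_MAP[map_key])
        return key  # no known prefix: leave the key unchanged, as the patch does

    print(convert_key("lora_unet_input_blocks_1_0_in_layers_2"))
    # -> lora_unet_down_blocks_0_resnets_0_conv1

With the full map, a LyCORIS key such as lora_unet_input_blocks_1_0_in_layers_2 resolves to lora_unet_down_blocks_0_resnets_0_conv1, which is exactly the name registered by assign_network_names_to_compvis_modules() on the Diffusers backend.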
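For completeness, a rough sketch of how that lookup table is built, mirroring the naming scheme assign_network_names_to_compvis_modules() applies on the Diffusers backend; `pipe` is assumed to be a Diffusers pipeline exposing text_encoder, text_encoder_2, and unet attributes.

    def build_layer_mapping(pipe, is_sdxl: bool) -> dict:
        # flatten each submodule name with "." -> "_" and tag it with the
        # component prefix that converted LoRA keys are expected to carry
        mapping = {}
        te_prefix = "lora_te1_" if is_sdxl else "lora_te_"
        for name, module in pipe.text_encoder.named_modules():
            mapping[te_prefix + name.replace(".", "_")] = module
        if is_sdxl:
            for name, module in pipe.text_encoder_2.named_modules():
                mapping["lora_te2_" + name.replace(".", "_")] = module
        for name, module in pipe.unet.named_modules():
            mapping["lora_unet_" + name.replace(".", "_")] = module
        return mapping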