Update networks.py for LyCORIS loading on Diffusers Backend

pull/2293/head
AI-Casanova 2023-10-07 12:10:31 -05:00 committed by GitHub
parent d3ce70d4a9
commit 0ddfb5d4ed
1 changed file with 182 additions and 45 deletions

@@ -1,7 +1,8 @@
-from typing import Union
+from typing import Dict, Union
 import logging
 import os
 import re
+import bisect
 import lora_patches
 import network
 import network_lora
@@ -41,6 +42,147 @@ suffix_conversion = {
 }
 
 
+def make_unet_conversion_map() -> Dict[str, str]:
+    unet_conversion_map_layer = []
+    for i in range(3):  # num_blocks is 3 in sdxl
+        # loop over downblocks/upblocks
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+            if i < 3:
+                # no attention layers in down_blocks.3
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+        for j in range(3):
+            # loop over resnets/attentions for upblocks
+            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+            sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
+            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+            # if i > 0: commented out for sdxl
+            # no attention layers in up_blocks.0
+            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
+            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+        if i < 3:
+            # no downsample in down_blocks.3
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+            # no upsample in up_blocks.3
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{2}."  # change for sdxl
+            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2 * j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0.", "norm1."),
+        ("in_layers.2.", "conv1."),
+        ("out_layers.0.", "norm2."),
+        ("out_layers.3.", "conv2."),
+        ("emb_layers.1.", "time_emb_proj."),
+        ("skip_connection.", "conv_shortcut."),
+    ]
+    unet_conversion_map = []
+    for sd, hf in unet_conversion_map_layer:
+        if "resnets" in hf:
+            for sd_res, hf_res in unet_conversion_map_resnet:
+                unet_conversion_map.append((sd + sd_res, hf + hf_res))
+        else:
+            unet_conversion_map.append((sd, hf))
+    for j in range(2):
+        hf_time_embed_prefix = f"time_embedding.linear_{j + 1}."
+        sd_time_embed_prefix = f"time_embed.{j * 2}."
+        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+    for j in range(2):
+        hf_label_embed_prefix = f"add_embedding.linear_{j + 1}."
+        sd_label_embed_prefix = f"label_emb.0.{j * 2}."
+        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+    unet_conversion_map.append(("out.0.", "conv_norm_out."))
+    unet_conversion_map.append(("out.2.", "conv_out."))
+    sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
+    return sd_hf_conversion_map
+
+
+class KeyConvert:
+    def __init__(self):
+        if shared.backend == shared.Backend.ORIGINAL:
+            self.converter = self.original
+            self.is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
+        else:
+            self.converter = self.diffusers
+            self.is_sdxl = True if shared.sd_model_type == "sdxl" else False
+            self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
+            self.LORA_PREFIX_UNET = "lora_unet"
+            self.LORA_PREFIX_TEXT_ENCODER = "lora_te"
+            # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
+            self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
+            self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
+
+    def original(self, key):
+        key = convert_diffusers_name_to_compvis(key, self.is_sd2)
+        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        if sd_module is None:
+            m = re_x_proj.match(key)
+            if m:
+                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
+        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
+        if sd_module is None and "lora_unet" in key:
+            key = key.replace("lora_unet", "diffusion_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        elif sd_module is None and "lora_te1_text_model" in key:
+            key = key.replace("lora_te1_text_model", "0_transformer_text_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+            # some SD1 Loras also have correct compvis keys
+            if sd_module is None:
+                key = key.replace("lora_te1_text_model", "transformer_text_model")
+                sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        return key, sd_module
+
+    def diffusers(self, key):
+        if self.is_sdxl:
+            # UNET_CONVERSION_MAP is only built for SDXL, so it must only be read inside this guard
+            map_keys = list(self.UNET_CONVERSION_MAP.keys())  # prefixes of U-Net modules
+            map_keys.sort()
+            search_key = key.replace(self.LORA_PREFIX_UNET + "_", "").replace(self.LORA_PREFIX_TEXT_ENCODER1 + "_", "").replace(self.LORA_PREFIX_TEXT_ENCODER2 + "_", "")
+            position = bisect.bisect_right(map_keys, search_key)
+            map_key = map_keys[position - 1]
+            if search_key.startswith(map_key):
+                key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key])
+        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        return key, sd_module
+
+    def __call__(self, key):
+        return self.converter(key)
+
+
 def convert_diffusers_name_to_compvis(key, is_sd2):
     def match(match_list, regex_text):
         regex = re_compiled.get(regex_text)
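
To make the new conversion path concrete, here is a minimal sketch of how make_unet_conversion_map() and KeyConvert.diffusers() cooperate on SDXL (the example keys are derived from the code above, not read from a real checkpoint):

    conv_map = make_unet_conversion_map()
    # underscored kohya-ss prefixes map to underscored Diffusers prefixes:
    assert conv_map["middle_block_1"] == "mid_block_attentions_0"
    assert conv_map["input_blocks_1_0_in_layers_0"] == "down_blocks_0_resnets_0_norm1"
    # For a LoRA key such as "lora_unet_input_blocks_1_1_proj_in", diffusers()
    # strips the "lora_unet_" prefix, uses bisect_right to find the closest map
    # key not greater than the remainder ("input_blocks_1_1"), confirms it is a
    # prefix, and rewrites the key to
    # "lora_unet_down_blocks_0_attentions_0_proj_in" for the layer-mapping lookup.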
@@ -109,17 +251,33 @@ def assign_network_names_to_compvis_modules(sd_model):
         network_layer_mapping[network_name] = module
         module.network_layer_name = network_name
     """
-    if not hasattr(shared.sd_model, 'cond_stage_model'):
-        return
     network_layer_mapping = {}
-    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
-        network_name = name.replace(".", "_")
-        network_layer_mapping[network_name] = module
-        module.network_layer_name = network_name
-    for name, module in shared.sd_model.model.named_modules():
-        network_name = name.replace(".", "_")
-        network_layer_mapping[network_name] = module
-        module.network_layer_name = network_name
+    if shared.backend == shared.Backend.DIFFUSERS:
+        for name, module in shared.sd_model.text_encoder.named_modules():
+            prefix = "lora_te1_" if shared.sd_model_type == "sdxl" else "lora_te_"
+            network_name = prefix + name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
+        if shared.sd_model_type == "sdxl":
+            for name, module in shared.sd_model.text_encoder_2.named_modules():
+                network_name = "lora_te2_" + name.replace(".", "_")
+                network_layer_mapping[network_name] = module
+                module.network_layer_name = network_name
+        for name, module in shared.sd_model.unet.named_modules():
+            network_name = "lora_unet_" + name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
+    else:
+        if not hasattr(shared.sd_model, 'cond_stage_model'):
+            return
+        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+            network_name = name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
+        for name, module in shared.sd_model.model.named_modules():
+            network_name = name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
     sd_model.network_layer_mapping = network_layer_mapping
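
The naming rule here is simply a backend-specific prefix plus the module path with dots replaced by underscores. A tiny runnable sketch (the helper name and the example module path are illustrative, not part of the file):

    def kohya_name(prefix: str, module_path: str) -> str:
        # prefix + dotted module path, with dots replaced by underscores
        return prefix + module_path.replace(".", "_")

    assert kohya_name("lora_unet_", "down_blocks.0.attentions.0.proj_in") == "lora_unet_down_blocks_0_attentions_0_proj_in"
    # On SDXL the two text encoders get distinct prefixes, "lora_te1_" and
    # "lora_te2_", which is what lets KeyConvert tell their keys apart.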
@@ -128,30 +286,13 @@ def load_network(name, network_on_disk):
     net.mtime = os.path.getmtime(network_on_disk.filename)
     sd = sd_models.read_state_dict(network_on_disk.filename)
-    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
-    if not hasattr(shared.sd_model, 'network_layer_mapping'):
-        assign_network_names_to_compvis_modules(shared.sd_model)
+    assign_network_names_to_compvis_modules(shared.sd_model)
     keys_failed_to_match = {}
-    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
     matched_networks = {}
+    convert = KeyConvert()
     for key_network, weight in sd.items():
         key_network_without_network_parts, network_part = key_network.split(".", 1)
-        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
-        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-        if sd_module is None:
-            m = re_x_proj.match(key)
-            if m:
-                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
-        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
-        if sd_module is None and "lora_unet" in key_network_without_network_parts:
-            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
-            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
-            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
-            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-            # some SD1 Loras also have correct compvis keys
-            if sd_module is None:
-                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
-                sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        key, sd_module = convert(key_network_without_network_parts)
         if sd_module is None:
             keys_failed_to_match[key_network] = key
             continue
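
For reference, each state-dict key splits into the module part that KeyConvert resolves and the network part that names the LoRA tensor. A sketch with an illustrative key (not taken from a real file):

    key_network = "lora_unet_input_blocks_1_1_proj_in.lora_down.weight"
    module_part, network_part = key_network.split(".", 1)
    assert module_part == "lora_unet_input_blocks_1_1_proj_in"  # passed to convert()
    assert network_part == "lora_down.weight"                   # selects the LoRA tensor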
@@ -222,12 +363,9 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
         net = networks_in_memory.get(name)
         if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
             try:
-                if shared.backend == shared.Backend.ORIGINAL:
-                    net = load_network(name, network_on_disk)
-                    networks_in_memory.pop(name, None)
-                    networks_in_memory[name] = net
-                elif shared.backend == shared.Backend.DIFFUSERS:
-                    net = load_diffusers(name, network_on_disk, te_multipliers[i] if te_multipliers else 1.0, unet_multipliers[i] if unet_multipliers else 1.0, dyn_dims[i] if dyn_dims else 1.0)
+                net = load_network(name, network_on_disk)
+                networks_in_memory.pop(name, None)
+                networks_in_memory[name] = net
             except Exception as e:
                 errors.display(e, f"loading network {network_on_disk.filename}")
                 continue
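
The cache policy above is plain mtime-based invalidation, now shared by both backends. A standalone sketch of the same rule (the names are stand-ins for networks_in_memory and load_network):

    import os

    cache = {}  # name -> loaded network, each carrying an .mtime stamp

    def get_or_load(name, filename, loader):
        net = cache.get(name)
        if net is None or os.path.getmtime(filename) > net.mtime:  # file changed on disk
            net = loader(name, filename)
            cache.pop(name, None)  # drop the stale entry before re-inserting
            cache[name] = net
        return net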
@@ -237,12 +375,10 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
             failed_to_load_networks.append(name)
             logging.info(f"Couldn't find network with name {name}")
             continue
-        else:
-            net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
-            net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
-            net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
-            if shared.backend == shared.Backend.ORIGINAL: # load_diffusers cache is handled separately
-                loaded_networks.append(net)
+        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
+        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
+        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
+        loaded_networks.append(net)
     if failed_to_load_networks:
         sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
     purge_networks_from_memory()
@@ -251,7 +387,8 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
         shared.log.info("Networks: Recompiling model")
         sd_models.compile_diffusers(shared.sd_model)
 
-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear, diffusers_lora.LoRACompatibleConv]):
     weights_backup = getattr(self, "network_weights_backup", None)
     bias_backup = getattr(self, "network_bias_backup", None)
     if weights_backup is None and bias_backup is None:
@@ -275,7 +412,7 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
         self.bias = None
 
-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear, diffusers_lora.LoRACompatibleConv]):
     """
     Applies the currently selected set of networks to the weights of torch layer self.
     If weights already have this particular set of networks applied, does nothing.
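
The "does nothing" guarantee in this docstring is typically implemented by stamping each layer with the set of networks already baked into it and comparing on the next call. A hedged sketch of that idempotency check (attribute and variable names are illustrative, not necessarily the ones used elsewhere in this file):

    # conceptually, at the top of network_apply_weights:
    wanted_names = tuple((net.name, net.te_multiplier, net.unet_multiplier, net.dyn_dim) for net in loaded_networks)
    current_names = getattr(self, "network_current_names", ())
    if current_names == wanted_names:
        return  # this exact set of networks and multipliers is already applied
    # ...otherwise restore the backed-up weights, re-apply each network, then:
    self.network_current_names = wanted_names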