Update networks.py for LyCORIS loading on Diffusers Backend

pull/2293/head
AI-Casanova 2023-10-07 12:10:31 -05:00 committed by GitHub
parent d3ce70d4a9
commit 0ddfb5d4ed
1 changed file with 182 additions and 45 deletions

networks.py

@@ -1,7 +1,8 @@
-from typing import Union
+from typing import Dict, Union
 import logging
 import os
 import re
+import bisect
 import lora_patches
 import network
 import network_lora
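The new `bisect` import backs a longest-prefix lookup over the sorted conversion-map keys in `KeyConvert.diffusers` below: `bisect_right` finds the insertion point for the search key, and the entry just before it is the candidate prefix. A minimal, self-contained sketch of that trick (the keys here are made up for illustration):

```python
import bisect

# Sorted prefixes, shaped like the keys make_unet_conversion_map() produces.
map_keys = sorted(["down_blocks_0_resnets_0", "mid_block_attentions_0"])

search_key = "down_blocks_0_resnets_0_conv1"  # hypothetical module key
position = bisect.bisect_right(map_keys, search_key)
map_key = map_keys[position - 1]  # largest key that sorts <= search_key
if search_key.startswith(map_key):
    print(f"{search_key!r} matched prefix {map_key!r}")
```

Any prefix of `search_key` sorts at or before it, so for a key set like this one the matching prefix (if any) lands immediately left of the insertion point.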
@@ -41,6 +42,147 @@ suffix_conversion = {
 }
+
+
+def make_unet_conversion_map() -> Dict[str, str]:
+    unet_conversion_map_layer = []
+    for i in range(3):  # num_blocks is 3 in sdxl
+        # loop over downblocks/upblocks
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+            if i < 3:
+                # no attention layers in down_blocks.3
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+        for j in range(3):
+            # loop over resnets/attentions for upblocks
+            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+            sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
+            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+            # if i > 0: commented out for sdxl
+            # no attention layers in up_blocks.0
+            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
+            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+        if i < 3:
+            # no downsample in down_blocks.3
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+            # no upsample in up_blocks.3
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{2}."  # change for sdxl
+            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2 * j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0.", "norm1."),
+        ("in_layers.2.", "conv1."),
+        ("out_layers.0.", "norm2."),
+        ("out_layers.3.", "conv2."),
+        ("emb_layers.1.", "time_emb_proj."),
+        ("skip_connection.", "conv_shortcut."),
+    ]
+
+    unet_conversion_map = []
+    for sd, hf in unet_conversion_map_layer:
+        if "resnets" in hf:
+            for sd_res, hf_res in unet_conversion_map_resnet:
+                unet_conversion_map.append((sd + sd_res, hf + hf_res))
+        else:
+            unet_conversion_map.append((sd, hf))
+
+    for j in range(2):
+        hf_time_embed_prefix = f"time_embedding.linear_{j + 1}."
+        sd_time_embed_prefix = f"time_embed.{j * 2}."
+        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+
+    for j in range(2):
+        hf_label_embed_prefix = f"add_embedding.linear_{j + 1}."
+        sd_label_embed_prefix = f"label_emb.0.{j * 2}."
+        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+
+    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+    unet_conversion_map.append(("out.0.", "conv_norm_out."))
+    unet_conversion_map.append(("out.2.", "conv_out."))
+
+    sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
+    return sd_hf_conversion_map
+
+
+class KeyConvert:
+    def __init__(self):
+        if shared.backend == shared.Backend.ORIGINAL:
+            self.converter = self.original
+            self.is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
+        else:
+            self.converter = self.diffusers
+            self.is_sdxl = shared.sd_model_type == "sdxl"
+            self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
+            self.LORA_PREFIX_UNET = "lora_unet"
+            self.LORA_PREFIX_TEXT_ENCODER = "lora_te"
+            # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
+            self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
+            self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
+
+    def original(self, key):
+        key = convert_diffusers_name_to_compvis(key, self.is_sd2)
+        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        if sd_module is None:
+            m = re_x_proj.match(key)
+            if m:
+                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
+        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
+        if sd_module is None and "lora_unet" in key:
+            key = key.replace("lora_unet", "diffusion_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        elif sd_module is None and "lora_te1_text_model" in key:
+            key = key.replace("lora_te1_text_model", "0_transformer_text_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+            # some SD1 Loras also have correct compvis keys
+            if sd_module is None:
+                key = key.replace("lora_te1_text_model", "transformer_text_model")
+                sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        return key, sd_module
+
+    def diffusers(self, key):
+        if self.is_sdxl:
+            map_keys = list(self.UNET_CONVERSION_MAP.keys())  # prefix of U-Net modules
+            map_keys.sort()
+            search_key = key.replace(self.LORA_PREFIX_UNET + "_", "").replace(self.LORA_PREFIX_TEXT_ENCODER1 + "_", "").replace(self.LORA_PREFIX_TEXT_ENCODER2 + "_", "")
+            position = bisect.bisect_right(map_keys, search_key)
+            map_key = map_keys[position - 1]
+            if search_key.startswith(map_key):
+                key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key])
+        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        return key, sd_module
+
+    def __call__(self, key):
+        return self.converter(key)
+
+
 def convert_diffusers_name_to_compvis(key, is_sd2):
     def match(match_list, regex_text):
         regex = re_compiled.get(regex_text)
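For orientation, here is how the new converter is meant to be called on the diffusers backend. This is a hypothetical usage sketch: it assumes a loaded SDXL pipeline, and the key is constructed from the conversion map above rather than taken from a real file:

```python
# Hypothetical usage; requires shared.sd_model to be loaded (diffusers backend, SDXL).
convert = KeyConvert()

# Module part of a LoRA state-dict key such as
# "lora_unet_input_blocks_1_0_in_layers_2.lora_down.weight"
key, sd_module = convert("lora_unet_input_blocks_1_0_in_layers_2")
# key       -> "lora_unet_down_blocks_0_resnets_0_conv1"
# sd_module -> the matching module from network_layer_mapping, or None
```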
@@ -109,9 +251,25 @@ def assign_network_names_to_compvis_modules(sd_model):
         network_layer_mapping[network_name] = module
         module.network_layer_name = network_name
     """
-    if not hasattr(shared.sd_model, 'cond_stage_model'):
-        return
-    network_layer_mapping = {}
-    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
-        network_name = name.replace(".", "_")
-        network_layer_mapping[network_name] = module
+    network_layer_mapping = {}
+    if shared.backend == shared.Backend.DIFFUSERS:
+        for name, module in shared.sd_model.text_encoder.named_modules():
+            prefix = "lora_te1_" if shared.sd_model_type == "sdxl" else "lora_te_"
+            network_name = prefix + name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
+        if shared.sd_model_type == "sdxl":
+            for name, module in shared.sd_model.text_encoder_2.named_modules():
+                network_name = "lora_te2_" + name.replace(".", "_")
+                network_layer_mapping[network_name] = module
+                module.network_layer_name = network_name
+        for name, module in shared.sd_model.unet.named_modules():
+            network_name = "lora_unet_" + name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
+    else:
+        if not hasattr(shared.sd_model, 'cond_stage_model'):
+            return
+        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+            network_name = name.replace(".", "_")
+            network_layer_mapping[network_name] = module
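The mapping built here is what `KeyConvert.diffusers` resolves against: each module's dotted path is underscored and prefixed. A tiny sketch of the naming scheme (the module path is hypothetical, not read from a real model):

```python
# How a diffusers UNet submodule path becomes a network-layer key.
name = "down_blocks.0.resnets.0.conv1"  # hypothetical named_modules() entry
network_name = "lora_unet_" + name.replace(".", "_")
print(network_name)  # lora_unet_down_blocks_0_resnets_0_conv1
```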
@@ -128,30 +286,13 @@ def load_network(name, network_on_disk):
     net.mtime = os.path.getmtime(network_on_disk.filename)
     sd = sd_models.read_state_dict(network_on_disk.filename)
     # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
-    if not hasattr(shared.sd_model, 'network_layer_mapping'):
-        assign_network_names_to_compvis_modules(shared.sd_model)
+    assign_network_names_to_compvis_modules(shared.sd_model)
     keys_failed_to_match = {}
-    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
     matched_networks = {}
+    convert = KeyConvert()
     for key_network, weight in sd.items():
         key_network_without_network_parts, network_part = key_network.split(".", 1)
-        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
-        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-        if sd_module is None:
-            m = re_x_proj.match(key)
-            if m:
-                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
-        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
-        if sd_module is None and "lora_unet" in key_network_without_network_parts:
-            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
-            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
-            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
-            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-            # some SD1 Loras also have correct compvis keys
-            if sd_module is None:
-                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
-                sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        key, sd_module = convert(key_network_without_network_parts)
         if sd_module is None:
             keys_failed_to_match[key_network] = key
             continue
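The split on the first "." separates the module path from the per-network tensor name, so a single converter call per key is enough. A sketch with a made-up key:

```python
key_network = "lora_unet_input_blocks_1_0_in_layers_2.lora_down.weight"  # hypothetical
key_without_network_parts, network_part = key_network.split(".", 1)
# key_without_network_parts -> "lora_unet_input_blocks_1_0_in_layers_2"
# network_part              -> "lora_down.weight"
```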
@@ -222,12 +363,9 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
         net = networks_in_memory.get(name)
         if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
             try:
-                if shared.backend == shared.Backend.ORIGINAL:
-                    net = load_network(name, network_on_disk)
-                    networks_in_memory.pop(name, None)
-                    networks_in_memory[name] = net
-                elif shared.backend == shared.Backend.DIFFUSERS:
-                    net = load_diffusers(name, network_on_disk, te_multipliers[i] if te_multipliers else 1.0, unet_multipliers[i] if unet_multipliers else 1.0, dyn_dims[i] if dyn_dims else 1.0)
+                net = load_network(name, network_on_disk)
+                networks_in_memory.pop(name, None)
+                networks_in_memory[name] = net
             except Exception as e:
                 errors.display(e, f"loading network {network_on_disk.filename}")
                 continue
@@ -237,11 +375,9 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
             failed_to_load_networks.append(name)
             logging.info(f"Couldn't find network with name {name}")
             continue
-        else:
-            net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
-            net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
-            net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
-            if shared.backend == shared.Backend.ORIGINAL: # load_diffusers cache is handled separately
-                loaded_networks.append(net)
+        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
+        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
+        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
+        loaded_networks.append(net)
     if failed_to_load_networks:
         sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
@@ -251,7 +387,8 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
             shared.log.info("Networks: Recompiling model")
             sd_models.compile_diffusers(shared.sd_model)


-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear, diffusers_lora.LoRACompatibleConv]):
     weights_backup = getattr(self, "network_weights_backup", None)
     bias_backup = getattr(self, "network_bias_backup", None)
     if weights_backup is None and bias_backup is None:
@@ -275,7 +412,7 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear, diffusers_lora.LoRACompatibleConv]):
         self.bias = None


-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear, diffusers_lora.LoRACompatibleConv]):
     """
     Applies the currently selected set of networks to the weights of torch layer self.
     If weights already have this particular set of networks applied, does nothing.
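The widened `Union` annotations reflect that diffusers wraps some layers in `LoRACompatibleLinear`/`LoRACompatibleConv`, which expose `weight` and `bias` like their plain torch counterparts, so the same backup/restore logic applies to them. A simplified sketch of that pattern, assuming a generic module with a `weight` attribute (the real code also handles devices, biases, and `MultiheadAttention` separately):

```python
import torch

def backup_weights(module: torch.nn.Module) -> None:
    # Stash a copy of the original weights once, before applying networks.
    if getattr(module, "network_weights_backup", None) is None:
        module.network_weights_backup = module.weight.detach().clone()

def restore_weights(module: torch.nn.Module) -> None:
    # Copy the stashed weights back in place, if a backup exists.
    weights_backup = getattr(module, "network_weights_backup", None)
    if weights_backup is not None:
        module.weight.data.copy_(weights_backup)
```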