mirror of https://github.com/vladmandic/automatic
Update networks.py for LyCORIS loading on Diffusers Backend
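Route LoRA/LyCORIS loading through a single code path on both backends: add a kohya-style SD-to-Diffusers UNet key-conversion map, a KeyConvert helper that resolves a state-dict key to its target module, register the Diffusers text encoder(s) and UNet modules under lora_te*/lora_unet names, and widen the weight backup/apply type hints to cover diffusers LoRACompatible layers.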
parent d3ce70d4a9
commit 0ddfb5d4ed
@@ -1,7 +1,8 @@
-from typing import Union
+from typing import Dict, Union
 import logging
 import os
 import re
+import bisect
 import lora_patches
 import network
 import network_lora
@@ -41,6 +42,147 @@ suffix_conversion = {
 }
 
 
+def make_unet_conversion_map() -> Dict[str, str]:
+    unet_conversion_map_layer = []
+
+    for i in range(3):  # num_blocks is 3 in sdxl
+        # loop over downblocks/upblocks
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+            if i < 3:
+                # no attention layers in down_blocks.3
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+        for j in range(3):
+            # loop over resnets/attentions for upblocks
+            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+            sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
+            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+            # if i > 0: commented out for sdxl
+            # no attention layers in up_blocks.0
+            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
+            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+        if i < 3:
+            # no downsample in down_blocks.3
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+            # no upsample in up_blocks.3
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{2}."  # change for sdxl
+            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2 * j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0.", "norm1."),
+        ("in_layers.2.", "conv1."),
+        ("out_layers.0.", "norm2."),
+        ("out_layers.3.", "conv2."),
+        ("emb_layers.1.", "time_emb_proj."),
+        ("skip_connection.", "conv_shortcut."),
+    ]
+
+    unet_conversion_map = []
+    for sd, hf in unet_conversion_map_layer:
+        if "resnets" in hf:
+            for sd_res, hf_res in unet_conversion_map_resnet:
+                unet_conversion_map.append((sd + sd_res, hf + hf_res))
+        else:
+            unet_conversion_map.append((sd, hf))
+
+    for j in range(2):
+        hf_time_embed_prefix = f"time_embedding.linear_{j + 1}."
+        sd_time_embed_prefix = f"time_embed.{j * 2}."
+        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+
+    for j in range(2):
+        hf_label_embed_prefix = f"add_embedding.linear_{j + 1}."
+        sd_label_embed_prefix = f"label_emb.0.{j * 2}."
+        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+
+    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+    unet_conversion_map.append(("out.0.", "conv_norm_out."))
+    unet_conversion_map.append(("out.2.", "conv_out."))
+
+    sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
+    return sd_hf_conversion_map
+
+
+class KeyConvert:
+    def __init__(self):
+        if shared.backend == shared.Backend.ORIGINAL:
+            self.converter = self.original
+            self.is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
+
+        else:
+            self.converter = self.diffusers
+            self.is_sdxl = True if shared.sd_model_type == "sdxl" else False
+            self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
+            self.LORA_PREFIX_UNET = "lora_unet"
+            self.LORA_PREFIX_TEXT_ENCODER = "lora_te"
+
+            # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
+            self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
+            self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
+
+    def original(self, key):
+        key = convert_diffusers_name_to_compvis(key, self.is_sd2)
+        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        if sd_module is None:
+            m = re_x_proj.match(key)
+            if m:
+                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
+        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
+        if sd_module is None and "lora_unet" in key:
+            key = key.replace("lora_unet", "diffusion_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        elif sd_module is None and "lora_te1_text_model" in key:
+            key = key.replace("lora_te1_text_model", "0_transformer_text_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+            # some SD1 Loras also have correct compvis keys
+            if sd_module is None:
+                key = key.replace("lora_te1_text_model", "transformer_text_model")
+                sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        return key, sd_module
+
+    def diffusers(self, key):
+        map_keys = list(self.UNET_CONVERSION_MAP.keys())  # prefix of U-Net modules
+        map_keys.sort()
+
+        if self.is_sdxl:
+            search_key = key.replace(self.LORA_PREFIX_UNET + "_", "").replace(self.LORA_PREFIX_TEXT_ENCODER1 + "_", "").replace(self.LORA_PREFIX_TEXT_ENCODER2 + "_", "")
+            position = bisect.bisect_right(map_keys, search_key)
+            map_key = map_keys[position - 1]
+            if search_key.startswith(map_key):
+                key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key])
+        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        return key, sd_module
+
+    def __call__(self, key):
+        return self.converter(key)
+
+
 def convert_diffusers_name_to_compvis(key, is_sd2):
     def match(match_list, regex_text):
         regex = re_compiled.get(regex_text)
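For orientation, a minimal sketch of how the two new helpers cooperate (not part of the commit; the sample entries follow directly from the map construction above, and the search key is a hypothetical LyCORIS key whose "lora_unet_" prefix has already been stripped, as KeyConvert.diffusers does before the lookup):

    import bisect

    conversion_map = make_unet_conversion_map()
    # entries are dot-flattened prefixes with the trailing underscore removed, e.g.:
    #   conversion_map["input_blocks_1_0_in_layers_0"] == "down_blocks_0_resnets_0_norm1"
    #   conversion_map["middle_block_1"] == "mid_block_attentions_0"

    map_keys = sorted(conversion_map.keys())
    search_key = "input_blocks_1_0_in_layers_0_lora_down"  # hypothetical example key

    # bisect_right returns the insertion point for search_key; the entry just
    # before it is the largest map key sorting <= search_key, which for this
    # key layout is the only candidate that can be a prefix of search_key
    position = bisect.bisect_right(map_keys, search_key)
    map_key = map_keys[position - 1]
    if search_key.startswith(map_key):
        print(search_key.replace(map_key, conversion_map[map_key]))
        # -> "down_blocks_0_resnets_0_norm1_lora_down"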
@@ -109,17 +251,33 @@ def assign_network_names_to_compvis_modules(sd_model):
         network_layer_mapping[network_name] = module
         module.network_layer_name = network_name
     """
-    if not hasattr(shared.sd_model, 'cond_stage_model'):
-        return
     network_layer_mapping = {}
-    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
-        network_name = name.replace(".", "_")
-        network_layer_mapping[network_name] = module
-        module.network_layer_name = network_name
-    for name, module in shared.sd_model.model.named_modules():
-        network_name = name.replace(".", "_")
-        network_layer_mapping[network_name] = module
-        module.network_layer_name = network_name
+    if shared.backend == shared.Backend.DIFFUSERS:
+        for name, module in shared.sd_model.text_encoder.named_modules():
+            prefix = "lora_te1_" if shared.sd_model_type == "sdxl" else "lora_te_"
+            network_name = prefix + name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
+        if shared.sd_model_type == "sdxl":
+            for name, module in shared.sd_model.text_encoder_2.named_modules():
+                network_name = "lora_te2_" + name.replace(".", "_")
+                network_layer_mapping[network_name] = module
+                module.network_layer_name = network_name
+        for name, module in shared.sd_model.unet.named_modules():
+            network_name = "lora_unet_" + name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
+    else:
+        if not hasattr(shared.sd_model, 'cond_stage_model'):
+            return
+        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+            network_name = name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
+        for name, module in shared.sd_model.model.named_modules():
+            network_name = name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name
     sd_model.network_layer_mapping = network_layer_mapping
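The Diffusers branch above deliberately mirrors the kohya/LyCORIS naming scheme, so state-dict keys can be looked up without translation. A one-line sketch (the module path is a real SDXL UNet submodule; the surrounding variables are illustrative):

    name = "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q"
    network_name = "lora_unet_" + name.replace(".", "_")
    # -> "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q",
    # the same prefix scheme kohya-trained LoRA/LyCORIS state dicts use, so
    # their keys resolve against network_layer_mapping directly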
@@ -128,30 +286,13 @@ def load_network(name, network_on_disk):
     net.mtime = os.path.getmtime(network_on_disk.filename)
     sd = sd_models.read_state_dict(network_on_disk.filename)
     # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
-    if not hasattr(shared.sd_model, 'network_layer_mapping'):
-        assign_network_names_to_compvis_modules(shared.sd_model)
+    assign_network_names_to_compvis_modules(shared.sd_model)
     keys_failed_to_match = {}
-    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
     matched_networks = {}
+    convert = KeyConvert()
     for key_network, weight in sd.items():
         key_network_without_network_parts, network_part = key_network.split(".", 1)
-        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
-        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-        if sd_module is None:
-            m = re_x_proj.match(key)
-            if m:
-                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
-        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
-        if sd_module is None and "lora_unet" in key_network_without_network_parts:
-            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
-            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
-            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
-            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-            # some SD1 Loras also have correct compvis keys
-            if sd_module is None:
-                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
-                sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        key, sd_module = convert(key_network_without_network_parts)
         if sd_module is None:
             keys_failed_to_match[key_network] = key
             continue
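With KeyConvert in place, the per-key matching in load_network collapses to one call that works on either backend. A minimal usage sketch (the example key is hypothetical):

    convert = KeyConvert()
    keys_failed_to_match = {}
    key_network = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
    key_network_without_network_parts, network_part = key_network.split(".", 1)
    key, sd_module = convert(key_network_without_network_parts)  # returns (converted key, target module or None)
    if sd_module is None:
        keys_failed_to_match[key_network] = key  # same bookkeeping as the loop above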
@@ -222,12 +363,9 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
         net = networks_in_memory.get(name)
         if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
             try:
-                if shared.backend == shared.Backend.ORIGINAL:
-                    net = load_network(name, network_on_disk)
-                    networks_in_memory.pop(name, None)
-                    networks_in_memory[name] = net
-                elif shared.backend == shared.Backend.DIFFUSERS:
-                    net = load_diffusers(name, network_on_disk, te_multipliers[i] if te_multipliers else 1.0, unet_multipliers[i] if unet_multipliers else 1.0, dyn_dims[i] if dyn_dims else 1.0)
+                net = load_network(name, network_on_disk)
+                networks_in_memory.pop(name, None)
+                networks_in_memory[name] = net
             except Exception as e:
                 errors.display(e, f"loading network {network_on_disk.filename}")
                 continue
@@ -237,12 +375,10 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
             failed_to_load_networks.append(name)
             logging.info(f"Couldn't find network with name {name}")
             continue
-        else:
-            net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
-            net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
-            net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
-            if shared.backend == shared.Backend.ORIGINAL: # load_diffusers cache is handled separately
-                loaded_networks.append(net)
+        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
+        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
+        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
+        loaded_networks.append(net)
     if failed_to_load_networks:
         sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
     purge_networks_from_memory()
@@ -251,7 +387,8 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
         shared.log.info("Networks: Recompiling model")
         sd_models.compile_diffusers(shared.sd_model)
 
-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
+
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear, diffusers_lora.LoRACompatibleConv]):
     weights_backup = getattr(self, "network_weights_backup", None)
     bias_backup = getattr(self, "network_bias_backup", None)
     if weights_backup is None and bias_backup is None:
@@ -275,7 +412,7 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
         self.bias = None
 
 
-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear, diffusers_lora.LoRACompatibleConv]):
     """
     Applies the currently selected set of networks to the weights of torch layer self.
     If weights already have this particular set of networks applied, does nothing.
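The widened type hints work because diffusers' LoRA-aware layers are drop-in subclasses of the plain torch modules. A sketch of that assumption (requires a diffusers release that still ships diffusers.models.lora, which later versions deprecate; the alias diffusers_lora is presumed to be how networks.py imports it):

    import torch
    import diffusers.models.lora as diffusers_lora

    layer = diffusers_lora.LoRACompatibleLinear(4, 4)
    # LoRACompatibleLinear subclasses torch.nn.Linear (and LoRACompatibleConv
    # subclasses torch.nn.Conv2d), so the backup/restore code that reads and
    # assigns .weight/.bias works on these layers unchanged
    assert isinstance(layer, torch.nn.Linear)
    layer.network_weights_backup = layer.weight.detach().clone()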