Major lora refactor: works on my machine edition

pull/3593/head
AI-Casanova 2024-11-23 21:57:03 -06:00
parent 2b147272f8
commit cb561fa486
No known key found for this signature in database
GPG Key ID: 2A04488D60A5BF98
21 changed files with 2225 additions and 7 deletions


@@ -5,7 +5,7 @@ from lora_extract import create_ui
 from network import NetworkOnDisk
 from ui_extra_networks_lora import ExtraNetworksPageLora
 from extra_networks_lora import ExtraNetworkLora
-from modules import script_callbacks, extra_networks, ui_extra_networks, ui_models # pylint: disable=unused-import
+from modules import script_callbacks, extra_networks, ui_extra_networks, ui_models, shared # pylint: disable=unused-import
 re_lora = re.compile("<lora:([^:]+):")
@@ -56,9 +56,9 @@ def infotext_pasted(infotext, d): # pylint: disable=unused-argument
     hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
     d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
-script_callbacks.on_app_started(api_networks)
-script_callbacks.on_before_ui(before_ui)
-script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
-script_callbacks.on_infotext_pasted(networks.infotext_pasted)
-script_callbacks.on_infotext_pasted(infotext_pasted)
+if not shared.native:
+    script_callbacks.on_app_started(api_networks)
+    script_callbacks.on_before_ui(before_ui)
+    script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
+    script_callbacks.on_infotext_pasted(networks.infotext_pasted)
+    script_callbacks.on_infotext_pasted(infotext_pasted)


@@ -0,0 +1,151 @@
import re
import time
import numpy as np
import modules.lora.networks as networks
from modules import extra_networks, shared
# from https://github.com/cheald/sd-webui-loractl/blob/master/loractl/lib/utils.py
def get_stepwise(param, step, steps):
def sorted_positions(raw_steps):
steps = [[float(s.strip()) for s in re.split("[@~]", x)]
for x in re.split("[,;]", str(raw_steps))]
if len(steps[0]) == 1: # If we just got a single number, just return it
return steps[0][0]
steps = [[s[0], s[1] if len(s) == 2 else 1] for s in steps] # Add implicit 1s to any steps which don't have a weight
steps.sort(key=lambda k: k[1]) # Sort by index
steps = [list(v) for v in zip(*steps)]
return steps
def calculate_weight(m, step, max_steps, step_offset=2):
if isinstance(m, list):
if m[1][-1] <= 1.0:
step = step / (max_steps - step_offset) if max_steps > 0 else 1.0
v = np.interp(step, m[1], m[0])
return v
else:
return m
stepwise = calculate_weight(sorted_positions(param), step, steps)
return stepwise
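# NOTE (editor's sketch, not part of this commit): get_stepwise resolves loractl-style
# schedules such as "0.0@0,1.0@10". Positions greater than 1.0 are absolute sampling
# steps; positions at or below 1.0 are treated as fractions of the total step count.
def _demo_get_stepwise(): # hypothetical helper, for illustration only
    assert get_stepwise("0.8", step=3, steps=20) == 0.8 # a plain number passes through unchanged
    assert abs(get_stepwise("0.0@0,1.0@10", step=5, steps=20) - 0.5) < 1e-6 # halfway up the 0..10 ramp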
def prompt(p):
if shared.opts.lora_apply_tags == 0:
return
all_tags = []
for loaded in networks.loaded_networks:
page = [en for en in shared.extra_networks if en.name == 'lora'][0]
item = page.create_item(loaded.name)
tags = (item or {}).get("tags", {})
loaded.tags = list(tags)
if len(loaded.tags) == 0:
loaded.tags.append(loaded.name)
if shared.opts.lora_apply_tags > 0:
loaded.tags = loaded.tags[:shared.opts.lora_apply_tags]
all_tags.extend(loaded.tags)
if len(all_tags) > 0:
shared.log.debug(f"Load network: type=LoRA tags={all_tags} max={shared.opts.lora_apply_tags} apply")
all_tags = ', '.join(all_tags)
p.extra_generation_params["LoRA tags"] = all_tags
if '_tags_' in p.prompt:
p.prompt = p.prompt.replace('_tags_', all_tags)
else:
p.prompt = f"{p.prompt}, {all_tags}"
if p.all_prompts is not None:
for i in range(len(p.all_prompts)):
if '_tags_' in p.all_prompts[i]:
p.all_prompts[i] = p.all_prompts[i].replace('_tags_', all_tags)
else:
p.all_prompts[i] = f"{p.all_prompts[i]}, {all_tags}"
def infotext(p):
names = [i.name for i in networks.loaded_networks]
if len(names) > 0:
p.extra_generation_params["LoRA networks"] = ", ".join(names)
if shared.opts.lora_add_hashes_to_infotext:
network_hashes = []
for item in networks.loaded_networks:
if not item.network_on_disk.shorthash:
continue
network_hashes.append(item.network_on_disk.shorthash)
if len(network_hashes) > 0:
p.extra_generation_params["LoRA hashes"] = ", ".join(network_hashes)
def parse(p, params_list, step=0):
names = []
te_multipliers = []
unet_multipliers = []
dyn_dims = []
for params in params_list:
assert params.items
names.append(params.positional[0])
te_multiplier = params.named.get("te", params.positional[1] if len(params.positional) > 1 else shared.opts.extra_networks_default_multiplier)
if isinstance(te_multiplier, str) and "@" in te_multiplier:
te_multiplier = get_stepwise(te_multiplier, step, p.steps)
else:
te_multiplier = float(te_multiplier)
unet_multiplier = [params.positional[2] if len(params.positional) > 2 else te_multiplier] * 3
unet_multiplier = [params.named.get("unet", unet_multiplier[0])] * 3
unet_multiplier[0] = params.named.get("in", unet_multiplier[0])
unet_multiplier[1] = params.named.get("mid", unet_multiplier[1])
unet_multiplier[2] = params.named.get("out", unet_multiplier[2])
for i in range(len(unet_multiplier)):
if isinstance(unet_multiplier[i], str) and "@" in unet_multiplier[i]:
unet_multiplier[i] = get_stepwise(unet_multiplier[i], step, p.steps)
else:
unet_multiplier[i] = float(unet_multiplier[i])
dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
te_multipliers.append(te_multiplier)
unet_multipliers.append(unet_multiplier)
dyn_dims.append(dyn_dim)
return names, te_multipliers, unet_multipliers, dyn_dims
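# NOTE (editor's sketch, not part of this commit): for a prompt tag such as
# <lora:mylora:0.8:0.6:16>, parse() yields names=["mylora"], te_multipliers=[0.8],
# unet_multipliers=[[0.6, 0.6, 0.6]] and dyn_dims=[16]; named arguments (te=, unet=,
# in=, mid=, out=, dyn=) override positional values, and any value containing "@"
# is resolved per-step through get_stepwise().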
class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('lora')
self.active = False
self.model = None
self.errors = {}
def activate(self, p, params_list, step=0):
t0 = time.time()
self.errors.clear()
if self.active:
if self.model != shared.opts.sd_model_checkpoint: # reset if model changed
self.active = False
if len(params_list) > 0 and not self.active: # activate patches once
shared.log.debug(f'Activate network: type=LoRA model="{shared.opts.sd_model_checkpoint}"')
self.active = True
self.model = shared.opts.sd_model_checkpoint
names, te_multipliers, unet_multipliers, dyn_dims = parse(p, params_list, step)
networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
t1 = time.time()
if len(networks.loaded_networks) > 0 and step == 0:
infotext(p)
prompt(p)
shared.log.info(f'Load network: type=LoRA apply={[n.name for n in networks.loaded_networks]} te={te_multipliers} unet={unet_multipliers} dims={dyn_dims} load={t1-t0:.2f}')
def deactivate(self, p):
t0 = time.time()
if shared.native and len(networks.diffuser_loaded) > 0:
if hasattr(shared.sd_model, "unload_lora_weights") and hasattr(shared.sd_model, "text_encoder"):
if not (shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled is True):
try:
if shared.opts.lora_fuse_diffusers:
shared.sd_model.unfuse_lora()
shared.sd_model.unload_lora_weights() # fails for non-CLIP models
except Exception:
pass
t1 = time.time()
networks.timer['restore'] += t1 - t0
if self.active and networks.debug:
shared.log.debug(f"Network end: type=LoRA load={networks.timer['load']:.2f} apply={networks.timer['apply']:.2f} restore={networks.timer['restore']:.2f}")
if self.errors:
for k, v in self.errors.items():
shared.log.error(f'LoRA: name="{k}" errors={v}')
self.errors.clear()

modules/lora/lora.py (new file, 8 lines)

@@ -0,0 +1,8 @@
# import networks
#
# list_available_loras = networks.list_available_networks
# available_loras = networks.available_networks
# available_lora_aliases = networks.available_network_aliases
# available_lora_hash_lookup = networks.available_network_hash_lookup
# forbidden_lora_aliases = networks.forbidden_network_aliases
# loaded_loras = networks.loaded_networks


@@ -0,0 +1,477 @@
import os
import re
import bisect
from typing import Dict
import torch
from modules import shared
debug = os.environ.get('SD_LORA_DEBUG', None) is not None
suffix_conversion = {
"attentions": {},
"resnets": {
"conv1": "in_layers_2",
"conv2": "out_layers_3",
"norm1": "in_layers_0",
"norm2": "out_layers_0",
"time_emb_proj": "emb_layers_1",
"conv_shortcut": "skip_connection",
}
}
re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}
def make_unet_conversion_map() -> Dict[str, str]:
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j + 1}."
sd_time_embed_prefix = f"time_embed.{j * 2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j + 1}."
sd_label_embed_prefix = f"label_emb.0.{j * 2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
return sd_hf_conversion_map
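# NOTE (editor's sketch, not part of this commit): the resulting map translates
# underscore-joined LDM UNet prefixes into their Diffusers counterparts, e.g.
# "input_blocks_1_0_in_layers_2" -> "down_blocks_0_resnets_0_conv1" and
# "middle_block_1" -> "mid_block_attentions_0".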
class KeyConvert:
def __init__(self):
self.is_sdxl = True if shared.sd_model_type == "sdxl" else False
self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
self.LORA_PREFIX_UNET = "lora_unet_"
self.LORA_PREFIX_TEXT_ENCODER = "lora_te_"
self.OFT_PREFIX_UNET = "oft_unet_"
        # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1_"
self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2_"
def __call__(self, key):
if self.is_sdxl:
if "diffusion_model" in key: # Fix NTC Slider naming error
key = key.replace("diffusion_model", "lora_unet")
map_keys = list(self.UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
map_keys.sort()
search_key = key.replace(self.LORA_PREFIX_UNET, "").replace(self.OFT_PREFIX_UNET, "").replace(self.LORA_PREFIX_TEXT_ENCODER1, "").replace(self.LORA_PREFIX_TEXT_ENCODER2, "")
position = bisect.bisect_right(map_keys, search_key)
map_key = map_keys[position - 1]
if search_key.startswith(map_key):
key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key]).replace("oft", "lora") # pylint: disable=unsubscriptable-object
if "lycoris" in key and "transformer" in key:
key = key.replace("lycoris", "lora_transformer")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
if sd_module is None:
sd_module = shared.sd_model.network_layer_mapping.get(key.replace("guidance", "timestep"), None) # FLUX1 fix
if debug and sd_module is None:
raise RuntimeError(f"LoRA key not found in network_layer_mapping: key={key} mapping={shared.sd_model.network_layer_mapping.keys()}")
return key, sd_module
# Taken from https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_conversion_utils.py
# Modified from 'lora_A' and 'lora_B' to 'lora_down' and 'lora_up'
# Added early exit
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
# All credits go to `kohya-ss`.
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
# scale weight by alpha and dim
rank = down_weight.shape[0]
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
    # calculate scale_down and scale_up so that their product stays equal to scale;
    # the loop only rebalances when scale < 0.5 (e.g. scale 0.25 -> scale_down 0.5, scale_up 0.5)
    scale_down = scale
    scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
ait_sd[ait_key + ".lora_down.weight"] = down_weight * scale_down
ait_sd[ait_key + ".lora_up.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
sd_lora_rank = down_weight.shape[0]
# scale weight by alpha and dim
alpha = sds_sd.pop(sds_key + ".alpha")
scale = alpha / sd_lora_rank
# calculate scale_down and scale_up
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
# calculate dims if not provided
num_splits = len(ait_keys)
if dims is None:
dims = [up_weight.shape[0] // num_splits] * num_splits
else:
assert sum(dims) == up_weight.shape[0]
# check upweight is sparse or not
is_sparse = False
if sd_lora_rank % num_splits == 0:
ait_rank = sd_lora_rank // num_splits
is_sparse = True
i = 0
for j in range(len(dims)):
for k in range(len(dims)):
if j == k:
continue
is_sparse = is_sparse and torch.all(
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
)
i += dims[j]
# if is_sparse:
# print(f"weight is sparse: {sds_key}")
# make ai-toolkit weight
ait_down_keys = [k + ".lora_down.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_up.weight" for k in ait_keys]
if not is_sparse:
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension
else:
# down_weight is chunked to each split
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension
# up_weight is sparse: only non-zero values are copied to each split
i = 0
for j in range(len(dims)):
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
i += dims[j]
def _convert_text_encoder_lora_key(key, lora_name):
"""
Converts a text encoder LoRA key to a Diffusers compatible key.
"""
if lora_name.startswith(("lora_te_", "lora_te1_")):
key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
else:
key_to_replace = "lora_te2_"
diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("text.projection", "text_projection")
if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
pass
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
return diffusers_name
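# NOTE (editor's sketch, not part of this commit): e.g. the kohya key
# "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight" becomes
# "text_model.encoder.layers.0.self_attn.to_q_lora.down.weight".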
def _convert_kohya_flux_lora_to_diffusers(state_dict):
def _convert_sd_scripts_to_ai_toolkit(sds_sd):
ait_sd = {}
for i in range(19):
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_out.0",
)
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.to_q",
f"transformer.transformer_blocks.{i}.attn.to_k",
f"transformer.transformer_blocks.{i}.attn.to_v",
],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_0",
f"transformer.transformer_blocks.{i}.ff.net.0.proj",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_2",
f"transformer.transformer_blocks.{i}.ff.net.2",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mod_lin",
f"transformer.transformer_blocks.{i}.norm1.linear",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_add_out",
)
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_0",
f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_2",
f"transformer.transformer_blocks.{i}.ff_context.net.2",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mod_lin",
f"transformer.transformer_blocks.{i}.norm1_context.linear",
)
for i in range(38):
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear1",
[
f"transformer.single_transformer_blocks.{i}.attn.to_q",
f"transformer.single_transformer_blocks.{i}.attn.to_k",
f"transformer.single_transformer_blocks.{i}.attn.to_v",
f"transformer.single_transformer_blocks.{i}.proj_mlp",
],
dims=[3072, 3072, 3072, 12288],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear2",
f"transformer.single_transformer_blocks.{i}.proj_out",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_modulation_lin",
f"transformer.single_transformer_blocks.{i}.norm.linear",
)
if len(sds_sd) > 0:
return None
return ait_sd
return _convert_sd_scripts_to_ai_toolkit(state_dict)
def _convert_kohya_sd3_lora_to_diffusers(state_dict):
def _convert_sd_scripts_to_ai_toolkit(sds_sd):
ait_sd = {}
for i in range(38):
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_joint_blocks_{i}_context_block_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.to_q",
f"transformer.transformer_blocks.{i}.attn.to_k",
f"transformer.transformer_blocks.{i}.attn.to_v",
],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_joint_blocks_{i}_context_block_mlp_fc1",
f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_joint_blocks_{i}_context_block_mlp_fc2",
f"transformer.transformer_blocks.{i}.ff_context.net.2",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_joint_blocks_{i}_x_block_mlp_fc1",
f"transformer.transformer_blocks.{i}.ff.net.0.proj",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_joint_blocks_{i}_x_block_mlp_fc2",
f"transformer.transformer_blocks.{i}.ff.net.2",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1",
f"transformer.transformer_blocks.{i}.norm1_context.linear",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_joint_blocks_{i}_x_block_adaLN_modulation_1",
f"transformer.transformer_blocks.{i}.norm1.linear",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_joint_blocks_{i}_context_block_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_add_out",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_joint_blocks_{i}_x_block_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_out_0",
)
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_joint_blocks_{i}_x_block_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
],
)
remaining_keys = list(sds_sd.keys())
te_state_dict = {}
if remaining_keys:
if not all(k.startswith("lora_te1") for k in remaining_keys):
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
for key in remaining_keys:
if not key.endswith("lora_down.weight"):
continue
lora_name = key.split(".")[0]
lora_name_up = f"{lora_name}.lora_up.weight"
lora_name_alpha = f"{lora_name}.alpha"
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
if lora_name.startswith(("lora_te_", "lora_te1_")):
down_weight = sds_sd.pop(key)
sd_lora_rank = down_weight.shape[0]
te_state_dict[diffusers_name] = down_weight
te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
if lora_name_alpha in sds_sd:
alpha = sds_sd.pop(lora_name_alpha).item()
scale = alpha / sd_lora_rank
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
te_state_dict[diffusers_name] *= scale_down
te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
if len(sds_sd) > 0:
print(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
if te_state_dict:
te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
new_state_dict = {**ait_sd, **te_state_dict}
return new_state_dict
return _convert_sd_scripts_to_ai_toolkit(state_dict)


@@ -0,0 +1,271 @@
import os
import time
import json
import datetime
import torch
from safetensors.torch import save_file
import gradio as gr
from rich import progress as p
from modules import shared, devices
from modules.ui_common import create_refresh_button
from modules.call_queue import wrap_gradio_gpu_call
class SVDHandler:
def __init__(self, maxrank=0, rank_ratio=1):
self.network_name: str = None
self.U: torch.Tensor = None
self.S: torch.Tensor = None
self.Vh: torch.Tensor = None
self.maxrank: int = maxrank
self.rank_ratio: float = rank_ratio
self.rank: int = 0
self.out_size: int = None
self.in_size: int = None
self.kernel_size: tuple[int, int] = None
self.conv2d: bool = False
def decompose(self, weight, backupweight):
self.conv2d = len(weight.size()) == 4
self.kernel_size = None if not self.conv2d else weight.size()[2:4]
self.out_size, self.in_size = weight.size()[0:2]
diffweight = weight.clone().to(devices.device)
diffweight -= backupweight.to(devices.device)
if self.conv2d:
if self.conv2d and self.kernel_size != (1, 1):
diffweight = diffweight.flatten(start_dim=1)
else:
diffweight = diffweight.squeeze()
self.U, self.S, self.Vh = torch.svd_lowrank(diffweight.to(device=devices.device, dtype=torch.float), self.maxrank, 2)
# del diffweight
self.U = self.U.to(device=devices.cpu, dtype=torch.bfloat16)
self.S = self.S.to(device=devices.cpu, dtype=torch.bfloat16)
self.Vh = self.Vh.t().to(device=devices.cpu, dtype=torch.bfloat16) # svd_lowrank outputs a transposed matrix
def findrank(self):
if self.rank_ratio < 1:
S_squared = self.S.pow(2)
S_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
index = int(torch.searchsorted(sum_S_squared, self.rank_ratio ** 2)) + 1
index = max(1, min(index, len(self.S) - 1))
self.rank = index
if self.maxrank > 0:
self.rank = min(self.rank, self.maxrank)
else:
self.rank = min(self.in_size, self.out_size, self.maxrank)
def makeweights(self):
self.findrank()
up = self.U[:, :self.rank] @ torch.diag(self.S[:self.rank])
down = self.Vh[:self.rank, :]
if self.conv2d and self.kernel_size is not None:
up = up.reshape(self.out_size, self.rank, 1, 1)
down = down.reshape(self.rank, self.in_size, self.kernel_size[0], self.kernel_size[1]) # pylint: disable=unsubscriptable-object
return_dict = {f'{self.network_name}.lora_up.weight': up.contiguous(),
f'{self.network_name}.lora_down.weight': down.contiguous(),
f'{self.network_name}.alpha': torch.tensor(down.shape[0]),
}
return return_dict
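# NOTE (editor's sketch, not part of this commit): the extracted pair satisfies
# lora_up @ lora_down ~= W_current - W_backup truncated to `rank` singular values, and
# alpha is stored equal to the rank so the usual alpha/rank scaling resolves to 1.0 on load.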
def loaded_lora():
    if not shared.sd_loaded:
        return [] # keep the return type consistent: a list of loaded network names
loaded = set()
if hasattr(shared.sd_model, 'unet'):
for _name, module in shared.sd_model.unet.named_modules():
current = getattr(module, "network_current_names", None)
if current is not None:
current = [item[0] for item in current]
loaded.update(current)
return list(loaded)
def loaded_lora_str():
return ", ".join(loaded_lora())
def make_meta(fn, maxrank, rank_ratio):
meta = {
"model_spec.sai_model_spec": "1.0.0",
"model_spec.title": os.path.splitext(os.path.basename(fn))[0],
"model_spec.author": "SD.Next",
"model_spec.implementation": "https://github.com/vladmandic/automatic",
"model_spec.date": datetime.datetime.now().astimezone().replace(microsecond=0).isoformat(),
"model_spec.base_model": shared.opts.sd_model_checkpoint,
"model_spec.dtype": str(devices.dtype),
"model_spec.base_lora": json.dumps(loaded_lora()),
"model_spec.config": f"maxrank={maxrank} rank_ratio={rank_ratio}",
}
if shared.sd_model_type == "sdxl":
meta["model_spec.architecture"] = "stable-diffusion-xl-v1-base/lora" # sai standard
meta["ss_base_model_version"] = "sdxl_base_v1-0" # kohya standard
elif shared.sd_model_type == "sd":
meta["model_spec.architecture"] = "stable-diffusion-v1/lora"
meta["ss_base_model_version"] = "sd_v1"
elif shared.sd_model_type == "f1":
meta["model_spec.architecture"] = "flux-1-dev/lora"
meta["ss_base_model_version"] = "flux1"
elif shared.sd_model_type == "sc":
meta["model_spec.architecture"] = "stable-cascade-v1-prior/lora"
return meta
def make_lora(fn, maxrank, auto_rank, rank_ratio, modules, overwrite):
if not shared.sd_loaded or not shared.native:
msg = "LoRA extract: model not loaded"
shared.log.warning(msg)
yield msg
return
    if len(loaded_lora()) == 0:
msg = "LoRA extract: no LoRA detected"
shared.log.warning(msg)
yield msg
return
if not fn:
msg = "LoRA extract: target filename required"
shared.log.warning(msg)
yield msg
return
t0 = time.time()
maxrank = int(maxrank)
rank_ratio = 1 if not auto_rank else rank_ratio
shared.log.debug(f'LoRA extract: modules={modules} maxrank={maxrank} auto={auto_rank} ratio={rank_ratio} fn="{fn}"')
shared.state.begin('LoRA extract')
with p.Progress(p.TextColumn('[cyan]LoRA extract'), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TextColumn('[cyan]{task.description}'), console=shared.console) as progress:
if 'te' in modules and getattr(shared.sd_model, 'text_encoder', None) is not None:
            te_modules = list(shared.sd_model.text_encoder.named_modules()) # materialize once so the `modules` argument is not clobbered
            task = progress.add_task(description="te1 decompose", total=len(te_modules))
for name, module in shared.sd_model.text_encoder.named_modules():
progress.update(task, advance=1)
weights_backup = getattr(module, "network_weights_backup", None)
if weights_backup is None or getattr(module, "network_current_names", None) is None:
continue
prefix = "lora_te1_" if hasattr(shared.sd_model, 'text_encoder_2') else "lora_te_"
module.svdhandler = SVDHandler(maxrank, rank_ratio)
module.svdhandler.network_name = prefix + name.replace(".", "_")
with devices.inference_context():
module.svdhandler.decompose(module.weight, weights_backup)
progress.remove_task(task)
t1 = time.time()
if 'te' in modules and getattr(shared.sd_model, 'text_encoder_2', None) is not None:
            te2_modules = list(shared.sd_model.text_encoder_2.named_modules())
            task = progress.add_task(description="te2 decompose", total=len(te2_modules))
for name, module in shared.sd_model.text_encoder_2.named_modules():
progress.update(task, advance=1)
weights_backup = getattr(module, "network_weights_backup", None)
if weights_backup is None or getattr(module, "network_current_names", None) is None:
continue
module.svdhandler = SVDHandler(maxrank, rank_ratio)
module.svdhandler.network_name = "lora_te2_" + name.replace(".", "_")
with devices.inference_context():
module.svdhandler.decompose(module.weight, weights_backup)
progress.remove_task(task)
t2 = time.time()
if 'unet' in modules and getattr(shared.sd_model, 'unet', None) is not None:
            unet_modules = list(shared.sd_model.unet.named_modules())
            task = progress.add_task(description="unet decompose", total=len(unet_modules))
for name, module in shared.sd_model.unet.named_modules():
progress.update(task, advance=1)
weights_backup = getattr(module, "network_weights_backup", None)
if weights_backup is None or getattr(module, "network_current_names", None) is None:
continue
module.svdhandler = SVDHandler(maxrank, rank_ratio)
module.svdhandler.network_name = "lora_unet_" + name.replace(".", "_")
with devices.inference_context():
module.svdhandler.decompose(module.weight, weights_backup)
progress.remove_task(task)
t3 = time.time()
# TODO: Handle quant for Flux
# if 'te' in modules and getattr(shared.sd_model, 'transformer', None) is not None:
# for name, module in shared.sd_model.transformer.named_modules():
# if "norm" in name and "linear" not in name:
# continue
# weights_backup = getattr(module, "network_weights_backup", None)
# if weights_backup is None:
# continue
# module.svdhandler = SVDHandler()
# module.svdhandler.network_name = "lora_transformer_" + name.replace(".", "_")
# module.svdhandler.decompose(module.weight, weights_backup)
# module.svdhandler.findrank(rank, rank_ratio)
lora_state_dict = {}
for sub in ['text_encoder', 'text_encoder_2', 'unet', 'transformer']:
submodel = getattr(shared.sd_model, sub, None)
if submodel is not None:
                sub_modules = list(submodel.named_modules())
                task = progress.add_task(description=f"{sub} extract", total=len(sub_modules))
for _name, module in submodel.named_modules():
progress.update(task, advance=1)
if not hasattr(module, "svdhandler"):
continue
lora_state_dict.update(module.svdhandler.makeweights())
del module.svdhandler
progress.remove_task(task)
t4 = time.time()
if not os.path.isabs(fn):
fn = os.path.join(shared.cmd_opts.lora_dir, fn)
if not fn.endswith('.safetensors'):
fn += '.safetensors'
if os.path.exists(fn):
if overwrite:
os.remove(fn)
else:
msg = f'LoRA extract: fn="{fn}" file exists'
shared.log.warning(msg)
yield msg
return
shared.state.end()
meta = make_meta(fn, maxrank, rank_ratio)
shared.log.debug(f'LoRA metadata: {meta}')
try:
save_file(tensors=lora_state_dict, metadata=meta, filename=fn)
except Exception as e:
msg = f'LoRA extract error: fn="{fn}" {e}'
shared.log.error(msg)
yield msg
return
t5 = time.time()
shared.log.debug(f'LoRA extract: time={t5-t0:.2f} te1={t1-t0:.2f} te2={t2-t1:.2f} unet={t3-t2:.2f} save={t5-t4:.2f}')
keys = list(lora_state_dict.keys())
msg = f'LoRA extract: fn="{fn}" keys={len(keys)}'
shared.log.info(msg)
yield msg
def create_ui():
def gr_show(visible=True):
return {"visible": visible, "__type__": "update"}
with gr.Tab(label="Extract LoRA"):
with gr.Row():
loaded = gr.Textbox(placeholder="Press refresh to query loaded LoRA", label="Loaded LoRA", interactive=False)
create_refresh_button(loaded, lambda: None, lambda: {'value': loaded_lora_str()}, "testid")
with gr.Group():
with gr.Row():
modules = gr.CheckboxGroup(label="Modules to extract", value=['unet'], choices=['te', 'unet'])
with gr.Row():
auto_rank = gr.Checkbox(value=False, label="Automatically determine rank")
rank_ratio = gr.Slider(label="Autorank ratio", value=1, minimum=0, maximum=1, step=0.05, visible=False)
rank = gr.Slider(label="Maximum rank", value=32, minimum=1, maximum=256)
with gr.Row():
filename = gr.Textbox(label="LoRA target filename")
overwrite = gr.Checkbox(value=False, label="Overwrite existing file")
with gr.Row():
extract = gr.Button(value="Extract LoRA", variant='primary')
status = gr.HTML(value="", show_label=False)
auto_rank.change(fn=lambda x: gr_show(x), inputs=[auto_rank], outputs=[rank_ratio])
extract.click(
fn=wrap_gradio_gpu_call(make_lora, extra_outputs=[]),
inputs=[filename, rank, auto_rank, rank_ratio, modules, overwrite],
outputs=[status]
)


@@ -0,0 +1,66 @@
import torch
def make_weight_cp(t, wa, wb):
temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
return torch.einsum('i j k l, i r -> r j k l', temp, wa)
def rebuild_conventional(up, down, shape, dyn_dim=None):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
if dyn_dim is not None:
up = up[:, :dyn_dim]
down = down[:dyn_dim, :]
return (up @ down).reshape(shape)
def rebuild_cp_decomposition(up, down, mid):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
"""
return a tuple of two value of input dimension decomposed by the number closest to factor
second value is higher or equal than first value.
In LoRA with Kroneckor Product, first value is a value for weight scale.
secon value is a value for weight.
Becuase of non-commutative property, AB BA. Meaning of two matrices is slightly different.
examples
factor
-1 2 4 8 16 ...
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
"""
if factor > 0 and (dimension % factor) == 0:
m = factor
n = dimension // factor
if m > n:
n, m = m, n
return m, n
if factor < 0:
factor = dimension
m, n = 1, dimension
length = m + n
while m<n:
new_m = m + 1
while dimension%new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m>factor:
break
m, n = new_m, new_n
if m > n:
n, m = m, n
return m, n
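# NOTE (editor's sketch, not part of this commit): the values match the docstring table.
def _demo_factorization(): # hypothetical helper, for illustration only
    assert factorization(128) == (8, 16) # balanced split chosen for the default factor -1
    assert factorization(250, 8) == (5, 50) # largest usable divisor not exceeding the factor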

modules/lora/network.py (new file, 187 lines)

@@ -0,0 +1,187 @@
import os
from collections import namedtuple
import enum
from modules import sd_models, hashes, shared
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
class SdVersion(enum.Enum):
    Unknown = 1
    SD1 = 2
    SD2 = 3
    SD3 = 4
    SDXL = 5
    SC = 6
    F1 = 7
class NetworkOnDisk:
def __init__(self, name, filename):
self.shorthash = None
self.hash = None
self.name = name
self.filename = filename
if filename.startswith(shared.cmd_opts.lora_dir):
self.fullname = os.path.splitext(filename[len(shared.cmd_opts.lora_dir):].strip("/"))[0]
else:
self.fullname = name
self.metadata = {}
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
if self.is_safetensors:
self.metadata = sd_models.read_metadata_from_safetensors(filename)
if self.metadata:
m = {}
for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
m[k] = v
self.metadata = m
self.alias = self.metadata.get('ss_output_name', self.name)
sha256 = hashes.sha256_from_cache(self.filename, "lora/" + self.name) or hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=True) or self.metadata.get('sshs_model_hash')
self.set_hash(sha256)
self.sd_version = self.detect_version()
def detect_version(self):
base = str(self.metadata.get('ss_base_model_version', "")).lower()
arch = str(self.metadata.get('modelspec.architecture', "")).lower()
if base.startswith("sd_v1"):
return 'sd1'
if base.startswith("sdxl"):
return 'xl'
if base.startswith("stable_cascade"):
return 'sc'
if base.startswith("sd3"):
return 'sd3'
if base.startswith("flux"):
return 'f1'
if arch.startswith("stable-diffusion-v1"):
return 'sd1'
if arch.startswith("stable-diffusion-xl"):
return 'xl'
if arch.startswith("stable-cascade"):
return 'sc'
if arch.startswith("flux"):
return 'f1'
if "v1-5" in str(self.metadata.get('ss_sd_model_name', "")):
return 'sd1'
if str(self.metadata.get('ss_v2', "")) == "True":
return 'sd2'
if 'flux' in self.name.lower():
return 'f1'
if 'xl' in self.name.lower():
return 'xl'
return ''
def set_hash(self, v):
self.hash = v or ''
self.shorthash = self.hash[0:8]
def read_hash(self):
if not self.hash:
self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
def get_alias(self):
import modules.lora.networks as networks
return self.name if shared.opts.lora_preferred_name == "filename" or self.alias.lower() in networks.forbidden_network_aliases else self.alias
class Network: # LoraModule
def __init__(self, name, network_on_disk: NetworkOnDisk):
self.name = name
self.network_on_disk = network_on_disk
self.te_multiplier = 1.0
self.unet_multiplier = [1.0] * 3
self.dyn_dim = None
self.modules = {}
self.bundle_embeddings = {}
self.mtime = None
        self.mentioned_name = None
        """the text that was used to add the network to prompt - can be either name or an alias"""
        self.tags = None
class ModuleType:
def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: # pylint: disable=W0613
return None
class NetworkModule:
def __init__(self, net: Network, weights: NetworkWeights):
self.network = net
self.network_key = weights.network_key
self.sd_key = weights.sd_key
self.sd_module = weights.sd_module
if hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
self.dim = None
self.bias = weights.w.get("bias")
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
self.scale = weights.w["scale"].item() if "scale" in weights.w else None
self.dora_scale = weights.w.get("dora_scale", None)
self.dora_norm_dims = len(self.shape) - 1
def multiplier(self):
unet_multiplier = 3 * [self.network.unet_multiplier] if not isinstance(self.network.unet_multiplier, list) else self.network.unet_multiplier
if 'transformer' in self.sd_key[:20]:
return self.network.te_multiplier
if "down_blocks" in self.sd_key:
return unet_multiplier[0]
if "mid_block" in self.sd_key:
return unet_multiplier[1]
if "up_blocks" in self.sd_key:
return unet_multiplier[2]
else:
return unet_multiplier[0]
def calc_scale(self):
if self.scale is not None:
return self.scale
if self.dim is not None and self.alpha is not None:
return self.alpha / self.dim
return 1.0
def apply_weight_decompose(self, updown, orig_weight):
# Match the device/dtype
orig_weight = orig_weight.to(updown.dtype)
dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
updown = updown.to(orig_weight.device)
merged_scale1 = updown + orig_weight
merged_scale1_norm = (
merged_scale1.transpose(0, 1)
.reshape(merged_scale1.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
.transpose(0, 1)
)
dora_merged = (
merged_scale1 * (dora_scale / merged_scale1_norm)
)
final_updown = dora_merged - orig_weight
return final_updown
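    # NOTE (editor's sketch, not part of this commit): this is the DoRA weight
    # decomposition step. With W = orig_weight, dW = updown and m = dora_scale it computes
    # W' = (W + dW) * m / ||W + dW|| (column-wise norm per input channel) and returns
    # W' - W, so callers can keep treating the result as an additive delta.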
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:
updown = updown.reshape(output_shape)
if orig_weight.size().numel() == updown.size().numel():
updown = updown.reshape(orig_weight.shape)
if ex_bias is not None:
ex_bias = ex_bias * self.multiplier()
if self.dora_scale is not None:
updown = self.apply_weight_decompose(updown, orig_weight)
return updown * self.calc_scale() * self.multiplier(), ex_bias
def calc_updown(self, target):
raise NotImplementedError
def forward(self, x, y):
raise NotImplementedError


@@ -0,0 +1,26 @@
import modules.lora.network as network
class ModuleTypeFull(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["diff"]):
return NetworkModuleFull(net, weights)
return None
class NetworkModuleFull(network.NetworkModule): # pylint: disable=abstract-method
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.weight = weights.w.get("diff")
self.ex_bias = weights.w.get("diff_b")
def calc_updown(self, target):
output_shape = self.weight.shape
updown = self.weight.to(target.device, dtype=target.dtype)
if self.ex_bias is not None:
ex_bias = self.ex_bias.to(target.device, dtype=target.dtype)
else:
ex_bias = None
return self.finalize_updown(updown, target, output_shape, ex_bias)


@@ -0,0 +1,30 @@
import modules.lora.network as network
class ModuleTypeGLora(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
return NetworkModuleGLora(net, weights)
return None
# adapted from https://github.com/KohakuBlueleaf/LyCORIS
class NetworkModuleGLora(network.NetworkModule): # pylint: disable=abstract-method
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
if hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
self.w1a = weights.w["a1.weight"]
self.w1b = weights.w["b1.weight"]
self.w2a = weights.w["a2.weight"]
self.w2b = weights.w["b2.weight"]
def calc_updown(self, target): # pylint: disable=arguments-differ
w1a = self.w1a.to(target.device, dtype=target.dtype)
w1b = self.w1b.to(target.device, dtype=target.dtype)
w2a = self.w2a.to(target.device, dtype=target.dtype)
w2b = self.w2b.to(target.device, dtype=target.dtype)
output_shape = [w1a.size(0), w1b.size(1)]
updown = (w2b @ w1b) + ((target @ w2a) @ w1a)
return self.finalize_updown(updown, target, output_shape)


@@ -0,0 +1,46 @@
import modules.lora.lyco_helpers as lyco_helpers
import modules.lora.network as network
class ModuleTypeHada(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
return NetworkModuleHada(net, weights)
return None
class NetworkModuleHada(network.NetworkModule): # pylint: disable=abstract-method
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
if hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
self.w1a = weights.w["hada_w1_a"]
self.w1b = weights.w["hada_w1_b"]
self.dim = self.w1b.shape[0]
self.w2a = weights.w["hada_w2_a"]
self.w2b = weights.w["hada_w2_b"]
self.t1 = weights.w.get("hada_t1")
self.t2 = weights.w.get("hada_t2")
def calc_updown(self, target):
w1a = self.w1a.to(target.device, dtype=target.dtype)
w1b = self.w1b.to(target.device, dtype=target.dtype)
w2a = self.w2a.to(target.device, dtype=target.dtype)
w2b = self.w2b.to(target.device, dtype=target.dtype)
output_shape = [w1a.size(0), w1b.size(1)]
if self.t1 is not None:
output_shape = [w1a.size(1), w1b.size(1)]
t1 = self.t1.to(target.device, dtype=target.dtype)
updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
output_shape += t1.shape[2:]
else:
if len(w1b.shape) == 4:
output_shape += w1b.shape[2:]
updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
if self.t2 is not None:
t2 = self.t2.to(target.device, dtype=target.dtype)
updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
else:
updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
updown = updown1 * updown2
return self.finalize_updown(updown, target, output_shape)
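# NOTE (editor's sketch, not part of this commit): LoHa composes the delta as an
# elementwise (Hadamard) product of two low-rank reconstructions,
# updown = (w1a @ w1b) * (w2a @ w2b), optionally CP-decomposed via the t1/t2 tensors.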


@@ -0,0 +1,24 @@
import modules.lora.network as network
class ModuleTypeIa3(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["weight"]):
return NetworkModuleIa3(net, weights)
return None
class NetworkModuleIa3(network.NetworkModule): # pylint: disable=abstract-method
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.w = weights.w["weight"]
self.on_input = weights.w["on_input"].item()
def calc_updown(self, target):
w = self.w.to(target.device, dtype=target.dtype)
output_shape = [w.size(0), target.size(1)]
if self.on_input:
output_shape.reverse()
else:
w = w.reshape(-1, 1)
updown = target * w
return self.finalize_updown(updown, target, output_shape)
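# NOTE (editor's sketch, not part of this commit): IA3 stores a per-channel rescaling
# vector rather than low-rank factors; since updown = target * w, the final weight
# becomes W * (1 + w * scale * multiplier) once finalize_updown() is applied.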


@@ -0,0 +1,57 @@
import torch
import modules.lora.lyco_helpers as lyco_helpers
import modules.lora.network as network
class ModuleTypeLokr(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w)
has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w)
if has_1 and has_2:
return NetworkModuleLokr(net, weights)
return None
def make_kron(orig_shape, w1, w2):
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
return torch.kron(w1, w2).reshape(orig_shape)
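# NOTE (editor's sketch, not part of this commit): LoKr rebuilds the full delta as a
# Kronecker product, so small factors reconstruct a large update without storing it.
def _demo_make_kron(): # hypothetical helper, for illustration only
    w1, w2 = torch.ones(4, 8), torch.ones(32, 64)
    assert make_kron((128, 512), w1, w2).shape == (128, 512) # (4*32) x (8*64)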
class NetworkModuleLokr(network.NetworkModule): # pylint: disable=abstract-method
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.w1 = weights.w.get("lokr_w1")
self.w1a = weights.w.get("lokr_w1_a")
self.w1b = weights.w.get("lokr_w1_b")
self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim
self.w2 = weights.w.get("lokr_w2")
self.w2a = weights.w.get("lokr_w2_a")
self.w2b = weights.w.get("lokr_w2_b")
self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim
self.t2 = weights.w.get("lokr_t2")
def calc_updown(self, target):
if self.w1 is not None:
w1 = self.w1.to(target.device, dtype=target.dtype)
else:
w1a = self.w1a.to(target.device, dtype=target.dtype)
w1b = self.w1b.to(target.device, dtype=target.dtype)
w1 = w1a @ w1b
if self.w2 is not None:
w2 = self.w2.to(target.device, dtype=target.dtype)
elif self.t2 is None:
w2a = self.w2a.to(target.device, dtype=target.dtype)
w2b = self.w2b.to(target.device, dtype=target.dtype)
w2 = w2a @ w2b
else:
t2 = self.t2.to(target.device, dtype=target.dtype)
w2a = self.w2a.to(target.device, dtype=target.dtype)
w2b = self.w2b.to(target.device, dtype=target.dtype)
w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
if len(target.shape) == 4:
output_shape = target.shape
updown = make_kron(output_shape, w1, w2)
return self.finalize_updown(updown, target, output_shape)


@@ -0,0 +1,78 @@
import torch
import diffusers.models.lora as diffusers_lora
import modules.lora.lyco_helpers as lyco_helpers
import modules.lora.network as network
from modules import devices
class ModuleTypeLora(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
return NetworkModuleLora(net, weights)
return None
class NetworkModuleLora(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.up_model = self.create_module(weights.w, "lora_up.weight")
self.down_model = self.create_module(weights.w, "lora_down.weight")
self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)
self.dim = weights.w["lora_down.weight"].shape[0]
def create_module(self, weights, key, none_ok=False):
from modules.shared import opts
weight = weights.get(key)
if weight is None and none_ok:
return None
linear_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear]
is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ in {"NNCFLinear", "QLinear", "Linear4bit"}
is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv] or self.sd_module.__class__.__name__ in {"NNCFConv2d", "QConv2d"}
if is_linear:
weight = weight.reshape(weight.shape[0], -1)
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif is_conv and key == "lora_down.weight" or key == "dyn_up":
if len(weight.shape) == 2:
weight = weight.reshape(weight.shape[0], -1, 1, 1)
if weight.shape[2] != 1 or weight.shape[3] != 1:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
else:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
elif is_conv and key == "lora_mid.weight":
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
elif is_conv and key == "lora_up.weight" or key == "dyn_down":
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
else:
raise AssertionError(f'Lora unsupported: layer={self.network_key} type={type(self.sd_module).__name__}')
with torch.no_grad():
if weight.shape != module.weight.shape:
weight = weight.reshape(module.weight.shape)
module.weight.copy_(weight)
if opts.lora_load_gpu:
module = module.to(device=devices.device, dtype=devices.dtype)
module.weight.requires_grad_(False)
return module
def calc_updown(self, target): # pylint: disable=W0237
target_dtype = target.dtype if target.dtype != torch.uint8 else self.up_model.weight.dtype
up = self.up_model.weight.to(target.device, dtype=target_dtype)
down = self.down_model.weight.to(target.device, dtype=target_dtype)
output_shape = [up.size(0), down.size(1)]
if self.mid_model is not None:
# cp-decomposition
mid = self.mid_model.weight.to(target.device, dtype=target_dtype)
updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
output_shape += mid.shape[2:]
else:
if len(down.shape) == 4:
output_shape += down.shape[2:]
updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)
return self.finalize_updown(updown, target, output_shape)
def forward(self, x, y):
self.up_model.to(device=devices.device)
self.down_model.to(device=devices.device)
if hasattr(y, "scale"):
return y(scale=1) + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()


@@ -0,0 +1,23 @@
import modules.lora.network as network
class ModuleTypeNorm(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["w_norm", "b_norm"]):
return NetworkModuleNorm(net, weights)
return None
class NetworkModuleNorm(network.NetworkModule): # pylint: disable=abstract-method
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.w_norm = weights.w.get("w_norm")
self.b_norm = weights.w.get("b_norm")
def calc_updown(self, target):
output_shape = self.w_norm.shape
updown = self.w_norm.to(target.device, dtype=target.dtype)
if self.b_norm is not None:
ex_bias = self.b_norm.to(target.device, dtype=target.dtype)
else:
ex_bias = None
return self.finalize_updown(updown, target, output_shape, ex_bias)


@@ -0,0 +1,81 @@
import torch
import modules.lora.network as network
from modules.lora.lyco_helpers import factorization
from einops import rearrange
class ModuleTypeOFT(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
return NetworkModuleOFT(net, weights)
return None
# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
class NetworkModuleOFT(network.NetworkModule): # pylint: disable=abstract-method
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.lin_module = None
        self.org_module: list[torch.nn.Module] = [self.sd_module]
self.scale = 1.0
# kohya-ss
if "oft_blocks" in weights.w.keys():
self.is_kohya = True
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
self.alpha = weights.w["alpha"] # alpha is constraint
self.dim = self.oft_blocks.shape[0] # lora dim
# LyCORIS
elif "oft_diag" in weights.w.keys():
self.is_kohya = False
self.oft_blocks = weights.w["oft_diag"]
# self.alpha is unused
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
if is_linear:
self.out_dim = self.sd_module.out_features
elif is_conv:
self.out_dim = self.sd_module.out_channels
elif is_other_linear:
self.out_dim = self.sd_module.embed_dim
if self.is_kohya:
self.constraint = self.alpha * self.out_dim
self.num_blocks = self.dim
self.block_size = self.out_dim // self.dim
else:
self.constraint = None
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
    def calc_updown(self, target):
        oft_blocks = self.oft_blocks.to(target.device, dtype=target.dtype)
        eye = torch.eye(self.block_size, device=target.device)
        if self.is_kohya:
            block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
            norm_Q = torch.norm(block_Q.flatten()).to(target.device)
            constraint = self.constraint.to(target.device) # constraint is only defined for the kohya variant
            new_norm_Q = torch.clamp(norm_Q, max=constraint)
            block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
            mat1 = eye + block_Q
            mat2 = (eye - block_Q).float().inverse()
            oft_blocks = torch.matmul(mat1, mat2)
R = oft_blocks.to(target.device, dtype=target.dtype)
# This errors out for MultiheadAttention, might need to be handled up-stream
merged_weight = rearrange(target, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
merged_weight = torch.einsum(
'k n m, k n ... -> k m ...',
R,
merged_weight
)
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
updown = merged_weight.to(target.device, dtype=target.dtype) - target
output_shape = target.shape
return self.finalize_updown(updown, target, output_shape)
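    # NOTE (editor's sketch, not part of this commit): OFT applies a learned
    # block-diagonal orthogonal rotation R to the original weight, so the delta is
    # R @ W - W; the kohya (COFT) variant additionally clamps the rotation by the
    # alpha-derived constraint and builds R via the Cayley transform above.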


@@ -0,0 +1,49 @@
from modules import shared
maybe_diffusers = [ # forced if lora_maybe_diffusers is enabled
'aaebf6360f7d', # sd15-lcm
'3d18b05e4f56', # sdxl-lcm
'b71dcb732467', # sdxl-tcd
'813ea5fb1c67', # sdxl-turbo
# not really needed, but just in case
'5a48ac366664', # hyper-sd15-1step
'ee0ff23dcc42', # hyper-sd15-2step
'e476eb1da5df', # hyper-sd15-4step
'ecb844c3f3b0', # hyper-sd15-8step
'1ab289133ebb', # hyper-sd15-8step-cfg
'4f494295edb1', # hyper-sdxl-8step
'ca14a8c621f8', # hyper-sdxl-8step-cfg
'1c88f7295856', # hyper-sdxl-4step
'fdd5dcd1d88a', # hyper-sdxl-2step
'8cca3706050b', # hyper-sdxl-1step
]
force_diffusers = [ # forced always
'816d0eed49fd', # flash-sdxl
'c2ec22757b46', # flash-sd15
]
force_models = [ # forced always
'sc',
# 'sd3',
'kandinsky',
'hunyuandit',
'auraflow',
]
force_classes = [ # forced always
]
def check_override(shorthash=''):
force = False
force = force or (shared.sd_model_type in force_models)
force = force or (shared.sd_model.__class__.__name__ in force_classes)
if len(shorthash) < 4:
return force
force = force or (any(x.startswith(shorthash) for x in maybe_diffusers) if shared.opts.lora_maybe_diffusers else False)
force = force or any(x.startswith(shorthash) for x in force_diffusers)
if force and shared.opts.lora_maybe_diffusers:
shared.log.debug('LoRA override: force diffusers')
return force

modules/lora/networks.py (new file, 453 lines)

@@ -0,0 +1,453 @@
from typing import Union, List
import os
import re
import time
import concurrent
import modules.lora.network as network
import modules.lora.network_lora as network_lora
import modules.lora.network_hada as network_hada
import modules.lora.network_ia3 as network_ia3
import modules.lora.network_oft as network_oft
import modules.lora.network_lokr as network_lokr
import modules.lora.network_full as network_full
import modules.lora.network_norm as network_norm
import modules.lora.network_glora as network_glora
import modules.lora.network_overrides as network_overrides
import modules.lora.lora_convert as lora_convert
import torch
import diffusers.models.lora
from modules import shared, devices, sd_models, sd_models_compile, errors, scripts, files_cache, model_quant
debug = os.environ.get('SD_LORA_DEBUG', None) is not None
extra_network_lora = None
available_networks = {}
available_network_aliases = {}
loaded_networks: List[network.Network] = []
timer = { 'load': 0, 'apply': 0, 'restore': 0, 'deactivate': 0 }
lora_cache = {}
diffuser_loaded = []
diffuser_scales = []
available_network_hash_lookup = {}
forbidden_network_aliases = {}
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
module_types = [
network_lora.ModuleTypeLora(),
network_hada.ModuleTypeHada(),
network_ia3.ModuleTypeIa3(),
network_oft.ModuleTypeOFT(),
network_lokr.ModuleTypeLokr(),
network_full.ModuleTypeFull(),
network_norm.ModuleTypeNorm(),
network_glora.ModuleTypeGLora(),
]
def assign_network_names_to_compvis_modules(sd_model):
if sd_model is None:
return
    sd_model = getattr(sd_model, "pipe", sd_model) # wrapped model compatibility
network_layer_mapping = {}
if hasattr(sd_model, 'text_encoder') and sd_model.text_encoder is not None:
for name, module in sd_model.text_encoder.named_modules():
prefix = "lora_te1_" if hasattr(sd_model, 'text_encoder_2') else "lora_te_"
network_name = prefix + name.replace(".", "_")
network_layer_mapping[network_name] = module
module.network_layer_name = network_name
if hasattr(sd_model, 'text_encoder_2'):
for name, module in sd_model.text_encoder_2.named_modules():
network_name = "lora_te2_" + name.replace(".", "_")
network_layer_mapping[network_name] = module
module.network_layer_name = network_name
if hasattr(sd_model, 'unet'):
for name, module in sd_model.unet.named_modules():
network_name = "lora_unet_" + name.replace(".", "_")
network_layer_mapping[network_name] = module
module.network_layer_name = network_name
if hasattr(sd_model, 'transformer'):
for name, module in sd_model.transformer.named_modules():
network_name = "lora_transformer_" + name.replace(".", "_")
network_layer_mapping[network_name] = module
if "norm" in network_name and "linear" not in network_name and shared.sd_model_type != "sd3":
continue
module.network_layer_name = network_name
shared.sd_model.network_layer_mapping = network_layer_mapping
def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_default_multiplier) -> network.Network | None:
name = name.replace(".", "_")
shared.log.debug(f'Load network: type=LoRA name="{name}" file="{network_on_disk.filename}" detected={network_on_disk.sd_version} method=diffusers scale={lora_scale} fuse={shared.opts.lora_fuse_diffusers}')
if not shared.native:
return None
if not hasattr(shared.sd_model, 'load_lora_weights'):
        shared.log.error(f'Load network: type=LoRA class={shared.sd_model.__class__} does not implement load_lora_weights')
return None
try:
shared.sd_model.load_lora_weights(network_on_disk.filename, adapter_name=name)
except Exception as e:
if 'already in use' in str(e):
pass
else:
if 'The following keys have not been correctly renamed' in str(e):
shared.log.error(f'Load network: type=LoRA name="{name}" diffusers unsupported format')
else:
shared.log.error(f'Load network: type=LoRA name="{name}" {e}')
if debug:
errors.display(e, "LoRA")
return None
if name not in diffuser_loaded:
diffuser_loaded.append(name)
diffuser_scales.append(lora_scale)
net = network.Network(name, network_on_disk)
net.mtime = os.path.getmtime(network_on_disk.filename)
return net
def load_network(name, network_on_disk) -> network.Network | None:
if not shared.sd_loaded:
return None
cached = lora_cache.get(name, None)
if debug:
        shared.log.debug(f'Load network: type=LoRA name="{name}" file="{network_on_disk.filename}" {"cached" if cached else ""}')
if cached is not None:
return cached
net = network.Network(name, network_on_disk)
net.mtime = os.path.getmtime(network_on_disk.filename)
sd = sd_models.read_state_dict(network_on_disk.filename, what='network')
if shared.sd_model_type == 'f1': # if kohya flux lora, convert state_dict
sd = lora_convert._convert_kohya_flux_lora_to_diffusers(sd) or sd # pylint: disable=protected-access
    if shared.sd_model_type == 'sd3': # if kohya sd3 lora, convert state_dict
try:
sd = lora_convert._convert_kohya_sd3_lora_to_diffusers(sd) or sd # pylint: disable=protected-access
except ValueError: # EAFP for diffusers PEFT keys
pass
assign_network_names_to_compvis_modules(shared.sd_model)
keys_failed_to_match = {}
matched_networks = {}
bundle_embeddings = {}
convert = lora_convert.KeyConvert()
for key_network, weight in sd.items():
parts = key_network.split('.')
if parts[0] == "bundle_emb":
emb_name, vec_name = parts[1], key_network.split(".", 2)[-1]
emb_dict = bundle_embeddings.get(emb_name, {})
emb_dict[vec_name] = weight
bundle_embeddings[emb_name] = emb_dict
continue
if len(parts) > 5: # messy handler for diffusers peft lora
key_network_without_network_parts = '_'.join(parts[:-2])
if not key_network_without_network_parts.startswith('lora_'):
key_network_without_network_parts = 'lora_' + key_network_without_network_parts
network_part = '.'.join(parts[-2:]).replace('lora_A', 'lora_down').replace('lora_B', 'lora_up')
else:
key_network_without_network_parts, network_part = key_network.split(".", 1)
key, sd_module = convert(key_network_without_network_parts)
if sd_module is None:
keys_failed_to_match[key_network] = key
continue
if key not in matched_networks:
matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
matched_networks[key].w[network_part] = weight
network_types = []
for key, weights in matched_networks.items():
net_module = None
for nettype in module_types:
net_module = nettype.create_module(net, weights)
if net_module is not None:
network_types.append(nettype.__class__.__name__)
break
if net_module is None:
shared.log.error(f'LoRA unhandled: name={name} key={key} weights={weights.w.keys()}')
else:
net.modules[key] = net_module
if len(keys_failed_to_match) > 0:
shared.log.warning(f'LoRA name="{name}" type={set(network_types)} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}')
if debug:
shared.log.debug(f'LoRA name="{name}" unmatched={keys_failed_to_match}')
else:
shared.log.debug(f'LoRA name="{name}" type={set(network_types)} keys={len(matched_networks)}')
if len(matched_networks) == 0:
return None
lora_cache[name] = net
net.bundle_embeddings = bundle_embeddings
return net
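
The `len(parts) > 5` branch above rewrites diffusers/PEFT key names into the kohya-style pairs the matcher expects; a rough worked example (the key name itself is illustrative):

key_network = 'transformer.single_blocks.0.attn.to_q.lora_A.weight'
parts = key_network.split('.')                  # 7 parts, so the PEFT branch is taken
prefix = 'lora_' + '_'.join(parts[:-2])         # 'lora_transformer_single_blocks_0_attn_to_q'
network_part = '.'.join(parts[-2:]).replace('lora_A', 'lora_down').replace('lora_B', 'lora_up')
print(prefix, network_part)                     # ... 'lora_down.weight'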
def maybe_recompile_model(names, te_multipliers):
recompile_model = False
if shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled:
if len(names) == len(shared.compiled_model_state.lora_model):
for i, name in enumerate(names):
                if shared.compiled_model_state.lora_model[i] != f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}":
recompile_model = True
shared.compiled_model_state.lora_model = []
break
if not recompile_model:
if len(loaded_networks) > 0 and debug:
                shared.log.debug('Model Compile: Skipping LoRA loading')
return
else:
recompile_model = True
shared.compiled_model_state.lora_model = []
if recompile_model:
backup_cuda_compile = shared.opts.cuda_compile
sd_models.unload_model_weights(op='model')
shared.opts.cuda_compile = []
sd_models.reload_model_weights(op='model')
shared.opts.cuda_compile = backup_cuda_compile
return recompile_model
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names]
if any(x is None for x in networks_on_disk):
list_available_networks()
networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names]
failed_to_load_networks = []
recompile_model = maybe_recompile_model(names, te_multipliers)
loaded_networks.clear()
diffuser_loaded.clear()
diffuser_scales.clear()
timer['load'] = 0
t0 = time.time()
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
net = None
if network_on_disk is not None:
shorthash = getattr(network_on_disk, 'shorthash', '').lower()
if debug:
shared.log.debug(f'Load network: type=LoRA name="{name}" file="{network_on_disk.filename}" hash="{shorthash}"')
try:
if recompile_model:
shared.compiled_model_state.lora_model.append(f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}")
                if shared.opts.lora_force_diffusers or network_overrides.check_override(shorthash): # OpenVINO only works with diffusers LoRA loading
net = load_diffusers(name, network_on_disk, lora_scale=te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier)
else:
net = load_network(name, network_on_disk)
if net is not None:
net.mentioned_name = name
network_on_disk.read_hash()
except Exception as e:
shared.log.error(f'Load network: type=LoRA file="{network_on_disk.filename}" {e}')
if debug:
errors.display(e, 'LoRA')
continue
if net is None:
failed_to_load_networks.append(name)
shared.log.error(f'Load network: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} failed')
continue
shared.sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings)
net.te_multiplier = te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier
net.unet_multiplier = unet_multipliers[i] if unet_multipliers else shared.opts.extra_networks_default_multiplier
net.dyn_dim = dyn_dims[i] if dyn_dims else shared.opts.extra_networks_default_multiplier
loaded_networks.append(net)
while len(lora_cache) > shared.opts.lora_in_memory_limit:
name = next(iter(lora_cache))
lora_cache.pop(name, None)
if len(diffuser_loaded) > 0:
shared.log.debug(f'Load network: type=LoRA loaded={diffuser_loaded} available={shared.sd_model.get_list_adapters()} active={shared.sd_model.get_active_adapters()} scales={diffuser_scales}')
try:
shared.sd_model.set_adapters(adapter_names=diffuser_loaded, adapter_weights=diffuser_scales)
if shared.opts.lora_fuse_diffusers:
shared.sd_model.fuse_lora(adapter_names=diffuser_loaded, lora_scale=1.0, fuse_unet=True, fuse_text_encoder=True) # fuse uses fixed scale since later apply does the scaling
shared.sd_model.unload_lora_weights()
except Exception as e:
shared.log.error(f'Load network: type=LoRA {e}')
if debug:
errors.display(e, 'LoRA')
if len(loaded_networks) > 0 and debug:
shared.log.debug(f'Load network: type=LoRA loaded={len(loaded_networks)} cache={list(lora_cache)}')
devices.torch_gc()
if recompile_model:
shared.log.info("Load network: type=LoRA recompiling model")
backup_lora_model = shared.compiled_model_state.lora_model
if 'Model' in shared.opts.cuda_compile:
shared.sd_model = sd_models_compile.compile_diffusers(shared.sd_model)
shared.compiled_model_state.lora_model = backup_lora_model
if shared.opts.diffusers_offload_mode == "balanced":
sd_models.apply_balanced_offload(shared.sd_model)
t1 = time.time()
timer['load'] += t1 - t0
def set_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown, ex_bias):
weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None)
if weights_backup is None and bias_backup is None:
return
device = self.weight.device
with devices.inference_context():
if weights_backup is not None:
if updown is not None:
if len(weights_backup.shape) == 4 and weights_backup.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable
weights_backup = weights_backup.clone().to(device)
weights_backup += updown.to(weights_backup)
if getattr(self, "quant_type", None) in ['nf4', 'fp4']:
bnb = model_quant.load_bnb('Load network: type=LoRA', silent=True)
if bnb is not None:
self.weight = bnb.nn.Params4bit(weights_backup, quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize)
else:
self.weight.copy_(weights_backup, non_blocking=True)
else:
self.weight.copy_(weights_backup, non_blocking=True)
if hasattr(self, "qweight") and hasattr(self, "freeze"):
self.freeze()
if bias_backup is not None:
if ex_bias is not None:
                bias_backup = bias_backup.clone() + ex_bias.to(bias_backup)
self.bias.copy_(bias_backup)
else:
self.bias = None
self.to(device)
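
The inpainting pad above is easy to misread: torch.nn.functional.pad consumes (before, after) pairs starting from the last dimension, so (0, 0, 0, 0, 0, 5) leaves kW and kH alone and appends five zero input channels. A quick check:

import torch
import torch.nn.functional as F

updown = torch.zeros(320, 4, 3, 3)              # conv delta: (out, in, kH, kW)
padded = F.pad(updown, (0, 0, 0, 0, 0, 5))      # pairs apply to kW, kH, then in-channels
print(padded.shape)                             # torch.Size([320, 9, 3, 3])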
def maybe_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], wanted_names): # pylint: disable=W0613
weights_backup = getattr(self, "network_weights_backup", None)
if weights_backup is None and wanted_names != (): # pylint: disable=C1803
if getattr(self.weight, "quant_type", None) in ['nf4', 'fp4']:
bnb = model_quant.load_bnb('Load network: type=LoRA', silent=True)
if bnb is not None:
with devices.inference_context():
weights_backup = bnb.functional.dequantize_4bit(self.weight, quant_state=self.weight.quant_state, quant_type=self.weight.quant_type, blocksize=self.weight.blocksize,)
self.quant_state = self.weight.quant_state
self.quant_type = self.weight.quant_type
self.blocksize = self.weight.blocksize
else:
weights_backup = self.weight.clone()
else:
weights_backup = self.weight.clone()
if shared.opts.lora_offload_backup and weights_backup is not None:
weights_backup = weights_backup.to(devices.cpu)
self.network_weights_backup = weights_backup
bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None:
if getattr(self, 'bias', None) is not None:
bias_backup = self.bias.clone()
else:
bias_backup = None
if shared.opts.lora_offload_backup and bias_backup is not None:
bias_backup = bias_backup.to(devices.cpu)
self.network_bias_backup = bias_backup
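
For NF4/FP4 layers the backup is therefore held in dequantized form, so the LoRA delta is added in full precision before set_weights requantizes via Params4bit. The round trip, sketched from the calls above (not runnable without bitsandbytes and a GPU):

# backup = bnb.functional.dequantize_4bit(w, quant_state=qs, quant_type='nf4', blocksize=bs)
# merged = backup + updown                      # full-precision LoRA merge
# w_new  = bnb.nn.Params4bit(merged, quant_state=qs, quant_type='nf4', blocksize=bs)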
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv]):
"""
Applies the currently selected set of networks to the weights of torch layer self.
If weights already have this particular set of networks applied, does nothing.
    If not, restores original weights from backup and alters weights according to networks.
"""
network_layer_name = getattr(self, 'network_layer_name', None)
if network_layer_name is None:
return
t0 = time.time()
current_names = getattr(self, "network_current_names", ())
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
    if any(net.modules.get(network_layer_name, None) for net in loaded_networks):
maybe_backup_weights(self, wanted_names)
if current_names != wanted_names:
for net in loaded_networks:
# default workflow where module is known and has weights
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight'):
try:
with devices.inference_context():
weight = self.weight # calculate quant weights once
updown, ex_bias = module.calc_updown(weight)
set_weights(self, updown, ex_bias)
except RuntimeError as e:
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
if debug:
module_name = net.modules.get(network_layer_name, None)
shared.log.error(f'LoRA apply weight name="{net.name}" module="{module_name}" layer="{network_layer_name}" {e}')
errors.display(e, 'LoRA')
raise RuntimeError('LoRA apply weight') from e
continue
if module is None:
continue
shared.log.warning(f'LoRA network="{net.name}" layer="{network_layer_name}" unsupported operation')
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
        if not loaded_networks: # restore from backup
            t5 = time.time()
            set_weights(self, None, None)
            timer['restore'] += time.time() - t5
self.network_current_names = wanted_names
t1 = time.time()
timer['apply'] += t1 - t0
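
The no-op check hinges on the wanted_names tuples: re-applying the same stack is free, while any multiplier change forces a recompute from the backup. An illustration (names and multipliers made up):

current = (('detail-lora', 1.0, 1.0, None),)            # (name, te, unet, dyn_dim) per network
assert current == (('detail-lora', 1.0, 1.0, None),)    # same stack: apply is skipped
assert current != (('detail-lora', 0.8, 0.8, None),)    # changed multiplier: recompute from backup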
def network_load():
    sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatibility
    for component_name in ['text_encoder', 'text_encoder_2', 'unet', 'transformer']:
component = getattr(sd_model, component_name, None)
if component is not None:
for _, module in component.named_modules():
network_apply_weights(module)
def list_available_networks():
t0 = time.time()
available_networks.clear()
available_network_aliases.clear()
forbidden_network_aliases.clear()
available_network_hash_lookup.clear()
forbidden_network_aliases.update({"none": 1, "Addams": 1})
if not os.path.exists(shared.cmd_opts.lora_dir):
shared.log.warning(f'LoRA directory not found: path="{shared.cmd_opts.lora_dir}"')
def add_network(filename):
if not os.path.isfile(filename):
return
name = os.path.splitext(os.path.basename(filename))[0]
name = name.replace('.', '_')
try:
entry = network.NetworkOnDisk(name, filename)
available_networks[entry.name] = entry
if entry.alias in available_network_aliases:
forbidden_network_aliases[entry.alias.lower()] = 1
if shared.opts.lora_preferred_name == 'filename':
available_network_aliases[entry.name] = entry
else:
available_network_aliases[entry.alias] = entry
if entry.shorthash:
available_network_hash_lookup[entry.shorthash] = entry
except OSError as e: # should catch FileNotFoundError and PermissionError etc.
shared.log.error(f'LoRA: filename="{filename}" {e}')
candidates = list(files_cache.list_files(shared.cmd_opts.lora_dir, ext_filter=[".pt", ".ckpt", ".safetensors"]))
with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
for fn in candidates:
executor.submit(add_network, fn)
t1 = time.time()
shared.log.info(f'Available LoRAs: path="{shared.cmd_opts.lora_dir}" items={len(available_networks)} folders={len(forbidden_network_aliases)} time={t1 - t0:.2f}')
def infotext_pasted(infotext, params): # pylint: disable=W0613
if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
return # if the other extension is active, it will handle those fields, no need to do anything
added = []
for k in params:
if not k.startswith("AddNet Model "):
continue
num = k[13:]
if params.get("AddNet Module " + num) != "LoRA":
continue
name = params.get("AddNet Model " + num)
if name is None:
continue
m = re_network_name.match(name)
if m:
name = m.group(1)
multiplier = params.get("AddNet Weight A " + num, "1.0")
added.append(f"<lora:{name}:{multiplier}>")
if added:
params["Prompt"] += "\n" + "".join(added)
list_available_networks()
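
For reference, a worked example of the AddNet conversion above (parameter values illustrative):

params = {
    'Prompt': 'a portrait',
    'AddNet Module 1': 'LoRA',
    'AddNet Model 1': 'myLora(ab12cd34)',       # re_network_name strips the trailing hash
    'AddNet Weight A 1': '0.7',
}
# after infotext_pasted(infotext, params):
#   params['Prompt'] == 'a portrait\n<lora:myLora:0.7>'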

View File

modules/lora/ui_extra_networks_lora.py Normal file
@ -0,0 +1,123 @@
import os
import json
import concurrent
import modules.lora.networks as networks
from modules import shared, ui_extra_networks
debug = os.environ.get('SD_LORA_DEBUG', None) is not None
class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def __init__(self):
super().__init__('Lora')
self.list_time = 0
def refresh(self):
networks.list_available_networks()
@staticmethod
def get_tags(l, info):
tags = {}
try:
if l.metadata is not None:
modelspec_tags = l.metadata.get('modelspec.tags', {})
                possible_tags = l.metadata.get('ss_tag_frequency', {}) # tags from model metadata
if isinstance(possible_tags, str):
possible_tags = {}
if isinstance(modelspec_tags, str):
modelspec_tags = {}
if len(list(modelspec_tags)) > 0:
possible_tags.update(modelspec_tags)
for k, v in possible_tags.items():
words = k.split('_', 1) if '_' in k else [v, k]
words = [str(w).replace('.json', '') for w in words]
if words[0] == '{}':
words[0] = 0
tag = ' '.join(words[1:]).lower()
tags[tag] = words[0]
def find_version():
found_versions = []
current_hash = l.hash[:8].upper()
all_versions = info.get('modelVersions', [])
            for v in all_versions:
for f in v.get('files', []):
if any(h.startswith(current_hash) for h in f.get('hashes', {}).values()):
found_versions.append(v)
if len(found_versions) == 0:
found_versions = all_versions
return found_versions
for v in find_version(): # trigger words from info json
possible_tags = v.get('trainedWords', [])
if isinstance(possible_tags, list):
for tag_str in possible_tags:
for tag in tag_str.split(','):
tag = tag.strip().lower()
if tag not in tags:
tags[tag] = 0
possible_tags = info.get('tags', []) # tags from info json
if not isinstance(possible_tags, list):
possible_tags = list(possible_tags.values())
for tag in possible_tags:
tag = tag.strip().lower()
if tag not in tags:
tags[tag] = 0
except Exception:
pass
bad_chars = [';', ':', '<', ">", "*", '?', '\'', '\"', '(', ')', '[', ']', '{', '}', '\\', '/']
clean_tags = {}
for k, v in tags.items():
tag = ''.join(i for i in k if i not in bad_chars).strip()
clean_tags[tag] = v
clean_tags.pop('img', None)
clean_tags.pop('dataset', None)
return clean_tags
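
Kohya's ss_tag_frequency keys are typically '<repeats>_<concept name>', which the split above turns into tag -> count pairs; for example:

k = '10_mychar style'                           # ss_tag_frequency folder key
words = k.split('_', 1)                         # ['10', 'mychar style']
tag = ' '.join(words[1:]).lower()               # 'mychar style'
# -> tags['mychar style'] = '10'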
def create_item(self, name):
l = networks.available_networks.get(name)
if l is None:
shared.log.warning(f'Networks: type=lora registered={len(list(networks.available_networks))} file="{name}" not registered')
return None
try:
# path, _ext = os.path.splitext(l.filename)
name = os.path.splitext(os.path.relpath(l.filename, shared.cmd_opts.lora_dir))[0]
item = {
"type": 'Lora',
"name": name,
"filename": l.filename,
"hash": l.shorthash,
"prompt": json.dumps(f" <lora:{l.get_alias()}:{shared.opts.extra_networks_default_multiplier}>"),
"metadata": json.dumps(l.metadata, indent=4) if l.metadata else None,
"mtime": os.path.getmtime(l.filename),
"size": os.path.getsize(l.filename),
"version": l.sd_version,
}
info = self.find_info(l.filename)
item["info"] = info
item["description"] = self.find_description(l.filename, info) # use existing info instead of double-read
item["tags"] = self.get_tags(l, info)
return item
except Exception as e:
shared.log.error(f'Networks: type=lora file="{name}" {e}')
if debug:
from modules import errors
errors.display(e, 'Lora')
return None
def list_items(self):
items = []
with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
future_items = {executor.submit(self.create_item, net): net for net in networks.available_networks}
for future in concurrent.futures.as_completed(future_items):
item = future.result()
if item is not None:
items.append(item)
self.update_all_previews(items)
return items
def allowed_directories_for_previews(self):
return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir]

View File

@ -8,6 +8,8 @@ from modules import shared, devices, processing, sd_models, errors, sd_hijack_hy
from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, save_intermediate, update_sampler, is_txt2img, is_refiner_enabled
from modules.processing_args import set_pipeline_args
from modules.onnx_impl import preprocess_pipeline as preprocess_onnx_pipeline, check_parameters_changed as olive_check_parameters_changed
from modules.lora.networks import network_load
from modules.lora.networks import timer as network_timer
debug = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None
@ -424,6 +426,9 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
p.prompts = p.all_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
if p.negative_prompts is None or len(p.negative_prompts) == 0:
p.negative_prompts = p.all_negative_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
network_timer['apply'] = 0
network_timer['restore'] = 0
network_load()
sd_models.move_model(shared.sd_model, devices.device)
sd_models_compile.openvino_recompile_model(p, hires=False, refiner=False) # recompile if a parameter changes

View File

@ -908,6 +908,7 @@ options_templates.update(options_section(('extra_networks', "Networks"), {
"lora_in_memory_limit": OptionInfo(0, "LoRA memory cache", gr.Slider, {"minimum": 0, "maximum": 24, "step": 1}),
"lora_quant": OptionInfo("NF4","LoRA precision in quantized models", gr.Radio, {"choices": ["NF4", "FP4"]}),
"lora_load_gpu": OptionInfo(True if not cmd_opts.lowvram else False, "Load LoRA directly to GPU"),
"lora_offload_backup": OptionInfo(True, "Offload LoRA Backup Weights"),
}))
options_templates.update(options_section((None, "Internal options"), {

scripts/lora_script.py Normal file
View File

@ -0,0 +1,62 @@
import re
import modules.lora.networks as networks
from modules.lora.lora_extract import create_ui
from modules.lora.network import NetworkOnDisk
from modules.lora.ui_extra_networks_lora import ExtraNetworksPageLora
from modules.lora.extra_networks_lora import ExtraNetworkLora
from modules import script_callbacks, extra_networks, ui_extra_networks, ui_models, shared # pylint: disable=unused-import
re_lora = re.compile("<lora:([^:]+):")
def before_ui():
ui_extra_networks.register_page(ExtraNetworksPageLora())
networks.extra_network_lora = ExtraNetworkLora()
extra_networks.register_extra_network(networks.extra_network_lora)
ui_models.extra_ui.append(create_ui)
def create_lora_json(obj: NetworkOnDisk):
return {
"name": obj.name,
"alias": obj.alias,
"path": obj.filename,
"metadata": obj.metadata,
}
def api_networks(_, app):
@app.get("/sdapi/v1/loras")
async def get_loras():
return [create_lora_json(obj) for obj in networks.available_networks.values()]
@app.post("/sdapi/v1/refresh-loras")
async def refresh_loras():
return networks.list_available_networks()
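
A usage sketch for the two endpoints (the host and port are assumed defaults; adjust to your instance):

import requests

base = 'http://127.0.0.1:7860'                  # assumed default local server
for lora in requests.get(f'{base}/sdapi/v1/loras', timeout=30).json():
    print(lora['name'], lora['path'])
requests.post(f'{base}/sdapi/v1/refresh-loras', timeout=60)  # rescan the lora directory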
def infotext_pasted(infotext, d): # pylint: disable=unused-argument
hashes = d.get("Lora hashes", None)
if hashes is None:
return
def network_replacement(m):
alias = m.group(1)
shorthash = hashes.get(alias)
if shorthash is None:
return m.group(0)
network_on_disk = networks.available_network_hash_lookup.get(shorthash)
if network_on_disk is None:
return m.group(0)
return f'<lora:{network_on_disk.get_alias()}:'
hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
if shared.native:
script_callbacks.on_app_started(api_networks)
script_callbacks.on_before_ui(before_ui)
script_callbacks.on_infotext_pasted(networks.infotext_pasted)
script_callbacks.on_infotext_pasted(infotext_pasted)