mirror of https://github.com/vladmandic/automatic
Major lora refactor: works on my machine edition
parent 2b147272f8
commit cb561fa486
@@ -5,7 +5,7 @@ from lora_extract import create_ui
 from network import NetworkOnDisk
 from ui_extra_networks_lora import ExtraNetworksPageLora
 from extra_networks_lora import ExtraNetworkLora
-from modules import script_callbacks, extra_networks, ui_extra_networks, ui_models # pylint: disable=unused-import
+from modules import script_callbacks, extra_networks, ui_extra_networks, ui_models, shared # pylint: disable=unused-import


 re_lora = re.compile("<lora:([^:]+):")
@@ -56,9 +56,9 @@ def infotext_pasted(infotext, d): # pylint: disable=unused-argument
     hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
     d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])


-script_callbacks.on_app_started(api_networks)
-script_callbacks.on_before_ui(before_ui)
-script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
-script_callbacks.on_infotext_pasted(networks.infotext_pasted)
-script_callbacks.on_infotext_pasted(infotext_pasted)
+if not shared.native:
+    script_callbacks.on_app_started(api_networks)
+    script_callbacks.on_before_ui(before_ui)
+    script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
+    script_callbacks.on_infotext_pasted(networks.infotext_pasted)
+    script_callbacks.on_infotext_pasted(infotext_pasted)
@@ -0,0 +1,151 @@
import re
import time
import numpy as np
import modules.lora.networks as networks
from modules import extra_networks, shared


# from https://github.com/cheald/sd-webui-loractl/blob/master/loractl/lib/utils.py
def get_stepwise(param, step, steps):
    def sorted_positions(raw_steps):
        steps = [[float(s.strip()) for s in re.split("[@~]", x)]
                 for x in re.split("[,;]", str(raw_steps))]
        if len(steps[0]) == 1:  # If we just got a single number, just return it
            return steps[0][0]
        steps = [[s[0], s[1] if len(s) == 2 else 1] for s in steps]  # Add implicit 1s to any steps which don't have a weight
        steps.sort(key=lambda k: k[1])  # Sort by index
        steps = [list(v) for v in zip(*steps)]
        return steps

    def calculate_weight(m, step, max_steps, step_offset=2):
        if isinstance(m, list):
            if m[1][-1] <= 1.0:
                step = step / (max_steps - step_offset) if max_steps > 0 else 1.0
            v = np.interp(step, m[1], m[0])
            return v
        else:
            return m

    stepwise = calculate_weight(sorted_positions(param), step, steps)
    return stepwise
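For orientation, the schedule syntax get_stepwise implements is loractl's weight@position form, with positions given either as step fractions or absolute steps and multiple points separated by commas. Two worked evaluations, matching the code above (illustrative only):

# "0@0,1@0.5" ramps the weight from 0 to 1 over the first half of sampling;
# with steps=22 and step_offset=2 the fraction at step 10 is 10 / 20 = 0.5
get_stepwise("0@0,1@0.5", step=10, steps=22)  # -> 1.0
get_stepwise("0@0,1@0.5", step=5, steps=22)   # -> 0.5
get_stepwise("0.8", step=5, steps=22)         # -> 0.8 (a single number is returned as-is)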


def prompt(p):
    if shared.opts.lora_apply_tags == 0:
        return
    all_tags = []
    for loaded in networks.loaded_networks:
        page = [en for en in shared.extra_networks if en.name == 'lora'][0]
        item = page.create_item(loaded.name)
        tags = (item or {}).get("tags", {})
        loaded.tags = list(tags)
        if len(loaded.tags) == 0:
            loaded.tags.append(loaded.name)
        if shared.opts.lora_apply_tags > 0:
            loaded.tags = loaded.tags[:shared.opts.lora_apply_tags]
        all_tags.extend(loaded.tags)
    if len(all_tags) > 0:
        shared.log.debug(f"Load network: type=LoRA tags={all_tags} max={shared.opts.lora_apply_tags} apply")
        all_tags = ', '.join(all_tags)
        p.extra_generation_params["LoRA tags"] = all_tags
        if '_tags_' in p.prompt:
            p.prompt = p.prompt.replace('_tags_', all_tags)
        else:
            p.prompt = f"{p.prompt}, {all_tags}"
        if p.all_prompts is not None:
            for i in range(len(p.all_prompts)):
                if '_tags_' in p.all_prompts[i]:
                    p.all_prompts[i] = p.all_prompts[i].replace('_tags_', all_tags)
                else:
                    p.all_prompts[i] = f"{p.all_prompts[i]}, {all_tags}"


def infotext(p):
    names = [i.name for i in networks.loaded_networks]
    if len(names) > 0:
        p.extra_generation_params["LoRA networks"] = ", ".join(names)
        if shared.opts.lora_add_hashes_to_infotext:
            network_hashes = []
            for item in networks.loaded_networks:
                if not item.network_on_disk.shorthash:
                    continue
                network_hashes.append(item.network_on_disk.shorthash)
            if len(network_hashes) > 0:
                p.extra_generation_params["LoRA hashes"] = ", ".join(network_hashes)


def parse(p, params_list, step=0):
    names = []
    te_multipliers = []
    unet_multipliers = []
    dyn_dims = []
    for params in params_list:
        assert params.items
        names.append(params.positional[0])
        te_multiplier = params.named.get("te", params.positional[1] if len(params.positional) > 1 else shared.opts.extra_networks_default_multiplier)
        if isinstance(te_multiplier, str) and "@" in te_multiplier:
            te_multiplier = get_stepwise(te_multiplier, step, p.steps)
        else:
            te_multiplier = float(te_multiplier)
        unet_multiplier = [params.positional[2] if len(params.positional) > 2 else te_multiplier] * 3
        unet_multiplier = [params.named.get("unet", unet_multiplier[0])] * 3
        unet_multiplier[0] = params.named.get("in", unet_multiplier[0])
        unet_multiplier[1] = params.named.get("mid", unet_multiplier[1])
        unet_multiplier[2] = params.named.get("out", unet_multiplier[2])
        for i in range(len(unet_multiplier)):
            if isinstance(unet_multiplier[i], str) and "@" in unet_multiplier[i]:
                unet_multiplier[i] = get_stepwise(unet_multiplier[i], step, p.steps)
            else:
                unet_multiplier[i] = float(unet_multiplier[i])
        dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
        dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
        te_multipliers.append(te_multiplier)
        unet_multipliers.append(unet_multiplier)
        dyn_dims.append(dyn_dim)
    return names, te_multipliers, unet_multipliers, dyn_dims
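The net effect of parse: positional arguments follow <lora:name:te:unet:dyn>, named arguments (te=, unet=, in=/mid=/out=, dyn=) override the positionals, and any multiplier containing "@" is routed through get_stepwise. A hypothetical illustration; FakeParams stands in for the params object extra_networks actually passes (it exposes .items, .positional and .named):

class FakeParams:  # illustrative stand-in, not part of the commit
    def __init__(self, positional, named=None):
        self.positional = positional
        self.named = named or {}
        self.items = positional

class FakeP:  # only .steps is read, and only for "@" schedules
    steps = 20

# <lora:detail:0.8:0.6:16>
names, te, unet, dyn = parse(FakeP(), [FakeParams(["detail", "0.8", "0.6", "16"])])
# names == ["detail"], te == [0.8], unet == [[0.6, 0.6, 0.6]], dyn == [16]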


class ExtraNetworkLora(extra_networks.ExtraNetwork):

    def __init__(self):
        super().__init__('lora')
        self.active = False
        self.model = None
        self.errors = {}

    def activate(self, p, params_list, step=0):
        t0 = time.time()
        self.errors.clear()
        if self.active:
            if self.model != shared.opts.sd_model_checkpoint:  # reset if model changed
                self.active = False
        if len(params_list) > 0 and not self.active:  # activate patches once
            shared.log.debug(f'Activate network: type=LoRA model="{shared.opts.sd_model_checkpoint}"')
            self.active = True
            self.model = shared.opts.sd_model_checkpoint
        names, te_multipliers, unet_multipliers, dyn_dims = parse(p, params_list, step)
        networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
        t1 = time.time()
        if len(networks.loaded_networks) > 0 and step == 0:
            infotext(p)
            prompt(p)
            shared.log.info(f'Load network: type=LoRA apply={[n.name for n in networks.loaded_networks]} te={te_multipliers} unet={unet_multipliers} dims={dyn_dims} load={t1-t0:.2f}')

    def deactivate(self, p):
        t0 = time.time()
        if shared.native and len(networks.diffuser_loaded) > 0:
            if hasattr(shared.sd_model, "unload_lora_weights") and hasattr(shared.sd_model, "text_encoder"):
                if not (shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled is True):
                    try:
                        if shared.opts.lora_fuse_diffusers:
                            shared.sd_model.unfuse_lora()
                        shared.sd_model.unload_lora_weights()  # fails for non-CLIP models
                    except Exception:
                        pass
        t1 = time.time()
        networks.timer['restore'] += t1 - t0
        if self.active and networks.debug:
            shared.log.debug(f"Network end: type=LoRA load={networks.timer['load']:.2f} apply={networks.timer['apply']:.2f} restore={networks.timer['restore']:.2f}")
        if self.errors:
            for k, v in self.errors.items():
                shared.log.error(f'LoRA: name="{k}" errors={v}')
            self.errors.clear()
@@ -0,0 +1,8 @@
# import networks
#
# list_available_loras = networks.list_available_networks
# available_loras = networks.available_networks
# available_lora_aliases = networks.available_network_aliases
# available_lora_hash_lookup = networks.available_network_hash_lookup
# forbidden_lora_aliases = networks.forbidden_network_aliases
# loaded_loras = networks.loaded_networks
@@ -0,0 +1,477 @@
import os
import re
import bisect
from typing import Dict
import torch
from modules import shared


debug = os.environ.get('SD_LORA_DEBUG', None) is not None
suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}
re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}


def make_unet_conversion_map() -> Dict[str, str]:
    unet_conversion_map_layer = []

    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
            # if i > 0: commentout for sdxl
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{2}."  # change for sdxl
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2 * j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j + 1}."
        sd_time_embed_prefix = f"time_embed.{j * 2}."
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j + 1}."
        sd_label_embed_prefix = f"label_emb.0.{j * 2}."
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
    return sd_hf_conversion_map
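A few spot checks of the resulting map (the final dict comprehension flattens the dotted prefixes with "_" and strips the trailing dot), illustrative only:

m = make_unet_conversion_map()
assert m["input_blocks_1_0_in_layers_2"] == "down_blocks_0_resnets_0_conv1"
assert m["middle_block_1"] == "mid_block_attentions_0"
assert m["time_embed_0"] == "time_embedding_linear_1"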


class KeyConvert:
    def __init__(self):
        self.is_sdxl = True if shared.sd_model_type == "sdxl" else False
        self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
        self.LORA_PREFIX_UNET = "lora_unet_"
        self.LORA_PREFIX_TEXT_ENCODER = "lora_te_"
        self.OFT_PREFIX_UNET = "oft_unet_"
        # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
        self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1_"
        self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2_"

    def __call__(self, key):
        if self.is_sdxl:
            if "diffusion_model" in key:  # Fix NTC Slider naming error
                key = key.replace("diffusion_model", "lora_unet")
            map_keys = list(self.UNET_CONVERSION_MAP.keys())  # prefix of U-Net modules
            map_keys.sort()
            search_key = key.replace(self.LORA_PREFIX_UNET, "").replace(self.OFT_PREFIX_UNET, "").replace(self.LORA_PREFIX_TEXT_ENCODER1, "").replace(self.LORA_PREFIX_TEXT_ENCODER2, "")
            position = bisect.bisect_right(map_keys, search_key)
            map_key = map_keys[position - 1]
            if search_key.startswith(map_key):
                key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key]).replace("oft", "lora")  # pylint: disable=unsubscriptable-object
        if "lycoris" in key and "transformer" in key:
            key = key.replace("lycoris", "lora_transformer")
        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        if sd_module is None:
            sd_module = shared.sd_model.network_layer_mapping.get(key.replace("guidance", "timestep"), None)  # FLUX1 fix
        if debug and sd_module is None:
            raise RuntimeError(f"LoRA key not found in network_layer_mapping: key={key} mapping={shared.sd_model.network_layer_mapping.keys()}")
        return key, sd_module


# Taken from https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_conversion_utils.py
# Modified from 'lora_A' and 'lora_B' to 'lora_down' and 'lora_up'
# Added early exit
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
# All credits go to `kohya-ss`.
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
    if sds_key + ".lora_down.weight" not in sds_sd:
        return
    down_weight = sds_sd.pop(sds_key + ".lora_down.weight")

    # scale weight by alpha and dim
    rank = down_weight.shape[0]
    alpha = sds_sd.pop(sds_key + ".alpha").item()  # alpha is scalar
    scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here

    # calculate scale_down and scale_up to keep the same value
    scale_down = scale
    scale_up = 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2

    ait_sd[ait_key + ".lora_down.weight"] = down_weight * scale_down
    ait_sd[ait_key + ".lora_up.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
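The halving loop preserves the invariant scale_down * scale_up == alpha / rank while balancing the two factors, so the merged product up @ down is scaled exactly once; e.g. alpha=4, rank=64 gives scale=0.0625, split into 0.25 * 0.25. A standalone sketch of the invariant (illustration only):

import torch

def split_scale(scale):  # mirrors the loop above
    scale_down, scale_up = scale, 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

down = torch.randn(64, 16)    # (rank, in)
up = torch.randn(32, 64)      # (out, rank)
sd, su = split_scale(4 / 64)  # alpha=4, rank=64 -> (0.25, 0.25)
assert abs(sd * su - 4 / 64) < 1e-9
assert torch.allclose((up * su) @ (down * sd), (4 / 64) * (up @ down), atol=1e-4)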


def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
    if sds_key + ".lora_down.weight" not in sds_sd:
        return
    down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
    up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
    sd_lora_rank = down_weight.shape[0]

    # scale weight by alpha and dim
    alpha = sds_sd.pop(sds_key + ".alpha")
    scale = alpha / sd_lora_rank

    # calculate scale_down and scale_up
    scale_down = scale
    scale_up = 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2

    down_weight = down_weight * scale_down
    up_weight = up_weight * scale_up

    # calculate dims if not provided
    num_splits = len(ait_keys)
    if dims is None:
        dims = [up_weight.shape[0] // num_splits] * num_splits
    else:
        assert sum(dims) == up_weight.shape[0]

    # check upweight is sparse or not
    is_sparse = False
    if sd_lora_rank % num_splits == 0:
        ait_rank = sd_lora_rank // num_splits
        is_sparse = True
        i = 0
        for j in range(len(dims)):
            for k in range(len(dims)):
                if j == k:
                    continue
                is_sparse = is_sparse and torch.all(
                    up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
                )
            i += dims[j]
        # if is_sparse:
        #     print(f"weight is sparse: {sds_key}")

    # make ai-toolkit weight
    ait_down_keys = [k + ".lora_down.weight" for k in ait_keys]
    ait_up_keys = [k + ".lora_up.weight" for k in ait_keys]
    if not is_sparse:
        # down_weight is copied to each split
        ait_sd.update({k: down_weight for k in ait_down_keys})

        # up_weight is split to each split
        ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416 # pylint: disable=unnecessary-comprehension
    else:
        # down_weight is chunked to each split
        ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})  # noqa: C416 # pylint: disable=unnecessary-comprehension

        # up_weight is sparse: only non-zero values are copied to each split
        i = 0
        for j in range(len(dims)):
            ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
            i += dims[j]


def _convert_text_encoder_lora_key(key, lora_name):
    """
    Converts a text encoder LoRA key to a Diffusers compatible key.
    """
    if lora_name.startswith(("lora_te_", "lora_te1_")):
        key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
    else:
        key_to_replace = "lora_te2_"

    diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
    diffusers_name = diffusers_name.replace("text.model", "text_model")
    diffusers_name = diffusers_name.replace("self.attn", "self_attn")
    diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
    diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
    diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
    diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
    diffusers_name = diffusers_name.replace("text.projection", "text_projection")

    if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
        pass
    elif "mlp" in diffusers_name:
        # Be aware that this is the new diffusers convention and the rest of the code might
        # not utilize it yet.
        diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
    return diffusers_name
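Tracing the replacement chain on a typical kohya text-encoder key (illustration only):

key = "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight"
_convert_text_encoder_lora_key(key, key.split(".")[0])
# -> "text_model.encoder.layers.0.self_attn.to_q_lora.down.weight"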


def _convert_kohya_flux_lora_to_diffusers(state_dict):
    def _convert_sd_scripts_to_ai_toolkit(sds_sd):
        ait_sd = {}
        for i in range(19):
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_attn_proj",
                f"transformer.transformer_blocks.{i}.attn.to_out.0",
            )
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_attn_qkv",
                [
                    f"transformer.transformer_blocks.{i}.attn.to_q",
                    f"transformer.transformer_blocks.{i}.attn.to_k",
                    f"transformer.transformer_blocks.{i}.attn.to_v",
                ],
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_mlp_0",
                f"transformer.transformer_blocks.{i}.ff.net.0.proj",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_mlp_2",
                f"transformer.transformer_blocks.{i}.ff.net.2",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_mod_lin",
                f"transformer.transformer_blocks.{i}.norm1.linear",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_attn_proj",
                f"transformer.transformer_blocks.{i}.attn.to_add_out",
            )
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_attn_qkv",
                [
                    f"transformer.transformer_blocks.{i}.attn.add_q_proj",
                    f"transformer.transformer_blocks.{i}.attn.add_k_proj",
                    f"transformer.transformer_blocks.{i}.attn.add_v_proj",
                ],
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_mlp_0",
                f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_mlp_2",
                f"transformer.transformer_blocks.{i}.ff_context.net.2",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_mod_lin",
                f"transformer.transformer_blocks.{i}.norm1_context.linear",
            )

        for i in range(38):
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"lora_unet_single_blocks_{i}_linear1",
                [
                    f"transformer.single_transformer_blocks.{i}.attn.to_q",
                    f"transformer.single_transformer_blocks.{i}.attn.to_k",
                    f"transformer.single_transformer_blocks.{i}.attn.to_v",
                    f"transformer.single_transformer_blocks.{i}.proj_mlp",
                ],
                dims=[3072, 3072, 3072, 12288],
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_single_blocks_{i}_linear2",
                f"transformer.single_transformer_blocks.{i}.proj_out",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_single_blocks_{i}_modulation_lin",
                f"transformer.single_transformer_blocks.{i}.norm.linear",
            )

        if len(sds_sd) > 0:
            return None

        return ait_sd

    return _convert_sd_scripts_to_ai_toolkit(state_dict)


def _convert_kohya_sd3_lora_to_diffusers(state_dict):
    def _convert_sd_scripts_to_ai_toolkit(sds_sd):
        ait_sd = {}
        for i in range(38):
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"lora_unet_joint_blocks_{i}_context_block_attn_qkv",
                [
                    f"transformer.transformer_blocks.{i}.attn.to_q",
                    f"transformer.transformer_blocks.{i}.attn.to_k",
                    f"transformer.transformer_blocks.{i}.attn.to_v",
                ],
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_joint_blocks_{i}_context_block_mlp_fc1",
                f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_joint_blocks_{i}_context_block_mlp_fc2",
                f"transformer.transformer_blocks.{i}.ff_context.net.2",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_joint_blocks_{i}_x_block_mlp_fc1",
                f"transformer.transformer_blocks.{i}.ff.net.0.proj",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_joint_blocks_{i}_x_block_mlp_fc2",
                f"transformer.transformer_blocks.{i}.ff.net.2",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1",
                f"transformer.transformer_blocks.{i}.norm1_context.linear",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_joint_blocks_{i}_x_block_adaLN_modulation_1",
                f"transformer.transformer_blocks.{i}.norm1.linear",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_joint_blocks_{i}_context_block_attn_proj",
                f"transformer.transformer_blocks.{i}.attn.to_add_out",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_joint_blocks_{i}_x_block_attn_proj",
                f"transformer.transformer_blocks.{i}.attn.to_out_0",
            )

            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"lora_unet_joint_blocks_{i}_x_block_attn_qkv",
                [
                    f"transformer.transformer_blocks.{i}.attn.add_q_proj",
                    f"transformer.transformer_blocks.{i}.attn.add_k_proj",
                    f"transformer.transformer_blocks.{i}.attn.add_v_proj",
                ],
            )
        remaining_keys = list(sds_sd.keys())
        te_state_dict = {}
        if remaining_keys:
            if not all(k.startswith("lora_te1") for k in remaining_keys):
                raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
            for key in remaining_keys:
                if not key.endswith("lora_down.weight"):
                    continue

                lora_name = key.split(".")[0]
                lora_name_up = f"{lora_name}.lora_up.weight"
                lora_name_alpha = f"{lora_name}.alpha"
                diffusers_name = _convert_text_encoder_lora_key(key, lora_name)

                if lora_name.startswith(("lora_te_", "lora_te1_")):
                    down_weight = sds_sd.pop(key)
                    sd_lora_rank = down_weight.shape[0]
                    te_state_dict[diffusers_name] = down_weight
                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)

                if lora_name_alpha in sds_sd:
                    alpha = sds_sd.pop(lora_name_alpha).item()
                    scale = alpha / sd_lora_rank

                    scale_down = scale
                    scale_up = 1.0
                    while scale_down * 2 < scale_up:
                        scale_down *= 2
                        scale_up /= 2

                    te_state_dict[diffusers_name] *= scale_down
                    te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up

        if len(sds_sd) > 0:
            print(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")

        if te_state_dict:
            te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}

        new_state_dict = {**ait_sd, **te_state_dict}
        return new_state_dict

    return _convert_sd_scripts_to_ai_toolkit(state_dict)
@@ -0,0 +1,271 @@
import os
import time
import json
import datetime
import torch
from safetensors.torch import save_file
import gradio as gr
from rich import progress as p
from modules import shared, devices
from modules.ui_common import create_refresh_button
from modules.call_queue import wrap_gradio_gpu_call


class SVDHandler:
    def __init__(self, maxrank=0, rank_ratio=1):
        self.network_name: str = None
        self.U: torch.Tensor = None
        self.S: torch.Tensor = None
        self.Vh: torch.Tensor = None
        self.maxrank: int = maxrank
        self.rank_ratio: float = rank_ratio
        self.rank: int = 0
        self.out_size: int = None
        self.in_size: int = None
        self.kernel_size: tuple[int, int] = None
        self.conv2d: bool = False

    def decompose(self, weight, backupweight):
        self.conv2d = len(weight.size()) == 4
        self.kernel_size = None if not self.conv2d else weight.size()[2:4]
        self.out_size, self.in_size = weight.size()[0:2]
        diffweight = weight.clone().to(devices.device)
        diffweight -= backupweight.to(devices.device)
        if self.conv2d:
            if self.conv2d and self.kernel_size != (1, 1):
                diffweight = diffweight.flatten(start_dim=1)
            else:
                diffweight = diffweight.squeeze()
        self.U, self.S, self.Vh = torch.svd_lowrank(diffweight.to(device=devices.device, dtype=torch.float), self.maxrank, 2)
        # del diffweight
        self.U = self.U.to(device=devices.cpu, dtype=torch.bfloat16)
        self.S = self.S.to(device=devices.cpu, dtype=torch.bfloat16)
        self.Vh = self.Vh.t().to(device=devices.cpu, dtype=torch.bfloat16)  # svd_lowrank outputs a transposed matrix

    def findrank(self):
        if self.rank_ratio < 1:
            S_squared = self.S.pow(2)
            S_fro_sq = float(torch.sum(S_squared))
            sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
            index = int(torch.searchsorted(sum_S_squared, self.rank_ratio ** 2)) + 1
            index = max(1, min(index, len(self.S) - 1))
            self.rank = index
            if self.maxrank > 0:
                self.rank = min(self.rank, self.maxrank)
        else:
            self.rank = min(self.in_size, self.out_size, self.maxrank)

    def makeweights(self):
        self.findrank()
        up = self.U[:, :self.rank] @ torch.diag(self.S[:self.rank])
        down = self.Vh[:self.rank, :]
        if self.conv2d and self.kernel_size is not None:
            up = up.reshape(self.out_size, self.rank, 1, 1)
            down = down.reshape(self.rank, self.in_size, self.kernel_size[0], self.kernel_size[1])  # pylint: disable=unsubscriptable-object
        return_dict = {f'{self.network_name}.lora_up.weight': up.contiguous(),
                       f'{self.network_name}.lora_down.weight': down.contiguous(),
                       f'{self.network_name}.alpha': torch.tensor(down.shape[0]),
                       }
        return return_dict
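decompose/makeweights together form a standard truncated-SVD factorization of the LoRA-induced weight delta: torch.svd_lowrank gives diff ≈ U @ diag(S) @ Vh, and folding S into U yields up @ down ≈ diff, with findrank picking the smallest rank whose singular values capture the requested fraction of the Frobenius energy. A minimal standalone sketch of that identity (illustration only):

import torch

diff = torch.randn(64, 64)                      # stand-in for weight - backup
rank = 8
U, S, V = torch.svd_lowrank(diff.float(), q=rank, niter=2)
Vh = V.t()                                      # svd_lowrank returns V, not Vh
up = U[:, :rank] @ torch.diag(S[:rank])         # (out, rank)
down = Vh[:rank, :]                             # (rank, in)
print((diff - up @ down).norm() / diff.norm())  # relative error of the rank-8 truncation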


def loaded_lora():
    if not shared.sd_loaded:
        return ""
    loaded = set()
    if hasattr(shared.sd_model, 'unet'):
        for _name, module in shared.sd_model.unet.named_modules():
            current = getattr(module, "network_current_names", None)
            if current is not None:
                current = [item[0] for item in current]
                loaded.update(current)
    return list(loaded)


def loaded_lora_str():
    return ", ".join(loaded_lora())


def make_meta(fn, maxrank, rank_ratio):
    meta = {
        "model_spec.sai_model_spec": "1.0.0",
        "model_spec.title": os.path.splitext(os.path.basename(fn))[0],
        "model_spec.author": "SD.Next",
        "model_spec.implementation": "https://github.com/vladmandic/automatic",
        "model_spec.date": datetime.datetime.now().astimezone().replace(microsecond=0).isoformat(),
        "model_spec.base_model": shared.opts.sd_model_checkpoint,
        "model_spec.dtype": str(devices.dtype),
        "model_spec.base_lora": json.dumps(loaded_lora()),
        "model_spec.config": f"maxrank={maxrank} rank_ratio={rank_ratio}",
    }
    if shared.sd_model_type == "sdxl":
        meta["model_spec.architecture"] = "stable-diffusion-xl-v1-base/lora"  # sai standard
        meta["ss_base_model_version"] = "sdxl_base_v1-0"  # kohya standard
    elif shared.sd_model_type == "sd":
        meta["model_spec.architecture"] = "stable-diffusion-v1/lora"
        meta["ss_base_model_version"] = "sd_v1"
    elif shared.sd_model_type == "f1":
        meta["model_spec.architecture"] = "flux-1-dev/lora"
        meta["ss_base_model_version"] = "flux1"
    elif shared.sd_model_type == "sc":
        meta["model_spec.architecture"] = "stable-cascade-v1-prior/lora"
    return meta


def make_lora(fn, maxrank, auto_rank, rank_ratio, modules, overwrite):
    if not shared.sd_loaded or not shared.native:
        msg = "LoRA extract: model not loaded"
        shared.log.warning(msg)
        yield msg
        return
    if len(loaded_lora()) == 0:  # loaded_lora() returns a list when a model is loaded, so compare by length
        msg = "LoRA extract: no LoRA detected"
        shared.log.warning(msg)
        yield msg
        return
    if not fn:
        msg = "LoRA extract: target filename required"
        shared.log.warning(msg)
        yield msg
        return
    t0 = time.time()
    maxrank = int(maxrank)
    rank_ratio = 1 if not auto_rank else rank_ratio
    shared.log.debug(f'LoRA extract: modules={modules} maxrank={maxrank} auto={auto_rank} ratio={rank_ratio} fn="{fn}"')
    shared.state.begin('LoRA extract')

    with p.Progress(p.TextColumn('[cyan]LoRA extract'), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TextColumn('[cyan]{task.description}'), console=shared.console) as progress:

        if 'te' in modules and getattr(shared.sd_model, 'text_encoder', None) is not None:
            named_modules = list(shared.sd_model.text_encoder.named_modules())  # use a local name so the `modules` selection is not clobbered
            task = progress.add_task(description="te1 decompose", total=len(named_modules))
            for name, module in shared.sd_model.text_encoder.named_modules():
                progress.update(task, advance=1)
                weights_backup = getattr(module, "network_weights_backup", None)
                if weights_backup is None or getattr(module, "network_current_names", None) is None:
                    continue
                prefix = "lora_te1_" if hasattr(shared.sd_model, 'text_encoder_2') else "lora_te_"
                module.svdhandler = SVDHandler(maxrank, rank_ratio)
                module.svdhandler.network_name = prefix + name.replace(".", "_")
                with devices.inference_context():
                    module.svdhandler.decompose(module.weight, weights_backup)
            progress.remove_task(task)
        t1 = time.time()

        if 'te' in modules and getattr(shared.sd_model, 'text_encoder_2', None) is not None:
            named_modules = list(shared.sd_model.text_encoder_2.named_modules())
            task = progress.add_task(description="te2 decompose", total=len(named_modules))
            for name, module in shared.sd_model.text_encoder_2.named_modules():
                progress.update(task, advance=1)
                weights_backup = getattr(module, "network_weights_backup", None)
                if weights_backup is None or getattr(module, "network_current_names", None) is None:
                    continue
                module.svdhandler = SVDHandler(maxrank, rank_ratio)
                module.svdhandler.network_name = "lora_te2_" + name.replace(".", "_")
                with devices.inference_context():
                    module.svdhandler.decompose(module.weight, weights_backup)
            progress.remove_task(task)
        t2 = time.time()

        if 'unet' in modules and getattr(shared.sd_model, 'unet', None) is not None:
            named_modules = list(shared.sd_model.unet.named_modules())
            task = progress.add_task(description="unet decompose", total=len(named_modules))
            for name, module in shared.sd_model.unet.named_modules():
                progress.update(task, advance=1)
                weights_backup = getattr(module, "network_weights_backup", None)
                if weights_backup is None or getattr(module, "network_current_names", None) is None:
                    continue
                module.svdhandler = SVDHandler(maxrank, rank_ratio)
                module.svdhandler.network_name = "lora_unet_" + name.replace(".", "_")
                with devices.inference_context():
                    module.svdhandler.decompose(module.weight, weights_backup)
            progress.remove_task(task)
        t3 = time.time()

        # TODO: Handle quant for Flux
        # if 'te' in modules and getattr(shared.sd_model, 'transformer', None) is not None:
        #     for name, module in shared.sd_model.transformer.named_modules():
        #         if "norm" in name and "linear" not in name:
        #             continue
        #         weights_backup = getattr(module, "network_weights_backup", None)
        #         if weights_backup is None:
        #             continue
        #         module.svdhandler = SVDHandler()
        #         module.svdhandler.network_name = "lora_transformer_" + name.replace(".", "_")
        #         module.svdhandler.decompose(module.weight, weights_backup)
        #         module.svdhandler.findrank(rank, rank_ratio)

        lora_state_dict = {}
        for sub in ['text_encoder', 'text_encoder_2', 'unet', 'transformer']:
            submodel = getattr(shared.sd_model, sub, None)
            if submodel is not None:
                named_modules = list(submodel.named_modules())
                task = progress.add_task(description=f"{sub} extract", total=len(named_modules))
                for _name, module in submodel.named_modules():
                    progress.update(task, advance=1)
                    if not hasattr(module, "svdhandler"):
                        continue
                    lora_state_dict.update(module.svdhandler.makeweights())
                    del module.svdhandler
                progress.remove_task(task)
        t4 = time.time()

    if not os.path.isabs(fn):
        fn = os.path.join(shared.cmd_opts.lora_dir, fn)
    if not fn.endswith('.safetensors'):
        fn += '.safetensors'
    if os.path.exists(fn):
        if overwrite:
            os.remove(fn)
        else:
            msg = f'LoRA extract: fn="{fn}" file exists'
            shared.log.warning(msg)
            yield msg
            return

    shared.state.end()
    meta = make_meta(fn, maxrank, rank_ratio)
    shared.log.debug(f'LoRA metadata: {meta}')
    try:
        save_file(tensors=lora_state_dict, metadata=meta, filename=fn)
    except Exception as e:
        msg = f'LoRA extract error: fn="{fn}" {e}'
        shared.log.error(msg)
        yield msg
        return
    t5 = time.time()
    shared.log.debug(f'LoRA extract: time={t5-t0:.2f} te1={t1-t0:.2f} te2={t2-t1:.2f} unet={t3-t2:.2f} save={t5-t4:.2f}')
    keys = list(lora_state_dict.keys())
    msg = f'LoRA extract: fn="{fn}" keys={len(keys)}'
    shared.log.info(msg)
    yield msg


def create_ui():
    def gr_show(visible=True):
        return {"visible": visible, "__type__": "update"}

    with gr.Tab(label="Extract LoRA"):
        with gr.Row():
            loaded = gr.Textbox(placeholder="Press refresh to query loaded LoRA", label="Loaded LoRA", interactive=False)
            create_refresh_button(loaded, lambda: None, lambda: {'value': loaded_lora_str()}, "testid")
        with gr.Group():
            with gr.Row():
                modules = gr.CheckboxGroup(label="Modules to extract", value=['unet'], choices=['te', 'unet'])
            with gr.Row():
                auto_rank = gr.Checkbox(value=False, label="Automatically determine rank")
                rank_ratio = gr.Slider(label="Autorank ratio", value=1, minimum=0, maximum=1, step=0.05, visible=False)
                rank = gr.Slider(label="Maximum rank", value=32, minimum=1, maximum=256)
            with gr.Row():
                filename = gr.Textbox(label="LoRA target filename")
                overwrite = gr.Checkbox(value=False, label="Overwrite existing file")
            with gr.Row():
                extract = gr.Button(value="Extract LoRA", variant='primary')
                status = gr.HTML(value="", show_label=False)

    auto_rank.change(fn=lambda x: gr_show(x), inputs=[auto_rank], outputs=[rank_ratio])
    extract.click(
        fn=wrap_gradio_gpu_call(make_lora, extra_outputs=[]),
        inputs=[filename, rank, auto_rank, rank_ratio, modules, overwrite],
        outputs=[status]
    )
@@ -0,0 +1,66 @@
import torch


def make_weight_cp(t, wa, wb):
    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', temp, wa)


def rebuild_conventional(up, down, shape, dyn_dim=None):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    if dyn_dim is not None:
        up = up[:, :dyn_dim]
        down = down[:dyn_dim, :]
    return (up @ down).reshape(shape)


def rebuild_cp_decomposition(up, down, mid):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)


# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    """
    Return a tuple of two values: the input dimension decomposed by the number closest to factor;
    the second value is greater than or equal to the first.

    In LoRA with Kronecker product, the first value is a value for weight scale,
    the second value is a value for weight.

    Because of the non-commutative property, A⊗B ≠ B⊗A, the meaning of the two matrices is slightly different.

    examples:
    factor
        -1               2                4                8                16
    127 -> 1, 127    127 -> 1, 127    127 -> 1, 127    127 -> 1, 127    127 -> 1, 127
    128 -> 8, 16     128 -> 2, 64     128 -> 4, 32     128 -> 8, 16     128 -> 8, 16
    250 -> 10, 25    250 -> 2, 125    250 -> 2, 125    250 -> 5, 50     250 -> 10, 25
    360 -> 8, 45     360 -> 2, 180    360 -> 4, 90     360 -> 8, 45     360 -> 12, 30
    512 -> 16, 32    512 -> 2, 256    512 -> 4, 128    512 -> 8, 64     512 -> 16, 32
    1024 -> 32, 32   1024 -> 2, 512   1024 -> 4, 256   1024 -> 8, 128   1024 -> 16, 64
    """
    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        if m > n:
            n, m = m, n
        return m, n
    if factor < 0:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n
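Spot checks against the docstring table (illustrative):

assert factorization(360, 8) == (8, 45)    # 360 divisible by 8: direct split
assert factorization(250, 4) == (2, 125)   # falls through to the search loop
assert factorization(128) == (8, 16)       # factor=-1 picks the most balanced pair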


@@ -0,0 +1,187 @@
import os
from collections import namedtuple
import enum

from modules import sd_models, hashes, shared

NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}


class SdVersion(enum.Enum):
    Unknown = 1
    SD1 = 2
    SD2 = 3
    SD3 = 4
    SDXL = 5
    SC = 6
    F1 = 7
class NetworkOnDisk:
    def __init__(self, name, filename):
        self.shorthash = None
        self.hash = None
        self.name = name
        self.filename = filename
        if filename.startswith(shared.cmd_opts.lora_dir):
            self.fullname = os.path.splitext(filename[len(shared.cmd_opts.lora_dir):].strip("/"))[0]
        else:
            self.fullname = name
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
        if self.is_safetensors:
            self.metadata = sd_models.read_metadata_from_safetensors(filename)
        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v
            self.metadata = m
        self.alias = self.metadata.get('ss_output_name', self.name)
        sha256 = hashes.sha256_from_cache(self.filename, "lora/" + self.name) or hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=True) or self.metadata.get('sshs_model_hash')
        self.set_hash(sha256)
        self.sd_version = self.detect_version()

    def detect_version(self):
        base = str(self.metadata.get('ss_base_model_version', "")).lower()
        arch = str(self.metadata.get('modelspec.architecture', "")).lower()
        if base.startswith("sd_v1"):
            return 'sd1'
        if base.startswith("sdxl"):
            return 'xl'
        if base.startswith("stable_cascade"):
            return 'sc'
        if base.startswith("sd3"):
            return 'sd3'
        if base.startswith("flux"):
            return 'f1'

        if arch.startswith("stable-diffusion-v1"):
            return 'sd1'
        if arch.startswith("stable-diffusion-xl"):
            return 'xl'
        if arch.startswith("stable-cascade"):
            return 'sc'
        if arch.startswith("flux"):
            return 'f1'

        if "v1-5" in str(self.metadata.get('ss_sd_model_name', "")):
            return 'sd1'
        if str(self.metadata.get('ss_v2', "")) == "True":
            return 'sd2'
        if 'flux' in self.name.lower():
            return 'f1'
        if 'xl' in self.name.lower():
            return 'xl'

        return ''

    def set_hash(self, v):
        self.hash = v or ''
        self.shorthash = self.hash[0:8]

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        import modules.lora.networks as networks
        return self.name if shared.opts.lora_preferred_name == "filename" or self.alias.lower() in networks.forbidden_network_aliases else self.alias


class Network:  # LoraModule
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.te_multiplier = 1.0
        self.unet_multiplier = [1.0] * 3
        self.dyn_dim = None
        self.modules = {}
        self.bundle_embeddings = {}
        self.mtime = None
        self.mentioned_name = None
        """the text that was used to add the network to prompt - can be either name or an alias"""
        self.tags = None


class ModuleType:
    def create_module(self, net: Network, weights: NetworkWeights) -> 'NetworkModule | None':  # pylint: disable=W0613
        return None


class NetworkModule:
    def __init__(self, net: Network, weights: NetworkWeights):
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module
        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape
        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None
        self.dora_scale = weights.w.get("dora_scale", None)
        self.dora_norm_dims = len(self.shape) - 1

    def multiplier(self):
        unet_multiplier = 3 * [self.network.unet_multiplier] if not isinstance(self.network.unet_multiplier, list) else self.network.unet_multiplier
        if 'transformer' in self.sd_key[:20]:
            return self.network.te_multiplier
        if "down_blocks" in self.sd_key:
            return unet_multiplier[0]
        if "mid_block" in self.sd_key:
            return unet_multiplier[1]
        if "up_blocks" in self.sd_key:
            return unet_multiplier[2]
        else:
            return unet_multiplier[0]

    def calc_scale(self):
        if self.scale is not None:
            return self.scale
        if self.dim is not None and self.alpha is not None:
            return self.alpha / self.dim
        return 1.0

    def apply_weight_decompose(self, updown, orig_weight):
        # Match the device/dtype
        orig_weight = orig_weight.to(updown.dtype)
        dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
        updown = updown.to(orig_weight.device)

        merged_scale1 = updown + orig_weight
        merged_scale1_norm = (
            merged_scale1.transpose(0, 1)
            .reshape(merged_scale1.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
            .transpose(0, 1)
        )

        dora_merged = (
            merged_scale1 * (dora_scale / merged_scale1_norm)
        )
        final_updown = dora_merged - orig_weight
        return final_updown
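For a 2-D Linear weight, the transpose/reshape chain above computes the norm of each input column of the merged weight, and the rescale sets every column's magnitude to the learned dora_scale (the DoRA weight decomposition). A minimal 2-D sketch, assuming an (out, in) weight and a (1, in) scale (illustration only):

import torch

W = torch.randn(4, 3)             # base weight: (out, in)
delta = 0.1 * torch.randn(4, 3)   # LoRA update
scale = torch.rand(1, 3) + 0.5    # stand-in for dora_scale: one magnitude per column

merged = W + delta
col_norm = merged.norm(dim=0, keepdim=True)  # what the chain computes in the 2-D case
dora = merged * (scale / col_norm)           # every column now has norm == scale
assert torch.allclose(dora.norm(dim=0, keepdim=True), scale, atol=1e-5)
final_updown = dora - W                      # returned, then added back onto W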

    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = updown.reshape(output_shape)
        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)
        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)
        if ex_bias is not None:
            ex_bias = ex_bias * self.multiplier()
        if self.dora_scale is not None:
            updown = self.apply_weight_decompose(updown, orig_weight)
        return updown * self.calc_scale() * self.multiplier(), ex_bias

    def calc_updown(self, target):
        raise NotImplementedError

    def forward(self, x, y):
        raise NotImplementedError
@@ -0,0 +1,26 @@
import modules.lora.network as network


class ModuleTypeFull(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["diff"]):
            return NetworkModuleFull(net, weights)
        return None


class NetworkModuleFull(network.NetworkModule):  # pylint: disable=abstract-method
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.weight = weights.w.get("diff")
        self.ex_bias = weights.w.get("diff_b")

    def calc_updown(self, target):
        output_shape = self.weight.shape
        updown = self.weight.to(target.device, dtype=target.dtype)
        if self.ex_bias is not None:
            ex_bias = self.ex_bias.to(target.device, dtype=target.dtype)
        else:
            ex_bias = None

        return self.finalize_updown(updown, target, output_shape, ex_bias)
@@ -0,0 +1,30 @@
import modules.lora.network as network


class ModuleTypeGLora(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
            return NetworkModuleGLora(net, weights)
        return None


# adapted from https://github.com/KohakuBlueleaf/LyCORIS
class NetworkModuleGLora(network.NetworkModule):  # pylint: disable=abstract-method
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

        self.w1a = weights.w["a1.weight"]
        self.w1b = weights.w["b1.weight"]
        self.w2a = weights.w["a2.weight"]
        self.w2b = weights.w["b2.weight"]

    def calc_updown(self, target):  # pylint: disable=arguments-differ
        w1a = self.w1a.to(target.device, dtype=target.dtype)
        w1b = self.w1b.to(target.device, dtype=target.dtype)
        w2a = self.w2a.to(target.device, dtype=target.dtype)
        w2b = self.w2b.to(target.device, dtype=target.dtype)
        output_shape = [w1a.size(0), w1b.size(1)]
        updown = (w2b @ w1b) + ((target @ w2a) @ w1a)
        return self.finalize_updown(updown, target, output_shape)
@@ -0,0 +1,46 @@
import modules.lora.lyco_helpers as lyco_helpers
import modules.lora.network as network


class ModuleTypeHada(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
            return NetworkModuleHada(net, weights)
        return None


class NetworkModuleHada(network.NetworkModule):  # pylint: disable=abstract-method
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)
        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape
        self.w1a = weights.w["hada_w1_a"]
        self.w1b = weights.w["hada_w1_b"]
        self.dim = self.w1b.shape[0]
        self.w2a = weights.w["hada_w2_a"]
        self.w2b = weights.w["hada_w2_b"]
        self.t1 = weights.w.get("hada_t1")
        self.t2 = weights.w.get("hada_t2")

    def calc_updown(self, target):
        w1a = self.w1a.to(target.device, dtype=target.dtype)
        w1b = self.w1b.to(target.device, dtype=target.dtype)
        w2a = self.w2a.to(target.device, dtype=target.dtype)
        w2b = self.w2b.to(target.device, dtype=target.dtype)
        output_shape = [w1a.size(0), w1b.size(1)]
        if self.t1 is not None:
            output_shape = [w1a.size(1), w1b.size(1)]
            t1 = self.t1.to(target.device, dtype=target.dtype)
            updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
            output_shape += t1.shape[2:]
        else:
            if len(w1b.shape) == 4:
                output_shape += w1b.shape[2:]
            updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
        if self.t2 is not None:
            t2 = self.t2.to(target.device, dtype=target.dtype)
            updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
        else:
            updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
        updown = updown1 * updown2
        return self.finalize_updown(updown, target, output_shape)
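calc_updown composes the two branches with an element-wise (Hadamard) product, so the effective LoHa delta is (w1a @ w1b) * (w2a @ w2b) and the representable rank can reach the product of the two branch ranks. A shape-level sketch (illustration only):

import torch
out_dim, in_dim, r = 16, 12, 4
w1a, w1b = torch.randn(out_dim, r), torch.randn(r, in_dim)
w2a, w2b = torch.randn(out_dim, r), torch.randn(r, in_dim)
updown = (w1a @ w1b) * (w2a @ w2b)  # element-wise product of two rank-4 matrices
assert updown.shape == (out_dim, in_dim)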


@@ -0,0 +1,24 @@
import modules.lora.network as network


class ModuleTypeIa3(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["weight"]):
            return NetworkModuleIa3(net, weights)
        return None


class NetworkModuleIa3(network.NetworkModule):  # pylint: disable=abstract-method
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)
        self.w = weights.w["weight"]
        self.on_input = weights.w["on_input"].item()

    def calc_updown(self, target):
        w = self.w.to(target.device, dtype=target.dtype)
        output_shape = [w.size(0), target.size(1)]
        if self.on_input:
            output_shape.reverse()
        else:
            w = w.reshape(-1, 1)
        updown = target * w
        return self.finalize_updown(updown, target, output_shape)
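(IA)^3 stores only a rescaling vector, so the updown handed to finalize_updown is the base weight multiplied element-wise by that vector, per output row unless on_input is set. Sketch (illustration only):

import torch
target = torch.randn(8, 4)          # base weight
w = torch.randn(8)                  # learned per-output-channel scaling
updown = target * w.reshape(-1, 1)  # each output row rescaled; shape (8, 4)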


@@ -0,0 +1,57 @@
import torch
import modules.lora.lyco_helpers as lyco_helpers
import modules.lora.network as network


class ModuleTypeLokr(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w)
        has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w)
        if has_1 and has_2:
            return NetworkModuleLokr(net, weights)
        return None


def make_kron(orig_shape, w1, w2):
    if len(w2.shape) == 4:
        w1 = w1.unsqueeze(2).unsqueeze(2)
    w2 = w2.contiguous()
    return torch.kron(w1, w2).reshape(orig_shape)
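For LoKr the full delta is a Kronecker product, so a small w1 tiles and scales a larger block w2: torch.kron of an (a, b) and a (c, d) matrix has shape (a*c, b*d), matching the output_shape computed in calc_updown below. Sketch (illustration only):

import torch
w1 = torch.randn(4, 2)
w2 = torch.randn(16, 32)
assert torch.kron(w1, w2).shape == (4 * 16, 2 * 32)  # (64, 64)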
|
||||
|
||||
|
||||
class NetworkModuleLokr(network.NetworkModule): # pylint: disable=abstract-method
|
||||
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||
super().__init__(net, weights)
|
||||
self.w1 = weights.w.get("lokr_w1")
|
||||
self.w1a = weights.w.get("lokr_w1_a")
|
||||
self.w1b = weights.w.get("lokr_w1_b")
|
||||
self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim
|
||||
self.w2 = weights.w.get("lokr_w2")
|
||||
self.w2a = weights.w.get("lokr_w2_a")
|
||||
self.w2b = weights.w.get("lokr_w2_b")
|
||||
self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim
|
||||
self.t2 = weights.w.get("lokr_t2")
|
||||
|
||||
def calc_updown(self, target):
|
||||
if self.w1 is not None:
|
||||
w1 = self.w1.to(target.device, dtype=target.dtype)
|
||||
else:
|
||||
w1a = self.w1a.to(target.device, dtype=target.dtype)
|
||||
w1b = self.w1b.to(target.device, dtype=target.dtype)
|
||||
w1 = w1a @ w1b
|
||||
if self.w2 is not None:
|
||||
w2 = self.w2.to(target.device, dtype=target.dtype)
|
||||
elif self.t2 is None:
|
||||
w2a = self.w2a.to(target.device, dtype=target.dtype)
|
||||
w2b = self.w2b.to(target.device, dtype=target.dtype)
|
||||
w2 = w2a @ w2b
|
||||
else:
|
||||
t2 = self.t2.to(target.device, dtype=target.dtype)
|
||||
w2a = self.w2a.to(target.device, dtype=target.dtype)
|
||||
w2b = self.w2b.to(target.device, dtype=target.dtype)
|
||||
w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
|
||||
output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
|
||||
if len(target.shape) == 4:
|
||||
output_shape = target.shape
|
||||
updown = make_kron(output_shape, w1, w2)
|
||||
return self.finalize_updown(updown, target, output_shape)
|
||||
|
|
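The output_shape arithmetic above follows directly from how the Kronecker product composes the two LoKr factors: kron of a (p, q) matrix with an (r, s) matrix has shape (p*r, q*s). A small sketch with hypothetical factor sizes:

# Sketch of why make_kron reconstructs the full weight: torch.kron of a
# (p, q) and an (r, s) matrix has shape (p*r, q*s). Sizes are hypothetical.
import torch

w1 = torch.randn(4, 8)    # first LoKr factor
w2 = torch.randn(80, 96)  # second LoKr factor
full = torch.kron(w1, w2)
assert full.shape == (4 * 80, 8 * 96)  # matches output_shape above
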
@ -0,0 +1,78 @@
import torch
import diffusers.models.lora as diffusers_lora
import modules.lora.lyco_helpers as lyco_helpers
import modules.lora.network as network
from modules import devices


class ModuleTypeLora(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
            return NetworkModuleLora(net, weights)
        return None


class NetworkModuleLora(network.NetworkModule):

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)
        self.up_model = self.create_module(weights.w, "lora_up.weight")
        self.down_model = self.create_module(weights.w, "lora_down.weight")
        self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)
        self.dim = weights.w["lora_down.weight"].shape[0]

    def create_module(self, weights, key, none_ok=False):
        from modules.shared import opts
        weight = weights.get(key)
        if weight is None and none_ok:
            return None
        linear_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear]
        is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ in {"NNCFLinear", "QLinear", "Linear4bit"}
        is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv] or self.sd_module.__class__.__name__ in {"NNCFConv2d", "QConv2d"}
        if is_linear:
            weight = weight.reshape(weight.shape[0], -1)
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif is_conv and key == "lora_down.weight" or key == "dyn_up":
            if len(weight.shape) == 2:
                weight = weight.reshape(weight.shape[0], -1, 1, 1)
            if weight.shape[2] != 1 or weight.shape[3] != 1:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
            else:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif is_conv and key == "lora_mid.weight":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
        elif is_conv and key == "lora_up.weight" or key == "dyn_down":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            raise AssertionError(f'Lora unsupported: layer={self.network_key} type={type(self.sd_module).__name__}')
        with torch.no_grad():
            if weight.shape != module.weight.shape:
                weight = weight.reshape(module.weight.shape)
            module.weight.copy_(weight)
        if opts.lora_load_gpu:
            module = module.to(device=devices.device, dtype=devices.dtype)
        module.weight.requires_grad_(False)
        return module

    def calc_updown(self, target): # pylint: disable=W0237
        target_dtype = target.dtype if target.dtype != torch.uint8 else self.up_model.weight.dtype
        up = self.up_model.weight.to(target.device, dtype=target_dtype)
        down = self.down_model.weight.to(target.device, dtype=target_dtype)
        output_shape = [up.size(0), down.size(1)]
        if self.mid_model is not None:
            # cp-decomposition
            mid = self.mid_model.weight.to(target.device, dtype=target_dtype)
            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
            output_shape += mid.shape[2:]
        else:
            if len(down.shape) == 4:
                output_shape += down.shape[2:]
            updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)
        return self.finalize_updown(updown, target, output_shape)

    def forward(self, x, y):
        self.up_model.to(device=devices.device)
        self.down_model.to(device=devices.device)
        if hasattr(y, "scale"):
            return y(scale=1) + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
        return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()

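The non-CP path of calc_updown is, for plain 2D weights, just the familiar low-rank product of the two LoRA factors. A sketch under that assumption (rank and dims are hypothetical, and dyn_dim handling is ignored):

# Sketch of the conventional LoRA reconstruction used in calc_updown: the
# full delta is the product of the up and down factors. Sizes are
# hypothetical examples, not values from the codebase.
import torch

rank, d_in, d_out = 16, 768, 320
down = torch.randn(rank, d_in)   # "lora_down.weight"
up = torch.randn(d_out, rank)    # "lora_up.weight"
updown = up @ down               # roughly what rebuild_conventional returns for 2D weights
assert updown.shape == (d_out, d_in)
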
@ -0,0 +1,23 @@
import modules.lora.network as network


class ModuleTypeNorm(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["w_norm", "b_norm"]):
            return NetworkModuleNorm(net, weights)
        return None


class NetworkModuleNorm(network.NetworkModule): # pylint: disable=abstract-method
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)
        self.w_norm = weights.w.get("w_norm")
        self.b_norm = weights.w.get("b_norm")

    def calc_updown(self, target):
        output_shape = self.w_norm.shape
        updown = self.w_norm.to(target.device, dtype=target.dtype)
        if self.b_norm is not None:
            ex_bias = self.b_norm.to(target.device, dtype=target.dtype)
        else:
            ex_bias = None
        return self.finalize_updown(updown, target, output_shape, ex_bias)

@ -0,0 +1,81 @@
import torch
import modules.lora.network as network
from modules.lora.lyco_helpers import factorization
from einops import rearrange


class ModuleTypeOFT(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
            return NetworkModuleOFT(net, weights)
        return None


# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
class NetworkModuleOFT(network.NetworkModule): # pylint: disable=abstract-method
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)
        self.lin_module = None
        self.org_module: list[torch.nn.Module] = [self.sd_module]
        self.scale = 1.0

        # kohya-ss
        if "oft_blocks" in weights.w.keys():
            self.is_kohya = True
            self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
            self.alpha = weights.w["alpha"] # alpha is constraint
            self.dim = self.oft_blocks.shape[0] # lora dim
        # LyCORIS
        elif "oft_diag" in weights.w.keys():
            self.is_kohya = False
            self.oft_blocks = weights.w["oft_diag"]
            # self.alpha is unused
            self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)

        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]
        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported

        if is_linear:
            self.out_dim = self.sd_module.out_features
        elif is_conv:
            self.out_dim = self.sd_module.out_channels
        elif is_other_linear:
            self.out_dim = self.sd_module.embed_dim

        if self.is_kohya:
            self.constraint = self.alpha * self.out_dim
            self.num_blocks = self.dim
            self.block_size = self.out_dim // self.dim
        else:
            self.constraint = None
            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)

    def calc_updown(self, target):
        oft_blocks = self.oft_blocks.to(target.device, dtype=target.dtype)
        eye = torch.eye(self.block_size, device=target.device)

        if self.is_kohya:
            block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
            norm_Q = torch.norm(block_Q.flatten()).to(target.device)
            constraint = self.constraint.to(target.device) # constraint only exists for the kohya variant; it is None for LyCORIS
            new_norm_Q = torch.clamp(norm_Q, max=constraint)
            block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
            mat1 = eye + block_Q
            mat2 = (eye - block_Q).float().inverse()
            oft_blocks = torch.matmul(mat1, mat2)

        R = oft_blocks.to(target.device, dtype=target.dtype)

        # This errors out for MultiheadAttention, might need to be handled upstream
        merged_weight = rearrange(target, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
        merged_weight = torch.einsum(
            'k n m, k n ... -> k m ...',
            R,
            merged_weight
        )
        merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')

        updown = merged_weight.to(target.device, dtype=target.dtype) - target
        output_shape = target.shape
        return self.finalize_updown(updown, target, output_shape)

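The kohya branch above builds each orthogonal block via the Cayley transform of a skew-symmetric matrix. A minimal sketch verifying the underlying identity (a single unbatched block of hypothetical size, in float64 for a tight tolerance; the real code operates on a batch of blocks):

# Sketch of the Cayley transform used in the kohya branch: a skew-symmetric
# Q yields an orthogonal R = (I + Q)(I - Q)^-1. Block size is hypothetical.
import torch

n = 8
a = torch.randn(n, n, dtype=torch.float64)
q = a - a.T                                  # skew-symmetric, like block_Q
eye = torch.eye(n, dtype=torch.float64)
r = (eye + q) @ torch.inverse(eye - q)       # Cayley transform
assert torch.allclose(r @ r.T, eye)          # R is orthogonal
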
@ -0,0 +1,49 @@
from modules import shared


maybe_diffusers = [ # forced if lora_maybe_diffusers is enabled
    'aaebf6360f7d', # sd15-lcm
    '3d18b05e4f56', # sdxl-lcm
    'b71dcb732467', # sdxl-tcd
    '813ea5fb1c67', # sdxl-turbo
    # not really needed, but just in case
    '5a48ac366664', # hyper-sd15-1step
    'ee0ff23dcc42', # hyper-sd15-2step
    'e476eb1da5df', # hyper-sd15-4step
    'ecb844c3f3b0', # hyper-sd15-8step
    '1ab289133ebb', # hyper-sd15-8step-cfg
    '4f494295edb1', # hyper-sdxl-8step
    'ca14a8c621f8', # hyper-sdxl-8step-cfg
    '1c88f7295856', # hyper-sdxl-4step
    'fdd5dcd1d88a', # hyper-sdxl-2step
    '8cca3706050b', # hyper-sdxl-1step
]

force_diffusers = [ # forced always
    '816d0eed49fd', # flash-sdxl
    'c2ec22757b46', # flash-sd15
]

force_models = [ # forced always
    'sc',
    # 'sd3',
    'kandinsky',
    'hunyuandit',
    'auraflow',
]

force_classes = [ # forced always
]


def check_override(shorthash=''):
    force = False
    force = force or (shared.sd_model_type in force_models)
    force = force or (shared.sd_model.__class__.__name__ in force_classes)
    if len(shorthash) < 4:
        return force
    force = force or (any(x.startswith(shorthash) for x in maybe_diffusers) if shared.opts.lora_maybe_diffusers else False)
    force = force or any(x.startswith(shorthash) for x in force_diffusers)
    if force and shared.opts.lora_maybe_diffusers:
        shared.log.debug('LoRA override: force diffusers')
    return force

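check_override matches by hash prefix, so a caller can pass a shorter hash than the stored 12-character entries as long as it has at least 4 characters. A self-contained sketch of just that prefix test (the queried hash here is a hypothetical example):

# Sketch of the shorthash prefix test used by check_override: a stored hash
# matches when it starts with the (possibly shorter) hash passed in.
force_diffusers = ['816d0eed49fd', 'c2ec22757b46']

def matches(shorthash: str) -> bool:
    return any(x.startswith(shorthash) for x in force_diffusers)

assert matches('816d')          # 4+ characters is enough
assert not matches('ffff')
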
@ -0,0 +1,453 @@
from typing import Union, List
import os
import re
import time
import concurrent
import modules.lora.network as network
import modules.lora.network_lora as network_lora
import modules.lora.network_hada as network_hada
import modules.lora.network_ia3 as network_ia3
import modules.lora.network_oft as network_oft
import modules.lora.network_lokr as network_lokr
import modules.lora.network_full as network_full
import modules.lora.network_norm as network_norm
import modules.lora.network_glora as network_glora
import modules.lora.network_overrides as network_overrides
import modules.lora.lora_convert as lora_convert
import torch
import diffusers.models.lora
from modules import shared, devices, sd_models, sd_models_compile, errors, scripts, files_cache, model_quant


debug = os.environ.get('SD_LORA_DEBUG', None) is not None
extra_network_lora = None
available_networks = {}
available_network_aliases = {}
loaded_networks: List[network.Network] = []
timer = { 'load': 0, 'apply': 0, 'restore': 0, 'deactivate': 0 }
lora_cache = {}
diffuser_loaded = []
diffuser_scales = []
available_network_hash_lookup = {}
forbidden_network_aliases = {}
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
module_types = [
    network_lora.ModuleTypeLora(),
    network_hada.ModuleTypeHada(),
    network_ia3.ModuleTypeIa3(),
    network_oft.ModuleTypeOFT(),
    network_lokr.ModuleTypeLokr(),
    network_full.ModuleTypeFull(),
    network_norm.ModuleTypeNorm(),
    network_glora.ModuleTypeGLora(),
]

def assign_network_names_to_compvis_modules(sd_model):
    if sd_model is None:
        return
    sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatibility
    network_layer_mapping = {}
    if hasattr(sd_model, 'text_encoder') and sd_model.text_encoder is not None:
        for name, module in sd_model.text_encoder.named_modules():
            prefix = "lora_te1_" if hasattr(sd_model, 'text_encoder_2') else "lora_te_"
            network_name = prefix + name.replace(".", "_")
            network_layer_mapping[network_name] = module
            module.network_layer_name = network_name
    if hasattr(sd_model, 'text_encoder_2'):
        for name, module in sd_model.text_encoder_2.named_modules():
            network_name = "lora_te2_" + name.replace(".", "_")
            network_layer_mapping[network_name] = module
            module.network_layer_name = network_name
    if hasattr(sd_model, 'unet'):
        for name, module in sd_model.unet.named_modules():
            network_name = "lora_unet_" + name.replace(".", "_")
            network_layer_mapping[network_name] = module
            module.network_layer_name = network_name
    if hasattr(sd_model, 'transformer'):
        for name, module in sd_model.transformer.named_modules():
            network_name = "lora_transformer_" + name.replace(".", "_")
            network_layer_mapping[network_name] = module
            if "norm" in network_name and "linear" not in network_name and shared.sd_model_type != "sd3":
                continue
            module.network_layer_name = network_name
    shared.sd_model.network_layer_mapping = network_layer_mapping

def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_default_multiplier) -> network.Network | None:
    name = name.replace(".", "_")
    shared.log.debug(f'Load network: type=LoRA name="{name}" file="{network_on_disk.filename}" detected={network_on_disk.sd_version} method=diffusers scale={lora_scale} fuse={shared.opts.lora_fuse_diffusers}')
    if not shared.native:
        return None
    if not hasattr(shared.sd_model, 'load_lora_weights'):
        shared.log.error(f'Load network: type=LoRA class={shared.sd_model.__class__} does not implement load lora')
        return None
    try:
        shared.sd_model.load_lora_weights(network_on_disk.filename, adapter_name=name)
    except Exception as e:
        if 'already in use' in str(e):
            pass
        else:
            if 'The following keys have not been correctly renamed' in str(e):
                shared.log.error(f'Load network: type=LoRA name="{name}" diffusers unsupported format')
            else:
                shared.log.error(f'Load network: type=LoRA name="{name}" {e}')
            if debug:
                errors.display(e, "LoRA")
            return None
    if name not in diffuser_loaded:
        diffuser_loaded.append(name)
        diffuser_scales.append(lora_scale)
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)
    return net

def load_network(name, network_on_disk) -> network.Network | None:
    if not shared.sd_loaded:
        return None

    cached = lora_cache.get(name, None)
    if debug:
        shared.log.debug(f'Load network: type=LoRA name="{name}" file="{network_on_disk.filename}" type=lora {"cached" if cached else ""}')
    if cached is not None:
        return cached
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)
    sd = sd_models.read_state_dict(network_on_disk.filename, what='network')
    if shared.sd_model_type == 'f1': # if kohya flux lora, convert state_dict
        sd = lora_convert._convert_kohya_flux_lora_to_diffusers(sd) or sd # pylint: disable=protected-access
    if shared.sd_model_type == 'sd3': # if kohya sd3 lora, convert state_dict
        try:
            sd = lora_convert._convert_kohya_sd3_lora_to_diffusers(sd) or sd # pylint: disable=protected-access
        except ValueError: # EAFP for diffusers PEFT keys
            pass
    assign_network_names_to_compvis_modules(shared.sd_model)
    keys_failed_to_match = {}
    matched_networks = {}
    bundle_embeddings = {}
    convert = lora_convert.KeyConvert()
    for key_network, weight in sd.items():
        parts = key_network.split('.')
        if parts[0] == "bundle_emb":
            emb_name, vec_name = parts[1], key_network.split(".", 2)[-1]
            emb_dict = bundle_embeddings.get(emb_name, {})
            emb_dict[vec_name] = weight
            bundle_embeddings[emb_name] = emb_dict
            continue
        if len(parts) > 5: # messy handler for diffusers peft lora
            key_network_without_network_parts = '_'.join(parts[:-2])
            if not key_network_without_network_parts.startswith('lora_'):
                key_network_without_network_parts = 'lora_' + key_network_without_network_parts
            network_part = '.'.join(parts[-2:]).replace('lora_A', 'lora_down').replace('lora_B', 'lora_up')
        else:
            key_network_without_network_parts, network_part = key_network.split(".", 1)
        key, sd_module = convert(key_network_without_network_parts)
        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue
        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
        matched_networks[key].w[network_part] = weight
    network_types = []
    for key, weights in matched_networks.items():
        net_module = None
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                network_types.append(nettype.__class__.__name__)
                break
        if net_module is None:
            shared.log.error(f'LoRA unhandled: name={name} key={key} weights={weights.w.keys()}')
        else:
            net.modules[key] = net_module
    if len(keys_failed_to_match) > 0:
        shared.log.warning(f'LoRA name="{name}" type={set(network_types)} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}')
        if debug:
            shared.log.debug(f'LoRA name="{name}" unmatched={keys_failed_to_match}')
    else:
        shared.log.debug(f'LoRA name="{name}" type={set(network_types)} keys={len(matched_networks)}')
    if len(matched_networks) == 0:
        return None
    lora_cache[name] = net
    net.bundle_embeddings = bundle_embeddings
    return net

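The `len(parts) > 5` branch above remaps diffusers PEFT naming onto the kohya-style keys the matcher expects. A worked example of that remapping on a hypothetical PEFT-style key:

# Worked example of the "messy handler for diffusers peft lora" branch in
# load_network, applied to a hypothetical PEFT-style key.
key_network = 'transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight'
parts = key_network.split('.')
assert len(parts) > 5

key_without = '_'.join(parts[:-2])
if not key_without.startswith('lora_'):
    key_without = 'lora_' + key_without
network_part = '.'.join(parts[-2:]).replace('lora_A', 'lora_down').replace('lora_B', 'lora_up')

assert key_without == 'lora_transformer_single_transformer_blocks_0_attn_to_q'
assert network_part == 'lora_down.weight'
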
def maybe_recompile_model(names, te_multipliers):
    recompile_model = False
    if shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled:
        if len(names) == len(shared.compiled_model_state.lora_model):
            for i, name in enumerate(names):
                if shared.compiled_model_state.lora_model[i] != f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}":
                    recompile_model = True
                    shared.compiled_model_state.lora_model = []
                    break
            if not recompile_model:
                if len(loaded_networks) > 0 and debug:
                    shared.log.debug('Model Compile: Skipping LoRa loading')
                return
        else:
            recompile_model = True
            shared.compiled_model_state.lora_model = []
    if recompile_model:
        backup_cuda_compile = shared.opts.cuda_compile
        sd_models.unload_model_weights(op='model')
        shared.opts.cuda_compile = []
        sd_models.reload_model_weights(op='model')
        shared.opts.cuda_compile = backup_cuda_compile
    return recompile_model

def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        list_available_networks()
        networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names]
    failed_to_load_networks = []
    recompile_model = maybe_recompile_model(names, te_multipliers)

    loaded_networks.clear()
    diffuser_loaded.clear()
    diffuser_scales.clear()
    timer['load'] = 0
    t0 = time.time()

    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        net = None
        if network_on_disk is not None:
            shorthash = getattr(network_on_disk, 'shorthash', '').lower()
            if debug:
                shared.log.debug(f'Load network: type=LoRA name="{name}" file="{network_on_disk.filename}" hash="{shorthash}"')
            try:
                if recompile_model:
                    shared.compiled_model_state.lora_model.append(f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}")
                if shared.opts.lora_force_diffusers or network_overrides.check_override(shorthash): # OpenVINO only works with Diffusers LoRa loading
                    net = load_diffusers(name, network_on_disk, lora_scale=te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier)
                else:
                    net = load_network(name, network_on_disk)
                if net is not None:
                    net.mentioned_name = name
                    network_on_disk.read_hash()
            except Exception as e:
                shared.log.error(f'Load network: type=LoRA file="{network_on_disk.filename}" {e}')
                if debug:
                    errors.display(e, 'LoRA')
                continue
        if net is None:
            failed_to_load_networks.append(name)
            shared.log.error(f'Load network: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} failed')
            continue
        shared.sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings)
        net.te_multiplier = te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier
        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else shared.opts.extra_networks_default_multiplier
        net.dyn_dim = dyn_dims[i] if dyn_dims else shared.opts.extra_networks_default_multiplier
        loaded_networks.append(net)

    while len(lora_cache) > shared.opts.lora_in_memory_limit:
        name = next(iter(lora_cache))
        lora_cache.pop(name, None)

    if len(diffuser_loaded) > 0:
        shared.log.debug(f'Load network: type=LoRA loaded={diffuser_loaded} available={shared.sd_model.get_list_adapters()} active={shared.sd_model.get_active_adapters()} scales={diffuser_scales}')
        try:
            shared.sd_model.set_adapters(adapter_names=diffuser_loaded, adapter_weights=diffuser_scales)
            if shared.opts.lora_fuse_diffusers:
                shared.sd_model.fuse_lora(adapter_names=diffuser_loaded, lora_scale=1.0, fuse_unet=True, fuse_text_encoder=True) # fuse uses fixed scale since later apply does the scaling
                shared.sd_model.unload_lora_weights()
        except Exception as e:
            shared.log.error(f'Load network: type=LoRA {e}')
            if debug:
                errors.display(e, 'LoRA')

    if len(loaded_networks) > 0 and debug:
        shared.log.debug(f'Load network: type=LoRA loaded={len(loaded_networks)} cache={list(lora_cache)}')

    devices.torch_gc()

    if recompile_model:
        shared.log.info("Load network: type=LoRA recompiling model")
        backup_lora_model = shared.compiled_model_state.lora_model
        if 'Model' in shared.opts.cuda_compile:
            shared.sd_model = sd_models_compile.compile_diffusers(shared.sd_model)

        shared.compiled_model_state.lora_model = backup_lora_model
    if shared.opts.diffusers_offload_mode == "balanced":
        sd_models.apply_balanced_offload(shared.sd_model)
    t1 = time.time()
    timer['load'] += t1 - t0

def set_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown, ex_bias):
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)
    if weights_backup is None and bias_backup is None:
        return
    device = self.weight.device
    with devices.inference_context():
        if weights_backup is not None:
            if updown is not None:
                if len(weights_backup.shape) == 4 and weights_backup.shape[1] == 9:
                    # inpainting model. zero pad updown to make channel[1] 4 to 9
                    updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable
                weights_backup = weights_backup.clone().to(device)
                weights_backup += updown.to(weights_backup)
            if getattr(self, "quant_type", None) in ['nf4', 'fp4']:
                bnb = model_quant.load_bnb('Load network: type=LoRA', silent=True)
                if bnb is not None:
                    self.weight = bnb.nn.Params4bit(weights_backup, quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize)
                else:
                    self.weight.copy_(weights_backup, non_blocking=True)
            else:
                self.weight.copy_(weights_backup, non_blocking=True)
            if hasattr(self, "qweight") and hasattr(self, "freeze"):
                self.freeze()
        if bias_backup is not None:
            if ex_bias is not None:
                bias_backup = bias_backup.clone() + ex_bias.to(weights_backup)
            self.bias.copy_(bias_backup)
        else:
            self.bias = None
    self.to(device)

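The inpainting pad in set_weights relies on F.pad's tuple convention: pairs apply to the last dimensions first, so (0, 0, 0, 0, 0, 5) leaves width and height untouched and appends five zero channels to dim 1. A sketch with hypothetical tensor sizes:

# Sketch of the inpainting zero-pad in set_weights. Tensor sizes are
# hypothetical examples.
import torch

updown = torch.randn(320, 4, 3, 3)                       # conv delta with 4 input channels
padded = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
assert padded.shape == (320, 9, 3, 3)                    # matches the 9-channel inpainting weight
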
def maybe_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], wanted_names): # pylint: disable=W0613
    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None and wanted_names != (): # pylint: disable=C1803
        if getattr(self.weight, "quant_type", None) in ['nf4', 'fp4']:
            bnb = model_quant.load_bnb('Load network: type=LoRA', silent=True)
            if bnb is not None:
                with devices.inference_context():
                    weights_backup = bnb.functional.dequantize_4bit(self.weight, quant_state=self.weight.quant_state, quant_type=self.weight.quant_type, blocksize=self.weight.blocksize)
                    self.quant_state = self.weight.quant_state
                    self.quant_type = self.weight.quant_type
                    self.blocksize = self.weight.blocksize
            else:
                weights_backup = self.weight.clone()
        else:
            weights_backup = self.weight.clone()
        if shared.opts.lora_offload_backup and weights_backup is not None:
            weights_backup = weights_backup.to(devices.cpu)
        self.network_weights_backup = weights_backup
    bias_backup = getattr(self, "network_bias_backup", None)
    if bias_backup is None:
        if getattr(self, 'bias', None) is not None:
            bias_backup = self.bias.clone()
        else:
            bias_backup = None
        if shared.opts.lora_offload_backup and bias_backup is not None:
            bias_backup = bias_backup.to(devices.cpu)
        self.network_bias_backup = bias_backup

def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
    If not, restores original weights from backup and alters weights according to networks.
    """
    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return
    t0 = time.time()
    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
    if any([net.modules.get(network_layer_name, None) for net in loaded_networks]): # noqa: C419 # pylint: disable=R1729
        maybe_backup_weights(self, wanted_names)
    if current_names != wanted_names:
        for net in loaded_networks:
            # default workflow where module is known and has weights
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                try:
                    with devices.inference_context():
                        weight = self.weight # calculate quant weights once
                        updown, ex_bias = module.calc_updown(weight)
                        set_weights(self, updown, ex_bias)
                except RuntimeError as e:
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
                    if debug:
                        module_name = net.modules.get(network_layer_name, None)
                        shared.log.error(f'LoRA apply weight name="{net.name}" module="{module_name}" layer="{network_layer_name}" {e}')
                        errors.display(e, 'LoRA')
                        raise RuntimeError('LoRA apply weight') from e
                continue
            if module is None:
                continue
            shared.log.warning(f'LoRA network="{net.name}" layer="{network_layer_name}" unsupported operation')
            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
        if not loaded_networks: # restore from backup
            t5 = time.time()
            set_weights(self, None, None)
        self.network_current_names = wanted_names
    t1 = time.time()
    timer['apply'] += t1 - t0

def network_load():
    sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatibility
    for component_name in ['text_encoder', 'text_encoder_2', 'unet', 'transformer']:
        component = getattr(sd_model, component_name, None)
        if component is not None:
            for _, module in component.named_modules():
                network_apply_weights(module)

def list_available_networks():
    t0 = time.time()
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    forbidden_network_aliases.update({"none": 1, "Addams": 1})
    if not os.path.exists(shared.cmd_opts.lora_dir):
        shared.log.warning(f'LoRA directory not found: path="{shared.cmd_opts.lora_dir}"')

    def add_network(filename):
        if not os.path.isfile(filename):
            return
        name = os.path.splitext(os.path.basename(filename))[0]
        name = name.replace('.', '_')
        try:
            entry = network.NetworkOnDisk(name, filename)
            available_networks[entry.name] = entry
            if entry.alias in available_network_aliases:
                forbidden_network_aliases[entry.alias.lower()] = 1
            if shared.opts.lora_preferred_name == 'filename':
                available_network_aliases[entry.name] = entry
            else:
                available_network_aliases[entry.alias] = entry
            if entry.shorthash:
                available_network_hash_lookup[entry.shorthash] = entry
        except OSError as e: # should catch FileNotFoundError and PermissionError etc.
            shared.log.error(f'LoRA: filename="{filename}" {e}')

    candidates = list(files_cache.list_files(shared.cmd_opts.lora_dir, ext_filter=[".pt", ".ckpt", ".safetensors"]))
    with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
        for fn in candidates:
            executor.submit(add_network, fn)
    t1 = time.time()
    shared.log.info(f'Available LoRAs: path="{shared.cmd_opts.lora_dir}" items={len(available_networks)} folders={len(forbidden_network_aliases)} time={t1 - t0:.2f}')

def infotext_pasted(infotext, params): # pylint: disable=W0613
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return # if the other extension is active, it will handle those fields, no need to do anything
    added = []
    for k in params:
        if not k.startswith("AddNet Model "):
            continue
        num = k[13:]
        if params.get("AddNet Module " + num) != "LoRA":
            continue
        name = params.get("AddNet Model " + num)
        if name is None:
            continue
        m = re_network_name.match(name)
        if m:
            name = m.group(1)
        multiplier = params.get("AddNet Weight A " + num, "1.0")
        added.append(f"<lora:{name}:{multiplier}>")
    if added:
        params["Prompt"] += "\n" + "".join(added)


list_available_networks()

@ -0,0 +1,123 @@
import os
import json
import concurrent
import modules.lora.networks as networks
from modules import shared, ui_extra_networks


debug = os.environ.get('SD_LORA_DEBUG', None) is not None


class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
    def __init__(self):
        super().__init__('Lora')
        self.list_time = 0

    def refresh(self):
        networks.list_available_networks()

    @staticmethod
    def get_tags(l, info):
        tags = {}
        try:
            if l.metadata is not None:
                modelspec_tags = l.metadata.get('modelspec.tags', {})
                possible_tags = l.metadata.get('ss_tag_frequency', {}) # tags from model metadata
                if isinstance(possible_tags, str):
                    possible_tags = {}
                if isinstance(modelspec_tags, str):
                    modelspec_tags = {}
                if len(list(modelspec_tags)) > 0:
                    possible_tags.update(modelspec_tags)
                for k, v in possible_tags.items():
                    words = k.split('_', 1) if '_' in k else [v, k]
                    words = [str(w).replace('.json', '') for w in words]
                    if words[0] == '{}':
                        words[0] = 0
                    tag = ' '.join(words[1:]).lower()
                    tags[tag] = words[0]

            def find_version():
                found_versions = []
                current_hash = l.hash[:8].upper()
                all_versions = info.get('modelVersions', [])
                for v in info.get('modelVersions', []):
                    for f in v.get('files', []):
                        if any(h.startswith(current_hash) for h in f.get('hashes', {}).values()):
                            found_versions.append(v)
                if len(found_versions) == 0:
                    found_versions = all_versions
                return found_versions

            for v in find_version(): # trigger words from info json
                possible_tags = v.get('trainedWords', [])
                if isinstance(possible_tags, list):
                    for tag_str in possible_tags:
                        for tag in tag_str.split(','):
                            tag = tag.strip().lower()
                            if tag not in tags:
                                tags[tag] = 0

            possible_tags = info.get('tags', []) # tags from info json
            if not isinstance(possible_tags, list):
                possible_tags = list(possible_tags.values())
            for tag in possible_tags:
                tag = tag.strip().lower()
                if tag not in tags:
                    tags[tag] = 0
        except Exception:
            pass
        bad_chars = [';', ':', '<', ">", "*", '?', '\'', '\"', '(', ')', '[', ']', '{', '}', '\\', '/']
        clean_tags = {}
        for k, v in tags.items():
            tag = ''.join(i for i in k if i not in bad_chars).strip()
            clean_tags[tag] = v

        clean_tags.pop('img', None)
        clean_tags.pop('dataset', None)
        return clean_tags

    def create_item(self, name):
        l = networks.available_networks.get(name)
        if l is None:
            shared.log.warning(f'Networks: type=lora registered={len(list(networks.available_networks))} file="{name}" not registered')
            return None
        try:
            # path, _ext = os.path.splitext(l.filename)
            name = os.path.splitext(os.path.relpath(l.filename, shared.cmd_opts.lora_dir))[0]
            item = {
                "type": 'Lora',
                "name": name,
                "filename": l.filename,
                "hash": l.shorthash,
                "prompt": json.dumps(f" <lora:{l.get_alias()}:{shared.opts.extra_networks_default_multiplier}>"),
                "metadata": json.dumps(l.metadata, indent=4) if l.metadata else None,
                "mtime": os.path.getmtime(l.filename),
                "size": os.path.getsize(l.filename),
                "version": l.sd_version,
            }
            info = self.find_info(l.filename)
            item["info"] = info
            item["description"] = self.find_description(l.filename, info) # use existing info instead of double-read
            item["tags"] = self.get_tags(l, info)
            return item
        except Exception as e:
            shared.log.error(f'Networks: type=lora file="{name}" {e}')
            if debug:
                from modules import errors
                errors.display(e, 'Lora')
            return None

    def list_items(self):
        items = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
            future_items = {executor.submit(self.create_item, net): net for net in networks.available_networks}
            for future in concurrent.futures.as_completed(future_items):
                item = future.result()
                if item is not None:
                    items.append(item)
        self.update_all_previews(items)
        return items

    def allowed_directories_for_previews(self):
        return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir]

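The ss_tag_frequency parsing in get_tags above assumes dataset folder keys of the form "<repeats>_<name>" and splits off the repeat count. A simplified sketch of just that split (the key is a hypothetical example, and the fallback value is substituted with 0 here rather than the metadata value used in the real code):

# Worked example of the ss_tag_frequency key split in get_tags.
k = '10_my character'
words = k.split('_', 1) if '_' in k else [0, k]
tag = ' '.join(str(w) for w in words[1:]).lower()
assert (words[0], tag) == ('10', 'my character')
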
@ -8,6 +8,8 @@ from modules import shared, devices, processing, sd_models, errors, sd_hijack_hy
from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, save_intermediate, update_sampler, is_txt2img, is_refiner_enabled
from modules.processing_args import set_pipeline_args
from modules.onnx_impl import preprocess_pipeline as preprocess_onnx_pipeline, check_parameters_changed as olive_check_parameters_changed
from modules.lora.networks import network_load
from modules.lora.networks import timer as network_timer


debug = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None

@ -424,6 +426,9 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
    p.prompts = p.all_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
    if p.negative_prompts is None or len(p.negative_prompts) == 0:
        p.negative_prompts = p.all_negative_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
    network_timer['apply'] = 0
    network_timer['restore'] = 0
    network_load()

    sd_models.move_model(shared.sd_model, devices.device)
    sd_models_compile.openvino_recompile_model(p, hires=False, refiner=False) # recompile if a parameter changes

@ -908,6 +908,7 @@ options_templates.update(options_section(('extra_networks', "Networks"), {
    "lora_in_memory_limit": OptionInfo(0, "LoRA memory cache", gr.Slider, {"minimum": 0, "maximum": 24, "step": 1}),
    "lora_quant": OptionInfo("NF4", "LoRA precision in quantized models", gr.Radio, {"choices": ["NF4", "FP4"]}),
    "lora_load_gpu": OptionInfo(True if not cmd_opts.lowvram else False, "Load LoRA directly to GPU"),
    "lora_offload_backup": OptionInfo(True, "Offload LoRA Backup Weights"),
}))

options_templates.update(options_section((None, "Internal options"), {

@ -0,0 +1,62 @@
import re
import modules.lora.networks as networks
from modules.lora.lora_extract import create_ui
from modules.lora.network import NetworkOnDisk
from modules.lora.ui_extra_networks_lora import ExtraNetworksPageLora
from modules.lora.extra_networks_lora import ExtraNetworkLora
from modules import script_callbacks, extra_networks, ui_extra_networks, ui_models, shared # pylint: disable=unused-import


re_lora = re.compile("<lora:([^:]+):")


def before_ui():
    ui_extra_networks.register_page(ExtraNetworksPageLora())
    networks.extra_network_lora = ExtraNetworkLora()
    extra_networks.register_extra_network(networks.extra_network_lora)
    ui_models.extra_ui.append(create_ui)


def create_lora_json(obj: NetworkOnDisk):
    return {
        "name": obj.name,
        "alias": obj.alias,
        "path": obj.filename,
        "metadata": obj.metadata,
    }


def api_networks(_, app):
    @app.get("/sdapi/v1/loras")
    async def get_loras():
        return [create_lora_json(obj) for obj in networks.available_networks.values()]

    @app.post("/sdapi/v1/refresh-loras")
    async def refresh_loras():
        return networks.list_available_networks()


def infotext_pasted(infotext, d): # pylint: disable=unused-argument
    hashes = d.get("Lora hashes", None)
    if hashes is None:
        return

    def network_replacement(m):
        alias = m.group(1)
        shorthash = hashes.get(alias)
        if shorthash is None:
            return m.group(0)
        network_on_disk = networks.available_network_hash_lookup.get(shorthash)
        if network_on_disk is None:
            return m.group(0)
        return f'<lora:{network_on_disk.get_alias()}:'

    hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
    hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
    d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])


if shared.native:
    script_callbacks.on_app_started(api_networks)
    script_callbacks.on_before_ui(before_ui)
    script_callbacks.on_infotext_pasted(networks.infotext_pasted)
    script_callbacks.on_infotext_pasted(infotext_pasted)

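The infotext handler above rewrites `<lora:alias:...>` prompt references back to on-disk aliases by letting re.sub call network_replacement once per match of re_lora, which only captures up to the second colon so the weight suffix is preserved. A sketch of that mechanism with a hypothetical alias map and prompt:

# Sketch of how re_lora drives the prompt rewrite: the replacement function
# fires once per "<lora:name:" occurrence and leaves the rest of the tag
# intact. Names and the alias map are hypothetical examples.
import re

re_lora = re.compile("<lora:([^:]+):")
aliases = {'old-name': 'new-name'}

def network_replacement(m):
    return f"<lora:{aliases.get(m.group(1), m.group(1))}:"

prompt = "a photo <lora:old-name:0.8> of a cat"
assert re.sub(re_lora, network_replacement, prompt) == "a photo <lora:new-name:0.8> of a cat"
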