mirror of https://github.com/vladmandic/automatic
new lora handler and remove lyco
parent
7dc098182a
commit
3727bf3d02
16
CHANGELOG.md
16
CHANGELOG.md
|
|
@ -45,6 +45,9 @@ Upgrades are still possible and supported, but above is recommended for best exp
|
|||
can be disabled in *settings -> extra networks -> show built-in*
|
||||
- **VAE**
|
||||
- VAEs are now also listed as part of extra networks
|
||||
- **LoRA**
|
||||
- LoRAs are now automatically filtered based on compatibility with currently loaded model
|
||||
note that if lora type cannot be auto-determined, it will be left in the list
|
||||
- **Refiner**
|
||||
- you can load model from extra networks as base model or as refiner
|
||||
simply select button in top-right of models page
|
||||
|
|
@ -91,6 +94,10 @@ Upgrades are still possible and supported, but above is recommended for best exp
|
|||
- fix long outstanding memory leak in legacy code, amazing this went undetected for so long
|
||||
- more high quality upscalers available by default
|
||||
**SwinIR** (2), **ESRGAN** (12), **RealESRGAN** (6), **SCUNet** (2)
|
||||
- if that is not enough, there is new **chaiNNer** integration:
|
||||
adds 15 more upscalers from different families out-of-the-box:
|
||||
**HAT** (6), **RealHAT** (2), **DAT** (1), **RRDBNet** (1), **SPSRNet** (1), **SRFormer** (2), **SwiftSR** (2)
|
||||
and yes, you can download and add your own, just place them in `models/chaiNNer`
|
||||
- two additional latent upscalers based on SD upscale models when using Diffusers backend
|
||||
**SD Upscale 2x**, **SD Upscale 4x**
|
||||
note: Recommended usage for *SD Upscale* is by using second pass instead of upscaler
|
||||
|
|
@ -102,12 +109,6 @@ Upgrades are still possible and supported, but above is recommended for best exp
|
|||
simply set *denoising strength* to 0 so hires does not get triggered
|
||||
- unified init/download/execute/progress code
|
||||
- easier installation
|
||||
- and if that is not enough, install extension:
|
||||
<https://github.com/vladmandic/sd-extension-chainner>
|
||||
and it will add 15 more upscalers from different families:
|
||||
**HAT** (6), **RealHAT** (2), **DAT** (1), **RRDBNet** (1), **SPSRNet** (1), **SRFormer** (2), **SwiftSR** (2)
|
||||
and yes, you can download and add your own, just place them in `models/chaiNNer`
|
||||
note: extension will probably be added to default built-in list in the near-future
|
||||
- **Samplers**:
|
||||
- moved ui options to submenu
|
||||
- default list for new installs is now all samplers, list can be modified in settings
|
||||
|
|
@ -128,6 +129,9 @@ Upgrades are still possible and supported, but above is recommended for best exp
|
|||
description file present in format of *[model].txt*
|
||||
- to enable search, make sure all models have set hash values
|
||||
*Models -> Validate -> Calculate hashes*
|
||||
- **LoRA**
|
||||
- for *backend:original*, lyco handler has been removed and replaced with new
|
||||
unified lora/lyco handler that supports all variants of loras
|
||||
- **Compute**
|
||||
- **Intel Arc/IPEX**:
|
||||
- tons of optimizations, built-in binary wheels for Windows
|
||||
|
|
|
|||
|
|
@ -1,40 +1,53 @@
|
|||
import lora
|
||||
import networks
|
||||
from modules import extra_networks, shared
|
||||
|
||||
|
||||
class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||
def __init__(self):
|
||||
super().__init__('lora')
|
||||
self.errors = {}
|
||||
"""mapping of network names to the number of errors the network had during operation"""
|
||||
|
||||
def activate(self, p, params_list):
|
||||
additional = shared.opts.sd_lora
|
||||
|
||||
if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||
self.errors.clear()
|
||||
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
|
||||
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||
|
||||
names = []
|
||||
multipliers = []
|
||||
te_multipliers = []
|
||||
unet_multipliers = []
|
||||
dyn_dims = []
|
||||
for params in params_list:
|
||||
assert len(params.items) > 0
|
||||
names.append(params.items[0])
|
||||
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
|
||||
|
||||
lora.load_loras(names, multipliers)
|
||||
|
||||
assert params.items
|
||||
names.append(params.positional[0])
|
||||
te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
|
||||
te_multiplier = float(params.named.get("te", te_multiplier))
|
||||
unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier
|
||||
unet_multiplier = float(params.named.get("unet", unet_multiplier))
|
||||
dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
|
||||
dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
|
||||
te_multipliers.append(te_multiplier)
|
||||
unet_multipliers.append(unet_multiplier)
|
||||
dyn_dims.append(dyn_dim)
|
||||
networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
|
||||
if shared.opts.lora_add_hashes_to_infotext:
|
||||
lora_hashes = []
|
||||
for item in lora.loaded_loras:
|
||||
shorthash = item.lora_on_disk.shorthash
|
||||
network_hashes = []
|
||||
for item in networks.loaded_networks:
|
||||
shorthash = item.network_on_disk.shorthash
|
||||
if not shorthash:
|
||||
continue
|
||||
alias = item.mentioned_name
|
||||
if not alias:
|
||||
continue
|
||||
alias = alias.replace(":", "").replace(",", "")
|
||||
lora_hashes.append(f"{alias}: {shorthash}")
|
||||
if lora_hashes:
|
||||
p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
|
||||
network_hashes.append(f"{alias}: {shorthash}")
|
||||
if network_hashes:
|
||||
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
|
||||
|
||||
def deactivate(self, p):
|
||||
pass
|
||||
if self.errors:
|
||||
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
|
||||
for k, v in self.errors.items():
|
||||
shared.log.error(f'LoRA errors: file="{k}" errors={v}')
|
||||
self.errors.clear()
|
||||
|
|
|
|||
|
|
@ -1,499 +1,8 @@
|
|||
import os
|
||||
import re
|
||||
from typing import Union
|
||||
import torch
|
||||
from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes
|
||||
from modules.modelloader import directory_files, extension_filter
|
||||
|
||||
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
||||
|
||||
re_digits = re.compile(r"\d+")
|
||||
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
|
||||
re_compiled = {}
|
||||
|
||||
suffix_conversion = {
|
||||
"attentions": {},
|
||||
"resnets": {
|
||||
"conv1": "in_layers_2",
|
||||
"conv2": "out_layers_3",
|
||||
"time_emb_proj": "emb_layers_1",
|
||||
"conv_shortcut": "skip_connection",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def convert_diffusers_name_to_compvis(key, is_sd2):
|
||||
def match(match_list, regex_text):
|
||||
regex = re_compiled.get(regex_text)
|
||||
if regex is None:
|
||||
regex = re.compile(regex_text)
|
||||
re_compiled[regex_text] = regex
|
||||
|
||||
r = re.match(regex, key)
|
||||
if not r:
|
||||
return False
|
||||
|
||||
match_list.clear()
|
||||
match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
|
||||
return True
|
||||
|
||||
m = []
|
||||
|
||||
if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
|
||||
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
|
||||
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
|
||||
|
||||
if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
|
||||
suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
|
||||
return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
|
||||
|
||||
if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
|
||||
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
|
||||
return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
|
||||
|
||||
if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
|
||||
return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
|
||||
|
||||
if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
|
||||
return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
|
||||
|
||||
if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
|
||||
if is_sd2:
|
||||
if 'mlp_fc1' in m[1]:
|
||||
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
|
||||
elif 'mlp_fc2' in m[1]:
|
||||
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
|
||||
else:
|
||||
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
|
||||
|
||||
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
||||
|
||||
return key
|
||||
|
||||
|
||||
class LoraOnDisk:
|
||||
def __init__(self, name, filename):
|
||||
self.name = name
|
||||
self.filename = filename
|
||||
self.metadata = {}
|
||||
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
||||
if self.is_safetensors:
|
||||
try:
|
||||
self.metadata = sd_models.read_metadata_from_safetensors(filename)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading lora metadata: {filename}")
|
||||
if self.metadata:
|
||||
m = {}
|
||||
for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
|
||||
m[k] = v
|
||||
self.metadata = m
|
||||
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
|
||||
self.alias = self.metadata.get('ss_output_name', self.name)
|
||||
self.hash = None
|
||||
self.shorthash = None
|
||||
self.set_hash(self.metadata.get('sshs_model_hash') or (hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors)) or '')
|
||||
|
||||
def set_hash(self, v):
|
||||
self.hash = v
|
||||
self.shorthash = self.hash[0:10]
|
||||
if self.shorthash:
|
||||
available_lora_hash_lookup[self.shorthash] = self
|
||||
|
||||
def read_hash(self):
|
||||
if not self.hash:
|
||||
self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
|
||||
|
||||
def get_alias(self):
|
||||
if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases:
|
||||
return self.name
|
||||
else:
|
||||
return self.alias
|
||||
|
||||
|
||||
class LoraModule:
|
||||
def __init__(self, name, lora_on_disk: LoraOnDisk):
|
||||
self.name = name
|
||||
self.lora_on_disk = lora_on_disk
|
||||
self.multiplier = 1.0
|
||||
self.modules = {}
|
||||
self.mtime = None
|
||||
self.mentioned_name = None
|
||||
"""the text that was used to add lora to prompt - can be either name or an alias"""
|
||||
|
||||
|
||||
class LoraUpDownModule:
|
||||
def __init__(self):
|
||||
self.up = None
|
||||
self.down = None
|
||||
self.alpha = None
|
||||
|
||||
|
||||
def assign_lora_names_to_compvis_modules(sd_model):
|
||||
lora_layer_mapping = {}
|
||||
if not hasattr(shared.sd_model, 'cond_stage_model'):
|
||||
return
|
||||
|
||||
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
|
||||
lora_name = name.replace(".", "_")
|
||||
lora_layer_mapping[lora_name] = module
|
||||
module.lora_layer_name = lora_name
|
||||
|
||||
for name, module in shared.sd_model.model.named_modules():
|
||||
lora_name = name.replace(".", "_")
|
||||
lora_layer_mapping[lora_name] = module
|
||||
module.lora_layer_name = lora_name
|
||||
|
||||
sd_model.lora_layer_mapping = lora_layer_mapping
|
||||
|
||||
def load_diffuser_lora(name, lora_on_disk, multiplier, num_loras):
|
||||
lora = LoraModule(name, lora_on_disk)
|
||||
lora.mtime = os.path.getmtime(lora_on_disk.filename)
|
||||
from modules.lora_diffusers import load_diffusers_lora
|
||||
load_diffusers_lora(name, lora_on_disk, multiplier, num_loras)
|
||||
return lora
|
||||
|
||||
|
||||
def load_lora(name, lora_on_disk):
|
||||
lora = LoraModule(name, lora_on_disk)
|
||||
lora.mtime = os.path.getmtime(lora_on_disk.filename)
|
||||
|
||||
sd = sd_models.read_state_dict(lora_on_disk.filename)
|
||||
|
||||
# this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
|
||||
if not hasattr(shared.sd_model, 'lora_layer_mapping'):
|
||||
assign_lora_names_to_compvis_modules(shared.sd_model)
|
||||
|
||||
keys_failed_to_match = {}
|
||||
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
|
||||
|
||||
for key_diffusers, weight in sd.items():
|
||||
key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
|
||||
key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
|
||||
|
||||
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||
|
||||
if sd_module is None:
|
||||
m = re_x_proj.match(key)
|
||||
if m:
|
||||
sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
|
||||
|
||||
if sd_module is None:
|
||||
keys_failed_to_match[key_diffusers] = key
|
||||
continue
|
||||
|
||||
lora_module = lora.modules.get(key, None)
|
||||
if lora_module is None:
|
||||
lora_module = LoraUpDownModule()
|
||||
lora.modules[key] = lora_module
|
||||
|
||||
if lora_key == "alpha":
|
||||
lora_module.alpha = weight.item()
|
||||
continue
|
||||
|
||||
if type(sd_module) == torch.nn.Linear:
|
||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||
elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
|
||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||
elif type(sd_module) == torch.nn.MultiheadAttention:
|
||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
|
||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
||||
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
|
||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
|
||||
else:
|
||||
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
||||
continue
|
||||
|
||||
with torch.no_grad():
|
||||
module.weight.copy_(weight)
|
||||
|
||||
module.to(device=devices.cpu, dtype=devices.dtype)
|
||||
|
||||
if lora_key == "lora_up.weight":
|
||||
lora_module.up = module
|
||||
elif lora_key == "lora_down.weight":
|
||||
lora_module.down = module
|
||||
else:
|
||||
raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
|
||||
|
||||
if len(keys_failed_to_match) > 0:
|
||||
print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
|
||||
|
||||
return lora
|
||||
|
||||
|
||||
def load_loras(names, multipliers=None):
|
||||
already_loaded = {}
|
||||
for lora in loaded_loras:
|
||||
if lora.name in names:
|
||||
already_loaded[lora.name] = lora
|
||||
loaded_loras.clear()
|
||||
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
||||
if any(x is None for x in loras_on_disk):
|
||||
list_available_loras()
|
||||
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
||||
failed_to_load_loras = []
|
||||
recompile_model = False
|
||||
if shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx":
|
||||
if len(names) == len(shared.compiled_model_state.lora_model):
|
||||
for i, name in enumerate(names):
|
||||
if shared.compiled_model_state.lora_model[i] != f"{name}:{multipliers[i]}":
|
||||
recompile_model = True
|
||||
break
|
||||
else:
|
||||
recompile_model = True
|
||||
shared.compiled_model_state.lora_model = []
|
||||
if recompile_model:
|
||||
sd_models.unload_model_weights(op='model')
|
||||
shared.opts.cuda_compile = False
|
||||
sd_models.reload_model_weights(op='model')
|
||||
shared.opts.cuda_compile = True
|
||||
|
||||
for i, name in enumerate(names):
|
||||
lora = already_loaded.get(name, None) if shared.backend == shared.Backend.ORIGINAL else None
|
||||
lora_on_disk = loras_on_disk[i]
|
||||
if lora_on_disk is not None:
|
||||
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
|
||||
try:
|
||||
if shared.backend == shared.Backend.DIFFUSERS:
|
||||
lora = load_diffuser_lora(name, lora_on_disk, multipliers[i] if multipliers else 1.0, len(names))
|
||||
else:
|
||||
lora = load_lora(name, lora_on_disk)
|
||||
except Exception as e:
|
||||
errors.display(e, f"loading Lora {lora_on_disk.filename}")
|
||||
continue
|
||||
lora.mentioned_name = name
|
||||
lora_on_disk.read_hash()
|
||||
if lora is None:
|
||||
failed_to_load_loras.append(name)
|
||||
print(f"Couldn't find Lora with name {name}")
|
||||
continue
|
||||
lora.multiplier = multipliers[i] if multipliers else 1.0
|
||||
loaded_loras.append(lora)
|
||||
|
||||
if len(failed_to_load_loras) > 0:
|
||||
sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))
|
||||
|
||||
if recompile_model:
|
||||
shared.log.info("Lora: Recompiling model")
|
||||
sd_models.compile_diffusers(shared.sd_model)
|
||||
|
||||
|
||||
def lora_calc_updown(lora, module, target):
|
||||
with torch.no_grad():
|
||||
up = module.up.weight.to(target.device, dtype=target.dtype)
|
||||
down = module.down.weight.to(target.device, dtype=target.dtype)
|
||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||
else:
|
||||
updown = up @ down
|
||||
updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||
return updown
|
||||
|
||||
|
||||
def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||
weights_backup = getattr(self, "lora_weights_backup", None)
|
||||
if weights_backup is None:
|
||||
return
|
||||
if isinstance(self, torch.nn.MultiheadAttention):
|
||||
self.in_proj_weight.copy_(weights_backup[0])
|
||||
self.out_proj.weight.copy_(weights_backup[1])
|
||||
else:
|
||||
self.weight.copy_(weights_backup)
|
||||
|
||||
|
||||
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||
"""
|
||||
Applies the currently selected set of Loras to the weights of torch layer self.
|
||||
If weights already have this particular set of loras applied, does nothing.
|
||||
If not, restores orginal weights from backup and alters weights according to loras.
|
||||
"""
|
||||
|
||||
lora_layer_name = getattr(self, 'lora_layer_name', None)
|
||||
if lora_layer_name is None:
|
||||
return
|
||||
|
||||
current_names = getattr(self, "lora_current_names", ())
|
||||
wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
|
||||
|
||||
weights_backup = getattr(self, "lora_weights_backup", None)
|
||||
if weights_backup is None:
|
||||
if isinstance(self, torch.nn.MultiheadAttention):
|
||||
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
|
||||
else:
|
||||
weights_backup = self.weight.to(devices.cpu, copy=True)
|
||||
|
||||
self.lora_weights_backup = weights_backup
|
||||
|
||||
if current_names != wanted_names:
|
||||
lora_restore_weights_from_backup(self)
|
||||
|
||||
for lora in loaded_loras:
|
||||
module = lora.modules.get(lora_layer_name, None)
|
||||
if module is not None and hasattr(self, 'weight'):
|
||||
self.weight += lora_calc_updown(lora, module, self.weight)
|
||||
continue
|
||||
|
||||
module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
|
||||
module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
|
||||
module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
|
||||
module_out = lora.modules.get(lora_layer_name + "_out_proj", None)
|
||||
|
||||
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
||||
updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
|
||||
updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
|
||||
updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
|
||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||
|
||||
self.in_proj_weight += updown_qkv
|
||||
self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
|
||||
continue
|
||||
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
print(f'failed to calculate lora weights for layer {lora_layer_name}')
|
||||
|
||||
self.lora_current_names = wanted_names
|
||||
|
||||
|
||||
def lora_forward(module, input, original_forward):
|
||||
"""
|
||||
Old way of applying Lora by executing operations during layer's forward.
|
||||
Stacking many loras this way results in big performance degradation.
|
||||
"""
|
||||
|
||||
if len(loaded_loras) == 0:
|
||||
return original_forward(module, input)
|
||||
|
||||
input = devices.cond_cast_unet(input)
|
||||
|
||||
lora_restore_weights_from_backup(module)
|
||||
lora_reset_cached_weight(module)
|
||||
|
||||
res = original_forward(module, input)
|
||||
|
||||
lora_layer_name = getattr(module, 'lora_layer_name', None)
|
||||
for lora in loaded_loras:
|
||||
module = lora.modules.get(lora_layer_name, None)
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
module.up.to(device=devices.device)
|
||||
module.down.to(device=devices.device)
|
||||
|
||||
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
||||
self.lora_current_names = ()
|
||||
self.lora_weights_backup = None
|
||||
|
||||
|
||||
def lora_Linear_forward(self, input):
|
||||
if shared.opts.lora_functional:
|
||||
return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
|
||||
|
||||
lora_apply_weights(self)
|
||||
|
||||
return torch.nn.Linear_forward_before_lora(self, input)
|
||||
|
||||
|
||||
def lora_Linear_load_state_dict(self, *args, **kwargs):
|
||||
lora_reset_cached_weight(self)
|
||||
|
||||
return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
|
||||
|
||||
|
||||
def lora_Conv2d_forward(self, input):
|
||||
if shared.opts.lora_functional:
|
||||
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
|
||||
|
||||
lora_apply_weights(self)
|
||||
|
||||
return torch.nn.Conv2d_forward_before_lora(self, input)
|
||||
|
||||
|
||||
def lora_Conv2d_load_state_dict(self, *args, **kwargs):
|
||||
lora_reset_cached_weight(self)
|
||||
|
||||
return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
|
||||
|
||||
|
||||
def lora_MultiheadAttention_forward(self, *args, **kwargs):
|
||||
lora_apply_weights(self)
|
||||
|
||||
return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)
|
||||
|
||||
|
||||
def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
||||
lora_reset_cached_weight(self)
|
||||
|
||||
return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)
|
||||
|
||||
|
||||
def list_available_loras():
|
||||
from modules.paths_internal import script_path
|
||||
available_loras.clear()
|
||||
available_lora_aliases.clear()
|
||||
forbidden_lora_aliases.clear()
|
||||
available_lora_hash_lookup.clear()
|
||||
forbidden_lora_aliases.update({"none": 1, "Addams": 1})
|
||||
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
||||
for filename in sorted([*filter(extension_filter(['.PT', '.CKPT', '.SAFETENSORS']), directory_files(shared.cmd_opts.lora_dir))], key=str.lower):
|
||||
if filename.startswith(script_path):
|
||||
filename = os.path.relpath(filename, script_path)
|
||||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
entry = LoraOnDisk(name, filename)
|
||||
available_loras[name] = entry
|
||||
if entry.alias in available_lora_aliases:
|
||||
forbidden_lora_aliases[entry.alias.lower()] = 1
|
||||
available_lora_aliases[name] = entry
|
||||
available_lora_aliases[entry.alias] = entry
|
||||
|
||||
|
||||
re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
|
||||
|
||||
|
||||
def infotext_pasted(infotext, params):
|
||||
if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
|
||||
return # if the other extension is active, it will handle those fields, no need to do anything
|
||||
|
||||
added = []
|
||||
|
||||
for k in params:
|
||||
if not k.startswith("AddNet Model "):
|
||||
continue
|
||||
|
||||
num = k[13:]
|
||||
|
||||
if params.get("AddNet Module " + num) != "LoRA":
|
||||
continue
|
||||
|
||||
name = params.get("AddNet Model " + num)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
m = re_lora_name.match(name)
|
||||
if m:
|
||||
name = m.group(1)
|
||||
|
||||
multiplier = params.get("AddNet Weight A " + num, "1.0")
|
||||
|
||||
added.append(f"<lora:{name}:{multiplier}>")
|
||||
|
||||
if added:
|
||||
params["Prompt"] += "\n" + "".join(added)
|
||||
|
||||
|
||||
available_loras = {}
|
||||
available_lora_aliases = {}
|
||||
available_lora_hash_lookup = {}
|
||||
forbidden_lora_aliases = {}
|
||||
loaded_loras = []
|
||||
|
||||
list_available_loras()
|
||||
import networks
|
||||
|
||||
# This and the following assignments keep the legacy `lora` module names available as aliases of objects in the `networks` module.
list_available_loras = networks.list_available_networks
|
||||
available_loras = networks.available_networks
|
||||
available_lora_aliases = networks.available_network_aliases
|
||||
available_lora_hash_lookup = networks.available_network_hash_lookup
|
||||
forbidden_lora_aliases = networks.forbidden_network_aliases
|
||||
loaded_loras = networks.loaded_networks
|
||||
|
|
|
|||
|
|
@ -0,0 +1,29 @@
|
|||
import torch
|
||||
import networks
|
||||
from modules import patches
|
||||
|
||||
|
||||
class LoraPatches:
    """Installs the `networks` replacements on torch.nn layer classes and can remove them again.

    For every layer class listed in `_layer_names`, both `forward` and
    `_load_from_state_dict` are patched. The value returned by each
    `patches.patch` call is stored on the instance as `<Layer>_forward` /
    `<Layer>_load_state_dict`; `undo` overwrites those attributes with the
    value returned by `patches.undo`.
    """

    # torch.nn class names, in the order their patches are applied
    _layer_names = ('Linear', 'Conv2d', 'GroupNorm', 'LayerNorm', 'MultiheadAttention')
    # (suffix used in attribute / replacement names, method patched on the torch class)
    _hooks = (('forward', 'forward'), ('load_state_dict', '_load_from_state_dict'))

    def __init__(self):
        for layer_name in self._layer_names:
            layer_cls = getattr(torch.nn, layer_name)
            for suffix, method_name in self._hooks:
                replacement = getattr(networks, f'network_{layer_name}_{suffix}')
                patched = patches.patch(__name__, layer_cls, method_name, replacement)
                setattr(self, f'{layer_name}_{suffix}', patched)

    def undo(self):
        """Revert every patch applied in `__init__`, in the same order."""
        for layer_name in self._layer_names:
            layer_cls = getattr(torch.nn, layer_name)
            for suffix, method_name in self._hooks:
                restored = patches.undo(__name__, layer_cls, method_name)
                setattr(self, f'{layer_name}_{suffix}', restored)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
import torch
|
||||
|
||||
|
||||
def make_weight_cp(t, wa, wb):
    """Contract the first two axes of a 4-D tensor `t` with the factor matrices `wa` and `wb`.

    `wb` is applied to axis 1 of `t` first, then `wa` to axis 0, so the result
    has shape (wa.shape[1], wb.shape[1], t.shape[2], t.shape[3]).
    """
    # step 1: replace axis 1 (j) of `t` by the column index of `wb`
    contracted_j = torch.einsum('i j k l, j r -> i r k l', t, wb)
    # step 2: replace axis 0 (i) by the column index of `wa`
    return torch.einsum('i j k l, i r -> r j k l', contracted_j, wa)
|
||||
|
||||
|
||||
def rebuild_conventional(up, down, shape, dyn_dim=None):
    """Multiply the `up` and `down` factors and reshape the product to `shape`.

    Both factors are flattened to 2-D, keeping their first axis. When `dyn_dim`
    is given, only the first `dyn_dim` columns of `up` and rows of `down`
    take part in the product.
    """
    up_2d = up.reshape(up.size(0), -1)
    down_2d = down.reshape(down.size(0), -1)
    if dyn_dim is None:
        product = up_2d @ down_2d
    else:
        product = up_2d[:, :dyn_dim] @ down_2d[:dyn_dim, :]
    return product.reshape(shape)
|
||||
|
||||
|
||||
def rebuild_cp_decomposition(up, down, mid):
    """Combine the 4-D core `mid` with the `up` and `down` factors into one 4-D tensor.

    `up` and `down` are flattened to 2-D (first axis kept). Axis 0 of `mid` is
    contracted with the columns of `up` and axis 1 with the rows of `down`,
    giving shape (up.shape[0], down_2d.shape[1], mid.shape[2], mid.shape[3]).
    """
    up_2d = up.reshape(up.size(0), -1)
    down_2d = down.reshape(down.size(0), -1)
    rebuilt = torch.einsum('n m k l, i n, m j -> i j k l', mid, up_2d, down_2d)
    return rebuilt
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
from __future__ import annotations
|
||||
import os
|
||||
from collections import namedtuple
|
||||
import enum
|
||||
|
||||
from modules import sd_models, hashes, shared
|
||||
|
||||
# Per-layer bundle consumed by NetworkModule: the network-side key, the model-side key, `w` (a mapping looked up by names such as "alpha" and "bias") and the target module.
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
|
||||
|
||||
# Sort priority for well-known metadata keys (lower sorts first); NetworkOnDisk sorts keys not listed here with priority 999.
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
||||
|
||||
|
||||
@enum.unique
class SdVersion(enum.Enum):
    """Base-model family of a network file, as returned by `NetworkOnDisk.detect_version`."""

    Unknown = 1  # the file's metadata gave no indication
    SD1 = 2
    SD2 = 3
    SDXL = 4
|
||||
|
||||
|
||||
class NetworkOnDisk:
    """A network weights file found on disk, with its metadata, alias, hash and detected model version.

    Attributes set by the constructor:
        name, filename: as passed in.
        is_safetensors: True when the file extension is ".safetensors" (case-insensitive).
        metadata: safetensors metadata ordered by `metadata_tags_order`, without the
            'ssmd_cover_images' entry; empty for other file types or when reading fails.
        alias: metadata 'ss_output_name' if present, otherwise `name`.
        hash, shorthash: full hash string and its first 12 characters ('' when unknown).
        sd_version: result of `detect_version()`.
    """

    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        if self.is_safetensors:
            try:
                raw_metadata = sd_models.read_metadata_from_safetensors(filename) or {}
            except Exception as e:
                # one unreadable file must not abort the whole listing; treat it as "no metadata"
                shared.log.error(f'LoRA metadata read failed: file="{filename}" {e}')
                raw_metadata = {}
            # build a new dict so the object returned by the reader is never mutated
            self.metadata = {
                k: v
                for k, v in sorted(raw_metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999))
                if k != 'ssmd_cover_images'  # those are cover images, and they are too big to display in UI as text
            }

        self.alias = self.metadata.get('ss_output_name', self.name)
        self.hash = None
        self.shorthash = None
        # prefer the hash embedded in metadata, then a cached hash; never compute one here
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )
        self.sd_version = self.detect_version()

    def detect_version(self):
        """Return the `SdVersion` implied by the metadata, or `SdVersion.Unknown` when there is none."""
        if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
            return SdVersion.SDXL
        if str(self.metadata.get('ss_v2', "")) == "True":
            return SdVersion.SD2
        if self.metadata:
            # any other non-empty metadata is treated as SD1
            return SdVersion.SD1
        return SdVersion.Unknown

    def set_hash(self, v):
        """Store hash string `v` and, when it is non-empty, register this entry under its 12-character short hash."""
        self.hash = v
        self.shorthash = self.hash[0:12]
        if self.shorthash:
            import networks  # imported here rather than at module level, as in the original
            networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        """Compute the file hash via `hashes.sha256` if no hash is known yet."""
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        """Return `name` when filenames are preferred or the alias is forbidden, otherwise `alias`."""
        import networks  # imported here rather than at module level, as in the original
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        return self.alias
|
||||
|
||||
|
||||
class Network: # LoraModule
    """A network selected for use: its source file, per-part multipliers and the modules built for it."""

    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        # multipliers start neutral, no dynamic dimension
        self.te_multiplier = 1.0
        self.unet_multiplier = 1.0
        self.dyn_dim = None
        # network modules keyed by layer name
        self.modules = {}
        self.mtime = None
        # the text that was used to add the network to prompt - can be either name or an alias
        self.mentioned_name = None
|
||||
|
||||
|
||||
class ModuleType:
    """Base factory for network modules; subclasses return a module when they recognise the weight keys."""

    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: # pylint: disable=W0613
        """Return None: the base type accepts no weights."""
        return None
|
||||
|
||||
|
||||
class NetworkModule:
    """Base class for one network layer; subclasses implement calc_updown() and optionally forward()."""

    def __init__(self, net: Network, weights: NetworkWeights):
        params = weights.w
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module
        # shape is only recorded when the matched layer exposes a weight
        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape
        self.dim = None
        self.bias = params.get("bias")
        self.alpha = params["alpha"].item() if "alpha" in params else None
        self.scale = params["scale"].item() if "scale" in params else None

    def multiplier(self):
        """Return the text-encoder multiplier when 'transformer' appears early in the layer key, else the unet one."""
        targets_text_encoder = 'transformer' in self.sd_key[:20]
        return self.network.te_multiplier if targets_text_encoder else self.network.unet_multiplier

    def calc_scale(self):
        """Return the explicit scale if present, else alpha / dim when both are known, else 1.0."""
        if self.scale is not None:
            return self.scale
        if self.dim is None or self.alpha is None:
            return 1.0
        return self.alpha / self.dim

    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        """Add the optional bias, reshape, and apply scale and multiplier.

        Returns:
            Tuple of (scaled weight delta, extra bias scaled by the multiplier or None).
        """
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        # same element count as the target weight: adopt the target's exact shape
        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        scaled_bias = ex_bias if ex_bias is None else ex_bias * self.multiplier()
        return updown * self.calc_scale() * self.multiplier(), scaled_bias

    def calc_updown(self, target):
        """Compute the weight delta for `target`; must be provided by subclasses."""
        raise NotImplementedError()

    def forward(self, x, y):
        """Apply the module functionally; must be provided by subclasses that support it."""
        raise NotImplementedError()
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
import network
|
||||
|
||||
|
||||
class ModuleTypeFull(network.ModuleType):
    """Factory for full-difference modules, recognised by a "diff" weight."""

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if "diff" not in weights.w:
            return None
        return NetworkModuleFull(net, weights)
|
||||
|
||||
|
||||
class NetworkModuleFull(network.NetworkModule):
    """Module whose "diff" weight is used directly as the delta, with optional "diff_b" as extra bias."""

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)
        self.weight = weights.w.get("diff")
        self.ex_bias = weights.w.get("diff_b")

    def calc_updown(self, orig_weight):
        """Return the stored difference (and extra bias) moved to the target's device and dtype, then finalized."""
        output_shape = self.weight.shape
        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
        ex_bias = None if self.ex_bias is None else self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
import lyco_helpers
|
||||
import network
|
||||
|
||||
|
||||
class ModuleTypeHada(network.ModuleType):
    """Factory for hada modules, recognised by all four hada_w{1,2}_{a,b} weights."""

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        required = ("hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b")
        if any(name not in weights.w for name in required):
            return None
        return NetworkModuleHada(net, weights)
|
||||
|
||||
|
||||
class NetworkModuleHada(network.NetworkModule):
    """Module whose delta is the element-wise product of two rebuilt factors (w1 and w2)."""

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)
        params = weights.w
        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape
        self.w1a = params["hada_w1_a"]
        self.w1b = params["hada_w1_b"]
        self.dim = self.w1b.shape[0]
        self.w2a = params["hada_w2_a"]
        self.w2b = params["hada_w2_b"]
        # optional extra tensors; when present the factor is built with make_weight_cp instead
        self.t1 = params.get("hada_t1")
        self.t2 = params.get("hada_t2")

    def calc_updown(self, orig_weight):
        """Rebuild both factors on the target's device/dtype, multiply them and finalize."""
        def cast(tensor):
            return tensor.to(orig_weight.device, dtype=orig_weight.dtype)

        w1a, w1b = cast(self.w1a), cast(self.w1b)
        w2a, w2b = cast(self.w2a), cast(self.w2b)
        output_shape = [w1a.size(0), w1b.size(1)]

        if self.t1 is None:
            if len(w1b.shape) == 4:
                output_shape += w1b.shape[2:]
            updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
        else:
            output_shape = [w1a.size(1), w1b.size(1)]
            t1 = cast(self.t1)
            updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
            output_shape += t1.shape[2:]

        if self.t2 is None:
            updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
        else:
            t2 = cast(self.t2)
            updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)

        updown = updown1 * updown2
        return self.finalize_updown(updown, orig_weight, output_shape)
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
import network
|
||||
|
||||
|
||||
class ModuleTypeIa3(network.ModuleType):
    """Factory for ia3 modules, recognised by a "weight" entry."""

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if "weight" not in weights.w:
            return None
        return NetworkModuleIa3(net, weights)
|
||||
|
||||
|
||||
class NetworkModuleIa3(network.NetworkModule):
    """Module whose delta is the original weight multiplied by a learned vector."""

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)
        self.w = weights.w["weight"]
        self.on_input = weights.w["on_input"].item()

    def calc_updown(self, orig_weight):
        """Multiply the target weight by the stored vector and finalize the result."""
        w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
        output_shape = [w.size(0), orig_weight.size(1)]
        if not self.on_input:
            # reshape to a column so the product broadcasts across each row of the target
            w = w.reshape(-1, 1)
        else:
            output_shape.reverse()
        updown = orig_weight * w
        return self.finalize_updown(updown, orig_weight, output_shape)
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
import torch
|
||||
import lyco_helpers
|
||||
import network
|
||||
|
||||
|
||||
class ModuleTypeLokr(network.ModuleType):
    """Factory for lokr modules: each of w1 and w2 must be present whole or as an (a, b) pair."""

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        params = weights.w

        def available(prefix):
            return prefix in params or (f"{prefix}_a" in params and f"{prefix}_b" in params)

        if available("lokr_w1") and available("lokr_w2"):
            return NetworkModuleLokr(net, weights)
        return None
|
||||
|
||||
|
||||
def make_kron(orig_shape, w1, w2):
    """Return the Kronecker product of w1 and w2 reshaped to orig_shape.

    When w2 is 4-dimensional, two singleton dimensions are inserted into w1 at position 2 first.
    """
    if w2.dim() == 4:
        w1 = w1[:, :, None, None]
    product = torch.kron(w1, w2.contiguous())
    return product.reshape(orig_shape)
|
||||
|
||||
|
||||
class NetworkModuleLokr(network.NetworkModule):
    """Module whose delta is the Kronecker product of two factors, each stored whole or decomposed."""

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)
        params = weights.w
        self.w1 = params.get("lokr_w1")
        self.w1a = params.get("lokr_w1_a")
        self.w1b = params.get("lokr_w1_b")
        if self.w1b is not None:
            self.dim = self.w1b.shape[0]
        self.w2 = params.get("lokr_w2")
        self.w2a = params.get("lokr_w2_a")
        self.w2b = params.get("lokr_w2_b")
        # a decomposed w2 takes precedence over w1 when deciding dim
        if self.w2b is not None:
            self.dim = self.w2b.shape[0]
        self.t2 = params.get("lokr_t2")

    def calc_updown(self, orig_weight):
        """Rebuild w1 and w2 on the target's device/dtype, take their Kronecker product and finalize."""
        def cast(tensor):
            return tensor.to(orig_weight.device, dtype=orig_weight.dtype)

        if self.w1 is None:
            w1 = cast(self.w1a) @ cast(self.w1b)
        else:
            w1 = cast(self.w1)

        if self.w2 is not None:
            w2 = cast(self.w2)
        elif self.t2 is not None:
            t2 = cast(self.t2)
            w2a = cast(self.w2a)
            w2b = cast(self.w2b)
            w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
        else:
            w2 = cast(self.w2a) @ cast(self.w2b)

        # 4-dimensional targets take their own shape; otherwise it is the product of the factor sizes
        if len(orig_weight.shape) == 4:
            output_shape = orig_weight.shape
        else:
            output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]

        updown = make_kron(output_shape, w1, w2)
        return self.finalize_updown(updown, orig_weight, output_shape)
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
import torch
|
||||
|
||||
import lyco_helpers
|
||||
import network
|
||||
from modules import devices
|
||||
|
||||
|
||||
class ModuleTypeLora(network.ModuleType):
    """Factory for lora modules, recognised by both lora_up and lora_down weights."""

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        required = ("lora_up.weight", "lora_down.weight")
        if any(name not in weights.w for name in required):
            return None
        return NetworkModuleLora(net, weights)
|
||||
|
||||
|
||||
class NetworkModuleLora(network.NetworkModule):
    """Lora module holding up/down (and optional mid) layers built from the network weights."""

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)
        params = weights.w
        self.up_model = self.create_module(params, "lora_up.weight")
        self.down_model = self.create_module(params, "lora_down.weight")
        self.mid_model = self.create_module(params, "lora_mid.weight", none_ok=True)
        self.dim = params["lora_down.weight"].shape[0]

    def create_module(self, weights, key, none_ok=False):
        """Build a bias-free torch layer holding weights[key], shaped to suit the matched model layer.

        Returns None when the weight is absent and none_ok is set.

        Raises:
            AssertionError: the matched model layer type is not supported for this key.
        """
        weight = weights.get(key)
        if weight is None and none_ok:
            return None

        sd_type = type(self.sd_module)  # exact type match on purpose, subclasses are not accepted
        is_linear = sd_type in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
        is_conv = sd_type in [torch.nn.Conv2d]

        if is_linear:
            weight = weight.reshape(weight.shape[0], -1)
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif (is_conv and key == "lora_down.weight") or key == "dyn_up":
            if len(weight.shape) == 2:
                weight = weight.reshape(weight.shape[0], -1, 1, 1)
            if weight.shape[2] != 1 or weight.shape[3] != 1:
                # non-1x1 kernel: reuse the matched layer's convolution geometry
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
            else:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif is_conv and key == "lora_mid.weight":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
        elif (is_conv and key == "lora_up.weight") or key == "dyn_down":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')

        with torch.no_grad():
            if weight.shape != module.weight.shape:
                weight = weight.reshape(module.weight.shape)
            module.weight.copy_(weight)

        module.to(device=devices.cpu, dtype=devices.dtype)
        module.weight.requires_grad_(False)
        return module

    def calc_updown(self, orig_weight): # pylint: disable=W0237
        """Rebuild the weight delta from the up/down (and optional mid) weights and finalize it."""
        up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
        down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
        output_shape = [up.size(0), down.size(1)]

        if self.mid_model is None:
            if len(down.shape) == 4:
                output_shape += down.shape[2:]
            updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)
        else:
            # cp-decomposition
            mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
            output_shape += mid.shape[2:]

        return self.finalize_updown(updown, orig_weight, output_shape)

    def forward(self, x, y):
        """Return y plus up(down(x)) scaled by the multiplier and the module scale."""
        self.up_model.to(device=devices.device)
        self.down_model.to(device=devices.device)
        delta = self.up_model(self.down_model(x))
        return y + delta * self.multiplier() * self.calc_scale()
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
import network
|
||||
|
||||
|
||||
class ModuleTypeNorm(network.ModuleType):
    """Factory for norm modules, recognised by both w_norm and b_norm weights."""

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        required = ("w_norm", "b_norm")
        if any(name not in weights.w for name in required):
            return None
        return NetworkModuleNorm(net, weights)
|
||||
|
||||
|
||||
class NetworkModuleNorm(network.NetworkModule):
    """Module that uses w_norm as the weight delta and b_norm as the extra bias."""

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)
        self.w_norm = weights.w.get("w_norm")
        self.b_norm = weights.w.get("b_norm")

    def calc_updown(self, orig_weight):
        """Return w_norm (and b_norm, if any) moved to the target's device and dtype, then finalized."""
        output_shape = self.w_norm.shape
        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
        ex_bias = None if self.b_norm is None else self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
|
||||
|
|
@ -0,0 +1,478 @@
|
|||
from typing import Union
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import lora_patches
|
||||
import network
|
||||
import network_lora
|
||||
import network_hada
|
||||
import network_ia3
|
||||
import network_lokr
|
||||
import network_full
|
||||
import network_norm
|
||||
import torch
|
||||
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
|
||||
|
||||
|
||||
# Factories tried in this order by load_network(); the first one returning a module wins.
module_types = [
    module_type()
    for module_type in (
        network_lora.ModuleTypeLora,
        network_hada.ModuleTypeHada,
        network_ia3.ModuleTypeIa3,
        network_lokr.ModuleTypeLokr,
        network_full.ModuleTypeFull,
        network_norm.ModuleTypeNorm,
    )
]
|
||||
|
||||
|
||||
# Matches a run of digits; used by convert_diffusers_name_to_compvis to decide which captured groups become ints.
re_digits = re.compile(r"\d+")
|
||||
# Splits a key ending in _q_proj / _k_proj / _v_proj into (prefix, suffix); load_network retries the lookup with the prefix.
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
|
||||
# Cache of compiled patterns keyed by pattern text, filled lazily by convert_diffusers_name_to_compvis.
re_compiled = {}
|
||||
|
||||
# Per block kind, replacement for the trailing part of a layer name; suffixes not listed are kept as they are.
suffix_conversion = {
    "attentions": {},
    "resnets": dict(
        conv1="in_layers_2",
        conv2="out_layers_3",
        norm1="in_layers_0",
        norm2="out_layers_0",
        time_emb_proj="emb_layers_1",
        conv_shortcut="skip_connection",
    ),
}
|
||||
|
||||
|
||||
def convert_diffusers_name_to_compvis(key, is_sd2):
    """Translate a network layer key into the layer name used by the model's layer mapping.

    Args:
        key: network layer key without the trailing weight part (e.g. "lora_unet_down_blocks_0_...").
        is_sd2: when true, first text-encoder keys are mapped to "model_transformer_resblocks_*" names.

    Returns:
        The translated name, or `key` unchanged when no known pattern matches.
    """
    def match(match_list, regex_text):
        """Match `key` against regex_text; on success fill match_list with the groups (all-digit groups as int)."""
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex
        r = regex.match(key)
        if not r:
            return False
        match_list.clear()
        # fullmatch: only groups made up entirely of digits are converted, so "2_conv" stays a string
        match_list.extend([int(x) if re_digits.fullmatch(x) else x for x in r.groups()])
        return True

    m = []
    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'
    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'
    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"
    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
    return key
|
||||
|
||||
|
||||
def assign_network_names_to_compvis_modules(sd_model):
    """Give every module of `sd_model` a network layer name and store the name -> module mapping on the model.

    Names are the dotted module paths with "." replaced by "_". Modules under
    `sd_model.cond_stage_model.wrapped` are registered first, then those under `sd_model.model`,
    so on a name clash the latter wins. Each module also gets a `network_layer_name` attribute.

    Does nothing when `sd_model` has no `cond_stage_model` attribute.
    """
    if not hasattr(sd_model, 'cond_stage_model'):
        return
    network_layer_mapping = {}

    def register(root):
        """Record all modules below root in the mapping and tag them with their layer name."""
        for name, module in root.named_modules():
            network_name = name.replace(".", "_")
            network_layer_mapping[network_name] = module
            module.network_layer_name = network_name

    # operate on the model that was passed in, so the mapping always describes the model it is stored on
    register(sd_model.cond_stage_model.wrapped)
    register(sd_model.model)
    sd_model.network_layer_mapping = network_layer_mapping
|
||||
|
||||
|
||||
def load_network(name, network_on_disk):
    """Read a network file and build a network.Network with one module per matched model layer.

    Keys that cannot be matched to a model layer are skipped and reported at debug level.

    Raises:
        AssertionError: no registered module type accepts the weights collected for a layer.
    """
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)
    sd = sd_models.read_state_dict(network_on_disk.filename)

    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
    if not hasattr(shared.sd_model, 'network_layer_mapping'):
        assign_network_names_to_compvis_modules(shared.sd_model)

    layer_mapping = shared.sd_model.network_layer_mapping
    is_sd2 = 'model_transformer_resblocks' in layer_mapping
    keys_failed_to_match = {}
    matched_networks = {}

    for key_network, weight in sd.items():
        raw_key, network_part = key_network.split(".", 1)
        key = convert_diffusers_name_to_compvis(raw_key, is_sd2)
        sd_module = layer_mapping.get(key, None)

        if sd_module is None:
            # retry with the _q_proj / _k_proj / _v_proj suffix removed
            m = re_x_proj.match(key)
            if m:
                sd_module = layer_mapping.get(m.group(1), None)

        if sd_module is None:
            if "lora_unet" in raw_key:
                # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
                key = raw_key.replace("lora_unet", "diffusion_model")
                sd_module = layer_mapping.get(key, None)
            elif "lora_te1_text_model" in raw_key:
                key = raw_key.replace("lora_te1_text_model", "0_transformer_text_model")
                sd_module = layer_mapping.get(key, None)
                # some SD1 Loras also have correct compvis keys
                if sd_module is None:
                    key = raw_key.replace("lora_te1_text_model", "transformer_text_model")
                    sd_module = layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue

        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
        matched_networks[key].w[network_part] = weight

    for key, weights in matched_networks.items():
        # first module type that accepts the weights wins
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                break
        else:
            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")
        net.modules[key] = net_module

    if keys_failed_to_match:
        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
    return net
|
||||
|
||||
|
||||
def load_diffusers(name, network_on_disk, te_multiplier: float, unet_multiplier: float, dyn_dim): # pylint: disable=W0613
    """Load a network through modules.lora_diffusers and return a network.Network carrying name, file and mtime."""
    from modules.lora_diffusers import load_diffusers_lora

    result = network.Network(name, network_on_disk)
    result.mtime = os.path.getmtime(network_on_disk.filename)
    load_diffusers_lora(name, network_on_disk, te_multiplier, unet_multiplier, dyn_dim)
    return result
|
||||
|
||||
|
||||
def purge_networks_from_memory():
    """Drop cached networks, first-inserted first, until the cache fits the configured limit, then run torch gc."""
    while len(networks_in_memory) > shared.opts.lora_in_memory_limit and networks_in_memory:
        first_name = next(iter(networks_in_memory))
        networks_in_memory.pop(first_name, None)
    devices.torch_gc()
|
||||
|
||||
|
||||
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    """Make `names` the active set of networks, loading or reloading files as needed.

    Args:
        names: network names or aliases to activate.
        te_multipliers, unet_multipliers, dyn_dims: optional per-name values, indexed like `names`;
            each falls back to 1.0 when the list is not given.
    """
    def resolve():
        return [available_network_aliases.get(name, None) for name in names]

    already_loaded = {net.name: net for net in loaded_networks if net.name in names}
    loaded_networks.clear()

    networks_on_disk = resolve()
    if any(x is None for x in networks_on_disk):
        # something was not found: rescan the disk once and try again
        list_available_networks()
        networks_on_disk = resolve()

    failed_to_load_networks = []
    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        net = already_loaded.get(name, None)
        if network_on_disk is not None:
            if net is None:
                net = networks_in_memory.get(name)
            # load when not cached, or when the file on disk is newer than the cached copy
            if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
                try:
                    if shared.backend == shared.Backend.ORIGINAL:
                        net = load_network(name, network_on_disk)
                        networks_in_memory.pop(name, None)
                        networks_in_memory[name] = net
                    elif shared.backend == shared.Backend.DIFFUSERS:
                        net = load_diffusers(name, network_on_disk, te_multipliers[i] if te_multipliers else 1.0, unet_multipliers[i] if unet_multipliers else 1.0, dyn_dims[i] if dyn_dims else 1.0)
                except Exception as e:
                    errors.display(e, f"loading network {network_on_disk.filename}")
                    continue
            net.mentioned_name = name
            network_on_disk.read_hash()

        if net is None:
            failed_to_load_networks.append(name)
            logging.info(f"Couldn't find network with name {name}")
            continue

        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
        if shared.backend == shared.Backend.ORIGINAL: # load_diffusers cache is handled separately
            loaded_networks.append(net)

    if failed_to_load_networks:
        sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
    purge_networks_from_memory()
|
||||
|
||||
|
||||
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """Copy the backed-up weight and bias back into the layer.

    Does nothing when the layer has neither backup. Otherwise a missing bias backup resets the
    bias to None. For MultiheadAttention the weight backup is an (in_proj_weight, out_proj.weight)
    pair and the bias lives on out_proj.
    """
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)
    if weights_backup is None and bias_backup is None:
        return

    is_attention = isinstance(self, torch.nn.MultiheadAttention)

    if weights_backup is not None:
        if is_attention:
            in_proj_backup, out_proj_backup = weights_backup[0], weights_backup[1]
            self.in_proj_weight.copy_(in_proj_backup)
            self.out_proj.weight.copy_(out_proj_backup)
        else:
            self.weight.copy_(weights_backup)

    bias_owner = self.out_proj if is_attention else self
    if bias_backup is None:
        bias_owner.bias = None
    else:
        bias_owner.bias.copy_(bias_backup)
|
||||
|
||||
|
||||
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
    If not, restores original weights from backup and alters weights according to networks.

    Raises:
        RuntimeError: networks are recorded as applied but no weight backup exists to restore from.
    """
    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return

    is_attention = isinstance(self, torch.nn.MultiheadAttention)
    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    # take a CPU copy of the weights before the first modification
    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None and wanted_names != (): # pylint: disable=C1803
        if current_names != ():
            raise RuntimeError("no backup weights found and current weights are not unchanged")
        if is_attention:
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)
        self.network_weights_backup = weights_backup

    bias_backup = getattr(self, "network_bias_backup", None)
    if bias_backup is None:
        if is_attention and self.out_proj.bias is not None:
            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
        elif getattr(self, 'bias', None) is not None:
            bias_backup = self.bias.to(devices.cpu, copy=True)
        else:
            bias_backup = None
        self.network_bias_backup = bias_backup

    if current_names == wanted_names:
        return

    network_restore_weights_from_backup(self)
    for net in loaded_networks:
        module = net.modules.get(network_layer_name, None)

        # regular layers: one network module maps directly onto self.weight
        if module is not None and hasattr(self, 'weight'):
            try:
                with torch.no_grad():
                    updown, ex_bias = module.calc_updown(self.weight)
                    if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
                        # inpainting model. zero pad updown to make channel[1] 4 to 9
                        updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
                    self.weight += updown
                    if ex_bias is not None and hasattr(self, 'bias'):
                        if self.bias is None:
                            self.bias = torch.nn.Parameter(ex_bias)
                        else:
                            self.bias += ex_bias
            except RuntimeError as e:
                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
            continue

        # attention layers: separate q/k/v/out modules are stacked onto the fused projection weights
        module_q = net.modules.get(network_layer_name + "_q_proj", None)
        module_k = net.modules.get(network_layer_name + "_k_proj", None)
        module_v = net.modules.get(network_layer_name + "_v_proj", None)
        module_out = net.modules.get(network_layer_name + "_out_proj", None)
        if is_attention and module_q and module_k and module_v and module_out:
            try:
                with torch.no_grad():
                    updown_q, _ = module_q.calc_updown(self.in_proj_weight)
                    updown_k, _ = module_k.calc_updown(self.in_proj_weight)
                    updown_v, _ = module_v.calc_updown(self.in_proj_weight)
                    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                    updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
                    self.in_proj_weight += updown_qkv
                    self.out_proj.weight += updown_out
                    if ex_bias is not None:
                        if self.out_proj.bias is None:
                            self.out_proj.bias = torch.nn.Parameter(ex_bias)
                        else:
                            self.out_proj.bias += ex_bias
            except RuntimeError as e:
                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
            continue

        if module is None:
            continue

        # the network has a module for this layer but neither path above could apply it
        logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
        extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

    self.network_current_names = wanted_names
|
||||
|
||||
|
||||
def network_forward(module, input, original_forward): # pylint: disable=W0622
    """
    Old way of applying Lora by executing operations during layer's forward.
    Stacking many loras this way results in big performance degradation.

    Restores and un-caches the layer's weights, runs the original forward, then lets every
    loaded network that has a module for this layer adjust the output.
    """
    if not loaded_networks:
        return original_forward(module, input)

    input = devices.cond_cast_unet(input)
    network_restore_weights_from_backup(module)
    network_reset_cached_weight(module)
    y = original_forward(module, input)

    layer_name = getattr(module, 'network_layer_name', None)
    for lora in loaded_networks:
        lora_module = lora.modules.get(layer_name, None)
        if lora_module is not None:
            y = lora_module.forward(input, y)
    return y
|
||||
|
||||
|
||||
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    """Forget the weight backup and the record of which networks are applied to this layer."""
    self.network_weights_backup = None
    self.network_current_names = ()
|
||||
|
||||
|
||||
def network_Linear_forward(self, input): # pylint: disable=W0622
    """
    Replacement forward for Linear layers.

    Unless the `lora_functional` option is set, calls network_apply_weights on the
    layer and then runs the saved original forward; with the option set, the call
    is routed through network_forward instead.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return originals.Linear_forward(self, input)
    return network_forward(self, input, originals.Linear_forward)
|
||||
|
||||
|
||||
def network_Linear_load_state_dict(self, *args, **kwargs):
    """Replacement state-dict loader for Linear layers: reset the layer's network bookkeeping, then delegate to the saved original loader."""
    network_reset_cached_weight(self) # the bookkeeping refers to the weights that are about to be replaced
    load_original = originals.Linear_load_state_dict
    return load_original(self, *args, **kwargs)
|
||||
|
||||
|
||||
def network_Conv2d_forward(self, input): # pylint: disable=W0622
    """
    Replacement forward for Conv2d layers.

    Unless the `lora_functional` option is set, calls network_apply_weights on the
    layer and then runs the saved original forward; with the option set, the call
    is routed through network_forward instead.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return originals.Conv2d_forward(self, input)
    return network_forward(self, input, originals.Conv2d_forward)
|
||||
|
||||
|
||||
def network_Conv2d_load_state_dict(self, *args, **kwargs):
    """Replacement state-dict loader for Conv2d layers: reset the layer's network bookkeeping, then delegate to the saved original loader."""
    network_reset_cached_weight(self) # the bookkeeping refers to the weights that are about to be replaced
    load_original = originals.Conv2d_load_state_dict
    return load_original(self, *args, **kwargs)
|
||||
|
||||
|
||||
def network_GroupNorm_forward(self, input): # pylint: disable=W0622
    """
    Replacement forward for GroupNorm layers.

    Unless the `lora_functional` option is set, calls network_apply_weights on the
    layer and then runs the saved original forward; with the option set, the call
    is routed through network_forward instead.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return originals.GroupNorm_forward(self, input)
    return network_forward(self, input, originals.GroupNorm_forward)
|
||||
|
||||
|
||||
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    """Replacement state-dict loader for GroupNorm layers: reset the layer's network bookkeeping, then delegate to the saved original loader."""
    network_reset_cached_weight(self) # the bookkeeping refers to the weights that are about to be replaced
    load_original = originals.GroupNorm_load_state_dict
    return load_original(self, *args, **kwargs)
|
||||
|
||||
|
||||
def network_LayerNorm_forward(self, input): # pylint: disable=W0622
    """
    Replacement forward for LayerNorm layers.

    Unless the `lora_functional` option is set, calls network_apply_weights on the
    layer and then runs the saved original forward; with the option set, the call
    is routed through network_forward instead.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return originals.LayerNorm_forward(self, input)
    return network_forward(self, input, originals.LayerNorm_forward)
|
||||
|
||||
|
||||
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    """Replacement state-dict loader for LayerNorm layers: reset the layer's network bookkeeping, then delegate to the saved original loader."""
    network_reset_cached_weight(self) # the bookkeeping refers to the weights that are about to be replaced
    load_original = originals.LayerNorm_load_state_dict
    return load_original(self, *args, **kwargs)
|
||||
|
||||
|
||||
def network_MultiheadAttention_forward(self, *args, **kwargs):
    """
    Replacement forward for MultiheadAttention layers.

    Always calls network_apply_weights on the layer first and then runs the saved
    original forward; unlike the other layer wrappers in this module, the
    `lora_functional` option is not consulted here.
    """
    network_apply_weights(self)
    forward_original = originals.MultiheadAttention_forward
    return forward_original(self, *args, **kwargs)
|
||||
|
||||
|
||||
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    """Replacement state-dict loader for MultiheadAttention layers: reset the layer's network bookkeeping, then delegate to the saved original loader."""
    network_reset_cached_weight(self) # the bookkeeping refers to the weights that are about to be replaced
    load_original = originals.MultiheadAttention_load_state_dict
    return load_original(self, *args, **kwargs)
|
||||
|
||||
|
||||
def list_available_networks():
    """
    Rebuild the module-level network registries from the files on disk.

    All four registries are emptied first; the hash lookup is only cleared here,
    not refilled. The Lora directory is created if missing, then both the Lora and
    the LyCORIS directories are walked for .pt/.ckpt/.safetensors files. Each file
    becomes a network.NetworkOnDisk entry registered under its file name (without
    extension) and under its alias. An alias that is already registered when a
    new entry arrives is recorded, lower-cased, as forbidden.
    """
    for registry in (available_networks, available_network_aliases, forbidden_network_aliases, available_network_hash_lookup):
        registry.clear()
    forbidden_network_aliases.update({"none": 1, "Addams": 1})
    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
    candidates = [
        *shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]),
        *shared.walk_files(shared.cmd_opts.lyco_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]),
    ]
    for filename in candidates:
        if os.path.isdir(filename):
            continue
        name, _ext = os.path.splitext(os.path.basename(filename))
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError as e: # should catch FileNotFoundError and PermissionError etc.
            shared.log.error(f"Failed to load network {name} from {filename} {e}")
            continue
        available_networks[name] = entry
        alias = entry.alias
        # the clash check must run before this entry's own name/alias are registered
        if alias in available_network_aliases:
            forbidden_network_aliases[alias.lower()] = 1
        available_network_aliases[name] = entry
        available_network_aliases[alias] = entry
|
||||
|
||||
# Matches an AddNet model string of the form "name(hexhash)"; group 1 is everything before the last parenthesised hex run (the greedy group may keep trailing whitespace).
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
|
||||
|
||||
|
||||
def infotext_pasted(infotext, params): # pylint: disable=W0613
    """
    Convert "AddNet" LoRA fields found in pasted infotext into <lora:name:weight> prompt tags.

    For every "AddNet Model N" key whose matching "AddNet Module N" is "LoRA", a tag is
    built from the model name (with a trailing "(hash)" removed when present) and the
    "AddNet Weight A N" value (default "1.0"). All tags are appended to params["Prompt"]
    on a new line. Nothing is done when txt2img already exposes an "AddNet Module 1"
    infotext field.

    Arguments:
        infotext: the raw pasted text (unused)
        params: dict of parsed infotext fields; modified in place

    Returns:
        None
    """
    if any(field[1] == "AddNet Module 1" for field in scripts.scripts_txt2img.infotext_fields):
        return # if the other extension is active, it will handle those fields, no need to do anything
    added = []
    for k in params:
        if not k.startswith("AddNet Model "):
            continue
        num = k[13:] # the part after the "AddNet Model " prefix
        if params.get("AddNet Module " + num) != "LoRA":
            continue
        name = params.get("AddNet Model " + num)
        if name is None:
            continue
        m = re_network_name.match(name)
        if m:
            # the greedy capture group keeps any whitespace that precedes "(hash)";
            # strip it so "name (abc123)" gives <lora:name:...> rather than <lora:name :...>
            name = m.group(1).rstrip()
        multiplier = params.get("AddNet Weight A " + num, "1.0")
        added.append(f"<lora:{name}:{multiplier}>")
    if added:
        params["Prompt"] += "\n" + "".join(added)
|
||||
|
||||
|
||||
# registries rebuilt by list_available_networks()
available_networks = {} # file name without extension -> network entry
available_network_aliases = {} # file name and alias -> network entry
forbidden_network_aliases = {} # lower-cased aliases that clashed -> 1
available_network_hash_lookup = {} # cleared on every refresh; presumably filled elsewhere, keyed by short hash -- TODO confirm

# runtime state
loaded_networks = [] # networks consulted by network_forward()
networks_in_memory = {}
extra_network_lora = None # assigned from outside this module; its `errors` dict is updated when applying weights fails
originals: lora_patches.LoraPatches = None # assigned from outside this module; holds the saved original torch methods

# populate the registries at import time
list_available_networks()
|
||||
|
|
@ -4,3 +4,4 @@ from modules import paths
|
|||
|
||||
def preload(parser):
    """Register the --lora-dir and --lyco-dir-backcompat command-line options on *parser*, both defaulting to sub-folders of paths.models_path."""
    lora_default = os.path.join(paths.models_path, 'Lora')
    lyco_default = os.path.join(paths.models_path, 'LyCORIS')
    parser.add_argument(
        "--lora-dir",
        type=str,
        help="Path to directory with Lora networks.",
        default=lora_default,
    )
    parser.add_argument(
        "--lyco-dir-backcompat",
        type=str,
        help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).",
        default=lyco_default,
    )
|
||||
|
|
|
|||
|
|
@ -1,67 +1,49 @@
|
|||
import re
|
||||
|
||||
import torch
|
||||
import gradio as gr
|
||||
from fastapi import FastAPI
|
||||
|
||||
import lora
|
||||
import network
|
||||
import networks
|
||||
import lora # noqa:F401 # pylint: disable=unused-import
|
||||
import lora_patches
|
||||
import extra_networks_lora
|
||||
import ui_extra_networks_lora
|
||||
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||
|
||||
|
||||
def unload():
|
||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
||||
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
|
||||
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
|
||||
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
|
||||
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
|
||||
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
|
||||
networks.originals.undo()
|
||||
|
||||
|
||||
def before_ui():
|
||||
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
|
||||
extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
|
||||
networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
|
||||
extra_networks.register_extra_network(networks.extra_network_lora)
|
||||
# extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
|
||||
|
||||
|
||||
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
|
||||
torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
|
||||
|
||||
if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
|
||||
torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
|
||||
|
||||
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
|
||||
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
|
||||
|
||||
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
|
||||
torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
|
||||
|
||||
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
|
||||
torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
|
||||
|
||||
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
|
||||
torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
|
||||
|
||||
torch.nn.Linear.forward = lora.lora_Linear_forward
|
||||
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
|
||||
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
||||
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
|
||||
torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
|
||||
torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
|
||||
|
||||
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
||||
networks.originals = lora_patches.LoraPatches()
|
||||
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
|
||||
script_callbacks.on_script_unloaded(unload)
|
||||
script_callbacks.on_before_ui(before_ui)
|
||||
script_callbacks.on_infotext_pasted(lora.infotext_pasted)
|
||||
script_callbacks.on_infotext_pasted(networks.infotext_pasted)
|
||||
|
||||
|
||||
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: { "choices": ["None", *lora.available_loras], "visible": False }, refresh=lora.list_available_loras),
|
||||
"lora_preferred_name": shared.OptionInfo("Filename", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
|
||||
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext", gr.Checkbox, { "visible": False }),
|
||||
"sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks], "visible": False}, refresh=networks.list_available_networks),
|
||||
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
|
||||
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
|
||||
# "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
|
||||
# "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
|
||||
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
|
||||
}))
|
||||
|
||||
|
||||
def create_lora_json(obj: lora.LoraOnDisk):
|
||||
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
|
||||
"lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
|
||||
}))
|
||||
|
||||
|
||||
def create_lora_json(obj: network.NetworkOnDisk):
|
||||
return {
|
||||
"name": obj.name,
|
||||
"alias": obj.alias,
|
||||
|
|
@ -70,42 +52,37 @@ def create_lora_json(obj: lora.LoraOnDisk):
|
|||
}
|
||||
|
||||
|
||||
def api_loras(_: gr.Blocks, app: FastAPI):
|
||||
def api_networks(_: gr.Blocks, app: FastAPI):
|
||||
@app.get("/sdapi/v1/loras")
|
||||
async def get_loras():
|
||||
return [create_lora_json(obj) for obj in lora.available_loras.values()]
|
||||
return [create_lora_json(obj) for obj in networks.available_networks.values()]
|
||||
|
||||
@app.post("/sdapi/v1/refresh-loras")
|
||||
async def refresh_loras():
|
||||
return lora.list_available_loras()
|
||||
return networks.list_available_networks()
|
||||
|
||||
|
||||
script_callbacks.on_app_started(api_loras)
|
||||
|
||||
script_callbacks.on_app_started(api_networks)
|
||||
re_lora = re.compile("<lora:([^:]+):")
|
||||
|
||||
|
||||
def infotext_pasted(infotext, d):
|
||||
def infotext_pasted(infotext, d): # pylint: disable=unused-argument
|
||||
hashes = d.get("Lora hashes")
|
||||
if not hashes:
|
||||
return
|
||||
|
||||
hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
|
||||
hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
|
||||
|
||||
def lora_replacement(m):
|
||||
def network_replacement(m):
|
||||
alias = m.group(1)
|
||||
shorthash = hashes.get(alias)
|
||||
if shorthash is None:
|
||||
return m.group(0)
|
||||
|
||||
lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)
|
||||
if lora_on_disk is None:
|
||||
network_on_disk = networks.available_network_hash_lookup.get(shorthash)
|
||||
if network_on_disk is None:
|
||||
return m.group(0)
|
||||
|
||||
return f'<lora:{lora_on_disk.get_alias()}:'
|
||||
|
||||
d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])
|
||||
return f'<lora:{network_on_disk.get_alias()}:'
|
||||
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
|
||||
|
||||
|
||||
script_callbacks.on_infotext_pasted(infotext_pasted)
|
||||
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import json
|
||||
import lora
|
||||
|
||||
import network
|
||||
import networks
|
||||
from modules import shared, ui_extra_networks
|
||||
|
||||
|
||||
|
|
@ -10,39 +10,100 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||
super().__init__('Lora')
|
||||
|
||||
def refresh(self):
|
||||
lora.list_available_loras()
|
||||
networks.list_available_networks()
|
||||
|
||||
def create_item(self, name):
    """
    Build the extra-networks card dictionary for the network registered as *name*.

    Returns None when the network is filtered out as incompatible:
      - original backend: networks detected as SDXL are hidden
      - diffusers backend with an 'sdxl' model: SD1 and SD2 networks are hidden
      - diffusers backend with an 'sd' model: SDXL networks are hidden
    (with no model loaded, or any other model type, nothing is filtered).
    Returns None as well, after a debug log entry, when anything raises while
    the item is being built -- including *name* not being a known network.

    Tags come from the 'ss_tag_frequency' metadata entry: a key containing '_'
    is split once on it into (value, tag); otherwise the entry's own value is
    stored against the key.
    """
    entry = networks.available_networks.get(name)
    try:
        path, _ext = os.path.splitext(entry.filename)
        metadata = entry.metadata
        possible_tags = metadata.get('ss_tag_frequency', {}) if metadata is not None else {}

        # compatibility filter
        if shared.backend == shared.Backend.ORIGINAL:
            if entry.sd_version == network.SdVersion.SDXL:
                return None
        elif shared.backend == shared.Backend.DIFFUSERS:
            model_type = shared.sd_model_type # 'none' (no model loaded) falls through: everything is listed
            if model_type == 'sdxl':
                if entry.sd_version in (network.SdVersion.SD1, network.SdVersion.SD2):
                    return None
            elif model_type == 'sd':
                if entry.sd_version == network.SdVersion.SDXL:
                    return None

        # a plain string cannot be iterated as tag -> frequency pairs, so it is ignored
        if isinstance(possible_tags, str):
            possible_tags = {}
        tags = {}
        for tag_key, tag_value in possible_tags.items():
            if '_' in tag_key:
                words = tag_key.split('_', 1)
            else:
                words = [tag_value, tag_key]
            words = [str(w).replace('.json', '') for w in words]
            if words[0] == '{}':
                words[0] = 0
            tags[' '.join(words[1:])] = words[0]

        # displayed name is the path relative to the Lora folder, without extension
        name = os.path.splitext(os.path.relpath(entry.filename, shared.cmd_opts.lora_dir))[0]
        return {
            "type": 'Lora',
            "name": name,
            "filename": entry.filename,
            "hash": entry.shorthash,
            "search_term": self.search_terms_from_path(entry.filename) + ' '.join(tags.keys()),
            "preview": self.find_preview(path),
            "description": self.find_description(path),
            "info": self.find_info(path),
            "prompt": json.dumps(f" <lora:{entry.get_alias()}:{shared.opts.extra_networks_default_multiplier}>"),
            "local_preview": f"{path}.{shared.opts.samples_format}",
            "metadata": json.dumps(entry.metadata, indent=4) if entry.metadata else None,
            "tags": tags,
        }
    except Exception as e:
        shared.log.debug(f"Extra networks error: type=lora file={name} {e}")
        return None
|
||||
|
||||
"""
|
||||
item = {
|
||||
"name": name,
|
||||
"filename": lora_on_disk.filename,
|
||||
"shorthash": lora_on_disk.shorthash,
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
|
||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||
"metadata": lora_on_disk.metadata,
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
|
||||
"sd_version": lora_on_disk.sd_version.name,
|
||||
}
|
||||
self.read_user_metadata(item)
|
||||
activation_text = item["user_metadata"].get("activation text")
|
||||
preferred_weight = item["user_metadata"].get("preferred weight", 0.0)
|
||||
item["prompt"] = quote_js(f"<lora:{alias}:") + " + " + (str(preferred_weight) if preferred_weight else "opts.extra_networks_default_multiplier") + " + " + quote_js(">")
|
||||
if activation_text:
|
||||
item["prompt"] += " + " + quote_js(" " + activation_text)
|
||||
sd_version = item["user_metadata"].get("sd version")
|
||||
if sd_version in network.SdVersion.__members__:
|
||||
item["sd_version"] = sd_version
|
||||
sd_version = network.SdVersion[sd_version]
|
||||
else:
|
||||
sd_version = lora_on_disk.sd_version
|
||||
if shared.opts.lora_show_all or not enable_filter:
|
||||
pass
|
||||
elif sd_version == network.SdVersion.Unknown:
|
||||
model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1
|
||||
if model_version.name in shared.opts.lora_hide_unknown_for_versions:
|
||||
return None
|
||||
elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
|
||||
return None
|
||||
elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2:
|
||||
return None
|
||||
elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1:
|
||||
return None
|
||||
return item
|
||||
"""
|
||||
|
||||
def list_items(self):
|
||||
for name, l in lora.available_loras.items():
|
||||
try:
|
||||
path, _ext = os.path.splitext(l.filename)
|
||||
possible_tags = l.metadata.get('ss_tag_frequency', {}) if l.metadata is not None else {}
|
||||
if isinstance(possible_tags, str):
|
||||
possible_tags = {}
|
||||
tags = {}
|
||||
for k, v in possible_tags.items():
|
||||
words = k.split('_', 1) if '_' in k else [v, k]
|
||||
words = [str(w).replace('.json', '') for w in words]
|
||||
if words[0] == '{}':
|
||||
words[0] = 0
|
||||
tags[' '.join(words[1:])] = words[0]
|
||||
name = os.path.splitext(os.path.relpath(l.filename, shared.cmd_opts.lora_dir))[0]
|
||||
yield {
|
||||
"type": 'Lora',
|
||||
"name": name,
|
||||
"filename": l.filename,
|
||||
"hash": l.shorthash,
|
||||
"search_term": self.search_terms_from_path(l.filename) + ' '.join(tags.keys()),
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"info": self.find_info(path),
|
||||
"prompt": json.dumps(f" <lora:{l.get_alias()}:{shared.opts.extra_networks_default_multiplier}>"),
|
||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||
"metadata": json.dumps(l.metadata, indent=4) if l.metadata else None,
|
||||
"tags": tags,
|
||||
}
|
||||
except Exception as e:
|
||||
shared.log.debug(f"Extra networks error: type=lora file={name} {e}")
|
||||
for _index, name in enumerate(networks.available_networks):
|
||||
item = self.create_item(name)
|
||||
if item is not None:
|
||||
yield item
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return [shared.cmd_opts.lora_dir]
|
||||
return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir]
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
Subproject commit 543bc702ba3b04877328bfb7cded5679597b627f
|
||||
Subproject commit 153157bd527865fbfdbd2b676f63cb14ae93a211
|
||||
|
|
@ -1,28 +1,35 @@
|
|||
import os
|
||||
import time
|
||||
import diffusers
|
||||
import diffusers.models.lora as diffusers_lora
|
||||
# from modules import shared
|
||||
import modules.shared as shared
|
||||
import modules.errors
|
||||
|
||||
|
||||
lora_state = { # TODO Lora state for Diffusers
|
||||
debug_output = os.environ.get('SD_LORA_DEBUG', None)
|
||||
debug = shared.log.info if debug_output is not None else lambda *args, **kwargs: None
|
||||
|
||||
|
||||
lora_state = { # Lora state for Diffusers
|
||||
'multiplier': [],
|
||||
'active': False,
|
||||
'loaded': [],
|
||||
'all_loras': [],
|
||||
}
|
||||
|
||||
def unload_diffusers_lora():
|
||||
try:
|
||||
pipe = shared.sd_model
|
||||
if shared.opts.diffusers_lora_loader == "diffusers":
|
||||
if len(lora_state['loaded']) > 1 and hasattr(pipe, "unfuse_lora"):
|
||||
debug(f'LoRA unfuse: loader={shared.opts.diffusers_lora_loader}')
|
||||
pipe.unfuse_lora()
|
||||
pipe.unload_lora_weights()
|
||||
pipe._remove_text_encoder_monkey_patch() # pylint: disable=W0212
|
||||
proc_cls_name = next(iter(pipe.unet.attn_processors.values())).__class__.__name__
|
||||
non_lora_proc_cls = getattr(diffusers.models.attention_processor, proc_cls_name)#[len("LORA"):])
|
||||
pipe.unet.set_attn_processor(non_lora_proc_cls())
|
||||
# shared.log.debug('Diffusers LoRA unloaded')
|
||||
else:
|
||||
lora_state['all_loras'].reverse()
|
||||
lora_state['multiplier'].reverse()
|
||||
|
|
@ -35,29 +42,30 @@ def unload_diffusers_lora():
|
|||
lora_state['loaded'].clear()
|
||||
lora_state['all_loras'] = []
|
||||
lora_state['multiplier'] = []
|
||||
debug(f'LoRA unloaded: loader={shared.opts.diffusers_lora_loader}')
|
||||
except Exception as e:
|
||||
shared.log.error(f"LoRA unload failed: {e}")
|
||||
|
||||
|
||||
def load_diffusers_lora(name, lora, strength = 1.0, num_loras = 1):
|
||||
if f'{lora.filename}:{strength}' in lora_state['loaded']:
|
||||
shared.log.info(f'LoRA cached: {name} strength={strength}')
|
||||
def load_diffusers_lora(name, lora, te_multiplier = 1.0, unet_multiplier = 1.0, dyn_dim = None): # TODO: te_multiplier is used as strength and unet_multiplier is ignored
|
||||
if f'{lora.filename}:{te_multiplier}' in lora_state['loaded']:
|
||||
debug(f'LoRA cached: {name} te-strength={te_multiplier} unet-strength={unet_multiplier} dyn-dim={dyn_dim}')
|
||||
return
|
||||
try:
|
||||
t0 = time.time()
|
||||
pipe = shared.sd_model
|
||||
lora_state['active'] = True
|
||||
lora_state['multiplier'].append(strength)
|
||||
lora_state['multiplier'].append(te_multiplier)
|
||||
fuse = 0
|
||||
if shared.opts.diffusers_lora_loader.startswith("diffusers"):
|
||||
pipe.load_lora_weights(lora.filename, cache_dir=shared.opts.diffusers_dir, local_files_only=True, lora_scale=strength, low_cpu_mem_usage=True)
|
||||
if num_loras > 1 and hasattr(pipe, "fuse_lora"):
|
||||
pipe.load_lora_weights(lora.filename, cache_dir=shared.opts.diffusers_dir, local_files_only=True, lora_scale=te_multiplier, low_cpu_mem_usage=True)
|
||||
if hasattr(pipe, "fuse_lora"):
|
||||
t2 = time.time()
|
||||
pipe.fuse_lora(lora_scale=strength)
|
||||
pipe.fuse_lora(lora_scale=te_multiplier)
|
||||
fuse = time.time() - t2
|
||||
lora_state['loaded'].append(f'{lora.filename}:{strength}')
|
||||
lora_state['loaded'].append(f'{lora.filename}:{te_multiplier}')
|
||||
if shared.compiled_model_state is not None: #filename breaks caching
|
||||
shared.compiled_model_state.lora_model.append(f'{name}:{strength}')
|
||||
shared.compiled_model_state.lora_model.append(f'{name}:{te_multiplier}')
|
||||
else:
|
||||
from safetensors.torch import load_file
|
||||
lora_sd = load_file(lora.filename)
|
||||
|
|
@ -65,23 +73,26 @@ def load_diffusers_lora(name, lora, strength = 1.0, num_loras = 1):
|
|||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
else:
|
||||
text_encoders = pipe.text_encoder
|
||||
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=strength)
|
||||
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=te_multiplier)
|
||||
lora_network.load_state_dict(lora_sd)
|
||||
if shared.opts.diffusers_lora_loader == "merge and apply":
|
||||
lora_network.merge_to(multiplier=strength)
|
||||
lora_network.merge_to(multiplier=te_multiplier)
|
||||
if shared.opts.diffusers_lora_loader == "sequential apply":
|
||||
lora_network.to(shared.device, dtype=pipe.unet.dtype)
|
||||
lora_network.apply_to(multiplier=strength)
|
||||
lora_network.apply_to(multiplier=te_multiplier)
|
||||
lora_state['all_loras'].append(lora_network)
|
||||
lora_state['loaded'].append(f'{lora.filename}:{strength}')
|
||||
lora_state['loaded'].append(f'{lora.filename}:{te_multiplier}')
|
||||
if shared.compiled_model_state is not None: #filename breaks caching
|
||||
shared.compiled_model_state.lora_model.append(f'{name}:{strength}')
|
||||
shared.compiled_model_state.lora_model.append(f'{name}:{te_multiplier}')
|
||||
t1 = time.time()
|
||||
fuse = f'fuse={fuse:.2f}s' if fuse > 0 else ''
|
||||
shared.log.info(f'LoRA loaded: {name} strength={strength} loader="{shared.opts.diffusers_lora_loader}" lora={t1-t0:.2f}s {fuse}')
|
||||
shared.log.info(f'LoRA loaded: {name} strength={te_multiplier} loader="{shared.opts.diffusers_lora_loader}" lora={t1-t0:.2f}s {fuse}')
|
||||
except Exception as e:
|
||||
lines = str(e).splitlines()
|
||||
shared.log.error(f'LoRA loading failed: {name} loader="{shared.opts.diffusers_lora_loader}" {lines[0]}')
|
||||
if debug_output is None:
|
||||
shared.log.error(f'LoRA load failed: {name} loader="{shared.opts.diffusers_lora_loader}" {lines[0]}')
|
||||
else:
|
||||
modules.errors.display(e, 'LoRA load failed')
|
||||
|
||||
|
||||
# LoRA that runs on Diffusers. This file is self-contained.
|
||||
|
|
@ -249,7 +260,7 @@ class LoRAModule(torch.nn.Module):
|
|||
self.org_module[0].forward = self.org_forward
|
||||
|
||||
# forward with lora
|
||||
def forward(self, x, scale = 1.0):
|
||||
def forward(self, x, scale = 1.0): # pylint: disable=unused-argument
|
||||
if not self.enabled:
|
||||
return self.org_forward(x)
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
|
|
@ -442,7 +453,7 @@ class LoRANetwork(torch.nn.Module): # pylint: disable=abstract-method
|
|||
|
||||
self.unet_loras: List[LoRAModule]
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
shared.log.debug(f"LoRA modules loaded/skipped: te={len(self.text_encoder_loras)}/{len(skipped_te)} unet={len(self.unet_loras)}/skip={len(skipped_un)}")
|
||||
debug(f"LoRA module: te_loaded={len(self.text_encoder_loras)} te_skipped={len(skipped_te)} unet_loaded={len(self.unet_loras)} unet_skipped={len(skipped_un)}")
|
||||
|
||||
# assertion
|
||||
names = set()
|
||||
|
|
@ -475,9 +486,7 @@ class LoRANetwork(torch.nn.Module): # pylint: disable=abstract-method
|
|||
converted_count += 1
|
||||
else:
|
||||
not_converted_count += 1
|
||||
if not_converted_count > 0:
|
||||
shared.log.warning(f'LoRA modules not converted: {not_converted_count}')
|
||||
|
||||
debug(f'LoRA module: unet converted={converted_count}/{not_converted_count}')
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
self.multiplier = multiplier
|
||||
|
|
|
|||
|
|
@ -99,6 +99,8 @@ def download_civit_preview(model_path: str, preview_url: str):
|
|||
raise ValueError(f'removed invalid download: bytes={written}')
|
||||
img = Image.open(preview_file)
|
||||
except Exception as e:
|
||||
os.remove(preview_file)
|
||||
res += f' error={e}'
|
||||
shared.log.error(f'CivitAI download error: url={preview_url} file={preview_file} {e}')
|
||||
shared.state.end()
|
||||
if img is None:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,64 @@
|
|||
from collections import defaultdict
|
||||
|
||||
|
||||
def patch(key, obj, field, replacement):
    """Swap a function on a module or class for a replacement, remembering the original.

    The original is recorded in this module's registry under (key, obj, field) and can
    be looked up again with original(key, obj, field). Patching the same (obj, field)
    twice under one key raises -- call undo() first.

    Arguments:
        key: identifies who is doing the replacement. You can use __name__.
        obj: the module or the class that owns the function
        field: name of the function, as a string
        replacement: the new function

    Returns:
        the original function

    Raises:
        RuntimeError: when this key has already patched (obj, field)
    """
    registry = originals[key]
    slot = (obj, field)
    if slot in registry:
        raise RuntimeError(f"patch for {field} is already applied")

    previous = getattr(obj, field)
    registry[slot] = previous
    setattr(obj, field, replacement)
    return previous
|
||||
|
||||
|
||||
def undo(key, obj, field):
    """Revert a replacement made by patch(), putting the recorded original back.

    Arguments:
        key: identifies who did the replacement. You can use __name__.
        obj: the module or the class that owns the function
        field: name of the function, as a string

    Returns:
        Always None

    Raises:
        RuntimeError: when this key has no recorded patch for (obj, field)
    """
    registry = originals[key]
    slot = (obj, field)
    if slot not in registry:
        raise RuntimeError(f"there is no patch for {field} to undo")

    # forget the record and restore the original in one step
    setattr(obj, field, registry.pop(slot))
    return None
|
||||
|
||||
|
||||
def original(key, obj, field):
    """Look up the original function recorded by patch() for (key, obj, field); None when no such patch is recorded."""
    return originals[key].get((obj, field), None)
|
||||
|
||||
|
||||
originals = defaultdict(dict)
|
||||
|
|
@ -164,12 +164,16 @@ class StableDiffusionProcessing:
|
|||
self.s_max = shared.opts.s_max
|
||||
self.s_tmin = shared.opts.s_tmin
|
||||
self.s_tmax = float('inf') # not representable as a standard ui option
|
||||
self.comments = {}
|
||||
shared.opts.data['clip_skip'] = clip_skip
|
||||
|
||||
@property
|
||||
def sd_model(self):
|
||||
return shared.sd_model
|
||||
|
||||
def comment(self, text):
|
||||
self.comments[text] = 1
|
||||
|
||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||
return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ def prepare_embedding_providers(pipe, clip_skip):
|
|||
def pad_to_same_length(embeds):
|
||||
try: #SDXL
|
||||
empty_embed = shared.sd_model.encode_prompt("")
|
||||
except: #SD1.5
|
||||
except Exception: #SD1.5
|
||||
empty_embed = shared.sd_model.encode_prompt("",shared.sd_model.device, 1, False)
|
||||
|
||||
empty_batched = torch.cat([empty_embed[0]] * embeds[0].shape[0])
|
||||
|
|
|
|||
|
|
@ -544,6 +544,8 @@ class ModelData:
|
|||
with self.lock:
|
||||
try:
|
||||
self.sd_model = reload_model_weights(op='model')
|
||||
if self.sd_model is not None:
|
||||
self.sd_model.is_sdxl = False # a1111 compatibility item
|
||||
self.initial = False
|
||||
except Exception as e:
|
||||
shared.log.error("Failed to load stable diffusion model")
|
||||
|
|
|
|||
|
|
@ -151,6 +151,9 @@ class State:
|
|||
devices.torch_gc()
|
||||
|
||||
def end(self):
|
||||
if self.time_start is None: # someone called end before being
|
||||
log.debug(f'Access state.end: {sys._getframe().f_back.f_code.co_name}') # pylint: disable=protected-access
|
||||
self.time_start = time.time()
|
||||
log.debug(f'State end: {self.job} time={time.time() - self.time_start:.2f}s')
|
||||
self.job = ""
|
||||
self.job_count = 0
|
||||
|
|
@ -679,9 +682,9 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
|||
"extra_networks_card_square": OptionInfo(True, "UI disable variable aspect ratio"),
|
||||
"extra_networks_card_fit": OptionInfo("cover", "UI image contain method", gr.Radio, lambda: {"choices": ["contain", "cover", "fill"], "visible": False}),
|
||||
"extra_network_skip_indexing": OptionInfo(False, "Do not automatically build extra network pages", gr.Checkbox),
|
||||
"lyco_patch_lora": OptionInfo(False, "Use LyCoris handler for all LoRA types", gr.Checkbox),
|
||||
"lyco_patch_lora": OptionInfo(False, "Use LyCoris handler for all LoRA types", gr.Checkbox, { "visible": False }),
|
||||
"lora_functional": OptionInfo(False, "Use Kohya method for handling multiple LoRA", gr.Checkbox, { "visible": False }),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: { "choices": ["None"] + list(hypernetworks.keys()), "visible": False }, refresh=reload_hypernetworks),
|
||||
}))
|
||||
|
||||
|
|
@ -1023,6 +1026,7 @@ def req(url_addr, headers = None, **kwargs):
|
|||
class Shared(sys.modules[__name__].__class__): # this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than at program startup.
|
||||
@property
|
||||
def sd_model(self):
|
||||
# log.debug(f'Access shared.sd_model: {sys._getframe().f_back.f_code.co_name}') # pylint: disable=protected-access
|
||||
import modules.sd_models # pylint: disable=W0621
|
||||
return modules.sd_models.model_data.get_sd_model()
|
||||
|
||||
|
|
@ -1048,6 +1052,10 @@ class Shared(sys.modules[__name__].__class__): # this class is here to provide s
|
|||
@property
|
||||
def sd_model_type(self):
|
||||
try:
|
||||
import modules.sd_models # pylint: disable=W0621
|
||||
if modules.sd_models.model_data.sd_model is None:
|
||||
model_type = 'none'
|
||||
return model_type
|
||||
if backend == Backend.ORIGINAL:
|
||||
model_type = 'ldm'
|
||||
elif "StableDiffusionXL" in self.sd_model.__class__.__name__:
|
||||
|
|
@ -1065,6 +1073,10 @@ class Shared(sys.modules[__name__].__class__): # this class is here to provide s
|
|||
@property
|
||||
def sd_refiner_type(self):
|
||||
try:
|
||||
import modules.sd_models # pylint: disable=W0621
|
||||
if modules.sd_models.model_data.sd_refiner is None:
|
||||
model_type = 'none'
|
||||
return model_type
|
||||
if backend == Backend.ORIGINAL:
|
||||
model_type = 'ldm'
|
||||
elif "StableDiffusionXL" in self.sd_refiner.__class__.__name__:
|
||||
|
|
|
|||
|
|
@ -129,6 +129,8 @@ class EmbeddingDatabase:
|
|||
return vec.shape[1]
|
||||
|
||||
def load_diffusers_embedding(self, filename: str, path: str):
|
||||
if shared.sd_model is None:
|
||||
return
|
||||
fn, ext = os.path.splitext(filename)
|
||||
if ext.lower() != ".pt" and ext.lower() != ".safetensors":
|
||||
return
|
||||
|
|
@ -276,6 +278,8 @@ class EmbeddingDatabase:
|
|||
continue
|
||||
|
||||
def load_textual_inversion_embeddings(self, force_reload=False):
|
||||
if shared.sd_model is None:
|
||||
return
|
||||
t0 = time.time()
|
||||
if not force_reload:
|
||||
need_reload = False
|
||||
|
|
|
|||
|
|
@ -178,6 +178,7 @@ class ExtraNetworksPage:
|
|||
try:
|
||||
img = Image.open(f)
|
||||
except Exception:
|
||||
img = None
|
||||
shared.log.warning(f'Extra network removing invalid image: {f}')
|
||||
try:
|
||||
if img is None:
|
||||
|
|
@ -203,7 +204,7 @@ class ExtraNetworksPage:
|
|||
self.refresh_time = time.time()
|
||||
except Exception as e:
|
||||
self.items = []
|
||||
shared.log.error(f'Extra networks error listing items: class={self.__class__} tab={tabname} {e}')
|
||||
shared.log.error(f'Extra networks error listing items: class={self.__class__.__name__} tab={tabname} {e}')
|
||||
for item in self.items:
|
||||
self.metadata[item["name"]] = item.get("metadata", {})
|
||||
t1 = time.time()
|
||||
|
|
@ -612,11 +613,19 @@ def create_ui(container, button_parent, tabname, skip_indexing = False):
|
|||
<tr><td>Resolution</td><td>{meta.get('modelspec.resolution', 'N/A')}</td></tr>
|
||||
'''
|
||||
if page.title == 'Lora':
|
||||
tags = getattr(item, 'tags', {})
|
||||
tags = [f'{name}:{tags[name]}' for i, name in enumerate(tags)]
|
||||
tags = ' '.join(tags)
|
||||
try:
|
||||
tags = getattr(item, 'tags', {})
|
||||
tags = [f'{name}:{tags[name]}' for i, name in enumerate(tags)]
|
||||
tags = ' '.join(tags)
|
||||
except Exception:
|
||||
tags = ''
|
||||
try:
|
||||
triggers = ' '.join(info.get('tags', []))
|
||||
except Exception:
|
||||
triggers = ''
|
||||
lora = f'''
|
||||
<tr><td>Tags</td><td>{tags}</td></tr>
|
||||
<tr><td>Model tags</td><td>{tags}</td></tr>
|
||||
<tr><td>User tags</td><td>{triggers}</td></tr>
|
||||
<tr><td>Base model</td><td>{meta.get('ss_sd_model_name', 'N/A')}</td></tr>
|
||||
<tr><td>Resolution</td><td>{meta.get('ss_resolution', 'N/A')}</td></tr>
|
||||
<tr><td>Training images</td><td>{meta.get('ss_num_train_images', 'N/A')}</td></tr>
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||
|
||||
def list_items(self):
|
||||
checkpoint: sd_models.CheckpointInfo
|
||||
for name, checkpoint in sd_models.checkpoints_list.items():
|
||||
checkpoints = sd_models.checkpoints_list.copy()
|
||||
for name, checkpoint in checkpoints.items():
|
||||
try:
|
||||
fn = os.path.splitext(checkpoint.filename)[0]
|
||||
record = {
|
||||
|
|
|
|||
|
|
@ -361,10 +361,14 @@ def create_ui():
|
|||
if r.status_code == 200:
|
||||
d = r.json()
|
||||
res.append(download_civit_meta(item['filename'], d['modelId']))
|
||||
if d.get('images') is not None and len(d['images']) > 0 and len(d['images'][0]['url']) > 0:
|
||||
preview_url = d['images'][0]['url']
|
||||
res.append(download_civit_preview(item['filename'], preview_url))
|
||||
found = True
|
||||
if d.get('images') is not None:
|
||||
for i in d['images']:
|
||||
preview_url = i['url']
|
||||
img_res = download_civit_preview(item['filename'], preview_url)
|
||||
res.append(img_res)
|
||||
if 'error' not in img_res:
|
||||
found = True
|
||||
break
|
||||
if not found and civit_previews_rehash and os.stat(item['filename']).st_size < (1024 * 1024 * 1024):
|
||||
sha = modules.hashes.calculate_sha256(item['filename'], quiet=True)[:10]
|
||||
r = req(f'https://civitai.com/api/v1/model-versions/by-hash/{sha}')
|
||||
|
|
@ -372,9 +376,14 @@ def create_ui():
|
|||
if r.status_code == 200:
|
||||
d = r.json()
|
||||
res.append(download_civit_meta(item['filename'], d['modelId']))
|
||||
if d.get('images') is not None and len(d['images']) > 0 and len(d['images'][0]['url']) > 0:
|
||||
preview_url = d['images'][0]['url']
|
||||
res.append(download_civit_preview(item['filename'], preview_url))
|
||||
if d.get('images') is not None:
|
||||
for i in d['images']:
|
||||
preview_url = i['url']
|
||||
img_res = download_civit_preview(item['filename'], preview_url)
|
||||
res.append(img_res)
|
||||
if 'error' not in img_res:
|
||||
found = True
|
||||
break
|
||||
txt = '<br>'.join([r for r in res if len(r) > 0])
|
||||
return txt
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ exclude = [
|
|||
"extensions",
|
||||
"extensions-builtin",
|
||||
"modules/lora",
|
||||
"modules/lycoris",
|
||||
"modules/dml",
|
||||
"modules/models/diffusion",
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in New Issue