lint fixes

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/3593/head
Vladimir Mandic 2024-11-29 08:01:26 -05:00
parent d6c1487f9a
commit a635421231
7 changed files with 26 additions and 25 deletions

@ -1 +1 @@
Subproject commit 3008cee4b67bb00f8f1a4fe4510ec27ba92aa418
Subproject commit f083ce41a9f18b500f26745ea9e86855e509d2cb

View File

@ -1,8 +0,0 @@
# import networks
#
# list_available_loras = networks.list_available_networks
# available_loras = networks.available_networks
# available_lora_aliases = networks.available_network_aliases
# available_lora_hash_lookup = networks.available_network_hash_lookup
# forbidden_lora_aliases = networks.forbidden_network_aliases
# loaded_loras = networks.loaded_networks

View File

@ -107,14 +107,14 @@ def make_unet_conversion_map() -> Dict[str, str]:
class KeyConvert:
def __init__(self):
self.is_sdxl = True if shared.sd_model_type == "sdxl" else False
self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
self.LORA_PREFIX_UNET = "lora_unet_"
self.LORA_PREFIX_TEXT_ENCODER = "lora_te_"
self.OFT_PREFIX_UNET = "oft_unet_"
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1_"
self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2_"
self.is_sdxl = True if shared.sd_model_type == "sdxl" else False
self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
self.LORA_PREFIX_UNET = "lora_unet_"
self.LORA_PREFIX_TEXT_ENCODER = "lora_te_"
self.OFT_PREFIX_UNET = "oft_unet_"
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1_"
self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2_"
def __call__(self, key):
if self.is_sdxl:
@ -446,6 +446,7 @@ def _convert_kohya_sd3_lora_to_diffusers(state_dict):
lora_name_alpha = f"{lora_name}.alpha"
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
sd_lora_rank = 1
if lora_name.startswith(("lora_te_", "lora_te1_")):
down_weight = sds_sd.pop(key)
sd_lora_rank = down_weight.shape[0]

View File

@ -1,9 +1,11 @@
import os
from collections import namedtuple
import enum
from typing import Union
from collections import namedtuple
from modules import sd_models, hashes, shared
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
@ -105,7 +107,7 @@ class Network: # LoraModule
class ModuleType:
def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: # pylint: disable=W0613
def create_module(self, net: Network, weights: NetworkWeights) -> Union[Network, None]: # pylint: disable=W0613
return None

View File

@ -1,5 +1,6 @@
import modules.lora.network as network
class ModuleTypeNorm(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["w_norm", "b_norm"]):

View File

@ -1,7 +1,7 @@
import torch
from einops import rearrange
import modules.lora.network as network
from modules.lora.lyco_helpers import factorization
from einops import rearrange
class ModuleTypeOFT(network.ModuleType):
@ -10,6 +10,7 @@ class ModuleTypeOFT(network.ModuleType):
return NetworkModuleOFT(net, weights)
return None
# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
class NetworkModuleOFT(network.NetworkModule): # pylint: disable=abstract-method

View File

@ -3,6 +3,9 @@ import os
import re
import time
import concurrent
import torch
import diffusers.models.lora
import modules.lora.network as network
import modules.lora.network_lora as network_lora
import modules.lora.network_hada as network_hada
@ -14,8 +17,6 @@ import modules.lora.network_norm as network_norm
import modules.lora.network_glora as network_glora
import modules.lora.network_overrides as network_overrides
import modules.lora.lora_convert as lora_convert
import torch
import diffusers.models.lora
from modules import shared, devices, sd_models, sd_models_compile, errors, scripts, files_cache, model_quant
@ -74,7 +75,7 @@ def assign_network_names_to_compvis_modules(sd_model):
shared.sd_model.network_layer_mapping = network_layer_mapping
def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_default_multiplier) -> network.Network | None:
def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_default_multiplier) -> Union[network.Network, None]:
name = name.replace(".", "_")
shared.log.debug(f'Load network: type=LoRA name="{name}" file="{network_on_disk.filename}" detected={network_on_disk.sd_version} method=diffusers scale={lora_scale} fuse={shared.opts.lora_fuse_diffusers}')
if not shared.native:
@ -103,7 +104,7 @@ def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_
return net
def load_network(name, network_on_disk) -> network.Network | None:
def load_network(name, network_on_disk) -> Union[network.Network, None]:
if not shared.sd_loaded:
return None
@ -173,6 +174,7 @@ def load_network(name, network_on_disk) -> network.Network | None:
net.bundle_embeddings = bundle_embeddings
return net
def maybe_recompile_model(names, te_multipliers):
recompile_model = False
if shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled:
@ -186,7 +188,7 @@ def maybe_recompile_model(names, te_multipliers):
if not recompile_model:
if len(loaded_networks) > 0 and debug:
shared.log.debug('Model Compile: Skipping LoRa loading')
return
return recompile_model
else:
recompile_model = True
shared.compiled_model_state.lora_model = []
@ -277,6 +279,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
t1 = time.time()
timer['load'] += t1 - t0
def set_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown, ex_bias):
weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None)
@ -389,6 +392,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
t1 = time.time()
timer['apply'] += t1 - t0
def network_load():
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatiblility
for component_name in ['text_encoder','text_encoder_2', 'unet', 'transformer']: