diff --git a/modules/api/endpoints.py b/modules/api/endpoints.py index ee30fac44..87362b4c9 100644 --- a/modules/api/endpoints.py +++ b/modules/api/endpoints.py @@ -41,10 +41,10 @@ def get_embeddings(): return {"loaded": convert_embeddings(db.word_embeddings), "skipped": convert_embeddings(db.skipped_embeddings)} def get_loras(): - from modules.lora import network, networks + from modules.lora import network, lora_load def create_lora_json(obj: network.NetworkOnDisk): return { "name": obj.name, "alias": obj.alias, "path": obj.filename, "metadata": obj.metadata } - return [create_lora_json(obj) for obj in networks.available_networks.values()] + return [create_lora_json(obj) for obj in lora_load.available_networks.values()] def get_extra_networks(page: Optional[str] = None, name: Optional[str] = None, filename: Optional[str] = None, title: Optional[str] = None, fullname: Optional[str] = None, hash: Optional[str] = None): # pylint: disable=redefined-builtin res = [] @@ -134,8 +134,8 @@ def post_refresh_vae(): return shared.refresh_vaes() def post_refresh_loras(): - from modules.lora import networks - return networks.list_available_networks() + from modules.lora import lora_load + return lora_load.list_available_networks() def get_extensions_list(): from modules import extensions diff --git a/modules/api/gallery.py b/modules/api/gallery.py index e56bb00ac..e1add81af 100644 --- a/modules/api/gallery.py +++ b/modules/api/gallery.py @@ -74,7 +74,7 @@ def register_api(app: FastAPI): # register api manager = ConnectionManager() def get_video_thumbnail(filepath): - from modules.ui_control_helpers import get_video_params + from modules.video import get_video_params try: stat = os.stat(filepath) frames, fps, duration, width, height, codec, frame = get_video_params(filepath, capture=True) diff --git a/modules/extra_networks.py b/modules/extra_networks.py index e882b113c..d8b638b85 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -19,8 +19,9 @@ def register_default_extra_networks(): from modules.ui_extra_networks_styles import ExtraNetworkStyles register_extra_network(ExtraNetworkStyles()) if not shared.opts.lora_legacy: - from modules.lora.networks import extra_network_lora - register_extra_network(extra_network_lora) + from modules.lora import lora_common, extra_networks_lora + lora_common.extra_network_lora = extra_networks_lora.ExtraNetworkLora() + register_extra_network(lora_common.extra_network_lora) if shared.opts.hypernetwork_enabled: from modules.ui_extra_networks_hypernet import ExtraNetworkHypernet register_extra_network(ExtraNetworkHypernet()) diff --git a/modules/infotext.py b/modules/infotext.py index 78c1fd92e..497879d31 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -31,7 +31,7 @@ def unquote(text): # disabled by default can be enabled if needed def check_lora(params): try: - import modules.lora.networks as networks + from modules.lora import lora_load from modules.errors import log # pylint: disable=redefined-outer-name except Exception: return @@ -39,14 +39,14 @@ def check_lora(params): found = [] missing = [] for l in loras: - lora = networks.available_network_hash_lookup.get(l, None) + lora = lora_load.available_network_hash_lookup.get(l, None) if lora is not None: found.append(lora.name) else: missing.append(l) loras = [s.strip() for s in params.get('LoRA networks', '').split(',')] for l in loras: - lora = networks.available_network_aliases.get(l, None) + lora = lora_load.available_network_aliases.get(l, None) if lora is not None: 
found.append(lora.name) else: @@ -54,7 +54,7 @@ def check_lora(params): # networks.available_network_aliases.get(name, None) loras = re_lora.findall(params.get('Prompt', '')) for l in loras: - lora = networks.available_network_aliases.get(l, None) + lora = lora_load.available_network_aliases.get(l, None) if lora is not None: found.append(lora.name) else: diff --git a/modules/lora/extra_networks_lora.py b/modules/lora/extra_networks_lora.py index a775ac91f..7227680ce 100644 --- a/modules/lora/extra_networks_lora.py +++ b/modules/lora/extra_networks_lora.py @@ -2,7 +2,7 @@ from typing import List import os import re import numpy as np -from modules.lora import networks, network_overrides +from modules.lora import networks, lora_overrides, lora_load from modules import extra_networks, shared, sd_models @@ -156,13 +156,13 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access debug_log(f'Network load: type=LoRA include={include} exclude={exclude} requested={requested} fn={fn}') - force_diffusers = network_overrides.check_override() + force_diffusers = lora_overrides.check_override() if force_diffusers: has_changed = False # diffusers handle their own loading if len(exclude) == 0: - networks.network_load(names, te_multipliers, unet_multipliers, dyn_dims) # load only on first call + lora_load.network_load(names, te_multipliers, unet_multipliers, dyn_dims) # load only on first call else: - networks.network_load(names, te_multipliers, unet_multipliers, dyn_dims) # load + lora_load.network_load(names, te_multipliers, unet_multipliers, dyn_dims) # load has_changed = self.changed(requested, include, exclude) if has_changed: networks.network_deactivate(include, exclude) @@ -180,7 +180,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): if shared.native: networks.previously_loaded_networks = networks.loaded_networks.copy() debug_log(f'Network load: type=LoRA active={[n.name for n in networks.previously_loaded_networks]} deactivate') - if shared.native and len(networks.diffuser_loaded) > 0: + if shared.native and len(lora_load.diffuser_loaded) > 0: if not (shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled is True): if hasattr(shared.sd_model, "unfuse_lora"): try: diff --git a/modules/lora/lora_apply.py b/modules/lora/lora_apply.py new file mode 100644 index 000000000..7e865a167 --- /dev/null +++ b/modules/lora/lora_apply.py @@ -0,0 +1,217 @@ +from typing import Union +import re +import time +import torch +import diffusers.models.lora +from modules.lora.lora_common import timer, debug, loaded_networks, previously_loaded_networks, extra_network_lora +from modules import shared, devices, errors, model_quant + + +bnb = None +re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") + + +def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, wanted_names: tuple): + global bnb # pylint: disable=W0603 + backup_size = 0 + if len(loaded_networks) > 0 and network_layer_name is not None and any([net.modules.get(network_layer_name, None) for net in loaded_networks]): # noqa: C419 # pylint: disable=R1729 + t0 = time.time() + + weights_backup = getattr(self, "network_weights_backup", None) + bias_backup = getattr(self, "network_bias_backup", None) + if weights_backup is not None or bias_backup is not 
None: + if (shared.opts.lora_fuse_diffusers and not isinstance(weights_backup, bool)) or (not shared.opts.lora_fuse_diffusers and isinstance(weights_backup, bool)): # invalidate so we can change direct/backup on-the-fly + weights_backup = None + bias_backup = None + self.network_weights_backup = weights_backup + self.network_bias_backup = bias_backup + + if weights_backup is None and wanted_names != (): # pylint: disable=C1803 + weight = getattr(self, 'weight', None) + self.network_weights_backup = None + if getattr(weight, "quant_type", None) in ['nf4', 'fp4']: + if bnb is None: + bnb = model_quant.load_bnb('Network load: type=LoRA', silent=True) + if bnb is not None: + with devices.inference_context(): + if shared.opts.lora_fuse_diffusers: + self.network_weights_backup = True + else: + self.network_weights_backup = bnb.functional.dequantize_4bit(weight, quant_state=weight.quant_state, quant_type=weight.quant_type, blocksize=weight.blocksize,) + self.quant_state = weight.quant_state + self.quant_type = weight.quant_type + self.blocksize = weight.blocksize + else: + if shared.opts.lora_fuse_diffusers: + self.network_weights_backup = True + else: + weights_backup = weight.clone() + self.network_weights_backup = weights_backup.to(devices.cpu) + else: + if shared.opts.lora_fuse_diffusers: + self.network_weights_backup = True + else: + self.network_weights_backup = weight.clone().to(devices.cpu) + + if bias_backup is None: + if getattr(self, 'bias', None) is not None: + if shared.opts.lora_fuse_diffusers: + self.network_bias_backup = True + else: + bias_backup = self.bias.clone() + self.network_bias_backup = bias_backup.to(devices.cpu) # store the backup on self, otherwise the original bias is lost for later restore + + if getattr(self, 'network_weights_backup', None) is not None: + backup_size += self.network_weights_backup.numel() * self.network_weights_backup.element_size() if isinstance(self.network_weights_backup, torch.Tensor) else 0 + if getattr(self, 'network_bias_backup', None) is not None: + backup_size += self.network_bias_backup.numel() * self.network_bias_backup.element_size() if isinstance(self.network_bias_backup, torch.Tensor) else 0 + timer.backup += time.time() - t0 + return backup_size + + +def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, use_previous: bool = False): + if shared.opts.diffusers_offload_mode == "none": + try: + self.to(devices.device) + except Exception: + pass + batch_updown = None + batch_ex_bias = None + loaded = loaded_networks if not use_previous else previously_loaded_networks + for net in loaded: + module = net.modules.get(network_layer_name, None) + if module is None: + continue + try: + t0 = time.time() + try: + weight = self.weight.to(devices.device) + except Exception: + weight = self.weight + + updown, ex_bias = module.calc_updown(weight) + if updown is not None: + if batch_updown is not None: + batch_updown += updown.to(batch_updown.device) + else: + batch_updown = updown.to(devices.device) + if ex_bias is not None: + if batch_ex_bias: + batch_ex_bias += ex_bias.to(batch_ex_bias.device) + else: + batch_ex_bias = ex_bias.to(devices.device) + timer.calc += time.time() - t0 + + if shared.opts.diffusers_offload_mode == "sequential": + t0 = time.time() + if batch_updown is not None: + batch_updown = batch_updown.to(devices.cpu) + if batch_ex_bias is not None: + batch_ex_bias = batch_ex_bias.to(devices.cpu) + t1 = time.time() + timer.move += t1 - t0 + except RuntimeError as
e: + extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 + if debug: + module_name = net.modules.get(network_layer_name, None) + shared.log.error(f'LoRA apply weight name="{net.name}" module="{module_name}" layer="{network_layer_name}" {e}') + errors.display(e, 'LoRA') + raise RuntimeError('LoRA apply weight') from e + continue + return batch_updown, batch_ex_bias + + +def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], model_weights: Union[None, torch.Tensor] = None, lora_weights: torch.Tensor = None, deactivate: bool = False, device: torch.device = devices.device): + if lora_weights is None: + return None + if deactivate: + lora_weights *= -1 + if model_weights is None: # use the backup weights if provided, otherwise fall back to self.weight + model_weights = self.weight + # TODO lora: add other quantization types + weight = None + if self.__class__.__name__ == 'Linear4bit' and bnb is not None: + try: + dequant_weight = bnb.functional.dequantize_4bit(model_weights.to(devices.device), quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize) + new_weight = dequant_weight.to(devices.device) + lora_weights.to(devices.device) + weight = bnb.nn.Params4bit(new_weight, quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize, requires_grad=False) + # weight._quantize(devices.device) # TODO force immediate quantization + except Exception as e: + shared.log.error(f'Network load: type=LoRA quant=bnb cls={self.__class__.__name__} type={self.quant_type} blocksize={self.blocksize} state={vars(self.quant_state)} weight={self.weight} bias={lora_weights} {e}') + else: + try: + new_weight = model_weights.to(devices.device) + lora_weights.to(devices.device) + except Exception: + new_weight = model_weights + lora_weights # try without device cast + weight = torch.nn.Parameter(new_weight, requires_grad=False) + try: + # weight.to(device=device) # TODO required since quantization happens only during .to call, not during params creation + pass + except Exception: + pass # may fail if weights is meta tensor + return weight + + +def network_apply_direct(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, deactivate: bool = False, device: torch.device = devices.device): + weights_backup = getattr(self, "network_weights_backup", False) + bias_backup = getattr(self, "network_bias_backup", False) + device = device or devices.device + if not isinstance(weights_backup, bool): # remove previous backup if we switched settings + weights_backup = True + if not isinstance(bias_backup, bool): + bias_backup = True + if not weights_backup and not bias_backup: + return + t0 = time.time() + + if weights_backup: + if updown is not None and len(self.weight.shape) == 4 and self.weight.shape[1] == 9: # inpainting model, so zero-pad updown to grow the channel dim from 4 to 9 + updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable + if updown is not None: + weight = network_add_weights(self, lora_weights=updown, deactivate=deactivate, device=device) + if weight is not None: + self.weight = weight + + if bias_backup: + if ex_bias is not None: + bias = network_add_weights(self, lora_weights=ex_bias, deactivate=deactivate,
device=device) + if bias is not None: + self.bias = bias + + if hasattr(self, "qweight") and hasattr(self, "freeze"): + self.freeze() + + timer.apply += time.time() - t0 + + +def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, device: torch.device, deactivate: bool = False): + weights_backup = getattr(self, "network_weights_backup", None) + bias_backup = getattr(self, "network_bias_backup", None) + if weights_backup is None and bias_backup is None: + return + t0 = time.time() + + if weights_backup is not None: + self.weight = None + if updown is not None and len(weights_backup.shape) == 4 and weights_backup.shape[1] == 9: # inpainting model, so zero-pad updown to grow channel[1] from 4 to 9 + updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable + if updown is not None: + weight = network_add_weights(self, model_weights=weights_backup, lora_weights=updown, deactivate=deactivate, device=device) + if weight is not None: + self.weight = weight + else: + self.weight = torch.nn.Parameter(weights_backup.to(device), requires_grad=False) + + if bias_backup is not None: + self.bias = None + if ex_bias is not None: + bias = network_add_weights(self, model_weights=bias_backup, lora_weights=ex_bias, deactivate=deactivate, device=device) # start from the bias backup, not the weight backup + if bias is not None: # explicit None check; truth-testing a tensor raises + self.bias = bias # assign to bias, not weight + else: + self.bias = torch.nn.Parameter(bias_backup.to(device), requires_grad=False) + + if hasattr(self, "qweight") and hasattr(self, "freeze"): + self.freeze() + + timer.apply += time.time() - t0 diff --git a/modules/lora/lora_common.py b/modules/lora/lora_common.py new file mode 100644 index 000000000..a6b15ae13 --- /dev/null +++ b/modules/lora/lora_common.py @@ -0,0 +1,21 @@ +from typing import List +import os +from modules.lora import lora_timers +from modules.lora import network_lora, network_hada, network_ia3, network_oft, network_lokr, network_full, network_norm, network_glora + + +timer = lora_timers.Timer() +debug = os.environ.get('SD_LORA_DEBUG', None) is not None +module_types = [ + network_lora.ModuleTypeLora(), + network_hada.ModuleTypeHada(), + network_ia3.ModuleTypeIa3(), + network_oft.ModuleTypeOFT(), + network_lokr.ModuleTypeLokr(), + network_full.ModuleTypeFull(), + network_norm.ModuleTypeNorm(), + network_glora.ModuleTypeGLora(), +] +loaded_networks: List = [] # no type due to circular import +previously_loaded_networks: List = [] # no type due to circular import +extra_network_lora = None # initialized in extra_networks.py diff --git a/modules/lora/lora_load.py b/modules/lora/lora_load.py new file mode 100644 index 000000000..43efc6f04 --- /dev/null +++ b/modules/lora/lora_load.py @@ -0,0 +1,285 @@ +from typing import Union +import os +import time +import concurrent +from modules import shared, errors, devices, sd_models, sd_models_compile, files_cache +from modules.lora import network, lora_overrides, lora_convert +from modules.lora.lora_common import timer, debug, module_types, loaded_networks + + +diffuser_loaded = [] +diffuser_scales = [] +lora_cache = {} +available_networks = {} +available_network_aliases = {} +forbidden_network_aliases = {} +available_network_hash_lookup = {} + + +def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_default_multiplier) -> Union[network.Network, None]: + t0 = time.time() + name = name.replace(".", "_") + shared.log.debug(f'Network
load: type=LoRA name="{name}" file="{network_on_disk.filename}" detected={network_on_disk.sd_version} method=diffusers scale={lora_scale} fuse={shared.opts.lora_fuse_diffusers}') + if not shared.native: + return None + if not hasattr(shared.sd_model, 'load_lora_weights'): + shared.log.error(f'Network load: type=LoRA class={shared.sd_model.__class__} does not implement load lora') + return None + try: + shared.sd_model.load_lora_weights(network_on_disk.filename, adapter_name=name) + except Exception as e: + if 'already in use' in str(e): + pass + else: + if 'The following keys have not been correctly renamed' in str(e): + shared.log.error(f'Network load: type=LoRA name="{name}" diffusers unsupported format') + else: + shared.log.error(f'Network load: type=LoRA name="{name}" {e}') + if debug: + errors.display(e, "LoRA") + return None + if name not in diffuser_loaded: + diffuser_loaded.append(name) + diffuser_scales.append(lora_scale) + net = network.Network(name, network_on_disk) + net.mtime = os.path.getmtime(network_on_disk.filename) + timer.activate += time.time() - t0 + return net + + +def load_safetensors(name, network_on_disk) -> Union[network.Network, None]: + if not shared.sd_loaded: + return None + + cached = lora_cache.get(name, None) + if debug: + shared.log.debug(f'Network load: type=LoRA name="{name}" file="{network_on_disk.filename}" type=lora {"cached" if cached else ""}') + if cached is not None: + return cached + net = network.Network(name, network_on_disk) + net.mtime = os.path.getmtime(network_on_disk.filename) + sd = sd_models.read_state_dict(network_on_disk.filename, what='network') + if shared.sd_model_type == 'f1': # if kohya flux lora, convert state_dict + sd = lora_convert._convert_kohya_flux_lora_to_diffusers(sd) or sd # pylint: disable=protected-access + if shared.sd_model_type == 'sd3': # if kohya sd3 lora, convert state_dict + try: + sd = lora_convert._convert_kohya_sd3_lora_to_diffusers(sd) or sd # pylint: disable=protected-access + except ValueError: # EAFP for diffusers PEFT keys + pass + lora_convert.assign_network_names_to_compvis_modules(shared.sd_model) + keys_failed_to_match = {} + matched_networks = {} + bundle_embeddings = {} + dtypes = [] + convert = lora_convert.KeyConvert() + for key_network, weight in sd.items(): + parts = key_network.split('.') + if parts[0] == "bundle_emb": + emb_name, vec_name = parts[1], key_network.split(".", 2)[-1] + emb_dict = bundle_embeddings.get(emb_name, {}) + emb_dict[vec_name] = weight + bundle_embeddings[emb_name] = emb_dict + continue + if len(parts) > 5: # messy handler for diffusers peft lora + key_network_without_network_parts = '_'.join(parts[:-2]) + if not key_network_without_network_parts.startswith('lora_'): + key_network_without_network_parts = 'lora_' + key_network_without_network_parts + network_part = '.'.join(parts[-2:]).replace('lora_A', 'lora_down').replace('lora_B', 'lora_up') + else: + key_network_without_network_parts, network_part = key_network.split(".", 1) + key, sd_module = convert(key_network_without_network_parts) + if sd_module is None: + keys_failed_to_match[key_network] = key + continue + if key not in matched_networks: + matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module) + matched_networks[key].w[network_part] = weight + if weight.dtype not in dtypes: + dtypes.append(weight.dtype) + network_types = [] + for key, weights in matched_networks.items(): + net_module =
nettype.create_module(net, weights) + if net_module is not None: + network_types.append(nettype.__class__.__name__) + break + if net_module is None: + shared.log.error(f'LoRA unhandled: name={name} key={key} weights={weights.w.keys()}') + else: + net.modules[key] = net_module + if len(keys_failed_to_match) > 0: + shared.log.warning(f'Network load: type=LoRA name="{name}" type={set(network_types)} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}') + if debug: + shared.log.debug(f'Network load: type=LoRA name="{name}" unmatched={keys_failed_to_match}') + else: + shared.log.debug(f'Network load: type=LoRA name="{name}" type={set(network_types)} keys={len(matched_networks)} dtypes={dtypes} direct={shared.opts.lora_fuse_diffusers}') + if len(matched_networks) == 0: + return None + lora_cache[name] = net + net.bundle_embeddings = bundle_embeddings + return net + + +def maybe_recompile_model(names, te_multipliers): + recompile_model = False + skip_lora_load = False + if shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled: + if len(names) == len(shared.compiled_model_state.lora_model): + for i, name in enumerate(names): + if shared.compiled_model_state.lora_model[ i] != f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}": + recompile_model = True + shared.compiled_model_state.lora_model = [] + break + if not recompile_model: + skip_lora_load = True + if len(loaded_networks) > 0 and debug: + shared.log.debug('Model Compile: Skipping LoRA loading') + return recompile_model, skip_lora_load + else: + recompile_model = True + shared.compiled_model_state.lora_model = [] + if recompile_model: + backup_cuda_compile = shared.opts.cuda_compile + backup_scheduler = getattr(shared.sd_model, "scheduler", None) + sd_models.unload_model_weights(op='model') + shared.opts.cuda_compile = [] + sd_models.reload_model_weights(op='model') + shared.opts.cuda_compile = backup_cuda_compile + if backup_scheduler is not None: + shared.sd_model.scheduler = backup_scheduler + return recompile_model, skip_lora_load + + +def list_available_networks(): + t0 = time.time() + available_networks.clear() + available_network_aliases.clear() + forbidden_network_aliases.clear() + available_network_hash_lookup.clear() + forbidden_network_aliases.update({"none": 1, "Addams": 1}) + if not os.path.exists(shared.cmd_opts.lora_dir): + shared.log.warning(f'LoRA directory not found: path="{shared.cmd_opts.lora_dir}"') + + def add_network(filename): + if not os.path.isfile(filename): + return + name = os.path.splitext(os.path.basename(filename))[0] + name = name.replace('.', '_') + try: + entry = network.NetworkOnDisk(name, filename) + available_networks[entry.name] = entry + if entry.alias in available_network_aliases: + forbidden_network_aliases[entry.alias.lower()] = 1 + if shared.opts.lora_preferred_name == 'filename': + available_network_aliases[entry.name] = entry + else: + available_network_aliases[entry.alias] = entry + if entry.shorthash: + available_network_hash_lookup[entry.shorthash] = entry + except OSError as e: # should catch FileNotFoundError and PermissionError etc.
+ shared.log.error(f'LoRA: filename="{filename}" {e}') + + candidates = sorted(files_cache.list_files(shared.cmd_opts.lora_dir, ext_filter=[".pt", ".ckpt", ".safetensors"])) + with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor: + for fn in candidates: + executor.submit(add_network, fn) + t1 = time.time() + timer.list = t1 - t0 + shared.log.info(f'Available LoRAs: path="{shared.cmd_opts.lora_dir}" items={len(available_networks)} folders={len(forbidden_network_aliases)} time={t1 - t0:.2f}') + + +def network_download(name): + from huggingface_hub import hf_hub_download + if os.path.exists(name): + return network.NetworkOnDisk(name, name) + parts = name.split('/') + if len(parts) >= 5 and parts[1] == 'huggingface.co': + repo_id = f'{parts[2]}/{parts[3]}' + filename = '/'.join(parts[4:]) + fn = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=shared.opts.hfcache_dir) + return network.NetworkOnDisk(name, fn) + return None + + +def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): + networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names] + if any(x is None for x in networks_on_disk): + list_available_networks() + networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names] + for i in range(len(names)): + if names[i].startswith('/'): + networks_on_disk[i] = network_download(names[i]) + failed_to_load_networks = [] + recompile_model, skip_lora_load = maybe_recompile_model(names, te_multipliers) + + loaded_networks.clear() + diffuser_loaded.clear() + diffuser_scales.clear() + t0 = time.time() + + for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)): + net = None + if network_on_disk is not None: + shorthash = getattr(network_on_disk, 'shorthash', '').lower() + if debug: + shared.log.debug(f'Network load: type=LoRA name="{name}" file="{network_on_disk.filename}" hash="{shorthash}"') + try: + if recompile_model: + shared.compiled_model_state.lora_model.append(f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}") + if shared.opts.lora_force_diffusers or lora_overrides.check_override(shorthash): # OpenVINO only works with Diffusers LoRa loading + net = load_diffusers(name, network_on_disk, lora_scale=te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier) + else: + net = load_safetensors(name, network_on_disk) + if net is not None: + net.mentioned_name = name + network_on_disk.read_hash() + except Exception as e: + shared.log.error(f'Network load: type=LoRA file="{network_on_disk.filename}" {e}') + if debug: + errors.display(e, 'LoRA') + continue + if net is None: + failed_to_load_networks.append(name) + shared.log.error(f'Network load: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} failed') + continue + if hasattr(shared.sd_model, 'embedding_db'): + shared.sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings) + net.te_multiplier = te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier + net.unet_multiplier = unet_multipliers[i] if unet_multipliers else shared.opts.extra_networks_default_multiplier + net.dyn_dim = dyn_dims[i] if dyn_dims else shared.opts.extra_networks_default_multiplier + loaded_networks.append(net) + + while len(lora_cache) > shared.opts.lora_in_memory_limit: + name = next(iter(lora_cache)) + 
lora_cache.pop(name, None) + + if not skip_lora_load and len(diffuser_loaded) > 0: + shared.log.debug(f'Network load: type=LoRA loaded={diffuser_loaded} available={shared.sd_model.get_list_adapters()} active={shared.sd_model.get_active_adapters()} scales={diffuser_scales}') + try: + t0 = time.time() + shared.sd_model.set_adapters(adapter_names=diffuser_loaded, adapter_weights=diffuser_scales) + if shared.opts.lora_fuse_diffusers and not lora_overrides.check_fuse(): + shared.sd_model.fuse_lora(adapter_names=diffuser_loaded, lora_scale=1.0, fuse_unet=True, fuse_text_encoder=True) # fuse uses fixed scale since later apply does the scaling + shared.sd_model.unload_lora_weights() + timer.activate += time.time() - t0 + except Exception as e: + shared.log.error(f'Network load: type=LoRA {e}') + if debug: + errors.display(e, 'LoRA') + + if len(loaded_networks) > 0 and debug: + shared.log.debug(f'Network load: type=LoRA loaded={[n.name for n in loaded_networks]} cache={list(lora_cache)}') + + if recompile_model: + shared.log.info("Network load: type=LoRA recompiling model") + backup_lora_model = shared.compiled_model_state.lora_model + if 'Model' in shared.opts.cuda_compile: + shared.sd_model = sd_models_compile.compile_diffusers(shared.sd_model) + shared.compiled_model_state.lora_model = backup_lora_model + + if len(loaded_networks) > 0: + devices.torch_gc() + + timer.load = time.time() - t0 diff --git a/modules/lora/network_overrides.py b/modules/lora/lora_overrides.py similarity index 100% rename from modules/lora/network_overrides.py rename to modules/lora/lora_overrides.py diff --git a/modules/lora/network.py b/modules/lora/network.py index c4768d9ad..f6d93009c 100644 --- a/modules/lora/network.py +++ b/modules/lora/network.py @@ -91,8 +91,10 @@ class NetworkOnDisk: self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '') def get_alias(self): - import modules.lora.networks as networks - return self.name if shared.opts.lora_preferred_name == "filename" or self.alias.lower() in networks.forbidden_network_aliases else self.alias + if shared.opts.lora_preferred_name == "filename": + return self.name + else: + return self.alias class Network: # LoraModule diff --git a/modules/lora/networks.py b/modules/lora/networks.py index 4903e5a6f..44dad1afd 100644 --- a/modules/lora/networks.py +++ b/modules/lora/networks.py @@ -1,583 +1,12 @@ -from typing import Union, List from contextlib import nullcontext -import os -import re import time -import concurrent -import torch -import diffusers.models.lora import rich.progress as rp - -from modules.lora import lora_timers, network, lora_convert, network_overrides -from modules.lora import network_lora, network_hada, network_ia3, network_oft, network_lokr, network_full, network_norm, network_glora -from modules.lora.extra_networks_lora import ExtraNetworkLora -from modules import shared, devices, sd_models, sd_models_compile, errors, files_cache, model_quant +from modules.lora.lora_common import timer, debug, loaded_networks, previously_loaded_networks +from modules.lora.lora_apply import network_apply_weights, network_apply_direct, network_backup_weights, network_calc_weights +from modules import shared, devices, sd_models -debug = os.environ.get('SD_LORA_DEBUG', None) is not None -extra_network_lora = ExtraNetworkLora() -available_networks = {} -available_network_aliases = {} -loaded_networks: List[network.Network] = [] -previously_loaded_networks: List[network.Network] = [] applied_layers: list[str] = [] -bnb 
= None -lora_cache = {} -diffuser_loaded = [] -diffuser_scales = [] -available_network_hash_lookup = {} -forbidden_network_aliases = {} -re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") -timer = lora_timers.Timer() -module_types = [ - network_lora.ModuleTypeLora(), - network_hada.ModuleTypeHada(), - network_ia3.ModuleTypeIa3(), - network_oft.ModuleTypeOFT(), - network_lokr.ModuleTypeLokr(), - network_full.ModuleTypeFull(), - network_norm.ModuleTypeNorm(), - network_glora.ModuleTypeGLora(), -] - -# section: load networks from disk - -def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_default_multiplier) -> Union[network.Network, None]: - t0 = time.time() - name = name.replace(".", "_") - shared.log.debug(f'Network load: type=LoRA name="{name}" file="{network_on_disk.filename}" detected={network_on_disk.sd_version} method=diffusers scale={lora_scale} fuse={shared.opts.lora_fuse_diffusers}') - if not shared.native: - return None - if not hasattr(shared.sd_model, 'load_lora_weights'): - shared.log.error(f'Network load: type=LoRA class={shared.sd_model.__class__} does not implement load lora') - return None - try: - shared.sd_model.load_lora_weights(network_on_disk.filename, adapter_name=name) - except Exception as e: - if 'already in use' in str(e): - pass - else: - if 'The following keys have not been correctly renamed' in str(e): - shared.log.error(f'Network load: type=LoRA name="{name}" diffusers unsupported format') - else: - shared.log.error(f'Network load: type=LoRA name="{name}" {e}') - if debug: - errors.display(e, "LoRA") - return None - if name not in diffuser_loaded: - diffuser_loaded.append(name) - diffuser_scales.append(lora_scale) - net = network.Network(name, network_on_disk) - net.mtime = os.path.getmtime(network_on_disk.filename) - timer.activate += time.time() - t0 - return net - - -def load_safetensors(name, network_on_disk) -> Union[network.Network, None]: - if not shared.sd_loaded: - return None - - cached = lora_cache.get(name, None) - if debug: - shared.log.debug(f'Network load: type=LoRA name="{name}" file="{network_on_disk.filename}" type=lora {"cached" if cached else ""}') - if cached is not None: - return cached - net = network.Network(name, network_on_disk) - net.mtime = os.path.getmtime(network_on_disk.filename) - sd = sd_models.read_state_dict(network_on_disk.filename, what='network') - if shared.sd_model_type == 'f1': # if kohya flux lora, convert state_dict - sd = lora_convert._convert_kohya_flux_lora_to_diffusers(sd) or sd # pylint: disable=protected-access - if shared.sd_model_type == 'sd3': # if kohya flux lora, convert state_dict - try: - sd = lora_convert._convert_kohya_sd3_lora_to_diffusers(sd) or sd # pylint: disable=protected-access - except ValueError: # EAFP for diffusers PEFT keys - pass - lora_convert.assign_network_names_to_compvis_modules(shared.sd_model) - keys_failed_to_match = {} - matched_networks = {} - bundle_embeddings = {} - convert = lora_convert.KeyConvert() - for key_network, weight in sd.items(): - parts = key_network.split('.') - if parts[0] == "bundle_emb": - emb_name, vec_name = parts[1], key_network.split(".", 2)[-1] - emb_dict = bundle_embeddings.get(emb_name, {}) - emb_dict[vec_name] = weight - bundle_embeddings[emb_name] = emb_dict - continue - if len(parts) > 5: # messy handler for diffusers peft lora - key_network_without_network_parts = '_'.join(parts[:-2]) - if not key_network_without_network_parts.startswith('lora_'): - key_network_without_network_parts = 'lora_' + 
key_network_without_network_parts - network_part = '.'.join(parts[-2:]).replace('lora_A', 'lora_down').replace('lora_B', 'lora_up') - else: - key_network_without_network_parts, network_part = key_network.split(".", 1) - key, sd_module = convert(key_network_without_network_parts) - if sd_module is None: - keys_failed_to_match[key_network] = key - continue - if key not in matched_networks: - matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module) - matched_networks[key].w[network_part] = weight - network_types = [] - for key, weights in matched_networks.items(): - net_module = None - for nettype in module_types: - net_module = nettype.create_module(net, weights) - if net_module is not None: - network_types.append(nettype.__class__.__name__) - break - if net_module is None: - shared.log.error(f'LoRA unhandled: name={name} key={key} weights={weights.w.keys()}') - else: - net.modules[key] = net_module - if len(keys_failed_to_match) > 0: - shared.log.warning(f'Network load: type=LoRA name="{name}" type={set(network_types)} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}') - if debug: - shared.log.debug(f'Network load: type=LoRA name="{name}" unmatched={keys_failed_to_match}') - else: - shared.log.debug(f'Network load: type=LoRA name="{name}" type={set(network_types)} keys={len(matched_networks)} direct={shared.opts.lora_fuse_diffusers}') - if len(matched_networks) == 0: - return None - lora_cache[name] = net - net.bundle_embeddings = bundle_embeddings - return net - - -def maybe_recompile_model(names, te_multipliers): - recompile_model = False - skip_lora_load = False - if shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled: - if len(names) == len(shared.compiled_model_state.lora_model): - for i, name in enumerate(names): - if shared.compiled_model_state.lora_model[ - i] != f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}": - recompile_model = True - shared.compiled_model_state.lora_model = [] - break - if not recompile_model: - skip_lora_load = True - if len(loaded_networks) > 0 and debug: - shared.log.debug('Model Compile: Skipping LoRa loading') - return recompile_model, skip_lora_load - else: - recompile_model = True - shared.compiled_model_state.lora_model = [] - if recompile_model: - backup_cuda_compile = shared.opts.cuda_compile - backup_scheduler = getattr(shared.sd_model, "scheduler", None) - sd_models.unload_model_weights(op='model') - shared.opts.cuda_compile = [] - sd_models.reload_model_weights(op='model') - shared.opts.cuda_compile = backup_cuda_compile - if backup_scheduler is not None: - shared.sd_model.scheduler = backup_scheduler - return recompile_model, skip_lora_load - - -def list_available_networks(): - t0 = time.time() - available_networks.clear() - available_network_aliases.clear() - forbidden_network_aliases.clear() - available_network_hash_lookup.clear() - forbidden_network_aliases.update({"none": 1, "Addams": 1}) - if not os.path.exists(shared.cmd_opts.lora_dir): - shared.log.warning(f'LoRA directory not found: path="{shared.cmd_opts.lora_dir}"') - - def add_network(filename): - if not os.path.isfile(filename): - return - name = os.path.splitext(os.path.basename(filename))[0] - name = name.replace('.', '_') - try: - entry = network.NetworkOnDisk(name, filename) - available_networks[entry.name] = entry - if entry.alias in available_network_aliases: - forbidden_network_aliases[entry.alias.lower()] = 1 - if 
shared.opts.lora_preferred_name == 'filename': - available_network_aliases[entry.name] = entry - else: - available_network_aliases[entry.alias] = entry - if entry.shorthash: - available_network_hash_lookup[entry.shorthash] = entry - except OSError as e: # should catch FileNotFoundError and PermissionError etc. - shared.log.error(f'LoRA: filename="{filename}" {e}') - - candidates = sorted(files_cache.list_files(shared.cmd_opts.lora_dir, ext_filter=[".pt", ".ckpt", ".safetensors"])) - with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor: - for fn in candidates: - executor.submit(add_network, fn) - t1 = time.time() - timer.list = t1 - t0 - shared.log.info(f'Available LoRAs: path="{shared.cmd_opts.lora_dir}" items={len(available_networks)} folders={len(forbidden_network_aliases)} time={t1 - t0:.2f}') - - -def network_download(name): - from huggingface_hub import hf_hub_download - if os.path.exists(name): - return network.NetworkOnDisk(name, name) - parts = name.split('/') - if len(parts) >= 5 and parts[1] == 'huggingface.co': - repo_id = f'{parts[2]}/{parts[3]}' - filename = '/'.join(parts[4:]) - fn = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=shared.opts.hfcache_dir) - return network.NetworkOnDisk(name, fn) - return None - - -def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): - networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names] - if any(x is None for x in networks_on_disk): - list_available_networks() - networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names] - for i in range(len(names)): - if names[i].startswith('/'): - networks_on_disk[i] = network_download(names[i]) - failed_to_load_networks = [] - recompile_model, skip_lora_load = maybe_recompile_model(names, te_multipliers) - - loaded_networks.clear() - diffuser_loaded.clear() - diffuser_scales.clear() - t0 = time.time() - - for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)): - net = None - if network_on_disk is not None: - shorthash = getattr(network_on_disk, 'shorthash', '').lower() - if debug: - shared.log.debug(f'Network load: type=LoRA name="{name}" file="{network_on_disk.filename}" hash="{shorthash}"') - try: - if recompile_model: - shared.compiled_model_state.lora_model.append(f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}") - if shared.opts.lora_force_diffusers or network_overrides.check_override(shorthash): # OpenVINO only works with Diffusers LoRa loading - net = load_diffusers(name, network_on_disk, lora_scale=te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier) - else: - net = load_safetensors(name, network_on_disk) - if net is not None: - net.mentioned_name = name - network_on_disk.read_hash() - except Exception as e: - shared.log.error(f'Network load: type=LoRA file="{network_on_disk.filename}" {e}') - if debug: - errors.display(e, 'LoRA') - continue - if net is None: - failed_to_load_networks.append(name) - shared.log.error(f'Network load: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} failed') - continue - if hasattr(shared.sd_model, 'embedding_db'): - shared.sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings) - net.te_multiplier = te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier - net.unet_multiplier = 
unet_multipliers[i] if unet_multipliers else shared.opts.extra_networks_default_multiplier - net.dyn_dim = dyn_dims[i] if dyn_dims else shared.opts.extra_networks_default_multiplier - loaded_networks.append(net) - - while len(lora_cache) > shared.opts.lora_in_memory_limit: - name = next(iter(lora_cache)) - lora = lora_cache.pop(name, None) - del lora - - if not skip_lora_load and len(diffuser_loaded) > 0: - shared.log.debug(f'Network load: type=LoRA loaded={diffuser_loaded} available={shared.sd_model.get_list_adapters()} active={shared.sd_model.get_active_adapters()} scales={diffuser_scales}') - try: - t0 = time.time() - shared.sd_model.set_adapters(adapter_names=diffuser_loaded, adapter_weights=diffuser_scales) - if shared.opts.lora_fuse_diffusers and not network_overrides.check_fuse(): - shared.sd_model.fuse_lora(adapter_names=diffuser_loaded, lora_scale=1.0, fuse_unet=True, fuse_text_encoder=True) # fuse uses fixed scale since later apply does the scaling - shared.sd_model.unload_lora_weights() - timer.activate += time.time() - t0 - except Exception as e: - shared.log.error(f'Network load: type=LoRA {e}') - if debug: - errors.display(e, 'LoRA') - - if len(loaded_networks) > 0 and debug: - shared.log.debug(f'Network load: type=LoRA loaded={[n.name for n in loaded_networks]} cache={list(lora_cache)}') - - if recompile_model: - shared.log.info("Network load: type=LoRA recompiling model") - backup_lora_model = shared.compiled_model_state.lora_model - if 'Model' in shared.opts.cuda_compile: - shared.sd_model = sd_models_compile.compile_diffusers(shared.sd_model) - shared.compiled_model_state.lora_model = backup_lora_model - - if len(loaded_networks) > 0: - devices.torch_gc() - - timer.load = time.time() - t0 - - -# section: process loaded networks - -def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, wanted_names: tuple): - global bnb # pylint: disable=W0603 - backup_size = 0 - if len(loaded_networks) > 0 and network_layer_name is not None and any([net.modules.get(network_layer_name, None) for net in loaded_networks]): # noqa: C419 # pylint: disable=R1729 - t0 = time.time() - - weights_backup = getattr(self, "network_weights_backup", None) - bias_backup = getattr(self, "network_bias_backup", None) - if weights_backup is not None or bias_backup is not None: - if (shared.opts.lora_fuse_diffusers and not isinstance(weights_backup, bool)) or (not shared.opts.lora_fuse_diffusers and isinstance(weights_backup, bool)): # invalidate so we can change direct/backup on-the-fly - weights_backup = None - bias_backup = None - self.network_weights_backup = weights_backup - self.network_bias_backup = bias_backup - - if weights_backup is None and wanted_names != (): # pylint: disable=C1803 - weight = getattr(self, 'weight', None) - self.network_weights_backup = None - if getattr(weight, "quant_type", None) in ['nf4', 'fp4']: - if bnb is None: - bnb = model_quant.load_bnb('Network load: type=LoRA', silent=True) - if bnb is not None: - with devices.inference_context(): - if shared.opts.lora_fuse_diffusers: - self.network_weights_backup = True - else: - self.network_weights_backup = bnb.functional.dequantize_4bit(weight, quant_state=weight.quant_state, quant_type=weight.quant_type, blocksize=weight.blocksize,) - self.quant_state = weight.quant_state - self.quant_type = weight.quant_type - self.blocksize = weight.blocksize - else: - if 
shared.opts.lora_fuse_diffusers: - self.network_weights_backup = True - else: - weights_backup = weight.clone() - self.network_weights_backup = weights_backup.to(devices.cpu) - else: - if shared.opts.lora_fuse_diffusers: - self.network_weights_backup = True - else: - self.network_weights_backup = weight.clone().to(devices.cpu) - - if bias_backup is None: - if getattr(self, 'bias', None) is not None: - if shared.opts.lora_fuse_diffusers: - self.network_bias_backup = True - else: - bias_backup = self.bias.clone() - bias_backup = bias_backup.to(devices.cpu) - - if getattr(self, 'network_weights_backup', None) is not None: - backup_size += self.network_weights_backup.numel() * self.network_weights_backup.element_size() if isinstance(self.network_weights_backup, torch.Tensor) else 0 - if getattr(self, 'network_bias_backup', None) is not None: - backup_size += self.network_bias_backup.numel() * self.network_bias_backup.element_size() if isinstance(self.network_bias_backup, torch.Tensor) else 0 - timer.backup += time.time() - t0 - return backup_size - - -def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, use_previous: bool = False): - if shared.opts.diffusers_offload_mode == "none": - try: - self.to(devices.device) - except Exception: - pass - batch_updown = None - batch_ex_bias = None - loaded = loaded_networks if not use_previous else previously_loaded_networks - for net in loaded: - module = net.modules.get(network_layer_name, None) - if module is None: - continue - try: - t0 = time.time() - try: - weight = self.weight.to(devices.device) - except Exception: - weight = self.weight - - updown, ex_bias = module.calc_updown(weight) - if updown is not None: - if batch_updown is not None: - batch_updown += updown.to(batch_updown.device) - else: - batch_updown = updown.to(devices.device) - if ex_bias is not None: - if batch_ex_bias: - batch_ex_bias += ex_bias.to(batch_ex_bias.device) - else: - batch_ex_bias = ex_bias.to(devices.device) - timer.calc += time.time() - t0 - - if shared.opts.diffusers_offload_mode == "sequential": - t0 = time.time() - if batch_updown is not None: - batch_updown = batch_updown.to(devices.cpu) - if batch_ex_bias is not None: - batch_ex_bias = batch_ex_bias.to(devices.cpu) - t1 = time.time() - timer.move += t1 - t0 - except RuntimeError as e: - extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 - if debug: - module_name = net.modules.get(network_layer_name, None) - shared.log.error(f'LoRA apply weight name="{net.name}" module="{module_name}" layer="{network_layer_name}" {e}') - errors.display(e, 'LoRA') - raise RuntimeError('LoRA apply weight') from e - continue - return batch_updown, batch_ex_bias - - -def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], model_weights: Union[None, torch.Tensor] = None, lora_weights: torch.Tensor = None, deactivate: bool = False, device: torch.device = None): - if lora_weights is None: - return None - if deactivate: - lora_weights *= -1 - if model_weights is None: # weights are used if provided-from-backup else use self.weight - model_weights = self.weight - # TODO lora: add other quantization types - weight = None - device = device or devices.device - if self.__class__.__name__ == 'Linear4bit' and 
bnb is not None: - try: - dequant_weight = bnb.functional.dequantize_4bit(model_weights.to(device), quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize) - new_weight = dequant_weight.to(device) + lora_weights.to(device) - weight = bnb.nn.Params4bit(new_weight, quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize) - except Exception as e: - shared.log.error(f'Network load: type=LoRA quant=bnb cls={self.__class__.__name__} type={self.quant_type} blocksize={self.blocksize} state={vars(self.quant_state)} weight={self.weight} bias={lora_weights} {e}') - else: - try: - new_weight = model_weights.to(device) + lora_weights.to(device) - except Exception: - new_weight = model_weights + lora_weights # try without device cast - weight = torch.nn.Parameter(new_weight, requires_grad=False) - try: - # weight = weight.to(device=devices.device) # required since quantization happens only during .to call, not during params creation - pass - except Exception: - pass # may fail if weights is meta tensor - return weight - - -def network_apply_direct(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, deactivate: bool = False, device: torch.device = None): - weights_backup = getattr(self, "network_weights_backup", False) - bias_backup = getattr(self, "network_bias_backup", False) - if not isinstance(weights_backup, bool): # remove previous backup if we switched settings - weights_backup = True - if not isinstance(bias_backup, bool): - bias_backup = True - if not weights_backup and not bias_backup: - return - device = device or devices.device - t0 = time.time() - - if weights_backup: - if updown is not None and len(self.weight.shape) == 4 and self.weight.shape[1] == 9: # inpainting model so zero pad updown to make channel 4 to 9 - updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable - if updown is not None: - weight = network_add_weights(self, lora_weights=updown, deactivate=deactivate, device=device) - if weight is not None: - self.weight = weight - - if bias_backup: - if ex_bias is not None: - bias = network_add_weights(self, lora_weights=ex_bias, deactivate=deactivate, device=device) - if bias is not None: - self.bias = bias - - if hasattr(self, "qweight") and hasattr(self, "freeze"): - self.freeze() - - timer.apply += time.time() - t0 - - -def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, orig_device: torch.device, deactivate: bool = False): - weights_backup = getattr(self, "network_weights_backup", None) - bias_backup = getattr(self, "network_bias_backup", None) - if weights_backup is None and bias_backup is None: - return - t0 = time.time() - - if weights_backup is not None: - self.weight = None - if updown is not None and len(weights_backup.shape) == 4 and weights_backup.shape[1] == 9: # inpainting model. 
zero pad updown to make channel[1] 4 to 9 - updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable - if updown is not None: - weight = network_add_weights(self, model_weights=weights_backup, lora_weights=updown, deactivate=deactivate) - if weight is not None: - self.weight = weight - else: - self.weight = torch.nn.Parameter(weights_backup.to(device=orig_device), requires_grad=False) - - if bias_backup is not None: - self.bias = None - if ex_bias is not None: - bias = network_add_weights(self, model_weights=weights_backup, lora_weights=ex_bias, deactivate=deactivate) - if bias: - self.weight = bias - else: - self.bias = torch.nn.Parameter(bias_backup.to(device=orig_device), requires_grad=False) - - if hasattr(self, "qweight") and hasattr(self, "freeze"): - self.freeze() - - timer.apply += time.time() - t0 - - -def network_deactivate(include=[], exclude=[]): - if not shared.opts.lora_fuse_diffusers or shared.opts.lora_force_diffusers: - return - t0 = time.time() - sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatiblility - if shared.opts.diffusers_offload_mode == "sequential": - sd_models.disable_offload(sd_model) - sd_models.move_model(sd_model, device=devices.cpu) - modules = {} - - components = include if len(include) > 0 else ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'unet', 'transformer'] - components = [x for x in components if x not in exclude] - active_components = [] - for name in components: - component = getattr(sd_model, name, None) - if component is not None and hasattr(component, 'named_modules'): - modules[name] = list(component.named_modules()) - active_components.append(name) - total = sum(len(x) for x in modules.values()) - if shared.opts.lora_apply_gpu: - device = devices.device - else: - device = devices.cpu - if len(previously_loaded_networks) > 0 and debug: - pbar = rp.Progress(rp.TextColumn('[cyan]Network: type=LoRA action=deactivate'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console) - task = pbar.add_task(description='', total=total) - else: - task = None - pbar = nullcontext() - with devices.inference_context(), pbar: - applied_layers.clear() - for component in modules.keys(): - orig_device = getattr(sd_model, component, None).device - for _, module in modules[component]: - network_layer_name = getattr(module, 'network_layer_name', None) - if shared.state.interrupted or network_layer_name is None: - if task is not None: - pbar.update(task, advance=1) - continue - batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name, use_previous=True) - if shared.opts.lora_fuse_diffusers: - network_apply_direct(module, batch_updown, batch_ex_bias, deactivate=True) - else: - network_apply_weights(module, batch_updown, batch_ex_bias, orig_device, deactivate=True) - if batch_updown is not None or batch_ex_bias is not None: - applied_layers.append(network_layer_name) - del batch_updown, batch_ex_bias - module.network_current_names = () - try: - module.to(device) - except Exception: - pass - if task is not None: - pbar.update(task, advance=1, description=f'networks={len(previously_loaded_networks)} modules={active_components} layers={total} unapply={len(applied_layers)}') - - timer.deactivate = time.time() - t0 - if debug and len(previously_loaded_networks) > 0: - shared.log.debug(f'Network deactivate: type=LoRA networks={[n.name for n in previously_loaded_networks]} 
modules={active_components} layers={total} apply={len(applied_layers)} fuse={shared.opts.lora_fuse_diffusers} time={timer.summary}') - modules.clear() - if shared.opts.diffusers_offload_mode == "sequential": - sd_models.set_diffuser_offload(sd_model, op="model") def network_activate(include=[], exclude=[]): @@ -604,10 +33,7 @@ def network_activate(include=[], exclude=[]): pbar = nullcontext() applied_weight = 0 applied_bias = 0 - if shared.opts.lora_apply_gpu: - device = devices.device - else: - device = devices.cpu + device = devices.device if shared.opts.lora_apply_gpu else devices.cpu with devices.inference_context(), pbar: wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) if len(loaded_networks) > 0 else () applied_layers.clear() @@ -624,21 +50,18 @@ def network_activate(include=[], exclude=[]): backup_size += network_backup_weights(module, network_layer_name, wanted_names) batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name) if shared.opts.lora_fuse_diffusers: - network_apply_direct(module, batch_updown, batch_ex_bias, device) + network_apply_direct(module, batch_updown, batch_ex_bias, device=device) else: - network_apply_weights(module, batch_updown, batch_ex_bias, orig_device) + network_apply_weights(module, batch_updown, batch_ex_bias, device=orig_device) if batch_updown is not None or batch_ex_bias is not None: applied_layers.append(network_layer_name) + # module.to(device) # TODO maybe if batch_updown is not None: applied_weight += 1 if batch_ex_bias is not None: applied_bias += 1 del batch_updown, batch_ex_bias module.network_current_names = wanted_names - try: - module.to(device) - except Exception: - pass if task is not None: pbar.update(task, advance=1, description=f'networks={len(loaded_networks)} modules={active_components} layers={total} weights={applied_weight} bias={applied_bias} backup={backup_size}') @@ -646,8 +69,65 @@ def network_activate(include=[], exclude=[]): pbar.remove_task(task) # hide progress bar for no action timer.activate += time.time() - t0 if debug and len(loaded_networks) > 0: - shared.log.debug(f'Network load: type=LoRA networks={[n.name for n in loaded_networks]} modules={active_components} layers={total} weights={applied_weight} bias={applied_bias} backup={backup_size} device={device} fuse={shared.opts.lora_fuse_diffusers} time={timer.summary}') + shared.log.debug(f'Network load: type=LoRA networks={[n.name for n in loaded_networks]} modules={active_components} layers={total} weights={applied_weight} bias={applied_bias} backup={backup_size} fuse={shared.opts.lora_fuse_diffusers} device={device} time={timer.summary}') modules.clear() if len(loaded_networks) > 0 and (applied_weight > 0 or applied_bias > 0): if shared.opts.diffusers_offload_mode == "sequential": sd_models.set_diffuser_offload(sd_model, op="model") + + +def network_deactivate(include=[], exclude=[]): + if not shared.opts.lora_fuse_diffusers or shared.opts.lora_force_diffusers: + return + t0 = time.time() + sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatibility + if shared.opts.diffusers_offload_mode == "sequential": + sd_models.disable_offload(sd_model) + sd_models.move_model(sd_model, device=devices.cpu) + modules = {} + + components = include if len(include) > 0 else ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'unet', 'transformer'] + components = [x for x in components if x not in exclude] + active_components = [] + for name in components: + component =
diff --git a/modules/model_auraflow.py b/modules/model_auraflow.py
index 83320040b..790095cae 100644
--- a/modules/model_auraflow.py
+++ b/modules/model_auraflow.py
@@ -17,5 +17,5 @@ def load_auraflow(checkpoint_info, diffusers_load_config={}):
         cache_dir = shared.opts.diffusers_dir,
         **diffusers_load_config,
     )
-    devices.torch_gc()
+    devices.torch_gc(force=True)
     return pipe
diff --git a/modules/model_flux.py b/modules/model_flux.py
index 74ba5c8dd..12ad0a471 100644
--- a/modules/model_flux.py
+++ b/modules/model_flux.py
@@ -342,6 +342,5 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
         vae = None
     for k in kwargs.keys():
         kwargs[k] = None
-    devices.torch_gc()
-
+    devices.torch_gc(force=True)
     return pipe
diff --git a/modules/model_kolors.py b/modules/model_kolors.py
index 932763b4b..b6c35c85a 100644
--- a/modules/model_kolors.py
+++ b/modules/model_kolors.py
@@ -23,5 +23,5 @@ def load_kolors(_checkpoint_info, diffusers_load_config={}):
         **diffusers_load_config,
     )
     pipe.vae.config.force_upcast = True
-    devices.torch_gc()
+    devices.torch_gc(force=True)
     return pipe
diff --git a/modules/model_lumina.py b/modules/model_lumina.py
index f9b09f23b..f19fcd7da 100644
--- a/modules/model_lumina.py
+++ b/modules/model_lumina.py
@@ -21,7 +21,7 @@ def load_lumina(_checkpoint_info, diffusers_load_config={}):
         cache_dir = shared.opts.diffusers_dir,
         **diffusers_load_config,
     )
-    devices.torch_gc()
+    devices.torch_gc(force=True)
     return pipe
 
@@ -40,4 +40,5 @@ def load_lumina2(checkpoint_info, diffusers_load_config={}):
     if ('TE' in shared.opts.bnb_quantization or 'TE' in shared.opts.torchao_quantization or 'TE' in shared.opts.quanto_quantization):
         kwargs['text_encoder'] = transformers.AutoModel.from_pretrained(repo_id, subfolder="text_encoder", cache_dir=shared.opts.diffusers_dir, torch_dtype=devices.dtype, **quant_args)
     sd_model = diffusers.Lumina2Text2ImgPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config, **quant_args, **kwargs)
+    devices.torch_gc(force=True)
     return sd_model
diff --git a/modules/model_meissonic.py b/modules/model_meissonic.py
index 69ceab458..d705a32d9 100644
--- a/modules/model_meissonic.py
+++ b/modules/model_meissonic.py
@@ -33,5 +33,5 @@ def load_meissonic(checkpoint_info, diffusers_load_config={}):
     diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["meissonic"] = PipelineMeissonic
     diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["meissonic"] = PipelineMeissonicImg2Img
     diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["meissonic"] = PipelineMeissonicInpaint
-    devices.torch_gc()
+    devices.torch_gc(force=True)
     return pipe
diff --git a/modules/model_omnigen.py b/modules/model_omnigen.py
index a08ad4ed5..b7b6e3546 100644
--- a/modules/model_omnigen.py
+++ b/modules/model_omnigen.py
@@ -20,12 +20,5 @@ def load_omnigen(checkpoint_info, diffusers_load_config={}): # pylint: disable=u
     if shared.opts.diffusers_eval:
         pipe.model.eval()
     pipe.vae.to(devices.device, dtype=devices.dtype)
-    devices.torch_gc()
-
-    # register
-    # from diffusers import pipelines
-    # pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["omnigen"] = pipe.__class__
-    # pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["omnigen"] = pipe.__class__
-    # pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["omnigen"] = pipe.__class__
-
+    devices.torch_gc(force=True)
     return pipe
diff --git a/modules/model_pixart.py b/modules/model_pixart.py
index c017cc468..0757a1216 100644
--- a/modules/model_pixart.py
+++ b/modules/model_pixart.py
@@ -26,5 +26,5 @@ def load_pixart(checkpoint_info, diffusers_load_config={}):
         **kwargs,
         **diffusers_load_config,
     )
-    devices.torch_gc()
+    devices.torch_gc(force=True)
     return pipe
diff --git a/modules/model_sana.py b/modules/model_sana.py
index a31985ab6..7f39f17e0 100644
--- a/modules/model_sana.py
+++ b/modules/model_sana.py
@@ -73,6 +73,5 @@ def load_sana(checkpoint_info, kwargs={}):
         pipe.transformer.eval()
     t1 = time.time()
     shared.log.debug(f'Load model: type=Sana target={devices.dtype} te={pipe.text_encoder.dtype} transformer={pipe.transformer.dtype} vae={pipe.vae.dtype} time={t1-t0:.2f}')
-
-    devices.torch_gc()
+    devices.torch_gc(force=True)
     return pipe
diff --git a/modules/model_sd3.py b/modules/model_sd3.py
index 5b8006c2a..e3774b291 100644
--- a/modules/model_sd3.py
+++ b/modules/model_sd3.py
@@ -156,5 +156,5 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None):
         config=config,
         **kwargs,
     )
-    devices.torch_gc()
+    devices.torch_gc(force=True)
     return pipe
diff --git a/modules/model_stablecascade.py b/modules/model_stablecascade.py
index 0c767d33b..fc143e8e7 100644
--- a/modules/model_stablecascade.py
+++ b/modules/model_stablecascade.py
@@ -155,6 +155,7 @@ def load_cascade_combined(checkpoint_info, diffusers_load_config):
         latent_dim_scale=sd_model.decoder_pipe.config.latent_dim_scale,
     )
+    devices.torch_gc(force=True)
     shared.log.debug(f'StableCascade combined: {sd_model.__class__.__name__}')
     return sd_model
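Editor's note: every model loader above now calls devices.torch_gc(force=True) right after pipeline construction, so cleanup runs unconditionally rather than only under memory pressure. The helper itself is not part of this diff; a hedged sketch of what a force-capable GC helper typically looks like (the repo's actual modules.devices.torch_gc may differ):

```python
import gc
import torch

def torch_gc(force: bool = False, threshold: float = 0.8):
    # assumed semantics: force=True collects unconditionally,
    # otherwise collect only when VRAM usage crosses a threshold
    if not force:
        if not torch.cuda.is_available():
            return
        used = torch.cuda.memory_allocated()
        total = torch.cuda.get_device_properties(0).total_memory
        if used / total < threshold:
            return  # below the pressure threshold: skip the (slow) collection
    gc.collect()  # reclaim Python-side references first
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached CUDA blocks back to the driver
        torch.cuda.ipc_collect()
```

Forcing the collection at load time is a reasonable trade: loaders run rarely, and the transient buffers they leave behind are large.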
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b500c51b7..2d4451db1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1053,10 +1053,10 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model',
 def clear_caches():
     # shared.log.debug('Cache clear')
     if not shared.opts.lora_legacy:
-        from modules.lora import networks
-        networks.loaded_networks.clear()
-        networks.previously_loaded_networks.clear()
-        networks.lora_cache.clear()
+        from modules.lora import lora_common, lora_load
+        lora_common.loaded_networks.clear()
+        lora_common.previously_loaded_networks.clear()
+        lora_load.lora_cache.clear()
     from modules import prompt_parser_diffusers
     prompt_parser_diffusers.cache.clear()
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index ef40092ec..435ff95ae 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -437,7 +437,6 @@ def create_html(search_text, sort_column):
 
 def create_ui():
-    import modules.ui
     extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "user", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all", visible=False)
     extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False, container=False)
     extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False, container=False)
diff --git a/modules/ui_extra_networks_lora.py b/modules/ui_extra_networks_lora.py
index 9dd1b3573..194f16b41 100644
--- a/modules/ui_extra_networks_lora.py
+++ b/modules/ui_extra_networks_lora.py
@@ -1,8 +1,8 @@
 import os
 import json
 import concurrent
-import modules.lora.networks as networks
 from modules import shared, ui_extra_networks
+from modules.lora import lora_load
 
 debug = os.environ.get('SD_LORA_DEBUG', None) is not None
 
@@ -14,7 +14,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         self.list_time = 0
 
     def refresh(self):
-        networks.list_available_networks()
+        lora_load.list_available_networks()
 
     @staticmethod
     def get_tags(l, info):
@@ -78,9 +78,9 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         return clean_tags
 
     def create_item(self, name):
-        l = networks.available_networks.get(name)
+        l = lora_load.available_networks.get(name)
         if l is None:
-            shared.log.warning(f'Networks: type=lora registered={len(list(networks.available_networks))} file="{name}" not registered')
+            shared.log.warning(f'Networks: type=lora registered={len(list(lora_load.available_networks))} file="{name}" not registered')
             return None
         try:
             # path, _ext = os.path.splitext(l.filename)
@@ -111,7 +111,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
     def list_items(self):
         items = []
         with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
-            future_items = {executor.submit(self.create_item, net): net for net in networks.available_networks}
+            future_items = {executor.submit(self.create_item, net): net for net in lora_load.available_networks}
             for future in concurrent.futures.as_completed(future_items):
                 item = future.result()
                 if item is not None:
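Editor's note: the hunks in sd_models.py, ui_extra_networks_lora.py, and webui.py all migrate call sites from the monolithic modules.lora.networks to the new lora_load/lora_common split. For out-of-tree extensions that import the old module, a compatibility shim along these lines would tolerate both layouts (module names are taken from this diff; the ImportError fallback is an assumption):

```python
try:
    from modules.lora import lora_load  # new split layout introduced by this refactor
except ImportError:
    import modules.lora.networks as lora_load  # legacy single-module layout

lora_load.list_available_networks()  # refresh the registry, as webui.py does at startup
print(f'registered={len(lora_load.available_networks)}')
```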
diff --git a/modules/ui_gallery.py b/modules/ui_gallery.py
index a1f317caa..40bcace03 100644
--- a/modules/ui_gallery.py
+++ b/modules/ui_gallery.py
@@ -3,7 +3,7 @@ from datetime import datetime
 from urllib.parse import unquote
 import gradio as gr
 from PIL import Image
-from modules import shared, ui_symbols, ui_common, images, ui_control_helpers
+from modules import shared, ui_symbols, ui_common, images, video
 from modules.ui_components import ToolButton
 
 def read_media(fn):
@@ -13,7 +13,7 @@ def read_media(fn):
         return [[], None, '', '', f'Media not found: {fn}']
     stat = os.stat(fn)
     if fn.lower().endswith('.mp4'):
-        frames, fps, duration, w, h, codec, _frame = ui_control_helpers.get_video_params(fn)
+        frames, fps, duration, w, h, codec, _frame = video.get_video_params(fn)
         geninfo = ''
         log = f'''
 Video {w} x {h}
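Editor's note: get_video_params moved from modules.ui_control_helpers to modules.video, and its callers unpack a 7-tuple (frames, fps, duration, width, height, codec, frame). The implementation is not part of this diff; a rough OpenCV-based equivalent with the same return shape might look like this:

```python
import cv2  # pip install opencv-python; an assumed backend, not necessarily what the repo uses

def get_video_params(filepath, capture=False):
    cap = cv2.VideoCapture(filepath)
    frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    duration = frames / fps if fps > 0 else 0
    fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
    codec = ''.join(chr((fourcc >> (8 * i)) & 0xFF) for i in range(4))  # decode FourCC bytes
    frame = None
    if capture:
        ok, img = cap.read()  # grab the first frame, as the thumbnail path does
        frame = img if ok else None
    cap.release()
    return frames, fps, duration, width, height, codec, frame
```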
diff --git a/modules/zluda_hijacks.py b/modules/zluda_hijacks.py
index ea60c846a..431c359b8 100644
--- a/modules/zluda_hijacks.py
+++ b/modules/zluda_hijacks.py
@@ -1,7 +1,7 @@
 from functools import wraps
 import torch
 import torch._dynamo.device_interface
-from modules import rocm, zluda, shared
+from modules import shared, rocm, zluda # pylint: disable=unused-import
 
 MEM_BUS_WIDTH = {
diff --git a/webui.py b/webui.py
index ecfafe3f1..e44c5c2b9 100644
--- a/webui.py
+++ b/webui.py
@@ -88,8 +88,8 @@ def initialize():
     timer.startup.record("models")
 
     if not shared.opts.lora_legacy:
-        import modules.lora.networks as lora_networks
-        lora_networks.list_available_networks()
+        from modules.lora import lora_load
+        lora_load.list_available_networks()
     timer.startup.record("lora")
 
     shared.prompt_styles.reload()