import os
import time
import diffusers
import diffusers.models.lora as diffusers_lora
# from modules import shared
import modules.shared as shared
import modules.errors


# Setting SD_LORA_DEBUG (to any value) routes debug() to the info log and
# makes load failures print through modules.errors.display.
debug_output = os.environ.get('SD_LORA_DEBUG', None)
debug = shared.log.info if debug_output is not None else lambda *args, **kwargs: None

lora_state = {  # Lora state for Diffusers
    'multiplier': [],  # strength of each successfully loaded LoRA, in load order
    'active': False,
    'loaded': [],  # "<filename>:<strength>" keys used to skip repeated loads
    'all_loras': [],  # LoRANetwork objects (non-"diffusers" loaders only)
}


def unload_diffusers_lora():
    """Undo every LoRA recorded in ``lora_state`` and reset that state.

    With the "diffusers" loader the pipeline's own unfuse/unload calls are used
    and the U-Net attention processors are re-created. With the other loaders
    the stored networks are walked in reverse load order and either subtracted
    from the weights ("merge and apply") or un-hooked ("sequential apply").
    Any exception is logged and swallowed.
    """
    try:
        pipe = shared.sd_model
        if shared.opts.diffusers_lora_loader == "diffusers":
            # NOTE(review): load_diffusers_lora fuses every LoRA, but unfuse only
            # runs when more than one is loaded - confirm `> 1` is intended.
            if len(lora_state['loaded']) > 1 and hasattr(pipe, "unfuse_lora"):
                debug(f'LoRA unfuse: loader={shared.opts.diffusers_lora_loader}')
                pipe.unfuse_lora()
            pipe.unload_lora_weights()
            pipe._remove_text_encoder_monkey_patch() # pylint: disable=W0212
            proc_cls_name = next(iter(pipe.unet.attn_processors.values())).__class__.__name__
            non_lora_proc_cls = getattr(diffusers.models.attention_processor, proc_cls_name)#[len("LORA"):])
            pipe.unet.set_attn_processor(non_lora_proc_cls())
        else:
            # Undo in reverse load order; both lists are reversed together so
            # index i still pairs a network with its own multiplier.
            lora_state['all_loras'].reverse()
            lora_state['multiplier'].reverse()
            for i, lora_network in enumerate(lora_state['all_loras']):
                if shared.opts.diffusers_lora_loader == "merge and apply":
                    lora_network.restore_from(multiplier=lora_state['multiplier'][i])
                if shared.opts.diffusers_lora_loader == "sequential apply":
                    lora_network.unapply_to()
        lora_state['active'] = False
        lora_state['loaded'].clear()
        lora_state['all_loras'] = []
        lora_state['multiplier'] = []
        debug(f'LoRA unloaded: loader={shared.opts.diffusers_lora_loader}')
    except Exception as e:
        shared.log.error(f"LoRA unload failed: {e}")


def load_diffusers_lora(name, lora, te_multiplier = 1.0, unet_multiplier = 1.0, dyn_dim = None):
    """Load one LoRA into ``shared.sd_model`` using the configured loader.

    Args:
        name: display name; used in log messages and as the compiled-model key.
        lora: object whose ``filename`` attribute is the LoRA weights file.
        te_multiplier: strength applied to the whole LoRA.
        unet_multiplier: only logged (see TODO below).
        dyn_dim: only logged.

    Returns None. A LoRA already loaded with the same file and strength is
    skipped. Failures are logged, not raised.
    """
    # TODO: te_multiplier is used as strength and unet_multiplier is ignored
    loaded_key = f'{lora.filename}:{te_multiplier}'
    if loaded_key in lora_state['loaded']:
        debug(f'LoRA cached: {name} te-strength={te_multiplier} unet-strength={unet_multiplier} dyn-dim={dyn_dim}')
        return
    try:
        t0 = time.time()
        pipe = shared.sd_model
        lora_state['active'] = True
        fuse = 0
        if shared.opts.diffusers_lora_loader.startswith("diffusers"):
            pipe.load_lora_weights(lora.filename, cache_dir=shared.opts.diffusers_dir, local_files_only=True, lora_scale=te_multiplier, low_cpu_mem_usage=True)
            if hasattr(pipe, "fuse_lora"):
                t2 = time.time()
                pipe.fuse_lora(lora_scale=te_multiplier)
                fuse = time.time() - t2
        else:
            from safetensors.torch import load_file
            lora_sd = load_file(lora.filename)
            if "XL" in pipe.__class__.__name__:
                text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
            else:
                text_encoders = pipe.text_encoder
            lora_network: "LoRANetwork" = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=te_multiplier)
            lora_network.load_state_dict(lora_sd)
            if shared.opts.diffusers_lora_loader == "merge and apply":
                lora_network.merge_to(multiplier=te_multiplier)
            if shared.opts.diffusers_lora_loader == "sequential apply":
                lora_network.to(shared.device, dtype=pipe.unet.dtype)
                lora_network.apply_to(multiplier=te_multiplier)
            lora_state['all_loras'].append(lora_network)
        # Record the strength only after a successful load: unload pairs
        # 'multiplier' with 'all_loras' by index, so a failed load must not
        # leave a stray entry behind.
        lora_state['multiplier'].append(te_multiplier)
        lora_state['loaded'].append(loaded_key)
        if shared.compiled_model_state is not None:
            #filename breaks caching
            shared.compiled_model_state.lora_model.append(f'{name}:{te_multiplier}')
        t1 = time.time()
        fuse = f'fuse={fuse:.2f}s' if fuse > 0 else ''
        shared.log.info(f'LoRA loaded: {name} strength={te_multiplier} loader="{shared.opts.diffusers_lora_loader}" lora={t1-t0:.2f}s {fuse}')
    except Exception as e:
        if debug_output is None:
            lines = str(e).splitlines()
            # An exception with an empty message has no first line to show.
            reason = lines[0] if lines else type(e).__name__
            shared.log.error(f'LoRA load failed: {name} loader="{shared.opts.diffusers_lora_loader}" {reason}')
        else:
            modules.errors.display(e, 'LoRA load failed')


# LoRA that runs on Diffusers. This file is self-contained.
# LoRA module for Diffusers. This file works independently.

import bisect # pylint: disable=wrong-import-order
import math # pylint: disable=wrong-import-order
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union # pylint: disable=wrong-import-order
from diffusers import UNet2DConditionModel # pylint: disable=wrong-import-order
from tqdm import tqdm # pylint: disable=wrong-import-order
from transformers import CLIPTextModel # pylint: disable=wrong-import-order
import torch # pylint: disable=wrong-import-order


def make_unet_conversion_map() -> Dict[str, str]:
    """Build the Stability-AI -> Diffusers U-Net module-name prefix map (SDXL layout).

    Keys and values have '.' replaced by '_' and the trailing separator
    removed, e.g. ``"middle_block_1" -> "mid_block_attentions_0"``.
    """
    unet_conversion_map_layer = []
    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            unet_conversion_map_layer.append((f"input_blocks.{3*i + j + 1}.0.", f"down_blocks.{i}.resnets.{j}."))
            unet_conversion_map_layer.append((f"input_blocks.{3*i + j + 1}.1.", f"down_blocks.{i}.attentions.{j}."))
        for j in range(3):
            # loop over resnets/attentions for upblocks
            unet_conversion_map_layer.append((f"output_blocks.{3*i + j}.0.", f"up_blocks.{i}.resnets.{j}."))
            unet_conversion_map_layer.append((f"output_blocks.{3*i + j}.1.", f"up_blocks.{i}.attentions.{j}."))
        unet_conversion_map_layer.append((f"input_blocks.{3*(i+1)}.0.op.", f"down_blocks.{i}.downsamplers.0.conv."))
        # upsampler sits at sub-index 2 (changed for sdxl)
        unet_conversion_map_layer.append((f"output_blocks.{3*i + 2}.2.", f"up_blocks.{i}.upsamplers.0."))

    unet_conversion_map_layer.append(("middle_block.1.", "mid_block.attentions.0."))
    for j in range(2):
        unet_conversion_map_layer.append((f"middle_block.{2*j}.", f"mid_block.resnets.{j}."))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            # resnet blocks also rename their inner layers
            unet_conversion_map.extend((sd + sd_res, hf + hf_res) for sd_res, hf_res in unet_conversion_map_resnet)
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        unet_conversion_map.append((f"time_embed.{j*2}.", f"time_embedding.linear_{j+1}."))
    for j in range(2):
        unet_conversion_map.append((f"label_emb.0.{j*2}.", f"add_embedding.linear_{j+1}."))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    # '.' -> '_' and drop the trailing separator
    return {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}


UNET_CONVERSION_MAP = make_unet_conversion_map()
# Sorted once so key conversion can bisect for the candidate prefix.
_UNET_CONVERSION_MAP_KEYS = sorted(UNET_CONVERSION_MAP)


class LoRAModule(torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        if isinstance(org_module, diffusers_lora.LoRACompatibleConv): #Modified to support Diffusers>=0.19.2
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # not included in gradient calculation

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        # Held in a plain list so the original module is not registered as a submodule.
        self.org_module = [org_module]
        self.enabled = True
        self.network: Optional["LoRANetwork"] = None
        self.org_forward = None

    # override org_module's forward method
    def apply_to(self, multiplier=None):
        """Hook this module into the original module's forward; optionally set the multiplier."""
        if multiplier is not None:
            self.multiplier = multiplier
        if self.org_forward is None:
            self.org_forward = self.org_module[0].forward
            self.org_module[0].forward = self.forward

    # restore org_module's forward method
    def unapply_to(self):
        """Put back the original forward if one was saved by apply_to."""
        if self.org_forward is not None:
            self.org_module[0].forward = self.org_forward

    # forward with lora
    def forward(self, x, scale = 1.0): # pylint: disable=unused-argument
        """Return the original output plus the scaled LoRA delta (original only when disabled)."""
        if not self.enabled:
            return self.org_forward(x)
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

    def set_network(self, network):
        self.network = network

    def _shift_org_weight(self, multiplier, subtract):
        """Add (or subtract) the LoRA delta weight to/from the original module's "weight"."""
        lora_weight = self.get_weight(multiplier)
        org_sd = self.org_module[0].state_dict()
        org_weight = org_sd["weight"]
        delta = lora_weight.to(org_weight.device, dtype=org_weight.dtype)
        org_sd["weight"] = org_weight - delta if subtract else org_weight + delta
        self.org_module[0].load_state_dict(org_sd)

    # merge lora weight to org weight
    def merge_to(self, multiplier=1.0):
        """Add the LoRA delta, scaled by ``multiplier``, into the original weight."""
        self._shift_org_weight(multiplier, subtract=False)

    # restore org weight from lora weight
    def restore_from(self, multiplier=1.0):
        """Subtract the LoRA delta, scaled by ``multiplier``, from the original weight."""
        self._shift_org_weight(multiplier, subtract=True)

    # return lora weight
    def get_weight(self, multiplier=None):
        """Return the float32 delta weight ``multiplier * (up . down) * scale``.

        ``multiplier`` defaults to ``self.multiplier`` when None.
        """
        if multiplier is None:
            multiplier = self.multiplier

        # get up/down weight from module
        up_weight = self.lora_up.weight.to(torch.float)
        down_weight = self.lora_down.weight.to(torch.float)

        # pre-calculated weight; every branch honours the resolved multiplier
        if len(down_weight.size()) == 2:
            # linear
            weight = multiplier * (up_weight @ down_weight) * self.scale
        elif down_weight.size()[2:4] == (1, 1):
            # conv2d 1x1
            product = up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
            weight = multiplier * product.unsqueeze(2).unsqueeze(3) * self.scale
        else:
            # conv2d 3x3
            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
            weight = multiplier * conved * self.scale
        return weight


# Create network from weights for inference, weights are not loaded here
def create_network_from_weights(
    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
    unet: UNet2DConditionModel,
    weights_sd: Dict,
    multiplier: float = 1.0
):
    """Build a LoRANetwork whose module dims/alphas are read from ``weights_sd``.

    The rank of each module is the first dimension of its ``lora_down`` tensor;
    modules without an ``alpha`` entry use their rank as alpha. The weights
    themselves are not loaded here.
    """
    # get dim/alpha mapping
    modules_dim = {}
    modules_alpha = {}
    for key, value in weights_sd.items():
        if "." not in key:
            continue
        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "lora_down" in key:
            modules_dim[lora_name] = value.size()[0]

    # support old LoRA without alpha
    for lora_name, dim in modules_dim.items():
        modules_alpha.setdefault(lora_name, dim)

    return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)


def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
    """Create a network from ``weights_sd``, load it and merge it into ``pipe``'s weights."""
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
    unet = pipe.unet
    lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
    lora_network.load_state_dict(weights_sd)
    lora_network.merge_to(multiplier=multiplier)


# simple version without block weight and training
class LoRANetwork(torch.nn.Module): # pylint: disable=abstract-method
    """Collection of LoRAModule objects for the text encoder(s) and the U-Net."""

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"

    def __init__(
        self,
        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
        unet: UNet2DConditionModel,
        multiplier: float = 1.0,
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        varbose: Optional[bool] = False, # pylint: disable=unused-argument
    ) -> None:
        """Create a LoRAModule for every target layer that is listed in ``modules_dim``.

        ``modules_dim`` / ``modules_alpha`` map LoRA module names to rank and
        alpha. Both are updated in place when Stability-AI style U-Net names
        are converted to Diffusers names. Raises AssertionError when a name in
        ``modules_dim`` matches no layer of the given models.
        """
        super().__init__()
        self.multiplier = multiplier

        # Treat the None defaults as "no modules" instead of failing on them.
        modules_dim = {} if modules_dim is None else modules_dim
        modules_alpha = {} if modules_alpha is None else modules_alpha

        # convert SDXL Stability AI's U-Net modules to Diffusers
        self.convert_unet_modules(modules_dim, modules_alpha)

        # create module instances
        def create_modules(
            is_unet: bool,
            text_encoder_idx: Optional[int],  # None, 1, 2
            root_module: torch.nn.Module,
            target_replace_modules: List[str],  # class names, matched against __class__.__name__
        ) -> Tuple[List[LoRAModule], List[str]]:
            if is_unet:
                prefix = self.LORA_PREFIX_UNET
            elif text_encoder_idx is None:
                prefix = self.LORA_PREFIX_TEXT_ENCODER
            elif text_encoder_idx == 1:
                prefix = self.LORA_PREFIX_TEXT_ENCODER1
            else:
                prefix = self.LORA_PREFIX_TEXT_ENCODER2

            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue
                for child_name, child_module in module.named_modules():
                    is_linear = isinstance(child_module, (torch.nn.Linear, diffusers_lora.LoRACompatibleLinear)) #Modified to support Diffusers>=0.19.2
                    is_conv2d = isinstance(child_module, (torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv)) #Modified to support Diffusers>=0.19.2
                    if not (is_linear or is_conv2d):
                        continue
                    lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")
                    if lora_name not in modules_dim:
                        skipped.append(lora_name)
                        continue
                    loras.append(
                        LoRAModule(
                            lora_name,
                            child_module,
                            self.multiplier,
                            modules_dim[lora_name],
                            modules_alpha[lora_name],
                        )
                    )
            return loras, skipped

        text_encoders = list(text_encoder) if isinstance(text_encoder, (list, tuple)) else [text_encoder]

        # create LoRA for text encoder
        # it is wasteful to create all modules every time, need to consider
        self.text_encoder_loras: List[LoRAModule] = []
        skipped_te = []
        for i, encoder in enumerate(text_encoders):
            # with several encoders they are numbered from 1 (lora_te1, lora_te2)
            index = i + 1 if len(text_encoders) > 1 else None
            text_encoder_loras, skipped = create_modules(False, index, encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
            self.text_encoder_loras.extend(text_encoder_loras)
            skipped_te += skipped

        # extend U-Net target modules to include Conv2d 3x3
        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_loras: List[LoRAModule]
        self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
        debug(f"LoRA module: te_loaded={len(self.text_encoder_loras)} te_skipped={len(skipped_te)} unet_loaded={len(self.unet_loras)} unet_skipped={len(skipped_un)}")

        # every module named in the weights must have been created; raised
        # explicitly so the check also runs under `python -O`
        names = {lora.lora_name for lora in self.text_encoder_loras + self.unet_loras}
        for lora_name in modules_dim:
            if lora_name not in names:
                raise AssertionError(f"{lora_name} is not found in created LoRA modules.")

        # make to work load_state_dict
        for lora in self.text_encoder_loras + self.unet_loras:
            self.add_module(lora.lora_name, lora)

    @staticmethod
    def _convert_unet_key(key: str) -> Optional[str]:
        """Return ``key`` with its Stability-AI U-Net prefix rewritten, or None if no prefix matches.

        ``key`` is expected to start with ``LORA_PREFIX_UNET + "_"``.
        """
        search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
        # the candidate is the greatest map key that sorts at or before search_key
        position = bisect.bisect_right(_UNET_CONVERSION_MAP_KEYS, search_key)
        map_key = _UNET_CONVERSION_MAP_KEYS[position - 1]
        if not search_key.startswith(map_key):
            return None
        return key.replace(map_key, UNET_CONVERSION_MAP[map_key])

    # SDXL: convert SDXL Stability AI's U-Net modules to Diffusers
    def convert_unet_modules(self, modules_dim, modules_alpha):
        """Rename Stability-AI style U-Net entries of both dicts in place to Diffusers names."""
        converted_count = 0
        not_converted_count = 0
        for key in list(modules_dim.keys()):
            if not key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
                continue
            new_key = self._convert_unet_key(key)
            if new_key is None:
                not_converted_count += 1
                continue
            modules_dim[new_key] = modules_dim[key]
            modules_alpha[new_key] = modules_alpha[key]
            del modules_dim[key]
            del modules_alpha[key]
            converted_count += 1
        debug(f'LoRA module: unet converted={converted_count}/{not_converted_count}')

    def set_multiplier(self, multiplier):
        """Set the multiplier on the network and on every LoRA module."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
        """Hook the selected LoRA modules into their original modules' forward."""
        if apply_text_encoder:
            for lora in self.text_encoder_loras:
                lora.apply_to(multiplier)
        if apply_unet:
            for lora in self.unet_loras:
                lora.apply_to(multiplier)

    def unapply_to(self):
        """Restore the original forward on every hooked module."""
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.unapply_to()

    def merge_to(self, multiplier=1.0):
        """Add every LoRA delta into the original weights."""
        for lora in tqdm(self.text_encoder_loras + self.unet_loras):
            lora.merge_to(multiplier)

    def restore_from(self, multiplier=1.0):
        """Subtract every LoRA delta from the original weights."""
        for lora in tqdm(self.text_encoder_loras + self.unet_loras):
            lora.restore_from(multiplier)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): # pylint: disable=arguments-differ
        """Load LoRA weights, accepting Stability-AI style U-Net key names.

        Works on a shallow copy, so the caller's mapping is not modified.
        Returns the result of ``torch.nn.Module.load_state_dict``.
        """
        state_dict = dict(state_dict)

        # convert SDXL Stability AI's state dict to Diffusers' based state dict
        for key in list(state_dict.keys()):
            if not key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
                continue
            new_key = self._convert_unet_key(key)
            if new_key is not None:
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        # in case of V2, some weights have different shape, so we need to convert them
        # because V2 LoRA is based on U-Net created by use_linear_projection=False
        my_state_dict = self.state_dict()
        for key in state_dict:
            # keys unknown to this network are left for super() to report (or
            # ignore when strict=False)
            if key in my_state_dict and state_dict[key].size() != my_state_dict[key].size():
                state_dict[key] = state_dict[key].view(my_state_dict[key].size())

        return super().load_state_dict(state_dict, strict)