import json from itertools import groupby from typing import Dict, List, Optional, Set, Tuple, Union import torch import torch.nn as nn from safetensors.torch import safe_open from safetensors.torch import save_file as safe_save from torch import dtype from dreambooth.utils.model_utils import disable_safe_unpickle, enable_safe_unpickle class LoraInjectedLinear(nn.Module): def __init__(self, in_features, out_features, bias=False, r=4, dropout_p=0.1): super().__init__() if r > min(in_features, out_features): raise ValueError( f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}" ) self.linear = nn.Linear(in_features, out_features, bias) self.lora_down = nn.Linear(in_features, r, bias=False) self.dropout = nn.Dropout(dropout_p) self.lora_up = nn.Linear(r, out_features, bias=False) self.scale = 1.0 nn.init.normal_(self.lora_down.weight, std=1 / r) nn.init.zeros_(self.lora_up.weight) def forward(self, input): return ( self.linear(input) + self.lora_up(self.dropout(self.lora_down(input))) * self.scale ) class LoraInjectedConv2d(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size, stride=1, padding=0, dilation=1, groups: int = 1, bias: bool = True, r: int = 4, dropout_p: float = 0.1, ): super().__init__() if r > min(in_channels, out_channels): raise ValueError( f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}" ) self.conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, ) self.lora_down = nn.Conv2d( in_channels=in_channels, out_channels=r, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False, ) self.dropout = nn.Dropout(dropout_p) self.lora_up = nn.Conv2d( in_channels=r, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=False, ) self.scale = 1.0 nn.init.normal_(self.lora_down.weight, std=1 / r) nn.init.zeros_(self.lora_up.weight) def forward(self, input): return ( self.conv(input) + self.lora_up(self.dropout(self.lora_down(input))) * self.scale ) UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"} UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"} TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"} TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"} DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE EMBED_FLAG = "" def _find_children( model, search_class=None, ): """ Find all modules of a certain class (or union of classes). Returns all matching modules, along with the parent of those moduless and the names they are referenced by. """ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear if search_class is None: search_class = [nn.Linear] for parent in model.modules(): for name, module in parent.named_children(): if any([isinstance(module, _class) for _class in search_class]): yield parent, name, module def _find_modules_v2( model, ancestor_class: Optional[Set[str]] = None, search_class=None, exclude_children_of=None, ): """ Find all modules of a certain class (or union of classes) that are direct or indirect descendants of other modules of a certain class (or union of classes). Returns all matching modules, along with the parent of those moduless and the names they are referenced by. """ # Get the targets we should replace all linears under if exclude_children_of is None: exclude_children_of = [ LoraInjectedLinear, LoraInjectedConv2d, ] if search_class is None: search_class = [nn.Linear] if ancestor_class is not None: ancestors = ( module for module in model.modules() if module.__class__.__name__ in ancestor_class ) else: # this, incase you want to naively iterate over all modules. ancestors = [module for module in model.modules()] # For each target find every linear_class module that isn't a child of a LoraInjectedLinear for ancestor in ancestors: for fullname, module in ancestor.named_modules(): if any([isinstance(module, _class) for _class in search_class]): # Find the direct parent if this is a descendant, not a child, of target *path, name = fullname.split(".") parent = ancestor while path: parent = parent.get_submodule(path.pop(0)) # Skip this linear if it's a child of a LoraInjectedLinear if exclude_children_of and any( [isinstance(parent, _class) for _class in exclude_children_of] ): continue # Otherwise, yield it yield parent, name, module def _find_modules_old( model, ancestor_class=None, search_class=None, ): if search_class is None: search_class = [nn.Linear] if ancestor_class is None: ancestor_class = DEFAULT_TARGET_REPLACE ret = [] for _module in model.modules(): if _module.__class__.__name__ in ancestor_class: for name, _child_module in _module.named_modules(): if _child_module.__class__ in search_class: ret.append((_module, name, _child_module)) print(ret) return ret _find_modules = _find_modules_v2 def inject_trainable_lora( model: nn.Module, target_replace_module=None, r: int = 4, loras=None, # path to lora .pt ): """ inject lora into model, and returns lora parameter groups. """ if target_replace_module is None: target_replace_module = DEFAULT_TARGET_REPLACE disable_safe_unpickle() require_grad_params = [] names = [] if loras is not None: loras = torch.load(loras) for _module, name, _child_module in _find_modules( model, target_replace_module, search_class=[nn.Linear] ): weight = _child_module.weight bias = _child_module.bias _tmp = LoraInjectedLinear( _child_module.in_features, _child_module.out_features, _child_module.bias is not None, r, ) _tmp.linear.weight = weight if bias is not None: _tmp.linear.bias = bias # switch the module _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) _module._modules[name] = _tmp require_grad_params.append(_module._modules[name].lora_up.parameters()) require_grad_params.append(_module._modules[name].lora_down.parameters()) if loras is not None: _module._modules[name].lora_up.weight = nn.Parameter(loras.pop(0)) _module._modules[name].lora_down.weight = nn.Parameter(loras.pop(0)) _module._modules[name].lora_up.weight.requires_grad = True _module._modules[name].lora_down.weight.requires_grad = True names.append(name) enable_safe_unpickle() return require_grad_params, names def inject_trainable_lora_extended( model: nn.Module, target_replace_module=None, r: int = 4, loras=None, # path to lora .pt ): """ inject lora into model, and returns lora parameter groups. """ if target_replace_module is None: target_replace_module = UNET_EXTENDED_TARGET_REPLACE disable_safe_unpickle() require_grad_params = [] names = [] if loras is not None: loras = torch.load(loras) for _module, name, _child_module in _find_modules( model, target_replace_module, search_class=[nn.Linear, nn.Conv2d] ): if _child_module.__class__ == nn.Linear: weight = _child_module.weight bias = _child_module.bias _tmp = LoraInjectedLinear( _child_module.in_features, _child_module.out_features, _child_module.bias is not None, r, ) _tmp.linear.weight = weight if bias is not None: _tmp.linear.bias = bias elif _child_module.__class__ == nn.Conv2d: weight = _child_module.weight bias = _child_module.bias _tmp = LoraInjectedConv2d( _child_module.in_channels, _child_module.out_channels, _child_module.kernel_size, _child_module.stride, _child_module.padding, _child_module.dilation, _child_module.groups, _child_module.bias is not None, r, ) _tmp.conv.weight = weight if bias is not None: _tmp.conv.bias = bias # switch the module _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) if bias is not None: _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype) _module._modules[name] = _tmp require_grad_params.append(_module._modules[name].lora_up.parameters()) require_grad_params.append(_module._modules[name].lora_down.parameters()) if loras is not None: _module._modules[name].lora_up.weight = nn.Parameter(loras.pop(0)) _module._modules[name].lora_down.weight = nn.Parameter(loras.pop(0)) _module._modules[name].lora_up.weight.requires_grad = True _module._modules[name].lora_down.weight.requires_grad = True names.append(name) enable_safe_unpickle() return require_grad_params, names def extract_lora_ups_down(model, target_replace_module=None): if target_replace_module is None: target_replace_module = DEFAULT_TARGET_REPLACE loras = [] for _m, _n, _child_module in _find_modules( model, target_replace_module, search_class=[LoraInjectedLinear, LoraInjectedConv2d], ): loras.append((_child_module.lora_up, _child_module.lora_down)) if len(loras) == 0: raise ValueError("No lora injected.") return loras def save_lora_weight( model, path="./lora.pt", target_replace_module=None, save_safetensors: bool = False, d_type: dtype = torch.float32 ): if target_replace_module is None: target_replace_module = DEFAULT_TARGET_REPLACE weights = [] for _up, _down in extract_lora_ups_down( model, target_replace_module=target_replace_module ): weights.append(_up.weight.to("cpu", dtype=d_type)) weights.append(_down.weight.to("cpu", dtype=d_type)) if save_safetensors: path = path.replace(".pt", ".safetensors") save_safeloras(weights, path) else: torch.save(weights, path) def save_lora_as_json(model, path="./lora.json"): weights = [] for _up, _down in extract_lora_ups_down(model): weights.append(_up.weight.detach().cpu().numpy().tolist()) weights.append(_down.weight.detach().cpu().numpy().tolist()) import json with open(path, "w") as f: json.dump(weights, f) def save_safeloras_with_embeds( modelmap=None, embeds=None, outpath="./lora.safetensors", ): """ Saves the Lora from multiple modules in a single safetensor file. modelmap is a dictionary of { "module name": (module, target_replace_module) } """ if embeds is None: embeds = {} if modelmap is None: modelmap = {} weights = {} metadata = {} for name, (model, target_replace_module) in modelmap.items(): metadata[name] = json.dumps(list(target_replace_module)) for i, (_up, _down) in enumerate( extract_lora_ups_down(model, target_replace_module) ): try: rank = getattr(_down, "out_features") except: rank = getattr(_down, "out_channels") metadata[f"{name}:{i}:rank"] = str(rank) weights[f"{name}:{i}:up"] = _up.weight weights[f"{name}:{i}:down"] = _down.weight for token, tensor in embeds.items(): metadata[token] = EMBED_FLAG weights[token] = tensor print(f"Saving weights to {outpath}") safe_save(weights, outpath, metadata) def save_safeloras( modelmap=None, outpath="./lora.safetensors", ): if modelmap is None: modelmap = {} return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) def convert_loras_to_safeloras_with_embeds( modelmap=None, embeds=None, outpath="./lora.safetensors", ): """ Converts the Lora from multiple pytorch .pt files into a single safetensor file. modelmap is a dictionary of { "module name": (pytorch_model_path, target_replace_module, rank) } """ if modelmap is None: modelmap = {} if embeds is None: embeds = {} weights = {} metadata = {} for name, (path, target_replace_module, r) in modelmap.items(): metadata[name] = json.dumps(list(target_replace_module)) disable_safe_unpickle() lora = torch.load(path) enable_safe_unpickle() for i, weight in enumerate(lora): is_up = i % 2 == 0 i = i // 2 if is_up: metadata[f"{name}:{i}:rank"] = str(r) weights[f"{name}:{i}:up"] = weight else: weights[f"{name}:{i}:down"] = weight for token, tensor in embeds.items(): metadata[token] = EMBED_FLAG weights[token] = tensor print(f"Saving weights to {outpath}") safe_save(weights, outpath, metadata) def convert_loras_to_safeloras( modelmap=None, outpath="./lora.safetensors", ): if modelmap is None: modelmap = {} convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) def parse_safeloras( safeloras, ) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]: """ Converts a loaded safetensor file that contains a set of module Loras into Parameters and other information Output is a dictionary of { "module name": ( [list of weights], [list of ranks], target_replacement_modules ) } """ loras = {} metadata = safeloras.metadata() get_name = lambda k: k.split(":")[0] keys = list(safeloras.keys()) keys.sort(key=get_name) for name, module_keys in groupby(keys, get_name): info = metadata.get(name) if not info: raise ValueError( f"Tensor {name} has no metadata - is this a Lora safetensor?" ) # Skip Textual Inversion embeds if info == EMBED_FLAG: continue # Handle Loras # Extract the targets target = json.loads(info) # Build the result lists - Python needs us to preallocate lists to insert into them module_keys = list(module_keys) ranks = [4] * (len(module_keys) // 2) weights = [None] * len(module_keys) for key in module_keys: # Split the model name and index out of the key _, idx, direction = key.split(":") idx = int(idx) # Add the rank ranks[idx] = int(metadata[f"{name}:{idx}:rank"]) # Insert the weight into the list idx = idx * 2 + (1 if direction == "down" else 0) weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key)) loras[name] = (weights, ranks, target) return loras def parse_safeloras_embeds( safeloras, ) -> Dict[str, torch.Tensor]: """ Converts a loaded safetensor file that contains Textual Inversion embeds into a dictionary of embed_token: Tensor """ embeds = {} metadata = safeloras.metadata() for key in safeloras.keys(): # Only handle Textual Inversion embeds meta = metadata.get(key) if not meta or meta != EMBED_FLAG: continue embeds[key] = safeloras.get_tensor(key) return embeds def load_safeloras(path, device="cpu"): safeloras = safe_open(path, framework="pt", device=device) return parse_safeloras(safeloras) def load_safeloras_embeds(path, device="cpu"): safeloras = safe_open(path, framework="pt", device=device) return parse_safeloras_embeds(safeloras) def load_safeloras_both(path, device="cpu"): safeloras = safe_open(path, framework="pt", device=device) return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras) def collapse_lora(model, alpha=1.0): for _module, name, _child_module in _find_modules( model, UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE, search_class=[LoraInjectedLinear, LoraInjectedConv2d], ): if isinstance(_child_module, LoraInjectedLinear): _child_module.linear.weight = nn.Parameter( _child_module.linear.weight.data + alpha * ( _child_module.lora_up.weight.data @ _child_module.lora_down.weight.data ) .type(_child_module.linear.weight.dtype) .to(_child_module.linear.weight.device) ) else: _child_module.conv.weight = nn.Parameter( _child_module.conv.weight.data + alpha * ( _child_module.lora_up.weight.data.flatten(start_dim=1) @ _child_module.lora_down.weight.data.flatten(start_dim=1) ) .reshape(_child_module.conv.weight.data.shape) .type(_child_module.conv.weight.dtype) .to(_child_module.conv.weight.device) ) def monkeypatch_or_replace_lora( model, loras, target_replace_module=None, r: Union[int, List[int]] = 4, ): if target_replace_module is None: target_replace_module = DEFAULT_TARGET_REPLACE for _module, name, _child_module in _find_modules( model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear] ): _source = ( _child_module.linear if isinstance(_child_module, LoraInjectedLinear) else _child_module ) weight = _source.weight bias = _source.bias _tmp = LoraInjectedLinear( _source.in_features, _source.out_features, _source.bias is not None, r=r.pop(0) if isinstance(r, list) else r, ) _tmp.linear.weight = weight if bias is not None: _tmp.linear.bias = bias # switch the module _module._modules[name] = _tmp up_weight = loras.pop(0) down_weight = loras.pop(0) _module._modules[name].lora_up.weight = nn.Parameter( up_weight.type(weight.dtype) ) _module._modules[name].lora_down.weight = nn.Parameter( down_weight.type(weight.dtype) ) _module._modules[name].to(weight.device) def monkeypatch_or_replace_lora_extended( model, loras, target_replace_module=None, r: Union[int, List[int]] = 4, ): if target_replace_module is None: target_replace_module = DEFAULT_TARGET_REPLACE for _module, name, _child_module in _find_modules( model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d], ): if (_child_module.__class__ == nn.Linear) or ( _child_module.__class__ == LoraInjectedLinear ): if len(loras[0].shape) != 2: continue _source = ( _child_module.linear if isinstance(_child_module, LoraInjectedLinear) else _child_module ) weight = _source.weight bias = _source.bias _tmp = LoraInjectedLinear( _source.in_features, _source.out_features, _source.bias is not None, r=r.pop(0) if isinstance(r, list) else r, ) _tmp.linear.weight = weight if bias is not None: _tmp.linear.bias = bias elif (_child_module.__class__ == nn.Conv2d) or ( _child_module.__class__ == LoraInjectedConv2d ): if len(loras[0].shape) != 4: continue _source = ( _child_module.conv if isinstance(_child_module, LoraInjectedConv2d) else _child_module ) weight = _source.weight bias = _source.bias _tmp = LoraInjectedConv2d( _source.in_channels, _source.out_channels, _source.kernel_size, _source.stride, _source.padding, _source.dilation, _source.groups, _source.bias is not None, r=r.pop(0) if isinstance(r, list) else r, ) _tmp.conv.weight = weight if bias is not None: _tmp.conv.bias = bias # switch the module _module._modules[name] = _tmp up_weight = loras.pop(0) down_weight = loras.pop(0) _module._modules[name].lora_up.weight = nn.Parameter( up_weight.type(weight.dtype) ) _module._modules[name].lora_down.weight = nn.Parameter( down_weight.type(weight.dtype) ) _module._modules[name].to(weight.device) def monkeypatch_or_replace_safeloras(models, safeloras): loras = parse_safeloras(safeloras) for name, (lora, ranks, target) in loras.items(): model = getattr(models, name, None) if not model: print(f"No model provided for {name}, contained in Lora") continue monkeypatch_or_replace_lora_extended(model, lora, target, ranks) def monkeypatch_remove_lora(model): for _module, name, _child_module in _find_modules( model, search_class=[LoraInjectedLinear, LoraInjectedConv2d] ): if isinstance(_child_module, LoraInjectedLinear): _source = _child_module.linear weight, bias = _source.weight, _source.bias _tmp = nn.Linear( _source.in_features, _source.out_features, bias is not None ) _tmp.weight = weight if bias is not None: _tmp.bias = bias else: _source = _child_module.conv weight, bias = _source.weight, _source.bias _tmp = nn.Conv2d( in_channels=_source.in_channels, out_channels=_source.out_channels, kernel_size=_source.kernel_size, stride=_source.stride, padding=_source.padding, dilation=_source.dilation, groups=_source.groups, bias=bias is not None, ) _tmp.weight = weight if bias is not None: _tmp.bias = bias _module._modules[name] = _tmp def monkeypatch_add_lora( model, loras, target_replace_module=None, alpha: float = 1.0, beta: float = 1.0, ): if target_replace_module is None: target_replace_module = DEFAULT_TARGET_REPLACE for _module, name, _child_module in _find_modules( model, target_replace_module, search_class=[LoraInjectedLinear] ): weight = _child_module.linear.weight up_weight = loras.pop(0) down_weight = loras.pop(0) _module._modules[name].lora_up.weight = nn.Parameter( up_weight.type(weight.dtype).to(weight.device) * alpha + _module._modules[name].lora_up.weight.to(weight.device) * beta ) _module._modules[name].lora_down.weight = nn.Parameter( down_weight.type(weight.dtype).to(weight.device) * alpha + _module._modules[name].lora_down.weight.to(weight.device) * beta ) _module._modules[name].to(weight.device) def tune_lora_scale(model, alpha: float = 1.0): for _module in model.modules(): if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]: _module.scale = alpha def _text_lora_path(path: str) -> str: assert path.endswith(".pt"), "Only .pt files are supported" return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"]) def _text_lora_path_ui(path: str) -> str: assert path.endswith(".pt"), "Only .pt files are supported" return path.replace(".pt", "_txt.pt") def _ti_lora_path(path: str) -> str: assert path.endswith(".pt"), "Only .pt files are supported" return ".".join(path.split(".")[:-1] + ["ti", "pt"]) def apply_learned_embed_in_clip( learned_embeds, text_encoder, tokenizer, token: Optional[Union[str, List[str]]] = None, idempotent=False, ): if isinstance(token, str): trained_tokens = [token] elif isinstance(token, list): assert len(learned_embeds.keys()) == len( token ), "The number of tokens and the number of embeds should be the same" trained_tokens = token else: trained_tokens = list(learned_embeds.keys()) for token in trained_tokens: print(token) embeds = learned_embeds[token] # cast to dtype of text_encoder dtype = text_encoder.get_input_embeddings().weight.dtype num_added_tokens = tokenizer.add_tokens(token) i = 1 if not idempotent: while num_added_tokens == 0: print(f"The tokenizer already contains the token {token}.") token = f"{token[:-1]}-{i}>" print(f"Attempting to add the token {token}.") num_added_tokens = tokenizer.add_tokens(token) i += 1 elif num_added_tokens == 0 and idempotent: print(f"The tokenizer already contains the token {token}.") print(f"Replacing {token} embedding.") # resize the token embeddings text_encoder.resize_token_embeddings(len(tokenizer)) # get the id for the token and assign the embeds token_id = tokenizer.convert_tokens_to_ids(token) text_encoder.get_input_embeddings().weight.data[token_id] = embeds return token def load_learned_embed_in_clip( learned_embeds_path, text_encoder, tokenizer, token: Optional[Union[str, List[str]]] = None, idempotent=False, ): disable_safe_unpickle() learned_embeds = torch.load(learned_embeds_path) enable_safe_unpickle() apply_learned_embed_in_clip( learned_embeds, text_encoder, tokenizer, token, idempotent ) def patch_pipe( pipe, maybe_unet_path, token: Optional[str] = None, r: int = 4, r_txt: int = 4, patch_unet=True, patch_text=True, patch_ti=False, idempotent_token=True, unet_target_replace_module=None, text_target_replace_module=None, ): if unet_target_replace_module is None: unet_target_replace_module = DEFAULT_TARGET_REPLACE if text_target_replace_module is None: text_target_replace_module = TEXT_ENCODER_DEFAULT_TARGET_REPLACE if maybe_unet_path.endswith(".pt"): # torch format if maybe_unet_path.endswith(".ti.pt"): unet_path = maybe_unet_path[:-6] + ".pt" elif maybe_unet_path.endswith(".text_encoder.pt"): unet_path = maybe_unet_path[:-16] + ".pt" else: unet_path = maybe_unet_path ti_path = _ti_lora_path(unet_path) text_path = _text_lora_path_ui(unet_path) disable_safe_unpickle() if patch_unet: print("LoRA : Patching Unet") lora_patch = get_target_module( "patch", bool(unet_target_replace_module == UNET_EXTENDED_TARGET_REPLACE) ) lora_patch( pipe.unet, torch.load(unet_path), r=r, target_replace_module=unet_target_replace_module, ) if patch_text: print("LoRA : Patching text encoder") monkeypatch_or_replace_lora( pipe.text_encoder, torch.load(text_path), target_replace_module=text_target_replace_module, r=r_txt, ) enable_safe_unpickle() if patch_ti: print("LoRA : Patching token input") token = load_learned_embed_in_clip( ti_path, pipe.text_encoder, pipe.tokenizer, token=token, idempotent=idempotent_token, ) elif maybe_unet_path.endswith(".safetensors"): safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu") monkeypatch_or_replace_safeloras(pipe, safeloras) tok_dict = parse_safeloras_embeds(safeloras) if patch_ti: apply_learned_embed_in_clip( tok_dict, pipe.text_encoder, pipe.tokenizer, token=token, idempotent=idempotent_token, ) return tok_dict @torch.no_grad() def inspect_lora(model): moved = {} for name, _module in model.named_modules(): if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]: ups = _module.lora_up.weight.data.clone() downs = _module.lora_down.weight.data.clone() wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1) dist = wght.flatten().abs().mean().item() if name in moved: moved[name].append(dist) else: moved[name] = [dist] return moved # Save loras from a diffusionpipeline def save_pipe( pipeline, model_base, save_safetensors=False, target_replace_module_text=None, target_replace_module_unet=None ): if target_replace_module_unet is None: target_replace_module_unet = DEFAULT_TARGET_REPLACE if target_replace_module_text is None: target_replace_module_text = TEXT_ENCODER_DEFAULT_TARGET_REPLACE save_unet_path = f"{model_base}" save_lora_weight( pipeline.unet, save_unet_path, target_replace_module=target_replace_module_unet, save_safetensors=save_safetensors ) print("Unet saved to ", save_unet_path) save_txt_path = _text_lora_path(save_unet_path), save_lora_weight( pipeline.text_encoder, save_txt_path, target_replace_module=target_replace_module_text, save_safetensors=save_safetensors ) print("Text Encoder saved to ", _text_lora_path(save_txt_path)) return save_unet_path, save_txt_path def save_all( unet, text_encoder, save_path, placeholder_token_ids=None, placeholder_tokens=None, save_lora=True, save_ti=True, target_replace_module_text=None, target_replace_module_unet=None, safe_form=True, ): if target_replace_module_text is None: target_replace_module_text = TEXT_ENCODER_DEFAULT_TARGET_REPLACE if target_replace_module_unet is None: target_replace_module_unet = DEFAULT_TARGET_REPLACE if not safe_form: # save ti if save_ti: ti_path = _ti_lora_path(save_path) learned_embeds_dict = {} for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] print( f"Current Learned Embeddings for {tok}:, id {tok_id} ", learned_embeds[:4], ) learned_embeds_dict[tok] = learned_embeds.detach().cpu() torch.save(learned_embeds_dict, ti_path) print("Ti saved to ", ti_path) # save text encoder if save_lora: save_lora_weight( unet, save_path, target_replace_module=target_replace_module_unet ) print("Unet saved to ", save_path) save_lora_weight( text_encoder, _text_lora_path(save_path), target_replace_module=target_replace_module_text, ) print("Text Encoder saved to ", _text_lora_path(save_path)) else: assert save_path.endswith( ".safetensors" ), f"Save path : {save_path} should end with .safetensors" loras = {} embeds = {} if save_lora: loras["unet"] = (unet, target_replace_module_unet) loras["text_encoder"] = (text_encoder, target_replace_module_text) if save_ti: for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] print( f"Current Learned Embeddings for {tok}:, id {tok_id} ", learned_embeds[:4], ) embeds[tok] = learned_embeds.detach().cpu() save_safeloras_with_embeds(loras, embeds, save_path) def merge_loras_to_pipe( pipline, lora_path=None, lora_alpha: float = 1, lora_txt_alpha: float = 1, r: int = 4, r_txt: int = 4, unet_target_module=None ): if unet_target_module is None: unet_target_module = UNET_DEFAULT_TARGET_REPLACE print( f"Merging UNET/CLIP with LoRA from {lora_path}. Merging ratio : UNET: {lora_alpha}, CLIP: {lora_txt_alpha}." ) patch_pipe( pipline, lora_path, r=r, r_txt=r_txt, unet_target_replace_module=unet_target_module ) collapse_lora(pipline.unet, lora_alpha) collapse_lora(pipline.text_encoder, lora_txt_alpha) monkeypatch_remove_lora(pipline.unet) monkeypatch_remove_lora(pipline.text_encoder) def merge_lora_to_model( model: nn.Module, lora: dict, is_tenc: bool, use_extended: bool, rank: int = 4, weight: float = 1.0 ): target_module = get_target_module("module", use_extended) if is_tenc: target_module = TEXT_ENCODER_EXTENDED_TARGET_REPLACE if use_extended else TEXT_ENCODER_DEFAULT_TARGET_REPLACE get_target_module("patch", use_extended)(model, lora, target_replace_module=target_module, r=rank) collapse_lora(model, weight) monkeypatch_remove_lora(model) def get_target_module(target_type: str = "injection", use_extended: bool = False): if target_type == "injection": return inject_trainable_lora if not use_extended else inject_trainable_lora_extended if target_type == "module": return UNET_DEFAULT_TARGET_REPLACE if not use_extended else UNET_EXTENDED_TARGET_REPLACE if target_type == "patch": return monkeypatch_or_replace_lora_extended if use_extended else monkeypatch_or_replace_lora def set_lora_requires_grad(model, requires_grad): for name, param in model.named_parameters(): if "lora" in name: if param.requires_grad != requires_grad: param.requires_grad = requires_grad