import json
from itertools import groupby
from typing import Dict, List, Optional, Set, Tuple, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from safetensors.torch import safe_open
    from safetensors.torch import save_file as safe_save

    safetensors_available = True
except ImportError:
    from .safe_open import safe_open

    def safe_save(
        tensors: Dict[str, torch.Tensor],
        filename: str,
        metadata: Optional[Dict[str, str]] = None,
    ) -> None:
        raise EnvironmentError(
            "Saving safetensors requires the safetensors library. Please install with pip or similar."
        )

    safetensors_available = False


class LoraInjectedLinear(nn.Module):
    def __init__(
        self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
    ):
        super().__init__()

        if r > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {r} must be less than or equal to {min(in_features, out_features)}"
            )
        self.r = r
        self.linear = nn.Linear(in_features, out_features, bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        self.selector = nn.Identity()

        nn.init.normal_(self.lora_down.weight, std=1 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        # y = W x + scale * up(selector(down(x))), with dropout on the LoRA branch
        return (
            self.linear(input)
            + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
            * self.scale
        )

    def realize_as_lora(self):
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        # diag is a 1D tensor of size (r,)
        assert diag.shape == (self.r,)
        self.selector = nn.Linear(self.r, self.r, bias=False)
        self.selector.weight.data = torch.diag(diag)
        self.selector.weight.data = self.selector.weight.data.to(
            self.lora_up.weight.device
        ).to(self.lora_up.weight.dtype)
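

def _example_lora_linear_identity_at_init():
    # Illustrative sketch (not part of the original API): because lora_up is
    # zero-initialized, a freshly built LoraInjectedLinear produces exactly the
    # same output as the plain nn.Linear it wraps; training only lora_down and
    # lora_up then learns a low-rank update on top of the frozen weight.
    layer = LoraInjectedLinear(16, 32, bias=True, r=4, dropout_p=0.0)
    x = torch.randn(2, 16)
    assert torch.allclose(layer(x), layer.linear(x))
    return layer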


class LoraInjectedConv2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups: int = 1,
        bias: bool = True,
        r: int = 4,
        dropout_p: float = 0.1,
        scale: float = 1.0,
    ):
        super().__init__()
        if r > min(in_channels, out_channels):
            raise ValueError(
                f"LoRA rank {r} must be less than or equal to {min(in_channels, out_channels)}"
            )
        self.r = r
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        self.lora_down = nn.Conv2d(
            in_channels=in_channels,
            out_channels=r,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Conv2d(
            in_channels=r,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.selector = nn.Identity()
        self.scale = scale

        nn.init.normal_(self.lora_down.weight, std=1 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        return (
            self.conv(input)
            + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
            * self.scale
        )

    def realize_as_lora(self):
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        # diag is a 1D tensor of size (r,)
        assert diag.shape == (self.r,)
        self.selector = nn.Conv2d(
            in_channels=self.r,
            out_channels=self.r,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # reshape the diagonal into a 1x1 conv kernel of shape (r, r, 1, 1)
        self.selector.weight.data = torch.diag(diag).view(self.r, self.r, 1, 1)

        # same device + dtype as lora_up
        self.selector.weight.data = self.selector.weight.data.to(
            self.lora_up.weight.device
        ).to(self.lora_up.weight.dtype)


UNET_DEFAULT_TARGET_REPLACE = {
    "MemoryEfficientCrossAttention",
    "CrossAttention",
    "Attention",
    "GEGLU",
}

UNET_EXTENDED_TARGET_REPLACE = {
    "TimestepEmbedSequential",
    "SpatialTemporalTransformer",
    "MemoryEfficientCrossAttention",
    "CrossAttention",
    "Attention",
    "GEGLU",
}

TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}

TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPMLP", "CLIPAttention"}

DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE

EMBED_FLAG = "<embed>"


def _find_children(
    model,
    search_class: List[Type[nn.Module]] = [nn.Linear],
):
    """
    Find all modules of a certain class (or union of classes).

    Returns all matching modules, along with the parent of those modules and the
    names they are referenced by.
    """
    # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
    for parent in model.modules():
        for name, module in parent.named_children():
            if any([isinstance(module, _class) for _class in search_class]):
                yield parent, name, module


def _find_modules_v2(
    model,
    ancestor_class: Optional[Set[str]] = None,
    search_class: List[Type[nn.Module]] = [nn.Linear],
    exclude_children_of: Optional[List[Type[nn.Module]]] = [
        LoraInjectedLinear,
        LoraInjectedConv2d,
    ],
):
    """
    Find all modules of a certain class (or union of classes) that are direct or
    indirect descendants of other modules of a certain class (or union of classes).

    Returns all matching modules, along with the parent of those modules and the
    names they are referenced by.
    """

    # Get the targets we should replace all linears under
    if ancestor_class is not None:
        if not isinstance(ancestor_class, set):
            ancestor_class = set(ancestor_class)
        print(ancestor_class)
        ancestors = (
            module
            for module in model.modules()
            if module.__class__.__name__ in ancestor_class
        )
    else:
        # this, in case you want to naively iterate over all modules.
        ancestors = [module for module in model.modules()]

    # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
    for ancestor in ancestors:
        for fullname, module in ancestor.named_children():
            if any([isinstance(module, _class) for _class in search_class]):
                # Find the direct parent if this is a descendant, not a child, of target
                *path, name = fullname.split(".")
                parent = ancestor
                while path:
                    parent = parent.get_submodule(path.pop(0))
                # Skip this linear if it's a child of a LoraInjectedLinear
                if exclude_children_of and any(
                    [isinstance(parent, _class) for _class in exclude_children_of]
                ):
                    continue
                # Otherwise, yield it
                yield parent, name, module


def _find_modules_old(
    model,
    ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
    search_class: List[Type[nn.Module]] = [nn.Linear],
    exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
):
    ret = []
    for _module in model.modules():
        if _module.__class__.__name__ in ancestor_class:

            for name, _child_module in _module.named_children():
                if _child_module.__class__ in search_class:
                    ret.append((_module, name, _child_module))
    print(ret)
    return ret


_find_modules = _find_modules_v2


def inject_trainable_lora(
    model: nn.Module,
    target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
    r: int = 4,
    loras=None,  # path to lora .pt
    verbose: bool = False,
    dropout_p: float = 0.0,
    scale: float = 1.0,
):
    """
    Inject LoRA into the model and return the LoRA parameter groups.
    """

    require_grad_params = []
    names = []

    if loras is not None:
        loras = torch.load(loras)

    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear]
    ):
        weight = _child_module.weight
        bias = _child_module.bias
        if verbose:
            print("LoRA Injection : injecting lora into ", name)
            print("LoRA Injection : weight shape", weight.shape)
        _tmp = LoraInjectedLinear(
            _child_module.in_features,
            _child_module.out_features,
            _child_module.bias is not None,
            r=r,
            dropout_p=dropout_p,
            scale=scale,
        )
        _tmp.linear.weight = weight
        if bias is not None:
            _tmp.linear.bias = bias

        # switch the module
        _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
        _module._modules[name] = _tmp

        require_grad_params.append(_module._modules[name].lora_up.parameters())
        require_grad_params.append(_module._modules[name].lora_down.parameters())

        if loras is not None:
            _module._modules[name].lora_up.weight = loras.pop(0)
            _module._modules[name].lora_down.weight = loras.pop(0)

        _module._modules[name].lora_up.weight.requires_grad = True
        _module._modules[name].lora_down.weight.requires_grad = True
        names.append(name)

    return require_grad_params, names
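

def _example_inject_and_train(unet: nn.Module):
    # Illustrative usage sketch (not part of the original API): `unet` stands in
    # for any module containing the target blocks (e.g. a diffusers UNet). The
    # returned generators of LoRA parameters are chained into one optimizer group.
    import itertools

    unet_lora_params, names = inject_trainable_lora(unet, r=4)
    optimizer = torch.optim.AdamW(itertools.chain(*unet_lora_params), lr=1e-4)
    return optimizer, names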


def inject_trainable_lora_extended(
    model: nn.Module,
    target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
    r: int = 4,
    loras=None,  # path to lora .pt
):
    """
    Inject LoRA into the model and return the LoRA parameter groups.
    """

    require_grad_params = []
    names = []

    if loras is not None:
        loras = torch.load(loras)

    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
    ):
        if _child_module.__class__ == nn.Linear:
            weight = _child_module.weight
            bias = _child_module.bias
            _tmp = LoraInjectedLinear(
                _child_module.in_features,
                _child_module.out_features,
                _child_module.bias is not None,
                r=r,
            )
            _tmp.linear.weight = weight
            if bias is not None:
                _tmp.linear.bias = bias
        elif _child_module.__class__ == nn.Conv2d:
            weight = _child_module.weight
            bias = _child_module.bias
            _tmp = LoraInjectedConv2d(
                _child_module.in_channels,
                _child_module.out_channels,
                _child_module.kernel_size,
                _child_module.stride,
                _child_module.padding,
                _child_module.dilation,
                _child_module.groups,
                _child_module.bias is not None,
                r=r,
            )

            _tmp.conv.weight = weight
            if bias is not None:
                _tmp.conv.bias = bias

        # switch the module
        _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
        if bias is not None:
            _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)

        _module._modules[name] = _tmp

        require_grad_params.append(_module._modules[name].lora_up.parameters())
        require_grad_params.append(_module._modules[name].lora_down.parameters())

        if loras is not None:
            _module._modules[name].lora_up.weight = loras.pop(0)
            _module._modules[name].lora_down.weight = loras.pop(0)

        _module._modules[name].lora_up.weight.requires_grad = True
        _module._modules[name].lora_down.weight.requires_grad = True
        names.append(name)

    return require_grad_params, names


def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):

    loras = []

    for _m, _n, _child_module in _find_modules(
        model,
        target_replace_module,
        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
    ):
        loras.append((_child_module.lora_up, _child_module.lora_down))

    if len(loras) == 0:
        raise ValueError("No lora injected.")

    return loras


def extract_lora_as_tensor(
    model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
):

    loras = []

    for _m, _n, _child_module in _find_modules(
        model,
        target_replace_module,
        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
    ):
        up, down = _child_module.realize_as_lora()
        if as_fp16:
            up = up.to(torch.float16)
            down = down.to(torch.float16)

        loras.append((up, down))

    if len(loras) == 0:
        raise ValueError("No lora injected.")

    return loras


def save_lora_weight(
    model,
    path="./lora.pt",
    target_replace_module=DEFAULT_TARGET_REPLACE,
):
    weights = []
    for _up, _down in extract_lora_ups_down(
        model, target_replace_module=target_replace_module
    ):
        weights.append(_up.weight.to("cpu").to(torch.float16))
        weights.append(_down.weight.to("cpu").to(torch.float16))

    torch.save(weights, path)
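

def _example_save_and_reload_lora(unet: nn.Module, path: str = "./lora.pt"):
    # Illustrative sketch (not part of the original API): persist the injected
    # LoRA weights of `unet` as a flat .pt list, then patch them back into a
    # model of the same architecture. The rank passed to the monkeypatch must
    # match the rank used at injection time.
    save_lora_weight(unet, path)
    monkeypatch_or_replace_lora(unet, torch.load(path), r=4)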


def save_lora_as_json(model, path="./lora.json"):
    weights = []
    for _up, _down in extract_lora_ups_down(model):
        weights.append(_up.weight.detach().cpu().numpy().tolist())
        weights.append(_down.weight.detach().cpu().numpy().tolist())

    with open(path, "w") as f:
        json.dump(weights, f)


def save_safeloras_with_embeds(
    modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
    embeds: Dict[str, torch.Tensor] = {},
    outpath="./lora.safetensors",
):
    """
    Saves the Lora from multiple modules in a single safetensor file.

    modelmap is a dictionary of {
        "module name": (module, target_replace_module)
    }
    """
    weights = {}
    metadata = {}

    for name, (model, target_replace_module) in modelmap.items():
        metadata[name] = json.dumps(list(target_replace_module))

        for i, (_up, _down) in enumerate(
            extract_lora_as_tensor(model, target_replace_module)
        ):
            rank = _down.shape[0]

            metadata[f"{name}:{i}:rank"] = str(rank)
            weights[f"{name}:{i}:up"] = _up
            weights[f"{name}:{i}:down"] = _down

    for token, tensor in embeds.items():
        metadata[token] = EMBED_FLAG
        weights[token] = tensor

    print(f"Saving weights to {outpath}")
    safe_save(weights, outpath, metadata)


def save_safeloras(
    modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
    outpath="./lora.safetensors",
):
    return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
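

def _example_save_safeloras(unet: nn.Module, text_encoder: nn.Module):
    # Illustrative sketch (not part of the original API) of the modelmap layout
    # expected above: module name -> (module, target_replace_module set).
    modelmap = {
        "unet": (unet, DEFAULT_TARGET_REPLACE),
        "text_encoder": (text_encoder, TEXT_ENCODER_DEFAULT_TARGET_REPLACE),
    }
    save_safeloras(modelmap, "./lora.safetensors")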


def convert_loras_to_safeloras_with_embeds(
    modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
    embeds: Dict[str, torch.Tensor] = {},
    outpath="./lora.safetensors",
):
    """
    Converts the Lora from multiple pytorch .pt files into a single safetensor file.

    modelmap is a dictionary of {
        "module name": (pytorch_model_path, target_replace_module, rank)
    }
    """

    weights = {}
    metadata = {}

    for name, (path, target_replace_module, r) in modelmap.items():
        metadata[name] = json.dumps(list(target_replace_module))

        lora = torch.load(path)
        for i, weight in enumerate(lora):
            is_up = i % 2 == 0
            i = i // 2

            if is_up:
                metadata[f"{name}:{i}:rank"] = str(r)
                weights[f"{name}:{i}:up"] = weight
            else:
                weights[f"{name}:{i}:down"] = weight

    for token, tensor in embeds.items():
        metadata[token] = EMBED_FLAG
        weights[token] = tensor

    print(f"Saving weights to {outpath}")
    safe_save(weights, outpath, metadata)


def convert_loras_to_safeloras(
    modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
    outpath="./lora.safetensors",
):
    convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)


def parse_safeloras(
    safeloras,
) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
    """
    Converts a loaded safetensor file that contains a set of module Loras
    into Parameters and other information

    Output is a dictionary of {
        "module name": (
            [list of weights],
            [list of ranks],
            target_replacement_modules
        )
    }
    """
    loras = {}
    metadata = safeloras.metadata()

    get_name = lambda k: k.split(":")[0]

    keys = list(safeloras.keys())
    keys.sort(key=get_name)

    for name, module_keys in groupby(keys, get_name):
        info = metadata.get(name)

        if not info:
            raise ValueError(
                f"Tensor {name} has no metadata - is this a Lora safetensor?"
            )

        # Skip Textual Inversion embeds
        if info == EMBED_FLAG:
            continue

        # Handle Loras
        # Extract the targets
        target = json.loads(info)

        # Build the result lists - Python needs us to preallocate lists to insert into them
        module_keys = list(module_keys)
        ranks = [4] * (len(module_keys) // 2)
        weights = [None] * len(module_keys)

        for key in module_keys:
            # Split the model name and index out of the key
            _, idx, direction = key.split(":")
            idx = int(idx)

            # Add the rank
            ranks[idx] = int(metadata[f"{name}:{idx}:rank"])

            # Insert the weight into the list
            idx = idx * 2 + (1 if direction == "down" else 0)
            weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))

        loras[name] = (weights, ranks, target)

    return loras


def parse_safeloras_embeds(
    safeloras,
) -> Dict[str, torch.Tensor]:
    """
    Converts a loaded safetensor file that contains Textual Inversion embeds into
    a dictionary of embed_token: Tensor
    """
    embeds = {}
    metadata = safeloras.metadata()

    for key in safeloras.keys():
        # Only handle Textual Inversion embeds
        meta = metadata.get(key)
        if not meta or meta != EMBED_FLAG:
            continue

        embeds[key] = safeloras.get_tensor(key)

    return embeds


def net_load_lora(net, checkpoint_path, alpha=1.0, remove=False):
    """
    Merge the LoRA weights from `checkpoint_path` directly into the base
    weights of `net`, scaled by `alpha`. With `remove=True` the same update
    is subtracted instead, undoing a previous merge.
    """
    visited = []
    state_dict = torch.load(checkpoint_path)
    for k, v in state_dict.items():
        state_dict[k] = v.to(net.device)

    for key in state_dict:
        if ".alpha" in key or key in visited:
            continue
        layer_infos = key.split(".")[:-2]  # remove lora_up and down weight
        curr_layer = net
        # find the target layer
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            curr_layer = curr_layer.__getattr__(temp_name)
            if len(layer_infos) > 0:
                temp_name = layer_infos.pop(0)
            elif len(layer_infos) == 0:
                break
        if curr_layer.__class__ not in [nn.Linear, nn.Conv2d]:
            print('missing param at:', key)
            continue
        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # update weight
        if len(state_dict[pair_keys[0]].shape) == 4:
            # for conv
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            if remove:
                curr_layer.weight.data -= alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            # for linear
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            if remove:
                curr_layer.weight.data -= alpha * torch.mm(weight_up, weight_down)
            else:
                curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)

        # update visited list
        for item in pair_keys:
            visited.append(item)
    print('load_weight_num:', len(visited))
    return


def change_lora(model, inject_lora=False, lora_scale=1.0, lora_path='', last_time_lora='', last_time_lora_scale=1.0):
    # remove lora
    if last_time_lora != '':
        net_load_lora(model, last_time_lora, alpha=last_time_lora_scale, remove=True)
    # add new lora
    if inject_lora:
        net_load_lora(model, lora_path, alpha=lora_scale)
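

def _example_swap_lora(model, new_lora_path: str, old_lora_path: str):
    # Illustrative sketch (not part of the original API): net_load_lora merges
    # LoRA deltas directly into the base weights, so switching LoRAs means
    # subtracting the previously merged one before adding the new one.
    change_lora(
        model,
        inject_lora=True,
        lora_scale=1.0,
        lora_path=new_lora_path,
        last_time_lora=old_lora_path,
        last_time_lora_scale=1.0,
    )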


def net_load_lora_v2(net, checkpoint_path, alpha=1.0, remove=False, origin_weight=None):
    """
    Same as net_load_lora, but keeps a copy of the original weights in
    `origin_weight` so that `remove=True` restores them exactly instead of
    subtracting the LoRA update.
    """
    visited = []
    state_dict = torch.load(checkpoint_path)
    for k, v in state_dict.items():
        state_dict[k] = v.to(net.device)

    for key in state_dict:
        if ".alpha" in key or key in visited:
            continue
        layer_infos = key.split(".")[:-2]  # remove lora_up and down weight
        curr_layer = net
        # find the target layer
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            curr_layer = curr_layer.__getattr__(temp_name)
            if len(layer_infos) > 0:
                temp_name = layer_infos.pop(0)
            elif len(layer_infos) == 0:
                break
        if curr_layer.__class__ not in [nn.Linear, nn.Conv2d]:
            print('missing param at:', key)
            continue
        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # store the original weight before the first update
        storage_key = key.replace("lora_down", "lora").replace("lora_up", "lora")
        if origin_weight is None:
            origin_weight = dict()
            origin_weight[storage_key] = curr_layer.weight.data.clone()
        elif storage_key not in origin_weight.keys():
            origin_weight[storage_key] = curr_layer.weight.data.clone()

        # update
        if len(state_dict[pair_keys[0]].shape) == 4:
            # for conv
            if remove:
                curr_layer.weight.data = origin_weight[storage_key].clone()
            else:
                weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
                weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
                curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            # for linear
            if remove:
                curr_layer.weight.data = origin_weight[storage_key].clone()
            else:
                weight_up = state_dict[pair_keys[0]].to(torch.float32)
                weight_down = state_dict[pair_keys[1]].to(torch.float32)
                curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)

        # update visited list
        for item in pair_keys:
            visited.append(item)
    print('load_weight_num:', len(visited))
    return origin_weight


def change_lora_v2(model, inject_lora=False, lora_scale=1.0, lora_path='', last_time_lora='', last_time_lora_scale=1.0, origin_weight=None):
    # remove lora
    if last_time_lora != '':
        origin_weight = net_load_lora_v2(model, last_time_lora, alpha=last_time_lora_scale, remove=True, origin_weight=origin_weight)
    # add new lora
    if inject_lora:
        origin_weight = net_load_lora_v2(model, lora_path, alpha=lora_scale, origin_weight=origin_weight)
    return origin_weight


def load_safeloras(path, device="cpu"):
    safeloras = safe_open(path, framework="pt", device=device)
    return parse_safeloras(safeloras)


def load_safeloras_embeds(path, device="cpu"):
    safeloras = safe_open(path, framework="pt", device=device)
    return parse_safeloras_embeds(safeloras)


def load_safeloras_both(path, device="cpu"):
    safeloras = safe_open(path, framework="pt", device=device)
    return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)


def collapse_lora(model, alpha=1.0):

    for _module, name, _child_module in _find_modules(
        model,
        UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
    ):

        if isinstance(_child_module, LoraInjectedLinear):
            print("Collapsing Lin Lora in", name)

            _child_module.linear.weight = nn.Parameter(
                _child_module.linear.weight.data
                + alpha
                * (
                    _child_module.lora_up.weight.data
                    @ _child_module.lora_down.weight.data
                )
                .type(_child_module.linear.weight.dtype)
                .to(_child_module.linear.weight.device)
            )

        else:
            print("Collapsing Conv Lora in", name)
            _child_module.conv.weight = nn.Parameter(
                _child_module.conv.weight.data
                + alpha
                * (
                    _child_module.lora_up.weight.data.flatten(start_dim=1)
                    @ _child_module.lora_down.weight.data.flatten(start_dim=1)
                )
                .reshape(_child_module.conv.weight.data.shape)
                .type(_child_module.conv.weight.dtype)
                .to(_child_module.conv.weight.device)
            )
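

def _example_bake_in_lora(unet: nn.Module):
    # Illustrative sketch (not part of the original API): fold the learned
    # low-rank updates into the base weights, then strip the injected wrapper
    # modules so the model behaves like a plain, LoRA-free network.
    collapse_lora(unet, alpha=1.0)
    monkeypatch_remove_lora(unet)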


def monkeypatch_or_replace_lora(
    model,
    loras,
    target_replace_module=DEFAULT_TARGET_REPLACE,
    r: Union[int, List[int]] = 4,
):
    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
    ):
        _source = (
            _child_module.linear
            if isinstance(_child_module, LoraInjectedLinear)
            else _child_module
        )

        weight = _source.weight
        bias = _source.bias
        _tmp = LoraInjectedLinear(
            _source.in_features,
            _source.out_features,
            _source.bias is not None,
            r=r.pop(0) if isinstance(r, list) else r,
        )
        _tmp.linear.weight = weight

        if bias is not None:
            _tmp.linear.bias = bias

        # switch the module
        _module._modules[name] = _tmp

        up_weight = loras.pop(0)
        down_weight = loras.pop(0)

        _module._modules[name].lora_up.weight = nn.Parameter(
            up_weight.type(weight.dtype)
        )
        _module._modules[name].lora_down.weight = nn.Parameter(
            down_weight.type(weight.dtype)
        )

        _module._modules[name].to(weight.device)


def monkeypatch_or_replace_lora_extended(
    model,
    loras,
    target_replace_module=DEFAULT_TARGET_REPLACE,
    r: Union[int, List[int]] = 4,
):
    for _module, name, _child_module in _find_modules(
        model,
        target_replace_module,
        search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d],
    ):

        if (_child_module.__class__ == nn.Linear) or (
            _child_module.__class__ == LoraInjectedLinear
        ):
            if len(loras[0].shape) != 2:
                continue

            _source = (
                _child_module.linear
                if isinstance(_child_module, LoraInjectedLinear)
                else _child_module
            )

            weight = _source.weight
            bias = _source.bias
            _tmp = LoraInjectedLinear(
                _source.in_features,
                _source.out_features,
                _source.bias is not None,
                r=r.pop(0) if isinstance(r, list) else r,
            )
            _tmp.linear.weight = weight

            if bias is not None:
                _tmp.linear.bias = bias

        elif (_child_module.__class__ == nn.Conv2d) or (
            _child_module.__class__ == LoraInjectedConv2d
        ):
            if len(loras[0].shape) != 4:
                continue
            _source = (
                _child_module.conv
                if isinstance(_child_module, LoraInjectedConv2d)
                else _child_module
            )

            weight = _source.weight
            bias = _source.bias
            _tmp = LoraInjectedConv2d(
                _source.in_channels,
                _source.out_channels,
                _source.kernel_size,
                _source.stride,
                _source.padding,
                _source.dilation,
                _source.groups,
                _source.bias is not None,
                r=r.pop(0) if isinstance(r, list) else r,
            )

            _tmp.conv.weight = weight

            if bias is not None:
                _tmp.conv.bias = bias

        # switch the module
        _module._modules[name] = _tmp

        up_weight = loras.pop(0)
        down_weight = loras.pop(0)

        _module._modules[name].lora_up.weight = nn.Parameter(
            up_weight.type(weight.dtype)
        )
        _module._modules[name].lora_down.weight = nn.Parameter(
            down_weight.type(weight.dtype)
        )

        _module._modules[name].to(weight.device)


def monkeypatch_or_replace_safeloras(models, safeloras):
    loras = parse_safeloras(safeloras)

    for name, (lora, ranks, target) in loras.items():
        model = getattr(models, name, None)

        if not model:
            print(f"No model provided for {name}, contained in Lora")
            continue

        monkeypatch_or_replace_lora_extended(model, lora, target, ranks)


def monkeypatch_remove_lora(model):
    for _module, name, _child_module in _find_modules(
        model, search_class=[LoraInjectedLinear, LoraInjectedConv2d]
    ):
        if isinstance(_child_module, LoraInjectedLinear):
            _source = _child_module.linear
            weight, bias = _source.weight, _source.bias

            _tmp = nn.Linear(
                _source.in_features, _source.out_features, bias is not None
            )

            _tmp.weight = weight
            if bias is not None:
                _tmp.bias = bias

        else:
            _source = _child_module.conv
            weight, bias = _source.weight, _source.bias

            _tmp = nn.Conv2d(
                in_channels=_source.in_channels,
                out_channels=_source.out_channels,
                kernel_size=_source.kernel_size,
                stride=_source.stride,
                padding=_source.padding,
                dilation=_source.dilation,
                groups=_source.groups,
                bias=bias is not None,
            )

            _tmp.weight = weight
            if bias is not None:
                _tmp.bias = bias

        _module._modules[name] = _tmp


def monkeypatch_add_lora(
    model,
    loras,
    target_replace_module=DEFAULT_TARGET_REPLACE,
    alpha: float = 1.0,
    beta: float = 1.0,
):
    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[LoraInjectedLinear]
    ):
        weight = _child_module.linear.weight

        up_weight = loras.pop(0)
        down_weight = loras.pop(0)

        _module._modules[name].lora_up.weight = nn.Parameter(
            up_weight.type(weight.dtype).to(weight.device) * alpha
            + _module._modules[name].lora_up.weight.to(weight.device) * beta
        )
        _module._modules[name].lora_down.weight = nn.Parameter(
            down_weight.type(weight.dtype).to(weight.device) * alpha
            + _module._modules[name].lora_down.weight.to(weight.device) * beta
        )

        _module._modules[name].to(weight.device)


def tune_lora_scale(model, alpha: float = 1.0):
    for _module in model.modules():
        if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
            _module.scale = alpha
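

def _example_tune_scale(unet: nn.Module):
    # Illustrative sketch (not part of the original API): the LoRA branch
    # contributes scale * up(down(x)), so alpha=0.0 silences the LoRA entirely
    # while alpha=1.0 applies it at full strength.
    tune_lora_scale(unet, alpha=0.5)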


def set_lora_diag(model, diag: torch.Tensor):
    for _module in model.modules():
        if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
            _module.set_selector_from_diag(diag)


def _text_lora_path(path: str) -> str:
    assert path.endswith(".pt"), "Only .pt files are supported"
    return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def _ti_lora_path(path: str) -> str:
    assert path.endswith(".pt"), "Only .pt files are supported"
    return ".".join(path.split(".")[:-1] + ["ti", "pt"])


def apply_learned_embed_in_clip(
    learned_embeds,
    text_encoder,
    tokenizer,
    token: Optional[Union[str, List[str]]] = None,
    idempotent=False,
):
    if isinstance(token, str):
        trained_tokens = [token]
    elif isinstance(token, list):
        assert len(learned_embeds.keys()) == len(
            token
        ), "The number of tokens and the number of embeds should be the same"
        trained_tokens = token
    else:
        trained_tokens = list(learned_embeds.keys())

    for token in trained_tokens:
        print(token)
        embeds = learned_embeds[token]

        # cast to dtype of text_encoder
        dtype = text_encoder.get_input_embeddings().weight.dtype
        num_added_tokens = tokenizer.add_tokens(token)

        i = 1
        if not idempotent:
            while num_added_tokens == 0:
                print(f"The tokenizer already contains the token {token}.")
                token = f"{token[:-1]}-{i}>"
                print(f"Attempting to add the token {token}.")
                num_added_tokens = tokenizer.add_tokens(token)
                i += 1
        elif num_added_tokens == 0 and idempotent:
            print(f"The tokenizer already contains the token {token}.")
            print(f"Replacing {token} embedding.")

        # resize the token embeddings
        text_encoder.resize_token_embeddings(len(tokenizer))

        # get the id for the token and assign the embeds
        token_id = tokenizer.convert_tokens_to_ids(token)
        text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    return token


def load_learned_embed_in_clip(
    learned_embeds_path,
    text_encoder,
    tokenizer,
    token: Optional[Union[str, List[str]]] = None,
    idempotent=False,
):
    learned_embeds = torch.load(learned_embeds_path)
    # return the (possibly renamed) token so callers such as patch_pipe can use it
    return apply_learned_embed_in_clip(
        learned_embeds, text_encoder, tokenizer, token, idempotent
    )


def patch_pipe(
    pipe,
    maybe_unet_path,
    token: Optional[str] = None,
    r: int = 4,
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    idempotent_token=True,
    unet_target_replace_module=DEFAULT_TARGET_REPLACE,
    text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
):
    if maybe_unet_path.endswith(".pt"):
        # torch format

        if maybe_unet_path.endswith(".ti.pt"):
            unet_path = maybe_unet_path[:-6] + ".pt"
        elif maybe_unet_path.endswith(".text_encoder.pt"):
            unet_path = maybe_unet_path[:-16] + ".pt"
        else:
            unet_path = maybe_unet_path

        ti_path = _ti_lora_path(unet_path)
        text_path = _text_lora_path(unet_path)

        if patch_unet:
            print("LoRA : Patching Unet")
            monkeypatch_or_replace_lora(
                pipe.unet,
                torch.load(unet_path),
                r=r,
                target_replace_module=unet_target_replace_module,
            )

        if patch_text:
            print("LoRA : Patching text encoder")
            monkeypatch_or_replace_lora(
                pipe.text_encoder,
                torch.load(text_path),
                target_replace_module=text_target_replace_module,
                r=r,
            )
        if patch_ti:
            print("LoRA : Patching token input")
            token = load_learned_embed_in_clip(
                ti_path,
                pipe.text_encoder,
                pipe.tokenizer,
                token=token,
                idempotent=idempotent_token,
            )

    elif maybe_unet_path.endswith(".safetensors"):
        safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
        monkeypatch_or_replace_safeloras(pipe, safeloras)
        tok_dict = parse_safeloras_embeds(safeloras)
        if patch_ti:
            apply_learned_embed_in_clip(
                tok_dict,
                pipe.text_encoder,
                pipe.tokenizer,
                token=token,
                idempotent=idempotent_token,
            )
        return tok_dict
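

def _example_patch_pipeline(pipe, lora_path: str = "./lora.safetensors"):
    # Illustrative sketch (not part of the original API): `pipe` stands in for a
    # diffusers StableDiffusionPipeline-like object exposing .unet, .text_encoder
    # and .tokenizer. After patching, the LoRA strength of each sub-model can be
    # adjusted independently with tune_lora_scale.
    patch_pipe(pipe, lora_path, patch_unet=True, patch_text=True, patch_ti=True)
    tune_lora_scale(pipe.unet, 1.0)
    tune_lora_scale(pipe.text_encoder, 1.0)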


@torch.no_grad()
def inspect_lora(model):
    moved = {}

    for name, _module in model.named_modules():
        if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
            ups = _module.lora_up.weight.data.clone()
            downs = _module.lora_down.weight.data.clone()

            wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)

            dist = wght.flatten().abs().mean().item()
            if name in moved:
                moved[name].append(dist)
            else:
                moved[name] = [dist]

    return moved


def save_all(
    unet,
    text_encoder,
    save_path,
    placeholder_token_ids=None,
    placeholder_tokens=None,
    save_lora=True,
    save_ti=True,
    target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
    target_replace_module_unet=DEFAULT_TARGET_REPLACE,
    safe_form=True,
):
    if not safe_form:
        # save ti
        if save_ti:
            ti_path = _ti_lora_path(save_path)
            learned_embeds_dict = {}
            for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
                learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
                print(
                    f"Current Learned Embeddings for {tok}, id {tok_id}: ",
                    learned_embeds[:4],
                )
                learned_embeds_dict[tok] = learned_embeds.detach().cpu()

            torch.save(learned_embeds_dict, ti_path)
            print("Ti saved to ", ti_path)

        # save unet and text encoder loras
        if save_lora:

            save_lora_weight(
                unet, save_path, target_replace_module=target_replace_module_unet
            )
            print("Unet saved to ", save_path)

            save_lora_weight(
                text_encoder,
                _text_lora_path(save_path),
                target_replace_module=target_replace_module_text,
            )
            print("Text Encoder saved to ", _text_lora_path(save_path))

    else:
        assert save_path.endswith(
            ".safetensors"
        ), f"Save path : {save_path} should end with .safetensors"

        loras = {}
        embeds = {}

        if save_lora:

            loras["unet"] = (unet, target_replace_module_unet)
            loras["text_encoder"] = (text_encoder, target_replace_module_text)

        if save_ti:
            for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
                learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
                print(
                    f"Current Learned Embeddings for {tok}, id {tok_id}: ",
                    learned_embeds[:4],
                )
                embeds[tok] = learned_embeds.detach().cpu()

        save_safeloras_with_embeds(loras, embeds, save_path)