# kohya_ss/tools/extract_lora_from_models-nw...
# (file-view header residue: 340 lines, 14 KiB, Python)

import argparse
import json
import os
import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, model_util, sdxl_model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# Singular-value floor for dynamic rank selection: in svd(), a layer whose largest
# singular value is <= MIN_SV is given rank 1 instead of running a dynamic selector.
MIN_SV = 1e-6
def index_sv_cumulative(S, target):
    """Return the smallest rank whose singular values sum to at least ``target`` of the total.

    ``target`` is compared against the running sum of ``S`` divided by ``sum(S)``.
    The result is clamped to the range [1, len(S) - 1].
    """
    total = float(torch.sum(S))
    fractions = torch.cumsum(S, dim=0) / total
    # searchsorted finds the first position whose running fraction reaches target.
    rank = 1 + int(torch.searchsorted(fractions, target))
    upper = len(S) - 1
    return max(1, min(rank, upper))
def index_sv_fro(S, target):
    """Return the smallest rank retaining at least ``target`` of the Frobenius norm.

    The Frobenius norm of the truncation is sqrt(sum of squared kept singular values).
    The comparison is therefore done on squared values against ``target ** 2``.
    The result is clamped to the range [1, len(S) - 1].
    """
    energy = S.pow(2)
    cumulative_energy = torch.cumsum(energy, dim=0) / float(torch.sum(energy))
    position = torch.searchsorted(cumulative_energy, target**2)
    return max(1, min(int(position) + 1, len(S) - 1))
def index_sv_ratio(S, target):
    """Return how many singular values exceed ``S[0] / target``.

    This keeps every component within a factor ``target`` of the largest one.
    The result is clamped to the range [1, len(S) - 1].
    """
    threshold = S[0] / target
    kept = int((S > threshold).sum().item())
    upper = len(S) - 1
    return max(1, min(kept, upper))
def index_sv_knee(S):
    """Determine rank using the knee point detection method.

    The singular-value curve is the set of points (k, S[k-1]) for k = 1..n. The knee
    is the point with the largest perpendicular distance to the straight line joining
    the first point (1, S[0]) and the last point (n, S[-1]). Its k is returned,
    clamped to the range [1, n - 1]. With fewer than 3 values there is no knee, so
    the function returns 1. On ties, the first maximal point wins.
    """
    n = len(S)
    if n < 3:  # Need at least 3 points to detect a knee
        return 1
    # Line through (1, S[0]) and (n, S[-1]) written as a*x + b*y + c = 0
    a = S[0] - S[-1]
    b = n - 1
    c = S[-1] - n * S[0]
    # Evaluate all n distances in one tensor expression rather than a Python loop of
    # 0-d tensor ops (which also forced a host sync per element for CUDA inputs).
    k = torch.arange(1, n + 1, dtype=S.dtype, device=S.device)
    distances = torch.abs(a * k + b * S + c) / torch.sqrt(a * a + b * b)
    # argmax is 0-based; +1 converts it back to the rank k.
    index = int(torch.argmax(distances)) + 1
    return max(1, min(index, n - 1))
def index_sv_rel_decrease(S, tau=0.1):
    """Determine rank based on relative decrease threshold.

    Returns the rank just before the first sharp drop: the smallest k >= 1 such that
    ``S[k] / S[k-1] < tau``. When no consecutive ratio falls below ``tau``, returns
    ``len(S) - 1``, the same upper cap used by the other ``index_sv_*`` selectors.
    Inputs with fewer than 2 values return 1. NaN ratios (0/0) never count as a drop.
    """
    if len(S) < 2:
        return 1
    # ratios[i] compares S[i+1] with S[i]
    ratios = S[1:] / S[:-1]
    drops = (ratios < tau).nonzero().flatten()
    if drops.numel() > 0:
        # A drop at ratio index i means the first i+1 values sit before it.
        return int(drops[0]) + 1
    return len(S) - 1
def save_to_file(file_name, model, state_dict, dtype, metadata=None):
    """Cast tensors in ``state_dict`` to ``dtype`` in place, then write ``model`` to disk.

    Files whose extension is ``.safetensors`` are written with safetensors'
    ``save_file`` together with ``metadata``. Any other extension goes through
    ``torch.save``, and ``metadata`` is not written.

    NOTE(review): the cast is applied to ``state_dict`` while ``model`` is the object
    that gets saved. The only caller in this file passes the same dict for both.
    Non-tensor values are left as they are.
    """
    if dtype is not None:
        # Iterate over a snapshot because entries are reassigned inside the loop.
        for key, value in list(state_dict.items()):
            if isinstance(value, torch.Tensor):
                state_dict[key] = value.to(dtype)

    _, extension = os.path.splitext(file_name)
    if extension != ".safetensors":
        torch.save(model, file_name)
        return
    save_file(model, file_name, metadata)
def _load_models(model_path, v2, sdxl, load_dtype, device):
    """Load the text encoder(s) and U-Net of one checkpoint.

    Returns ``(text_encoders, unet)``, with one text encoder for SD1.x/2.x and two for
    SDXL. ``device`` is forwarded only to the SDXL loader; the SD1.x/2.x loader is
    called without one. When ``load_dtype`` is set, every returned module is cast to it.
    """
    if sdxl:
        text_encoder_1, text_encoder_2, _, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_path, device
        )
        text_encoders = [text_encoder_1, text_encoder_2]
    else:
        text_encoder, _, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_path)
        text_encoders = [text_encoder]
    if load_dtype:
        for module in (*text_encoders, unet):
            module.to(load_dtype)
    return text_encoders, unet


def _select_rank(S, dynamic_method, dynamic_param, max_rank, in_dim, out_dim):
    """Choose the LoRA rank for one layer from its singular values ``S``.

    Without ``dynamic_method`` the rank is ``min(max_rank, in_dim, out_dim)``.
    Otherwise the named ``index_sv_*`` selector picks a rank (or 1 when
    ``S[0] <= MIN_SV``), which is then capped by that same bound.

    Raises:
        ValueError: if ``dynamic_method`` is not one of the supported names.
    """
    upper = min(max_rank, in_dim, out_dim)
    if not dynamic_method:
        return upper
    if S[0] <= MIN_SV:
        # Negligible difference for this layer: a single component is enough.
        rank = 1
    elif dynamic_method == "sv_ratio":
        rank = index_sv_ratio(S, dynamic_param)
    elif dynamic_method == "sv_cumulative":
        rank = index_sv_cumulative(S, dynamic_param)
    elif dynamic_method == "sv_fro":
        rank = index_sv_fro(S, dynamic_param)
    elif dynamic_method == "sv_knee":
        rank = index_sv_knee(S)
    elif dynamic_method == "sv_rel_decrease":
        rank = index_sv_rel_decrease(S, dynamic_param)
    else:
        # Fail loudly instead of leaving `rank` unbound.
        raise ValueError(f"Unknown dynamic_method: {dynamic_method!r}")
    return min(rank, upper)


def svd(
    model_org=None,
    model_tuned=None,
    save_to=None,
    dim=4,
    v2=None,
    sdxl=None,
    conv_dim=None,
    v_parameterization=None,
    device=None,
    save_precision=None,
    clamp_quantile=0.99,
    min_diff=0.01,
    no_metadata=False,
    load_precision=None,
    load_original_model_to=None,
    load_tuned_model_to=None,
    dynamic_method=None,
    dynamic_param=None,
    verbose=False,
):
    """Extract a LoRA from the weight difference between two checkpoints and save it.

    For each module targeted by ``lora.create_network``, the weight difference
    (tuned - original) is factored with SVD and truncated to a rank. The up factor
    is ``U * S`` and the down factor is ``Vh``. Both are clamped to +/- the
    ``clamp_quantile`` quantile of their concatenated entries.

    Args:
        model_org: Path of the original (base) checkpoint.
        model_tuned: Path of the fine-tuned checkpoint.
        save_to: Output path, written through ``save_to_file``.
        dim: Max rank for linear / 1x1 conv layers. It is also used for 3x3 convs
            when ``conv_dim`` is not given.
        v2: Load the model as SD2.x. Mutually exclusive with ``sdxl``.
        sdxl: Load the model as SDXL base.
        conv_dim: Max rank for 3x3 conv layers. It is also passed to ``create_network``
            as ``conv_dim``/``conv_alpha``.
        v_parameterization: Metadata flag; ``None`` means "same as ``v2``".
        device: Device the SVD runs on. Falsy means the CPU work device is kept.
        save_precision: "float", "fp16" or "bf16". Defaults to float32.
        clamp_quantile: Quantile used for clamping the factor values.
        min_diff: Text encoders are extracted only if some weight differs by more
            than this amount.
        no_metadata: Skip the sai model-spec metadata. The ``ss_*`` keys are always
            written.
        load_precision: "float", "fp16" or "bf16". Models are cast after loading.
        load_original_model_to, load_tuned_model_to: Load devices (SDXL only).
        dynamic_method, dynamic_param: Optional dynamic rank selection (see
            ``_select_rank``).
        verbose: Log per-layer retention statistics.
    """

    def str_to_dtype(p):
        return {"float": torch.float, "fp16": torch.float16, "bf16": torch.bfloat16}.get(p)

    assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time"
    v_parameterization = v2 if v_parameterization is None else v_parameterization
    load_dtype = str_to_dtype(load_precision) if load_precision else None
    save_dtype = str_to_dtype(save_precision) if save_precision else torch.float
    work_device = "cpu"

    # Load models
    model_kind = "SDXL" if sdxl else "SD"
    logger.info(f"Loading original {model_kind} model: {model_org}")
    text_encoders_o, unet_o = _load_models(model_org, v2, sdxl, load_dtype, load_original_model_to or "cpu")
    logger.info(f"Loading tuned {model_kind} model: {model_tuned}")
    text_encoders_t, unet_t = _load_models(model_tuned, v2, sdxl, load_dtype, load_tuned_model_to or "cpu")
    if sdxl:
        model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
    else:
        model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)

    # Create LoRA networks. They are only used here to enumerate the target modules,
    # so a small initial dimension keeps their own parameters cheap.
    kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} if conv_dim else {}
    init_dim = 4
    lora_network_o = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_o, unet_o, **kwargs)
    lora_network_t = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_t, unet_t, **kwargs)
    assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), "Model versions differ (SD1.x vs SD2.x)"

    # Compute differences; each source weight is dropped as soon as it is consumed.
    diffs = {}
    text_encoder_different = False
    for lora_o, lora_t in zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras):
        diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
        lora_o.org_module.weight = None
        lora_t.org_module.weight = None
        if not text_encoder_different:
            max_abs_diff = torch.max(torch.abs(diff))
            if max_abs_diff > min_diff:
                text_encoder_different = True
                logger.info(f"Text encoder differs: max diff {max_abs_diff} > {min_diff}")
        diffs[lora_o.lora_name] = diff

    if not text_encoder_different:
        logger.warning("Text encoders are identical. Extracting U-Net only.")
        lora_network_o.text_encoder_loras = []
        diffs.clear()

    for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras):
        diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
        lora_o.org_module.weight = None
        lora_t.org_module.weight = None
        diffs[lora_o.lora_name] = diff
    # Only the original model is needed from here on.
    del lora_network_t, unet_t, text_encoders_t

    # Module order is preserved (a set would make processing order hash-dependent).
    lora_names = list(
        dict.fromkeys(module.lora_name for module in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)
    )

    # Extract and resize LoRA using SVD
    logger.info("Extracting and resizing LoRA via SVD")
    lora_weights = {}
    with torch.no_grad():
        for lora_name in tqdm(lora_names):
            # pop() releases each stored diff once it has been consumed.
            mat = diffs.pop(lora_name)
            if device:
                mat = mat.to(device)
            mat = mat.to(torch.float)

            conv2d = mat.dim() == 4
            kernel_size = mat.size()[2:4] if conv2d else None
            conv2d_3x3 = conv2d and kernel_size != (1, 1)
            out_dim, in_dim = mat.size()[0:2]
            if conv2d:
                # (out, in, kh, kw) -> (out, in*kh*kw). Unlike squeeze(), this never
                # collapses a size-1 out/in channel dimension.
                mat = mat.flatten(start_dim=1)

            # Reduced SVD: only the leading min(rows, cols) factors can ever be kept.
            # A full Vh for a flattened 3x3 conv would be (9 * in_dim)^2 elements.
            U, S, Vh = torch.linalg.svd(mat, full_matrices=False)

            max_rank = conv_dim if conv2d_3x3 and conv_dim is not None else dim
            rank = _select_rank(S, dynamic_method, dynamic_param, max_rank, in_dim, out_dim)

            # Truncate SVD components; singular values are folded into the up factor.
            U = U[:, :rank] @ torch.diag(S[:rank])
            Vh = Vh[:rank, :]

            # Clamp values
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp_quantile)
            U = U.clamp(-hi_val, hi_val)
            Vh = Vh.clamp(-hi_val, hi_val)

            if conv2d:
                U = U.reshape(out_dim, rank, 1, 1)
                Vh = Vh.reshape(rank, in_dim, *kernel_size)
            U = U.to(work_device, dtype=save_dtype).contiguous()
            Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
            lora_weights[lora_name] = (U, Vh)

            if verbose:
                s_sum = float(torch.sum(S))
                s_rank = float(torch.sum(S[:rank]))
                fro = float(torch.sqrt(torch.sum(S.pow(2))))
                fro_rank = float(torch.sqrt(torch.sum(S[:rank].pow(2))))
                # An unchanged layer has all-zero S: report NaN instead of dividing by zero.
                sum_retained = s_rank / s_sum if s_sum > 0 else float("nan")
                fro_retained = fro_rank / fro if fro > 0 else float("nan")
                ratio = S[0] / S[rank - 1] if rank > 1 else float('inf')
                logger.info(f"{lora_name:75} | sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max ratio: {ratio:.1f}, rank: {rank}")

    # Create state dict; alpha is set equal to the extracted rank of each module.
    lora_sd = {}
    for lora_name, (up_weight, down_weight) in lora_weights.items():
        lora_sd[lora_name + ".lora_up.weight"] = up_weight
        lora_sd[lora_name + ".lora_down.weight"] = down_weight
        lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype)

    # Load and save LoRA
    lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
    lora_network_save.apply_to(text_encoders_o, unet_o)
    info = lora_network_save.load_state_dict(lora_sd)
    logger.info(f"Loaded extracted and resized LoRA weights: {info}")

    # os.makedirs("") raises, so only create a directory when save_to names one.
    save_dir = os.path.dirname(save_to)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Metadata
    net_kwargs = {"conv_dim": str(conv_dim), "conv_alpha": str(float(conv_dim))} if conv_dim else {}
    metadata = {
        "ss_v2": str(v2),
        "ss_base_model_version": model_version,
        "ss_network_module": "networks.lora",
        "ss_network_dim": str(dim) if not dynamic_method else "Dynamic",
        "ss_network_alpha": str(float(dim)) if not dynamic_method else "Dynamic",
        "ss_network_args": json.dumps(net_kwargs),
    }
    if not no_metadata:
        title = os.path.splitext(os.path.basename(save_to))[0]
        sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
        metadata.update(sai_metadata)

    save_to_file(save_to, lora_sd, lora_sd, save_dtype, metadata)
    logger.info(f"LoRA saved to: {save_to}")
def setup_parser():
    """Build the command-line parser for LoRA extraction.

    ``--v_parameterization`` defaults to ``None`` rather than ``False``. That lets
    ``svd()`` apply its "same as --v2" fallback, which the option's help text promises.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model")
    # default=None is required: with store_true alone the value would be False and
    # svd()'s `v2 if v_parameterization is None` fallback could never trigger.
    parser.add_argument("--v_parameterization", action="store_true", default=None, help="Set v-parameterization metadata (defaults to v2)")
    parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model")
    parser.add_argument("--load_precision", choices=[None, "float", "fp16", "bf16"], help="Precision for loading models")
    parser.add_argument("--save_precision", choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA")
    parser.add_argument("--model_org", required=True, help="Original Stable Diffusion model (ckpt/safetensors)")
    parser.add_argument("--model_tuned", required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)")
    parser.add_argument("--save_to", required=True, help="Output file name (ckpt/safetensors)")
    parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers")
    parser.add_argument("--conv_dim", type=int, help="Max dimension (rank) of LoRA for Conv2d-3x3")
    parser.add_argument("--device", default="cuda", help="Device for computation (e.g., cuda)")
    parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights")
    parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract")
    parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata")
    parser.add_argument("--load_original_model_to", help="Device for original model (SDXL only)")
    parser.add_argument("--load_tuned_model_to", help="Device for tuned model (SDXL only)")
    parser.add_argument("--dynamic_method", choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease"], help="Dynamic rank reduction method")
    parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction")
    parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info")
    return parser
if __name__ == "__main__":
    parser = setup_parser()
    args = parser.parse_args()
    # sv_knee takes no parameter; every other dynamic method reads dynamic_param.
    # Compare against None so that a legitimate value of 0.0 is not treated as missing.
    if args.dynamic_method and args.dynamic_method != "sv_knee" and args.dynamic_param is None:
        raise ValueError("Dynamic method requires a dynamic parameter")
    svd(**vars(args))