mirror of https://github.com/bmaltais/kohya_ss
545 lines
25 KiB
Python
545 lines
25 KiB
Python
import sys
|
|
import os
|
|
|
|
# Make the sd-scripts checkout importable. It is expected at
# <project_root>/sd-scripts, where <project_root> is the parent of the
# directory containing this script.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
sd_scripts_dir_path = os.path.join(project_root, "sd-scripts")

# Prepend (not append) so this directory is searched before any other
# sys.path entry; skip the insert when the entry is already present.
if sd_scripts_dir_path not in sys.path:
    sys.path.insert(0, sd_scripts_dir_path)

# `library` and `networks` are top-level packages resolved through the path
# entry added above.
try:
    from library import sai_model_spec, model_util, sdxl_model_util
    from library.utils import setup_logging
    from networks import lora
except ImportError as e:
    # Print enough context to diagnose a missing or misplaced sd-scripts
    # checkout, then re-raise so the script still fails with the original error.
    print(f"Error importing from sd-scripts. Please check your sd-scripts folder structure.")
    print(f"Attempted to load from: {sd_scripts_dir_path}")
    print(f"Original error: {e}")
    print("Current sys.path relevant entries:")
    for p in sys.path:
        if "sd-scripts" in p or "kohya_ss" in p:  # only show entries related to this project
            print(p)
    # Things to check when this fails: 'sd-scripts/networks/lora.py' and
    # 'sd-scripts/networks/__init__.py' must both exist.
    raise
|
|
|
|
# --- The rest of your script ---
|
|
import argparse
|
|
import json
|
|
# import os # Already imported
|
|
import time
|
|
import torch
|
|
from safetensors.torch import load_file, save_file
|
|
from tqdm import tqdm
|
|
|
|
# Call the sd-scripts logging helper before creating this module's logger
# (presumably it installs handlers/levels -- see library.utils.setup_logging).
setup_logging()
import logging

logger = logging.getLogger(__name__)

# Singular values at or below this threshold are treated as negligible:
# svd() falls back to rank 1 when the largest singular value does not exceed
# it, passes it to the knee-based rank selectors as their flatness threshold,
# and uses it to guard the divisions in the verbose statistics.
MIN_SV = 1e-6
|
|
|
|
def index_sv_cumulative(S, target):
    """Pick a rank from the cumulative share of the singular-value sum.

    Args:
        S: 1-D tensor of singular values, sorted in descending order.
        target: Fraction of ``sum(S)`` used as the search threshold.

    Returns:
        int: One more than the position at which the running share first
        reaches ``target``, limited to at most ``len(S) - 1`` and at least 1.
    """
    total = float(torch.sum(S))
    running_share = torch.cumsum(S, dim=0) / total
    candidate = int(torch.searchsorted(running_share, target)) + 1

    # Never keep every singular value, but always keep at least one.
    ceiling = len(S) - 1
    if candidate > ceiling:
        candidate = ceiling
    return candidate if candidate > 1 else 1
|
|
|
|
def index_sv_fro(S, target):
    """Pick a rank from the retained share of the Frobenius norm.

    Args:
        S: 1-D tensor of singular values, sorted in descending order.
        target: Fraction of the Frobenius norm to aim for; it is squared
            because the comparison is done on squared singular values.

    Returns:
        int: One more than the position at which the running share of
        ``sum(S**2)`` first reaches ``target**2``, limited to at most
        ``len(S) - 1`` and at least 1.
    """
    energy = S.pow(2)
    total_energy = float(torch.sum(energy))
    running_share = torch.cumsum(energy, dim=0) / total_energy
    candidate = int(torch.searchsorted(running_share, target**2)) + 1

    # Never keep every singular value, but always keep at least one.
    ceiling = len(S) - 1
    if candidate > ceiling:
        candidate = ceiling
    return candidate if candidate > 1 else 1
|
|
|
|
def index_sv_ratio(S, target):
    """Pick a rank by counting singular values close enough to the largest.

    Args:
        S: 1-D tensor of singular values; ``S[0]`` is taken as the largest.
        target: Maximum allowed ratio between ``S[0]`` and a kept value.

    Returns:
        int: Number of entries strictly greater than ``S[0] / target``,
        limited to at most ``len(S) - 1`` and at least 1.
    """
    cutoff = S[0] / target
    kept = int((S > cutoff).sum().item())

    # Never keep every singular value, but always keep at least one.
    ceiling = len(S) - 1
    if kept > ceiling:
        kept = ceiling
    return kept if kept > 1 else 1
|
|
|
|
def index_sv_knee_improved(S, MIN_SV_KNEE=1e-8):
    """Pick a rank at the knee of the singular-value curve.

    Both the singular values and their positions are rescaled to [0, 1] so the
    result does not depend on the absolute scale of ``S``. The knee is the
    point farthest from the straight line joining the first and last points.

    Args:
        S: 1-D tensor of singular values, sorted in descending order.
        MIN_SV_KNEE: If ``S[0] - S[-1]`` is below this value the spectrum is
            considered flat and no knee is searched for.

    Returns:
        int: 1-based rank in ``[1, len(S) - 1]``; 1 when there are fewer than
        three values or the spectrum is flat.
    """
    count = len(S)
    if count < 3:
        # A knee needs at least three points.
        return 1

    largest = S[0]
    smallest = S[-1]
    spread = largest - smallest
    if spread < MIN_SV_KNEE:
        # Nearly constant spectrum: no meaningful knee, choose the lowest rank.
        return 1

    # Rescaled curve runs from (0, 1) at the first value to (1, 0) at the last.
    heights = (S - smallest) / spread
    positions = torch.linspace(0, 1, count, device=S.device, dtype=S.dtype)

    # The chord through (0, 1) and (1, 0) is x + y - 1 = 0, so the distance of
    # a point to it is |x + y - 1| / sqrt(2); the constant factor is irrelevant
    # for locating the maximum.
    offsets = (positions + heights - 1).abs()
    knee_position = torch.argmax(offsets).item()

    # 0-based position -> 1-based rank, kept strictly below the full rank.
    chosen = knee_position + 1
    if chosen > count - 1:
        chosen = count - 1
    return chosen if chosen > 1 else 1
|
|
|
|
def index_sv_cumulative_knee(S, min_sv_threshold=1e-8):
    """Pick a rank at the knee of the cumulative singular-value curve.

    The running sum of ``S`` (as a share of the total) is rescaled to [0, 1]
    and compared with the diagonal joining its first and last points. The
    knee is the point farthest from that diagonal, i.e. where additional
    singular values start contributing little to the total.

    Args:
        S: 1-D tensor of singular values, sorted in descending order.
        min_sv_threshold: Used both as the minimum total sum and as the
            minimum spread of the cumulative curve for a knee to be searched.

    Returns:
        int: 1-based rank in ``[1, len(S) - 1]``; 1 when there are fewer than
        three values, the total is negligible, or the curve is flat.
    """
    count = len(S)
    if count < 3:
        # A knee needs at least three points.
        return 1

    total = torch.sum(S)
    if total < min_sv_threshold:
        # Everything is (close to) zero.
        return 1

    # Share of the total captured by the first 1, 2, ..., count values.
    shares = torch.cumsum(S, dim=0) / total
    first_share = shares[0]
    last_share = shares[count - 1]
    spread = last_share - first_share
    if spread < min_sv_threshold:
        # The first value already carries practically the whole sum.
        return 1

    # Rescaled curve runs from (0, 0) to (1, 1).
    heights = (shares - first_share) / spread
    positions = torch.linspace(0, 1, count, device=S.device, dtype=S.dtype)

    # Distance to the line y = x is |y - x| / sqrt(2); the constant factor is
    # irrelevant for locating the maximum.
    offsets = (heights - positions).abs()
    knee_position = torch.argmax(offsets).item()

    # 0-based position -> 1-based rank, kept strictly below the full rank.
    chosen = knee_position + 1
    if chosen > count - 1:
        chosen = count - 1
    return chosen if chosen > 1 else 1
|
|
|
|
def index_sv_rel_decrease(S, tau=0.1):
    """Pick a rank at the first sharp relative drop between neighbours.

    Args:
        S: 1-D tensor of singular values, sorted in descending order.
        tau: A drop is "sharp" when ``S[k + 1] / S[k] < tau``.

    Returns:
        int: ``k + 1`` for the smallest ``k`` with a sharp drop after
        ``S[k]``; ``len(S)`` when no such drop exists (the caller is expected
        to cap the result); 1 when ``S`` has fewer than two values.
    """
    if len(S) < 2:
        # No neighbouring pair to compare.
        return 1

    # Element k compares S[k + 1] with S[k]. Example: S=[10, 1, 0.5] gives
    # ratios [0.1, 0.5].
    neighbour_ratios = S[1:] / S[:-1]
    sharp_drop_flags = (neighbour_ratios < tau).tolist()

    for position, is_sharp in enumerate(sharp_drop_flags):
        if is_sharp:
            # Keep the values up to and including S[position].
            return position + 1

    # No sharp drop anywhere: suggest keeping everything.
    return len(S)
|
|
|
|
def save_to_file(file_name, model_to_save, state_dict_content, dtype, metadata=None):
    """Optionally cast a state dict and write an object to disk.

    Args:
        file_name: Destination path. A ``.safetensors`` extension selects the
            safetensors writer; anything else is written with ``torch.save``.
        model_to_save: The object that is actually written.
        state_dict_content: Mapping whose tensor values are cast to ``dtype``
            in place. The cast only affects the written file when this is the
            same object as ``model_to_save`` (as in this script's caller).
        dtype: Target torch dtype, or ``None`` to skip the cast.
        metadata: Passed to the safetensors writer only; ignored for
            ``torch.save``.
    """
    if dtype is not None:
        # Snapshot the tensor keys first so the mapping is not mutated while
        # it is being iterated.
        tensor_keys = [
            name
            for name, value in state_dict_content.items()
            if isinstance(value, torch.Tensor)
        ]
        for name in tensor_keys:
            state_dict_content[name] = state_dict_content[name].to(dtype)

    _, extension = os.path.splitext(file_name)
    if extension != ".safetensors":
        torch.save(model_to_save, file_name)
        return

    save_file(model_to_save, file_name, metadata=metadata)
|
|
|
|
def svd(
    model_org=None,
    model_tuned=None,
    save_to=None,
    dim=4,
    v2=None,
    sdxl=None,
    conv_dim=None,
    v_parameterization=None,
    device=None,
    save_precision=None,
    clamp_quantile=0.99,
    min_diff=0.01,
    no_metadata=False,
    load_precision=None,
    load_original_model_to=None,
    load_tuned_model_to=None,
    dynamic_method=None,
    dynamic_param=None,
    verbose=False,
):
    """Extract a LoRA from the weight difference of two checkpoints and save it.

    Both checkpoints are loaded, ``tuned - original`` is computed for every
    module a LoRA network would wrap, and each difference is factorised with an
    SVD truncated to the selected rank. ``sqrt(S)`` is folded into both
    factors, which are clamped and stored as ``lora_up`` / ``lora_down``;
    ``alpha`` is set to the rank of each module.

    Args:
        model_org: Path of the original checkpoint.
        model_tuned: Path of the tuned checkpoint.
        save_to: Output file path (``.safetensors`` or anything else for
            ``torch.save``).
        dim: Maximum rank for linear and 1x1-conv modules.
        v2: Load the checkpoints as Stable Diffusion v2.x.
        sdxl: Load the checkpoints as SDXL (mutually exclusive with ``v2``).
        conv_dim: Maximum rank for conv modules with a kernel other than 1x1;
            also enables conv LoRA modules in the network. Falls back to
            ``dim`` when not given.
        v_parameterization: Metadata flag; ``None`` means "same as ``v2``".
        device: Device on which the SVD is computed; falsy keeps the CPU.
        save_precision: ``"float"``, ``"fp16"`` or ``"bf16"``; defaults to
            float32 when not given.
        clamp_quantile: Quantile of the factor values used as clamp bound.
        min_diff: Text-encoder differences not exceeding this are ignored
            (U-Net only extraction).
        no_metadata: Skip the sai_model_spec metadata.
        load_precision: Optional dtype name the loaded models are cast to.
        load_original_model_to: Device for the original model (SDXL only).
        load_tuned_model_to: Device for the tuned model (SDXL only).
        dynamic_method: Name of the per-module rank selection method, or a
            falsy value for a fixed rank.
        dynamic_param: Threshold for the methods that need one.
        verbose: Log per-module statistics about the retained spectrum.

    Raises:
        AssertionError: If ``v2`` and ``sdxl`` are both set, or the two models
            expose a different number of text-encoder LoRA modules.
        ValueError: If ``dynamic_method`` is unknown, or needs
            ``dynamic_param`` and none was given.
    """

    def str_to_dtype(p):
        if p == "float":
            return torch.float
        if p == "fp16":
            return torch.float16
        if p == "bf16":
            return torch.bfloat16
        return None

    # Rank selectors called as f(S, dynamic_param) ...
    param_methods = {
        "sv_ratio": index_sv_ratio,
        "sv_cumulative": index_sv_cumulative,
        "sv_fro": index_sv_fro,
        "sv_rel_decrease": index_sv_rel_decrease,
    }
    # ... and selectors called as f(S, MIN_SV) (flatness threshold only).
    threshold_methods = {
        "sv_knee": index_sv_knee_improved,
        "sv_cumulative_knee": index_sv_cumulative_knee,
    }

    assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time"

    # Fail before the expensive model loading. Without this check an unknown
    # method would silently reuse the rank of the previous module (or hit an
    # unbound local on the first one).
    if dynamic_method:
        if dynamic_method not in param_methods and dynamic_method not in threshold_methods:
            raise ValueError(f"Unknown dynamic method: {dynamic_method!r}")
        if dynamic_method in param_methods and dynamic_param is None:
            raise ValueError(f"Dynamic method '{dynamic_method}' requires --dynamic_param to be set.")

    v_parameterization = v2 if v_parameterization is None else v_parameterization

    load_dtype = str_to_dtype(load_precision) if load_precision else None
    save_dtype = str_to_dtype(save_precision) if save_precision else torch.float
    work_device = "cpu"

    # Load models
    if not sdxl:
        logger.info(f"Loading original SD model: {model_org}")
        text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
        text_encoders_o = [text_encoder_o]
        if load_dtype:
            text_encoder_o.to(load_dtype)
            unet_o.to(load_dtype)

        logger.info(f"Loading tuned SD model: {model_tuned}")
        text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
        text_encoders_t = [text_encoder_t]
        if load_dtype:
            text_encoder_t.to(load_dtype)
            unet_t.to(load_dtype)

        model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
    else:
        device_org = load_original_model_to or "cpu"
        device_tuned = load_tuned_model_to or "cpu"

        logger.info(f"Loading original SDXL model: {model_org}")
        text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
        )
        text_encoders_o = [text_encoder_o1, text_encoder_o2]
        if load_dtype:
            text_encoder_o1.to(load_dtype)
            text_encoder_o2.to(load_dtype)
            unet_o.to(load_dtype)

        logger.info(f"Loading tuned SDXL model: {model_tuned}")
        text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
        )
        text_encoders_t = [text_encoder_t1, text_encoder_t2]
        if load_dtype:
            text_encoder_t1.to(load_dtype)
            text_encoder_t2.to(load_dtype)
            unet_t.to(load_dtype)

        model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0

    # Create LoRA networks. They are only used to enumerate the target modules
    # (and their original weights), so a small fixed dimension keeps memory low.
    kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} if conv_dim else {}
    init_dim = 4
    lora_network_o = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_o, unet_o, **kwargs)
    lora_network_t = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_t, unet_t, **kwargs)

    assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), "Model versions differ (SD1.x vs SD2.x)"

    # Compute differences (tuned - original) on the CPU, dropping each source
    # weight as soon as its difference has been taken.
    diffs = {}
    text_encoder_different = False
    for lora_o, lora_t in zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras):
        lora_name = lora_o.lora_name
        diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
        lora_o.org_module.weight = None
        lora_t.org_module.weight = None

        # Only look for a significant difference until the first one is found.
        if not text_encoder_different:
            max_abs_diff = torch.max(torch.abs(diff))
            if max_abs_diff > min_diff:
                text_encoder_different = True
                logger.info(f"Text encoder differs: max diff {max_abs_diff} > {min_diff}")
        diffs[lora_name] = diff

    if not text_encoder_different:
        logger.warning("Text encoders are identical. Extracting U-Net only.")
        lora_network_o.text_encoder_loras = []
        diffs.clear()

    for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras):
        lora_name = lora_o.lora_name
        diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
        lora_o.org_module.weight = None
        lora_t.org_module.weight = None
        diffs[lora_name] = diff

    del lora_network_t, unet_t

    # Filter relevant modules
    lora_names = set(
        module.lora_name for module in lora_network_o.text_encoder_loras + lora_network_o.unet_loras
    )

    # Extract and resize LoRA using SVD
    logger.info("Extracting and resizing LoRA via SVD")
    lora_weights = {}
    with torch.no_grad():
        for lora_name in tqdm(lora_names):
            mat = diffs[lora_name]
            if device:
                mat = mat.to(device)
            mat = mat.to(torch.float)

            conv2d = len(mat.size()) == 4
            kernel_size = mat.size()[2:4] if conv2d else None
            conv2d_3x3 = conv2d and kernel_size != (1, 1)
            out_dim, in_dim = mat.size()[0:2]

            if conv2d:
                # (out, in, kh, kw) -> (out, in * kh * kw). Also used for 1x1
                # kernels: unlike squeeze() it cannot drop a size-1 channel dim.
                mat = mat.flatten(start_dim=1)

            U, S, Vh = torch.linalg.svd(mat)

            # Determine rank
            max_rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
            if dynamic_method:
                if S[0] <= MIN_SV:
                    # Difference is numerically zero; keep the minimum rank.
                    rank = 1
                elif dynamic_method in param_methods:
                    rank = param_methods[dynamic_method](S, dynamic_param)
                else:
                    rank = threshold_methods[dynamic_method](S, MIN_SV)
                rank = min(rank, max_rank, in_dim, out_dim)
            else:
                rank = min(max_rank, in_dim, out_dim)

            rank = max(1, rank)  # Ensure rank is at least 1

            # Truncate SVD components
            S_k = S[:rank]
            U_k = U[:, :rank]
            Vh_k = Vh[:rank, :]

            # Guard against tiny negative values before the square root.
            s_sqrt = torch.sqrt(torch.clamp(S_k, min=0.0))

            # Distribute sqrt(S) over both factors:
            # U_final = U_k @ diag(s_sqrt), Vh_final = diag(s_sqrt) @ Vh_k
            U_final = U_k * s_sqrt.unsqueeze(0)  # (out_dim, rank) * (1, rank)
            Vh_final = Vh_k * s_sqrt.unsqueeze(1)  # (rank, in_features) * (rank, 1)

            # Clamp both factors to the same symmetric bound, taken from the
            # joint distribution of their values.
            dist = torch.cat([U_final.flatten(), Vh_final.flatten()])
            hi_val = torch.quantile(dist, clamp_quantile)
            U_clamped = U_final.clamp(-hi_val, hi_val)
            Vh_clamped = Vh_final.clamp(-hi_val, hi_val)

            if conv2d:
                # up: (out_dim, rank, 1, 1); down: (rank, in_dim, kh, kw)
                U_clamped = U_clamped.reshape(out_dim, rank, 1, 1)
                Vh_clamped = Vh_clamped.reshape(rank, in_dim, *kernel_size)

            U_clamped = U_clamped.to(work_device, dtype=save_dtype).contiguous()
            Vh_clamped = Vh_clamped.to(work_device, dtype=save_dtype).contiguous()
            lora_weights[lora_name] = (U_clamped, Vh_clamped)

            # Verbose output, computed from the untouched singular values S.
            if verbose:
                s_sum_total = float(torch.sum(S))
                s_sum_rank = float(torch.sum(S[:rank]))

                fro_orig_total = float(torch.sqrt(torch.sum(S.pow(2))))
                fro_reconstructed_rank = float(torch.sqrt(torch.sum(S[:rank].pow(2))))

                # Largest over smallest retained singular value; inf when the
                # smallest one is too close to zero to divide by.
                ratio_sv = S[0] / S[rank - 1] if rank > 0 and S[rank - 1].abs() > MIN_SV else float('inf')

                # Report 100% when the totals are too small to divide by.
                sum_s_retained_percentage = (s_sum_rank / s_sum_total) if s_sum_total > MIN_SV else 1.0
                fro_retained_percentage = (fro_reconstructed_rank / fro_orig_total) if fro_orig_total > MIN_SV else 1.0

                logger.info(
                    f"{lora_name:75} | rank: {rank}, "
                    f"sum(S) retained: {sum_s_retained_percentage:.2%}, "
                    f"Frobenius norm retained: {fro_retained_percentage:.2%}, "
                    f"max_retained_sv/min_retained_sv ratio: {ratio_sv:.2f}"
                )

    # Create state dict
    lora_sd = {}
    for lora_name, (up_weight, down_weight) in lora_weights.items():
        lora_sd[lora_name + ".lora_up.weight"] = up_weight
        lora_sd[lora_name + ".lora_down.weight"] = down_weight
        # alpha equals the rank (first dimension of the down weight)
        lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype)

    # Load the weights into a network object, then save the state dict
    lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
    lora_network_save.apply_to(text_encoders_o, unet_o)
    info = lora_network_save.load_state_dict(lora_sd)
    logger.info(f"Loaded extracted and resized LoRA weights into network object: {info}")

    # os.makedirs("") raises FileNotFoundError, so only create the directory
    # when save_to actually has a directory component.
    save_dir = os.path.dirname(save_to)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Metadata
    net_kwargs = {"conv_dim": str(conv_dim), "conv_alpha": str(float(conv_dim))} if conv_dim else {}
    if dynamic_method:
        # Rank (and therefore alpha) varies per module.
        network_dim_meta = "Dynamic"
        network_alpha_meta = "Dynamic"
    else:
        network_dim_meta = str(dim)
        network_alpha_meta = str(float(dim))

    metadata = {
        "ss_v2": str(v2),
        "ss_base_model_version": model_version,
        "ss_network_module": "networks.lora",
        "ss_network_dim": network_dim_meta,
        "ss_network_alpha": network_alpha_meta,
        "ss_network_args": json.dumps(net_kwargs),
        "ss_lowram": "False",
        "ss_num_train_images": "N/A",
    }
    if not no_metadata:
        title = os.path.splitext(os.path.basename(save_to))[0]
        # NOTE(review): the two positional booleans after `sdxl` were labelled
        # "is_sd2" / "is_v_pred_like" here before; in upstream sd-scripts they
        # appear to be `lora` and `textual_inversion` -- confirm against
        # sai_model_spec.build_metadata.
        sai_metadata = sai_model_spec.build_metadata(
            None,
            v2,
            v_parameterization,
            sdxl,
            True,
            False,
            time.time(),
            title=title,
        )
        # Drop entries without a value before merging.
        sai_metadata_cleaned = {k: v for k, v in sai_metadata.items() if v is not None}
        metadata.update(sai_metadata_cleaned)

    # The state dict (not the network object) is what gets written.
    save_to_file(save_to, lora_sd, lora_sd, save_dtype, metadata)
    logger.info(f"LoRA saved to: {save_to}")
|
|
|
|
def setup_parser():
    """Build the command-line parser for the LoRA extraction script.

    Every option name matches a keyword argument of ``svd()``, so the parsed
    namespace can be forwarded with ``svd(**vars(args))``.

    Returns:
        argparse.ArgumentParser: The configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model")
    # default=None (instead of store_true's implicit False) lets svd() tell
    # "flag not given" apart from an explicit value and fall back to --v2,
    # as the help text promises.
    parser.add_argument(
        "--v_parameterization",
        action="store_true",
        default=None,
        help="Set v-parameterization metadata (defaults to v2)",
    )
    parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model")
    parser.add_argument("--load_precision", choices=[None, "float", "fp16", "bf16"], help="Precision for loading models")
    parser.add_argument("--save_precision", choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA")
    parser.add_argument("--model_org", required=True, help="Original Stable Diffusion model (ckpt/safetensors)")
    parser.add_argument("--model_tuned", required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)")
    parser.add_argument("--save_to", required=True, help="Output file name (ckpt/safetensors)")
    parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers")
    parser.add_argument("--conv_dim", type=int, help="Max dimension (rank) of LoRA for Conv2d-3x3")
    parser.add_argument("--device", default="cuda", help="Device for computation (e.g., cuda)")
    parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights")
    parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract")
    parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata")
    parser.add_argument("--load_original_model_to", help="Device for original model (SDXL only)")
    parser.add_argument("--load_tuned_model_to", help="Device for tuned model (SDXL only)")
    parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction")
    parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info")
    parser.add_argument(
        "--dynamic_method",
        choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"],
        help="Dynamic rank reduction method",
    )
    return parser
|
|
|
|
if __name__ == "__main__":
    parser = setup_parser()
    args = parser.parse_args()

    # These rank-selection methods take a numeric threshold; the knee-based
    # ones do not, so --dynamic_param is only mandatory for the former.
    methods_requiring_param = ("sv_ratio", "sv_fro", "sv_cumulative", "sv_rel_decrease")
    param_missing = args.dynamic_param is None
    if param_missing and args.dynamic_method in methods_requiring_param:
        raise ValueError(f"Dynamic method '{args.dynamic_method}' requires --dynamic_param to be set.")

    # A fixed rank must be positive; with a dynamic method svd() keeps the
    # rank at 1 or above on its own.
    if not (args.dynamic_method or args.dim > 0):
        raise ValueError(f"--dim (rank) must be > 0. Got {args.dim}")

    conv_rank = args.conv_dim
    if conv_rank is not None and conv_rank <= 0:
        raise ValueError(f"--conv_dim (rank) must be > 0 if specified. Got {args.conv_dim}")

    # Option names match svd()'s keyword arguments one to one.
    svd(**vars(args))