pull/3264/head
bmaltais 2025-05-10 15:33:24 -04:00
parent be30140562
commit ece84c6d06
6 changed files with 1755 additions and 206 deletions

159
tools/analyse_loha.py Normal file
View File

@ -0,0 +1,159 @@
import safetensors.torch
import json
from collections import OrderedDict
import sys # To redirect stdout
import traceback
class Logger(object):
    """Tee-style stream that mirrors every write to the console and a log file.

    The console stream is whatever ``sys.stdout`` is at construction time; the
    log file is opened for writing (truncating any existing file) as UTF-8.
    """

    def __init__(self, filename="loha_analysis_output.txt"):
        self.terminal = sys.stdout
        self.log = open(filename, "w", encoding="utf-8")

    def write(self, message):
        """Send ``message`` to the console first, then to the log file."""
        for stream in (self.terminal, self.log):
            stream.write(message)

    def flush(self):
        """Flush both underlying streams (required of file-like stdout replacements)."""
        for stream in (self.terminal, self.log):
            stream.flush()

    def close(self):
        """Close the log file; the console stream is left open."""
        self.log.close()
def _print_scalar_value(tensor):
    """Print the single value held by ``tensor``, or a diagnostic if extraction fails."""
    try:
        value = tensor.item()
    except Exception as e:  # .item() fails for tensors that hold more than one element
        print(f" Value: Could not extract scalar value ({tensor}, error: {e})")
        return
    if isinstance(value, float):
        print(f" Value: {value:.8f}")  # Fixed precision keeps float output comparable
    else:
        print(f" Value: {value}")


def _print_tensor_entry(key, tensor):
    """Print the key, shape and dtype lines shared by every tensor listing."""
    print(f" - Key: {key}")
    print(f" Shape: {list(tensor.shape)}, Dtype: {tensor.dtype}")  # Shape as list for clarity


def _print_tensor_report(state_dict):
    """Print all tensors grouped by their parent module prefix.

    Keys are bucketed in a single pass by everything before their last ".".
    Each key is therefore listed exactly once, under its direct parent; keys
    without a (non-empty) prefix are reported in a trailing "other" section.
    """
    grouped = OrderedDict()
    other_keys = []
    for key in sorted(state_dict.keys()):
        prefix, sep, _ = key.rpartition(".")
        if sep and prefix:
            grouped.setdefault(prefix, []).append(key)
        else:
            other_keys.append(key)

    for prefix in sorted(grouped):
        print(f"\nModule: {prefix}")
        for key in grouped[prefix]:
            tensor = state_dict[key]
            _print_tensor_entry(key, tensor)
            if key.endswith((".alpha", ".dim")):
                _print_scalar_value(tensor)
            elif tensor.numel() < 10:  # Print small tensors' values
                print(f" Values (first few): {tensor.flatten()[:10].tolist()}")

    print("\n--- Other Tensor Keys (if any, not fitting typical module.parameter pattern) ---")
    for key in other_keys:
        tensor = state_dict[key]
        _print_tensor_entry(key, tensor)
        if key.endswith((".alpha", ".dim")) or tensor.numel() == 1:
            _print_scalar_value(tensor)
    if not other_keys:
        print("No other keys found.")


def _print_metadata_report(filepath):
    """Print the safetensors header metadata, expanding ``ss_network_args`` when it is a JSON object."""
    metadata_content = OrderedDict()
    malformed_metadata_keys = []
    try:
        # safe_open gives access to the header metadata without loading tensors
        with safetensors.safe_open(filepath, framework="pt", device="cpu") as f:
            metadata_keys = f.metadata()
        if metadata_keys is None:
            print("No metadata dictionary found in the file header (f.metadata() returned None).")
        else:
            for k in metadata_keys.keys():
                try:
                    metadata_content[k] = metadata_keys.get(k)
                except Exception as e:
                    malformed_metadata_keys.append((k, str(e)))
                    metadata_content[k] = f"[Error reading value: {e}]"
    except Exception as e:
        print(f"Could not open or read metadata using safe_open: {e}")
        traceback.print_exc(file=sys.stdout)

    if not metadata_content and not malformed_metadata_keys:
        print("No metadata content extracted.")
        return

    for key, value in metadata_content.items():
        print(f"- {key}: {value}")
        if key == "ss_network_args" and value and not value.startswith("[Error"):
            try:
                parsed_args = json.loads(value)
            except json.JSONDecodeError:
                print(" (ss_network_args is not a valid JSON string)")
                continue
            # Valid JSON is not necessarily an object; only objects have key/value pairs to list.
            if isinstance(parsed_args, dict):
                print(" Parsed ss_network_args:")
                for arg_key, arg_value in parsed_args.items():
                    print(f" - {arg_key}: {arg_value}")
            else:
                print(" (ss_network_args is valid JSON but not a JSON object)")
    if malformed_metadata_keys:
        print("\n--- Malformed Metadata Keys (could not be read) ---")
        for key, error_msg in malformed_metadata_keys:
            print(f"- {key}: Error: {error_msg}")


def analyze_safetensors_file(filepath, output_filename="loha_analysis_output.txt"):
    """
    Analyze a .safetensors file and write a report of its tensors and metadata.

    While the analysis runs, ``sys.stdout`` is replaced by a ``Logger`` so that
    everything printed goes both to the console and to ``output_filename``.
    The original stdout is always restored afterwards. Errors raised during the
    analysis are reported (with traceback) in the output instead of propagating.

    Args:
        filepath: Path of the .safetensors file to inspect.
        output_filename: Path of the text report to create (overwritten).
    """
    original_stdout = sys.stdout
    logger = Logger(filename=output_filename)
    sys.stdout = logger
    try:
        print(f"--- Analyzing: {filepath} ---\n")
        print(f"--- Output will be saved to: {output_filename} ---\n")
        # Load to CPU to avoid potential CUDA issues
        state_dict = safetensors.torch.load_file(filepath, device="cpu")
        print("--- Tensor Information ---")
        if not state_dict:
            print("No tensors found in the state dictionary.")
        else:
            _print_tensor_report(state_dict)
        print(f"\nTotal tensor keys found: {len(state_dict)}")
        print("\n--- Metadata (from safetensors header) ---")
        _print_metadata_report(filepath)
        print("\n--- End of Analysis ---")
    except Exception as e:
        print(f"\n!!! An error occurred during analysis !!!")
        print(str(e))
        traceback.print_exc(file=sys.stdout)  # Full traceback goes to the log file as well
    finally:
        sys.stdout = original_stdout  # Restore standard output
        logger.close()
    print(f"\nAnalysis complete. Output saved to: {output_filename}")
if __name__ == "__main__":
    # Paths that are pasted or drag-and-dropped into a terminal often arrive
    # wrapped in quotes or padded with whitespace; strip both so the file opens.
    raw_path = input("Enter the path to your working LoHA .safetensors file: ")
    input_file_path = raw_path.strip().strip("\"'")
    # Fixed report name; change it here if a different output file is wanted.
    output_file_name = "loha_analysis_results.txt"
    print(f"The analysis will be saved to: {output_file_name}")
    analyze_safetensors_file(input_file_path, output_filename=output_file_name)

204
tools/dummy_loha.py Normal file
View File

@ -0,0 +1,204 @@
import torch
from safetensors.torch import save_file
from collections import OrderedDict
import json
# --- Script Configuration ---
# This script generates a minimal, non-functional LoHA (LyCORIS Hadamard Product Adaptation)
# .safetensors file, designed to be structurally compatible with ComfyUI and
# based on the analysis of a working SDXL LoHA file.

# --- Global LoHA Parameters (mirroring the metadata of the analysed reference file) ---
# Every layer entry below carries its own rank/alpha, so these act as shared defaults.
# Reference metadata: ss_network_dim: 32, ss_network_alpha: 32.0
DEFAULT_RANK = 32
DEFAULT_ALPHA = 32.0
CONV_RANK = 8 # Reference ss_network_args: "conv_dim": "8"
CONV_ALPHA = 4.0 # Reference ss_network_args: "conv_alpha": "4"

# Example target layers, using names and dimensions representative of SDXL.
# Each entry is a dict with the keys:
#   "name"        - LoHA module name used as the tensor key prefix
#   "in_dim"      - input features (for Conv2d: in_channels)
#   "out_dim"     - output features (for Conv2d: out_channels)
#   "rank"        - LoHA rank for this layer
#   "alpha"       - LoHA alpha for this layer
#   "is_conv"     - True for convolutional layers
#   "kernel_size" - (conv only, optional) square kernel size; the generator defaults it to 3
# For conv layers the hada_wX_b tensors get shape (rank, in_channels * kernel_h * kernel_w).
#
# Background: an earlier dummy file triggered this loader error:
# "ERROR loha diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight shape '[640, 640]' is invalid..."
# That weight corresponds to lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_v.
# In the reference LoHA, similar attention layers (e.g., *_attn1_to_k) have out_dim=640, in_dim=640, rank=32, alpha=32.0
EXAMPLE_LAYERS_CONFIG = [
    # UNet Attention Layers (mimicking typical SDXL structure)
    {
        "name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_q", # Query
        "in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
    },
    {
        "name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_k", # Key
        "in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
    },
    {
        "name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_v", # Value - this one errored previously
        "in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
    },
    {
        "name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_out_0", # Output Projection
        "in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
    },
    # A deeper UNet attention block
    {
        "name": "lora_unet_middle_block_1_transformer_blocks_0_attn1_to_q",
        "in_dim": 1280, "out_dim": 1280, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
    },
    {
        "name": "lora_unet_middle_block_1_transformer_blocks_0_attn1_to_out_0",
        "in_dim": 1280, "out_dim": 1280, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
    },
    # Example UNet "Convolutional" LoHA (e.g., for a ResBlock's conv layer)
    # Based on the reference lora_unet_input_blocks_1_0_in_layers_2, which had rank 8, alpha 4
    # Assuming original conv was Conv2d(320, 320, kernel_size=3, padding=1)
    {
        "name": "lora_unet_input_blocks_1_0_in_layers_2",
        "in_dim": 320, # in_channels
        "out_dim": 320, # out_channels
        "rank": CONV_RANK,
        "alpha": CONV_ALPHA,
        "is_conv": True,
        "kernel_size": 3 # Assume 3x3 kernel for this example
    },
    # Example Text Encoder Layer (CLIP-L, first one in the reference listing)
    # lora_te1_text_model_encoder_layers_0_mlp_fc1 (original Linear(768, 3072))
    {
        "name": "lora_te1_text_model_encoder_layers_0_mlp_fc1",
        "in_dim": 768, "out_dim": 3072, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
    },
]

# Tensor dtype for every generated weight; bfloat16 as seen in the reference analysis
DTYPE = torch.bfloat16
# --- Main Script ---
def create_dummy_loha_file(filepath="dummy_loha_corrected.safetensors"):
    """
    Create and save a dummy (untrained) LoHA .safetensors file.

    For every entry of EXAMPLE_LAYERS_CONFIG four small random factor tensors
    and a scalar alpha tensor are generated:

        <name>.hada_w1_a / hada_w2_a : (out_dim, rank)
        <name>.hada_w1_b / hada_w2_b : (rank, in_dim)                       for linear layers
                                       (rank, in_dim * kernel * kernel)     for conv layers
        <name>.alpha                 : scalar

    No per-module ".dim" tensor is written; rank is implied by the weight
    shapes and the global metadata. Save errors are printed, not raised.

    Args:
        filepath: Destination path of the generated .safetensors file.
    """
    state_dict = OrderedDict()
    metadata = OrderedDict()
    print(f"Generating dummy LoHA with default rank={DEFAULT_RANK}, default alpha={DEFAULT_ALPHA}")
    print(f"Targeting DTYPE: {DTYPE}")

    for layer_config in EXAMPLE_LAYERS_CONFIG:
        layer_name = layer_config["name"]
        in_dim = layer_config["in_dim"]
        out_dim = layer_config["out_dim"]
        rank = layer_config["rank"]
        alpha = layer_config["alpha"]
        is_conv = layer_config["is_conv"]
        print(f"Processing layer: {layer_name} (in: {in_dim}, out: {out_dim}, rank: {rank}, alpha: {alpha}, conv: {is_conv})")

        if is_conv:
            # Full-kernel decomposition: the "b" factors span in_channels * k_h * k_w.
            # The reference analysis showed hada_w1_b as [8, 2880] for in_dim=320, rank=8,
            # and 2880 = 320 * 9, i.e. a 3x3 kernel folded into the input dimension.
            kernel_size = layer_config.get("kernel_size", 3)  # Default to 3x3 if not specified
            effective_in_dim = in_dim * kernel_size * kernel_size
        else:
            effective_in_dim = in_dim

        # Generation order (w1_a, w1_b, w2_a, w2_b) is kept fixed so that a seeded
        # RNG reproduces the same tensors.
        factor_shapes = (
            ("hada_w1_a", (out_dim, rank)),
            ("hada_w1_b", (rank, effective_in_dim)),
            ("hada_w2_a", (out_dim, rank)),
            ("hada_w2_b", (rank, effective_in_dim)),
        )
        for factor_name, shape in factor_shapes:
            state_dict[f"{layer_name}.{factor_name}"] = torch.randn(*shape, dtype=DTYPE) * 0.01
        # Alpha tensor (scalar)
        state_dict[f"{layer_name}.alpha"] = torch.tensor(float(alpha), dtype=DTYPE)

    # --- Metadata (mimicking the reference LoHA file) ---
    metadata["ss_network_module"] = "lycoris.kohya"
    metadata["ss_network_dim"] = str(DEFAULT_RANK)  # Global/default rank
    metadata["ss_network_alpha"] = str(DEFAULT_ALPHA)  # Global/default alpha
    metadata["ss_network_algo"] = "loha"  # Also specified inside ss_network_args by convention
    network_args = {
        "conv_dim": str(CONV_RANK),
        "conv_alpha": str(CONV_ALPHA),
        "algo": "loha",
        # The remaining values are copied from the reference file; they are not
        # expected to matter for a dummy.
        "dropout": "0.0",
        "rank_dropout": "0",
        "module_dropout": "0",
        "use_tucker": "False",
        "use_scalar": "False",
        "rank_dropout_scale": "False",
        "train_norm": "False",
    }
    metadata["ss_network_args"] = json.dumps(network_args)
    # Other metadata carried over from the reference file (optional for basic loading)
    metadata["ss_sd_model_name"] = "sd_xl_base_1.0.safetensors"  # Example base model
    metadata["ss_resolution"] = "(1024,1024)"  # Example, format might vary
    metadata["modelspec.sai_model_spec"] = "1.0.0"
    # Was "https_//..." in the previous revision, which is not a valid URL.
    metadata["modelspec.implementation"] = "https://github.com/Stability-AI/generative-models"
    metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base/lora"  # Even for LoHA, this is often used
    metadata["ss_mixed_precision"] = "bf16"
    metadata["ss_note"] = "Dummy LoHA (corrected) for ComfyUI validation. Not trained."

    # --- Save the State Dictionary with Metadata ---
    try:
        save_file(state_dict, filepath, metadata=metadata)
    except Exception as e:
        print(f"\nError saving file: {e}")
        import traceback
        traceback.print_exc()
        return

    print(f"\nSuccessfully saved dummy LoHA file to: {filepath}")
    print("\nFile structure (tensor keys):")
    for key, tensor in state_dict.items():
        print(f"- {key}: shape {tensor.shape}, dtype {tensor.dtype}")
    print("\nMetadata:")
    for key, value in metadata.items():
        print(f"- {key}: {value}")
# Script entry point: write the dummy LoHA file using the default output path.
if __name__ == "__main__":
    create_dummy_loha_file()
# --- Verification Note for ComfyUI ---
# 1. Place `dummy_loha_corrected.safetensors` into `ComfyUI/models/loras/`.
# 2. Load an SDXL base model in ComfyUI.
# 3. Add a "Load LoRA" node and select `dummy_loha_corrected.safetensors`.
# 4. Connect the LoRA node between the checkpoint loader and the KSampler.
#
# Expected outcome:
# - ComfyUI should load the file without "key not loaded" or "dimension mismatch" errors
# for the layers defined in EXAMPLE_LAYERS_CONFIG.
# - The LoRA node should correctly identify it as a LoHA/LyCORIS model.
# - If you have layers in your SDXL model that match the names in EXAMPLE_LAYERS_CONFIG,
# ComfyUI will attempt to apply these (random) weights.

View File

@ -0,0 +1,397 @@
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file, load_file
from tqdm import tqdm
import math
import json
from collections import OrderedDict
import signal
import sys
# --- Module-level state shared by main(), handle_interrupt() and perform_graceful_save() ---
# Kept global so the SIGINT handler can save whatever has been extracted so far.
extracted_loha_state_dict_global = OrderedDict()  # Output tensors collected so far, keyed by tensor name
layer_optimization_stats_global = []  # One stats dict (name, loss, iterations, ...) per optimized layer
args_global = None  # Parsed CLI arguments; assigned at the start of main()
processed_layers_count_global = 0  # Layers whose LoHA factors were stored
skipped_identical_count_global = 0  # Layers skipped because base and fine-tuned weights matched within tolerance
skipped_other_count_global = 0  # Layers skipped for any other reason (shape mismatch, max_layers, errors)
keys_scanned_this_run_global = 0 # This will track scans for the outer pbar
save_attempted_on_interrupt = False  # Set by handle_interrupt(); also polled by the optimization loops
outer_pbar_global = None  # Outer tqdm progress bar created in main()
# --- Per-layer LoHA optimization and weight-shape helpers ---
def optimize_loha_for_layer(
    layer_name: str, delta_W_target: torch.Tensor, out_dim: int, in_dim_effective: int,
    k_h: int, k_w: int, rank: int, initial_alpha_val: float, lr: float = 1e-3,
    max_iterations: int = 1000, min_iterations: int = 100, target_loss: float = None,
    weight_decay: float = 1e-4, device: str = 'cuda', dtype: torch.dtype = torch.float32,
    is_conv: bool = True, verbose_layer_debug: bool = False
):
    """Fit LoHA factors for one layer by gradient descent on the weight difference.

    The approximation being optimized is

        delta_W ~= (alpha / rank) * (hada_w1_a @ hada_w1_b) * (hada_w2_a @ hada_w2_b)

    with an element-wise product between the two matrix products. For conv
    layers the products are viewed as (out_dim, in_dim_effective, k_h, k_w)
    before being compared with ``delta_W_target``. ``alpha`` is a trainable
    parameter. The loss is the MSE against ``delta_W_target``; AdamW is used
    together with a ReduceLROnPlateau schedule.

    Args:
        layer_name: Name shown in the progress bar and log lines.
        delta_W_target: Target weight difference; moved to ``device``/``dtype``.
        out_dim, in_dim_effective: Output / input dimensions of the layer.
        k_h, k_w: Kernel height / width (only used when ``is_conv``).
        rank: LoHA rank.
        initial_alpha_val: Starting value of the trainable alpha.
        lr, weight_decay: AdamW hyper-parameters.
        max_iterations: Upper bound on optimization steps.
        min_iterations: Steps that must run before ``target_loss`` can stop the loop.
        target_loss: Optional loss threshold for early stopping.
        device, dtype: Where and in which precision to optimize.
        is_conv: Whether the layer is a 4-D convolution weight.
        verbose_layer_debug: Emit periodic per-iteration debug lines.

    Returns:
        A dict. Always contains 'final_loss', 'stopped_early', 'iterations_done'
        and 'interrupted_mid_layer'. Unless the run was interrupted mid-layer it
        also holds the CPU tensors 'hada_w1_a', 'hada_w1_b', 'hada_w2_a',
        'hada_w2_b' and 'alpha'.
    """
    delta_W_target = delta_W_target.to(device, dtype=dtype)

    # Conv kernels are folded into the input dimension of the "b" factors.
    if is_conv:
        factor_in_dim = in_dim_effective * (k_h * k_w)
        weight_shape = (out_dim, in_dim_effective, k_h, k_w)
    else:
        factor_in_dim = in_dim_effective
        weight_shape = (out_dim, in_dim_effective)

    hada_w1_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype))
    hada_w1_b = nn.Parameter(torch.empty(rank, factor_in_dim, device=device, dtype=dtype))
    hada_w2_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype))
    hada_w2_b = nn.Parameter(torch.empty(rank, factor_in_dim, device=device, dtype=dtype))
    nn.init.kaiming_uniform_(hada_w1_a, a=math.sqrt(5))
    nn.init.normal_(hada_w1_b, std=0.02)
    nn.init.kaiming_uniform_(hada_w2_a, a=math.sqrt(5))
    nn.init.normal_(hada_w2_b, std=0.02)
    alpha_param = nn.Parameter(torch.tensor(initial_alpha_val, device=device, dtype=dtype))

    optimizer = torch.optim.AdamW(
        [hada_w1_a, hada_w1_b, hada_w2_a, hada_w2_b, alpha_param], lr=lr, weight_decay=weight_decay
    )
    patience_epochs = max(10, int(max_iterations * 0.05))
    # `verbose` is intentionally not passed: the argument is deprecated in
    # PyTorch's LR schedulers (passing it triggers a warning and newer releases
    # may reject it), and silent operation is the default anyway.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=patience_epochs, factor=0.5, min_lr=1e-7
    )

    iter_pbar = tqdm(range(max_iterations), desc=f"Opt: {layer_name}", leave=False, dynamic_ncols=True, position=1)
    final_loss = float('inf')
    stopped_early_by_loss = False
    iterations_actually_done = 0
    # Debug lines are emitted roughly ten times per run (every step for very short runs).
    debug_every = max_iterations // 10 if max_iterations >= 10 else 1

    for i in iter_pbar:
        iterations_actually_done = i + 1
        if save_attempted_on_interrupt:
            print(f"\n Interrupt during opt of {layer_name}. Stopping this layer after {i} iters.")
            break

        optimizer.zero_grad()
        eff_alpha_scale = alpha_param / rank
        term1 = (hada_w1_a @ hada_w1_b).view(weight_shape)
        term2 = (hada_w2_a @ hada_w2_b).view(weight_shape)
        delta_W_loha = eff_alpha_scale * term1 * term2
        loss = F.mse_loss(delta_W_loha, delta_W_target)
        final_loss = loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step(final_loss)

        current_lr = optimizer.param_groups[0]['lr']
        iter_pbar.set_postfix_str(f"Loss={final_loss:.3e}, AlphaP={alpha_param.item():.2f}, LR={current_lr:.1e}", refresh=True)
        if verbose_layer_debug and (i == 0 or (i + 1) % debug_every == 0 or i == max_iterations - 1):
            iter_pbar.write(f" Debug {layer_name} - Iter {i+1}/{max_iterations}: Loss: {final_loss:.6e}, LR: {current_lr:.2e}, AlphaP: {alpha_param.item():.4f}")

        # Early stop only once the minimum number of iterations has been run.
        if target_loss is not None and i >= min_iterations - 1 and final_loss <= target_loss:
            if verbose_layer_debug or (args_global and args_global.verbose):
                iter_pbar.write(f" Target loss {target_loss:.2e} reached for {layer_name} at iter {i+1}.")
            stopped_early_by_loss = True
            break

    if not save_attempted_on_interrupt:
        iter_pbar.set_description_str(f"Opt: {layer_name} (Done)")
        iter_pbar.set_postfix_str(f"FinalLoss={final_loss:.2e}, It={iterations_actually_done}{', EarlyStop' if stopped_early_by_loss else ''}")
    iter_pbar.close()

    # A layer cut short by an interrupt is reported without tensors so the
    # caller does not store half-optimized factors.
    if save_attempted_on_interrupt and not stopped_early_by_loss and iterations_actually_done < max_iterations:
        return {'final_loss': final_loss, 'stopped_early': False, 'iterations_done': iterations_actually_done, 'interrupted_mid_layer': True}
    return {
        'hada_w1_a': hada_w1_a.data.cpu().contiguous(), 'hada_w1_b': hada_w1_b.data.cpu().contiguous(),
        'hada_w2_a': hada_w2_a.data.cpu().contiguous(), 'hada_w2_b': hada_w2_b.data.cpu().contiguous(),
        'alpha': alpha_param.data.cpu().contiguous(), 'final_loss': final_loss,
        'stopped_early': stopped_early_by_loss, 'iterations_done': iterations_actually_done,
        'interrupted_mid_layer': False
    }
def get_module_shape_info_from_weight(weight_tensor: torch.Tensor):
    """Classify a weight tensor by rank and return its shape details.

    Returns:
        ``(out_dim, in_dim, k_h, k_w, groups, is_conv)`` where ``groups`` is
        always 1. 4-D weights are reported as conv (``is_conv=True``); 2-D
        weights as linear with ``k_h`` and ``k_w`` set to ``None``. Any other
        rank yields ``None``.
    """
    dims = tuple(weight_tensor.shape)
    if len(dims) == 4:
        return (*dims, 1, True)
    if len(dims) == 2:
        return (*dims, None, None, 1, False)
    return None
# --- Saving extracted LoHA weights and extraction metadata ---
def perform_graceful_save(output_path_override=None):
    """Write the LoHA tensors collected so far, plus extraction metadata, to disk.

    Reads the module-level extraction state (collected tensors, per-layer
    stats, counters and the parsed CLI args). Floating-point tensors are cast
    to the dtype selected by ``args.save_weights_dtype`` ("fp16", "bf16",
    otherwise float32).

    * Paths ending in ".safetensors" are written with ``save_file`` and carry
      the ``ss_*`` metadata in the file header.
    * Any other path is written with ``torch.save`` so that finished work is
      not discarded; that format has no header metadata.

    In both cases an extended ``<stem>_extraction_metadata.json`` is written
    next to the output. Failures are printed, never raised.

    Args:
        output_path_override: Optional destination path; defaults to
            ``args.save_to``.
    """
    if not extracted_loha_state_dict_global:
        print("No layers were processed enough to save.")
        return
    args_to_use = args_global
    if not args_to_use:
        print("Error: Global args not available for saving metadata.")
        return

    final_save_path = output_path_override if output_path_override else args_to_use.save_to
    final_save_dtype_torch = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(
        args_to_use.save_weights_dtype, torch.float32
    )
    final_state_dict_to_save = OrderedDict(
        (k, v_tensor.to(final_save_dtype_torch) if v_tensor.is_floating_point() else v_tensor)
        for k, v_tensor in extracted_loha_state_dict_global.items()
    )
    print(f"\nAttempting to save {len(final_state_dict_to_save)} LoHA param sets for {processed_layers_count_global} layers to {final_save_path}")

    eff_global_network_alpha_str = f"{args_to_use.initial_alpha:.8f}"
    global_rank_str = str(args_to_use.rank)
    # Conv rank falls back to the global rank when not given explicitly.
    conv_rank_str = str(args_to_use.conv_rank if args_to_use.conv_rank is not None else args_to_use.rank)
    conv_alpha_str = f"{args_to_use.initial_conv_alpha:.8f}"

    network_args_dict = {
        "algo": "loha", "dim": global_rank_str, "alpha": eff_global_network_alpha_str,
        "conv_dim": conv_rank_str, "conv_alpha": conv_alpha_str,
        "dropout": str(args_to_use.dropout), "rank_dropout": str(args_to_use.rank_dropout),
        "module_dropout": str(args_to_use.module_dropout),
        "use_tucker": "false", "use_scalar": "false", "block_size": "1",
    }
    # safetensors header metadata: every value must be a string.
    sf_metadata = {
        "ss_network_module": "lycoris.kohya", "ss_network_rank": global_rank_str,
        "ss_network_alpha": eff_global_network_alpha_str, "ss_network_algo": "loha",
        "ss_network_args": json.dumps(network_args_dict),
        "ss_comment": f"Extracted LoHA (Interrupt: {save_attempted_on_interrupt}). OptPrec: {args_to_use.precision}. SaveDtype: {args_to_use.save_weights_dtype}. ATOL: {args_to_use.atol_fp32_check}. Layers: {processed_layers_count_global}. MaxIter: {args_to_use.max_iterations}. TargetLoss: {args_to_use.target_loss}",
        "ss_base_model_name": os.path.splitext(os.path.basename(args_to_use.base_model_path))[0],
        "ss_ft_model_name": os.path.splitext(os.path.basename(args_to_use.ft_model_path))[0],
        "ss_save_weights_dtype": args_to_use.save_weights_dtype,
        "ss_optimization_precision": args_to_use.precision,
    }
    json_metadata_for_file = {
        "comfyui_lora_type": "LyCORIS_LoHa",
        "model_name": os.path.splitext(os.path.basename(final_save_path))[0],
        "base_model_path": args_to_use.base_model_path,
        "ft_model_path": args_to_use.ft_model_path,
        "loha_extraction_settings": dict(vars(args_to_use)),
        "extraction_summary": {
            "processed_layers_count": processed_layers_count_global,
            "skipped_identical_count": skipped_identical_count_global,
            "skipped_other_count": skipped_other_count_global,
            "total_candidate_keys_scanned": keys_scanned_this_run_global,
        },
        "layer_optimization_details": layer_optimization_stats_global,
        "embedded_safetensors_metadata": sf_metadata,
        "interrupted_save": save_attempted_on_interrupt,
    }

    if final_save_path.endswith(".safetensors"):
        try:
            save_file(final_state_dict_to_save, final_save_path, metadata=sf_metadata)
        except Exception as e:
            print(f"Error saving .safetensors file: {e}")
            return
        print(f"LoHA state_dict saved to: {final_save_path}")
    else:
        # Previously nothing was written for non-.safetensors paths, silently
        # discarding every optimized layer. Persist the tensors with torch.save
        # instead; the header metadata is only available in the JSON sidecar.
        print(f"Saving to .pt not fully supported with interrupt metadata.")
        try:
            torch.save(final_state_dict_to_save, final_save_path)
        except Exception as e:
            print(f"Error saving state_dict with torch.save: {e}")
            return
        print(f"LoHA state_dict saved (torch.save, no embedded metadata) to: {final_save_path}")

    metadata_json_file_path = os.path.splitext(final_save_path)[0] + "_extraction_metadata.json"
    try:
        with open(metadata_json_file_path, 'w') as f:
            json.dump(json_metadata_for_file, f, indent=4)
        print(f"Extended metadata saved to: {metadata_json_file_path}")
    except Exception as e:
        print(f"Could not save extended metadata JSON: {e}")
# --- SIGINT (Ctrl+C) handling ---
def handle_interrupt(signum, frame):
    """SIGINT handler: save finished layers once, then exit.

    The first Ctrl+C marks the run as interrupted, closes the outer progress
    bar, saves whatever has been extracted and exits with status 0. A second
    Ctrl+C (while that save is in progress) terminates the process immediately
    with status 1.
    """
    global save_attempted_on_interrupt, outer_pbar_global
    rule = "=" * 30
    print(f"\n{rule}\nCtrl+C (SIGINT) detected!\n{rule}")

    if save_attempted_on_interrupt:
        print("Save already attempted. Force exiting.")
        os._exit(1)  # Hard exit; does not return

    save_attempted_on_interrupt = True
    if outer_pbar_global:
        outer_pbar_global.close()
    print("Attempting to save progress for processed layers...")
    perform_graceful_save()
    print("Graceful save attempt finished. Exiting.")
    sys.exit(0)
def _torch_dtype_from_flag(flag):
    """Map a CLI precision flag ("fp16", "bf16", anything else) to a torch dtype."""
    return {"fp16": torch.float16, "bf16": torch.bfloat16}.get(flag, torch.float32)


def _load_state_dict_on_cpu(path):
    """Load a model state dict onto the CPU from a .safetensors or torch checkpoint file."""
    if path.endswith(".safetensors"):
        return load_file(path, device='cpu')
    # NOTE(security): torch.load unpickles the file; only use checkpoints from a
    # trusted source. (Not switched to weights_only here because that could
    # reject checkpoints that currently load.)
    checkpoint = torch.load(path, map_location='cpu')
    return checkpoint.get('state_dict', checkpoint)


def main(cli_args):
    """Extract LoHA parameters for every differing weight between two models.

    Loads the base and fine-tuned state dicts, scans all ``.weight`` keys that
    exist in both with 2-D or 4-D shape, and for each layer whose weights differ
    beyond ``atol_fp32_check`` optimizes LoHA factors against the difference.
    Results accumulate in the module-level state so that the SIGINT handler can
    save partial progress; on normal completion a summary is printed and the
    result is saved.

    Args:
        cli_args: Parsed command-line namespace (stored in ``args_global``).
    """
    global args_global, extracted_loha_state_dict_global, layer_optimization_stats_global
    global processed_layers_count_global, save_attempted_on_interrupt, outer_pbar_global
    global skipped_identical_count_global, skipped_other_count_global, keys_scanned_this_run_global
    args_global = cli_args
    signal.signal(signal.SIGINT, handle_interrupt)

    target_opt_dtype = _torch_dtype_from_flag(args_global.precision)
    final_save_dtype = _torch_dtype_from_flag(args_global.save_weights_dtype)
    print(f"Using device: {args_global.device}, Opt Dtype: {target_opt_dtype}, Save Dtype: {final_save_dtype}")
    if args_global.target_loss:
        print(f"Target Loss: {args_global.target_loss:.2e} (after {args_global.min_iterations} min iters)")
    print(f"Max Iters/Layer: {args_global.max_iterations}")

    print(f"Loading base: {args_global.base_model_path}")
    base_model_sd = _load_state_dict_on_cpu(args_global.base_model_path)
    print(f"Loading fine-tuned: {args_global.ft_model_path}")
    ft_model_sd = _load_state_dict_on_cpu(args_global.ft_model_path)

    # Reset the shared extraction state for this run.
    extracted_loha_state_dict_global = OrderedDict()
    layer_optimization_stats_global = []
    processed_layers_count_global = 0
    skipped_identical_count_global = 0
    skipped_other_count_global = 0
    keys_scanned_this_run_global = 0

    all_candidate_keys = []
    for k in base_model_sd.keys():
        if not k.endswith('.weight'):
            continue
        if k in ft_model_sd:
            if len(base_model_sd[k].shape) in (2, 4):
                all_candidate_keys.append(k)
        elif args_global.verbose:
            print(f"Note: Base key '{k}' not in FT model.")
    all_candidate_keys.sort()
    total_candidate_keys_to_scan = len(all_candidate_keys)
    print(f"Found {total_candidate_keys_to_scan} candidate '.weight' keys common to both models and of suitable shape.")

    max_layers = args_global.max_layers
    has_max_layers = max_layers is not None and max_layers > 0
    max_layers_target_str = f"/{max_layers}" if has_max_layers else ""
    max_layers_notice_shown = False

    # The outer progress bar tracks the scan through ALL candidate keys.
    outer_pbar_global = tqdm(total=total_candidate_keys_to_scan, desc="Scanning Layers (Processed 0)", dynamic_ncols=True, position=0)
    try:
        for key_name in all_candidate_keys:
            if save_attempted_on_interrupt:
                break
            keys_scanned_this_run_global += 1
            outer_pbar_global.update(1)  # Update for every key scanned

            if has_max_layers and processed_layers_count_global >= max_layers:
                # Keep scanning (without processing) so the bar still reaches 100%.
                if args_global.verbose and not max_layers_notice_shown:
                    print(f"\nReached max_layers limit ({max_layers}). All further candidate layers will be skipped by the scan.")
                    max_layers_notice_shown = True
                outer_pbar_global.set_description_str(f"Scan (Max Layers Reached: {processed_layers_count_global}/{max_layers}, Scanned {keys_scanned_this_run_global}/{total_candidate_keys_to_scan})")
                skipped_other_count_global += 1
                continue

            # Compare shapes before paying for the float32 conversions.
            if base_model_sd[key_name].shape != ft_model_sd[key_name].shape:
                if args_global.verbose:
                    print(f"\nSkipping {key_name} (scan): shape mismatch.")
                skipped_other_count_global += 1
                continue
            base_W = base_model_sd[key_name].to(dtype=torch.float32)
            ft_W = ft_model_sd[key_name].to(dtype=torch.float32)

            shape_info = get_module_shape_info_from_weight(base_W)
            if shape_info is None:
                if args_global.verbose:
                    print(f"\nSkipping {key_name} (scan): not Linear or Conv2d.")
                skipped_other_count_global += 1
                continue

            delta_W_fp32 = ft_W - base_W
            if torch.allclose(delta_W_fp32, torch.zeros_like(delta_W_fp32), atol=args_global.atol_fp32_check):
                if args_global.verbose:
                    print(f"\nSkipping {key_name} (scan): weights effectively identical.")
                skipped_identical_count_global += 1
                continue

            # This layer WILL be processed.
            outer_pbar_global.set_description_str(f"Optimizing L{processed_layers_count_global + 1}{max_layers_target_str} (Scanned {keys_scanned_this_run_global}/{total_candidate_keys_to_scan})")

            # Translate the checkpoint module path into the LoHA key prefix.
            original_module_path = key_name[:-len(".weight")]
            unet_prefix = "model.diffusion_model."
            te1_prefix = "conditioner.embedders.0.transformer."
            te2_prefix = "conditioner.embedders.1.model.transformer."
            if original_module_path.startswith(unet_prefix):
                loha_key_prefix = "lora_unet_" + original_module_path[len(unet_prefix):].replace(".", "_")
            elif original_module_path.startswith(te1_prefix):
                loha_key_prefix = "lora_te1_" + original_module_path[len(te1_prefix):].replace(".", "_")
            elif original_module_path.startswith(te2_prefix):
                loha_key_prefix = "lora_te2_" + original_module_path[len(te2_prefix):].replace(".", "_")
            else:
                loha_key_prefix = "lora_" + original_module_path.replace(".", "_")
            if args_global.verbose:
                tqdm.write(f"\n Orig: {key_name} -> LoHA: {loha_key_prefix}")

            out_dim, in_dim_effective, k_h, k_w, _, is_conv = shape_info
            delta_W_target_for_opt = delta_W_fp32.to(dtype=target_opt_dtype)
            current_rank = args_global.conv_rank if is_conv and args_global.conv_rank is not None else args_global.rank
            current_initial_alpha = args_global.initial_conv_alpha if is_conv else args_global.initial_alpha
            tqdm.write(f"Optimizing Layer {processed_layers_count_global + 1}{max_layers_target_str}: {loha_key_prefix} (Orig: {original_module_path}, Shp: {list(base_W.shape)}, R: {current_rank}, Alpha_init: {current_initial_alpha:.1f})")

            try:
                opt_results = optimize_loha_for_layer(
                    layer_name=loha_key_prefix, delta_W_target=delta_W_target_for_opt,
                    out_dim=out_dim, in_dim_effective=in_dim_effective, k_h=k_h, k_w=k_w, rank=current_rank,
                    initial_alpha_val=current_initial_alpha, lr=args_global.lr,
                    max_iterations=args_global.max_iterations, min_iterations=args_global.min_iterations,
                    target_loss=args_global.target_loss, weight_decay=args_global.weight_decay,
                    device=args_global.device, dtype=target_opt_dtype, is_conv=is_conv,
                    verbose_layer_debug=args_global.verbose_layer_debug
                )
                if opt_results.get('interrupted_mid_layer'):
                    if args_global.verbose:
                        tqdm.write(f" Opt for {loha_key_prefix} interrupted; not saving params.")
                    continue

                # Everything that is not a bookkeeping entry is a tensor to store.
                for p_name, p_val in opt_results.items():
                    if p_name not in ['final_loss', 'stopped_early', 'iterations_done', 'interrupted_mid_layer']:
                        extracted_loha_state_dict_global[f'{loha_key_prefix}.{p_name}'] = p_val.to(final_save_dtype)
                layer_optimization_stats_global.append({
                    "name": loha_key_prefix, "original_name": original_module_path,
                    "final_loss": opt_results['final_loss'], "iterations_done": opt_results['iterations_done'],
                    "stopped_early_by_loss_target": opt_results['stopped_early']
                })
                tqdm.write(f" Layer {loha_key_prefix} Done. Loss: {opt_results['final_loss']:.4e}, Iters: {opt_results['iterations_done']}{', Stopped by Loss' if opt_results['stopped_early'] else ''}")

                if args_global.use_bias:
                    original_bias_key = f"{original_module_path}.bias"
                    if original_bias_key in ft_model_sd and original_bias_key in base_model_sd:
                        base_B = base_model_sd[original_bias_key].to(dtype=torch.float32)
                        ft_B = ft_model_sd[original_bias_key].to(dtype=torch.float32)
                        if not torch.allclose(base_B, ft_B, atol=args_global.atol_fp32_check):
                            extracted_loha_state_dict_global[original_bias_key] = ft_B.cpu().to(final_save_dtype)
                            if args_global.verbose:
                                tqdm.write(f" Saved differing bias for {original_bias_key}")
                    elif original_bias_key in ft_model_sd and original_bias_key not in base_model_sd:
                        if args_global.verbose:
                            tqdm.write(f" Bias {original_bias_key} in FT only. Saving.")
                        extracted_loha_state_dict_global[original_bias_key] = ft_model_sd[original_bias_key].cpu().to(final_save_dtype)
                processed_layers_count_global += 1
            except Exception as e:
                # One failing layer must not abort the whole extraction.
                print(f"\nError during optimization for {original_module_path} ({loha_key_prefix}): {e}")
                import traceback
                traceback.print_exc()
                skipped_other_count_global += 1
    finally:
        # Explicit None check: a tqdm bar with total == 0 is falsy.
        if outer_pbar_global is not None:
            if outer_pbar_global.n < outer_pbar_global.total:  # Fill to 100% if loop broke early
                outer_pbar_global.update(outer_pbar_global.total - outer_pbar_global.n)
            outer_pbar_global.close()

    if not save_attempted_on_interrupt:  # Normal completion
        print("\n--- Final Optimization Summary (Normal Completion) ---")
        for stat in layer_optimization_stats_global:
            print(f"Layer: {stat['name']}, Final Loss: {stat['final_loss']:.4e}, Iters: {stat['iterations_done']}{', Stopped by Loss' if stat['stopped_early_by_loss_target'] else ''}")
        print(f"\n--- Overall Summary ---")
        print(f"Total candidate weight keys: {total_candidate_keys_to_scan}")
        print(f"Total keys scanned by loop: {keys_scanned_this_run_global}")
        print(f"Processed {processed_layers_count_global} layers for LoHA extraction.")
        print(f"Skipped {skipped_identical_count_global} layers (identical).")
        print(f"Skipped {skipped_other_count_global} layers (other reasons).")
        perform_graceful_save()
    else:  # Interrupted
        print("\nProcess was interrupted. Saved data for fully completed layers.")
# --- __main__ block (UNCHANGED from previous version) ---
def setup_parser():
    """Build the command-line parser for the LoHA extraction script.

    Returns:
        argparse.ArgumentParser: parser with the three positional paths
        (base model, fine-tuned model, output file) and all tuning options.
    """
    parser = argparse.ArgumentParser(description="Extract LoHA parameters by optimizing against weight differences.")
    parser.add_argument("base_model_path", type=str, help="Path to the base model state_dict file (.pt, .pth, .safetensors)")
    parser.add_argument("ft_model_path", type=str, help="Path to the fine-tuned model state_dict file (.pt, .pth, .safetensors)")
    parser.add_argument("save_to", type=str, help="Path to save the extracted LoHA file (recommended .safetensors)")
    parser.add_argument("--rank", type=int, default=4, help="Default rank for LoHA decomposition (used for linear layers and as fallback for conv).")
    parser.add_argument("--conv_rank", type=int, default=None, help="Specific rank for convolutional LoHA layers. Defaults to --rank if not set.")
    parser.add_argument("--initial_alpha", type=float, default=None,
                        help="Global initial alpha for optimization (used for linear and as fallback for conv). Defaults to 'rank'. This is also used for 'ss_network_alpha'.")
    parser.add_argument("--initial_conv_alpha", type=float, default=None,
                        help="Specific initial alpha for convolutional LoHA layers. Defaults to --initial_alpha when that is given; otherwise to the conv rank (--conv_rank, or --rank if unset).")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate for LoHA optimization per layer.")
    parser.add_argument("--max_iterations", type=int, default=1000, help="Maximum number of optimization iterations per layer.")
    parser.add_argument("--min_iterations", type=int, default=100, help="Minimum number of optimization iterations per layer before checking target loss. Default 100.")
    parser.add_argument("--target_loss", type=float, default=None, help="Target MSE loss to achieve for stopping optimization for a layer. If None, runs for max_iterations.")
    parser.add_argument("--weight_decay", type=float, default=1e-5, help="Weight decay for LoHA optimization.")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device for computation ('cuda' or 'cpu').")
    parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"],
                        help="Computation precision for LoHA optimization. This is 'ss_mixed_precision'. Default: fp32")
    parser.add_argument("--save_weights_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"],
                        help="Data type for saving the final LoHA weights in the .safetensors file. Default: bf16.")
    parser.add_argument("--atol_fp32_check", type=float, default=1e-6,
                        help="Absolute tolerance for fp32 weight difference check to consider layers identical and skip them.")
    parser.add_argument("--use_bias", action="store_true", help="If set, save fine-tuned bias terms if they differ from the base model's bias (saved with original key names).")
    parser.add_argument("--dropout", type=float, default=0.0, help="General dropout for LoHA modules (for metadata).")
    parser.add_argument("--rank_dropout", type=float, default=0.0, help="Rank dropout rate for LoHA modules (for metadata).")
    parser.add_argument("--module_dropout", type=float, default=0.0, help="Module dropout rate for LoHA modules (for metadata).")
    parser.add_argument("--max_layers", type=int, default=None,
                        help="Process at most N differing layers for quick testing. Layers are sorted by name before processing.")
    parser.add_argument("--verbose", action="store_true", help="Print general verbose information during processing.")
    parser.add_argument("--verbose_layer_debug", action="store_true",
                        help="Print very detailed per-iteration debug info for each layer's optimization (can be very spammy).")
    return parser


def resolve_alpha_defaults(args):
    """Fill in ``initial_alpha`` and ``initial_conv_alpha`` when left unset.

    Resolution order:
      * ``initial_alpha``: the user's value, else ``float(rank)``.
      * ``initial_conv_alpha``: the user's value; else the user's
        ``initial_alpha`` if one was given; else the conv rank
        (``conv_rank``, falling back to ``rank``).

    The "was initial_alpha given" check must happen *before* initial_alpha
    is defaulted, otherwise the conv-rank fallback can never be reached.

    Args:
        args: namespace with ``rank``, ``conv_rank``, ``initial_alpha`` and
            ``initial_conv_alpha`` attributes; modified in place.

    Returns:
        The same namespace, for convenience.
    """
    alpha_was_given = args.initial_alpha is not None
    if not alpha_was_given:
        args.initial_alpha = float(args.rank)

    if args.initial_conv_alpha is None:
        if alpha_was_given:
            args.initial_conv_alpha = args.initial_alpha
        else:
            conv_rank = args.conv_rank if args.conv_rank is not None else args.rank
            args.initial_conv_alpha = float(conv_rank)
    return args


if __name__ == "__main__":
    parsed_args = setup_parser().parse_args()

    # Fail early, before any model loading, when an input file is missing.
    # SystemExit is raised directly: the bare exit() helper is injected by
    # the `site` module and is not guaranteed to exist.
    if not os.path.exists(parsed_args.base_model_path):
        print(f"Error: Base model path not found: {parsed_args.base_model_path}")
        raise SystemExit(1)
    if not os.path.exists(parsed_args.ft_model_path):
        print(f"Error: Fine-tuned model path not found: {parsed_args.ft_model_path}")
        raise SystemExit(1)

    # An empty dirname means "current directory"; nothing to create then.
    save_dir = os.path.dirname(parsed_args.save_to)
    if save_dir and not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        print(f"Created directory: {save_dir}")

    resolve_alpha_defaults(parsed_args)
    main(parsed_args)

View File

@ -6,7 +6,7 @@ import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from tqdm import tqdm from tqdm import tqdm
from library import sai_model_spec, model_util, sdxl_model_util from library import sai_model_spec, model_util, sdxl_model_util
import lora import lora # Assuming this is your existing lora script/library
from library.utils import setup_logging from library.utils import setup_logging
setup_logging() setup_logging()
import logging import logging
@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
MIN_SV = 1e-6 MIN_SV = 1e-6
# ... (Keep all your existing helper functions: index_sv_cumulative, index_sv_fro, etc.)
def index_sv_cumulative(S, target): def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S)) original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0) / original_sum cumulative_sums = torch.cumsum(S, dim=0) / original_sum
@ -36,62 +37,73 @@ def index_sv_ratio(S, target):
index = max(1, min(index, len(S) - 1)) index = max(1, min(index, len(S) - 1))
return index return index
def index_sv_knee_improved(S, MIN_SV_KNEE=1e-8):  # MIN_SV_KNEE can be same as global MIN_SV or specific
    """Pick a rank at the knee of the singular-value curve.

    Both axes are rescaled to [0, 1] (index on x, singular value on y) and
    the knee is the point with the largest absolute deviation from the
    anti-diagonal ``x + y = 1`` joining the first and last points.

    Args:
        S: 1-D tensor of singular values. ``S[0]`` and ``S[-1]`` are used as
            the maximum and minimum, i.e. descending order is assumed (the
            order ``torch.linalg.svd`` returns).
        MIN_SV_KNEE: if ``S[0] - S[-1]`` is below this, the curve is treated
            as flat and rank 1 is returned.

    Returns:
        int: rank in ``[1, len(S) - 1]``; 1 when fewer than 3 values are
        given or the curve is flat.
    """
    n = len(S)
    if n < 3:  # a knee needs at least three points
        return 1

    s_max, s_min = S[0], S[-1]
    if s_max - s_min < MIN_SV_KNEE:  # flat spectrum: avoid dividing by ~0
        return 1

    s_normalized = (S - s_min) / (s_max - s_min)
    x_normalized = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
    distances = (x_normalized + s_normalized - 1).abs()

    knee_index_0based = torch.argmax(distances).item()
    # Convert to a 1-based rank and keep it strictly below full rank.
    return max(1, min(knee_index_0based + 1, n - 1))
def index_sv_cumulative_knee(S, min_sv_threshold=1e-8):
    """Pick a rank at the knee of the cumulative singular-value curve.

    The cumulative sum of ``S`` (as a fraction of the total) is rescaled to
    [0, 1] and compared with the diagonal ``y = x``; the knee is the index
    where the curve deviates most from that diagonal.

    Args:
        S: 1-D tensor of singular values.
        min_sv_threshold: if the total sum, or the spread of the cumulative
            curve, is below this, rank 1 is returned.

    Returns:
        int: rank in ``[1, len(S) - 1]``; 1 when fewer than 3 values are
        given or the curve is degenerate.
    """
    n = len(S)
    if n < 3:  # a knee needs at least three points
        return 1

    s_sum = torch.sum(S)
    if s_sum < min_sv_threshold:  # (near-)zero spectrum: nothing to rank
        return 1

    y_values = torch.cumsum(S, dim=0) / s_sum
    y_min, y_max = y_values[0], y_values[n - 1]
    if y_max - y_min < min_sv_threshold:  # flat curve: avoid dividing by ~0
        return 1

    y_norm = (y_values - y_min) / (y_max - y_min)
    x_norm = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
    distances = (y_norm - x_norm).abs()

    knee_index_0based = torch.argmax(distances).item()
    # Convert to a 1-based rank and keep it strictly below full rank.
    return max(1, min(knee_index_0based + 1, n - 1))
def index_sv_rel_decrease(S, tau=0.1):
    """Determine rank from the first sharp relative drop between neighbours.

    Computes the ratios ``S[k+1] / S[k]`` and returns the smallest ``k + 1``
    for which the ratio falls below ``tau``, i.e. the number of singular
    values kept before the drop.

    Args:
        S: 1-D tensor of singular values.
        tau: relative-decrease threshold; a ratio below it counts as a drop.

    Returns:
        int: ``1`` if fewer than two values are given; otherwise the rank at
        the first drop, or ``len(S) - 1`` when no ratio falls below ``tau``.
    """
    n = len(S)
    if n < 2:
        return 1

    # Ratios of consecutive singular values.
    ratios = S[1:] / S[:-1]
    # `not (ratio >= tau)` equals `ratio < tau` for ordinary numbers but also
    # flags NaN (0/0 in an all-zero spectrum), which would otherwise never
    # register as a drop and push the result to the maximum rank.
    drop_positions = torch.nonzero(~(ratios >= tau)).flatten()
    if drop_positions.numel() > 0:
        return int(drop_positions[0]) + 1  # rank is the count before the drop

    # No drop below tau: keep everything but the last singular value.
    return n - 1
def save_to_file(file_name, model_sd, dtype, metadata=None):  # Changed model to model_sd for clarity
    """Write the tensors of a state dict to disk.

    Non-tensor entries are dropped, tensors are cast to ``dtype`` when one
    is given, and the result is written with ``save_file`` for a
    ``.safetensors`` extension or ``torch.save`` for anything else. The
    caller's ``model_sd`` is not modified.

    Args:
        file_name: destination path; the extension selects the format.
        model_sd: mapping of key -> value; only ``torch.Tensor`` values are saved.
        dtype: target dtype for the saved tensors, or ``None`` to keep each
            tensor's own dtype.
        metadata: optional metadata passed to ``save_file``. It is not
            written for non-safetensors files; a warning is logged instead.
    """
    # Build a fresh dict rather than casting in place, so the caller's
    # tensors keep their original dtype.
    final_sd = {}
    for key, value in model_sd.items():
        if not isinstance(value, torch.Tensor):
            continue  # filter out non-tensor entries that slipped into the dict
        final_sd[key] = value.to(dtype) if dtype is not None else value

    if os.path.splitext(file_name)[1] == ".safetensors":
        save_file(final_sd, file_name, metadata=metadata)
    else:
        # torch.save stores the tensor dict only; metadata has no standard place here.
        torch.save(final_sd, file_name)
        if metadata:
            logger.warning(".pt format does not standardly support metadata like safetensors. Metadata not saved in file.")
def svd_decomposition(
model_org_path=None, # Renamed for clarity
model_tuned_path=None, # Renamed for clarity
save_to=None, save_to=None,
dim=4, algo="lora", # New: lora or loha
network_dim=4, # For LoRA: rank. For LoHA: "factor" or "hada_dim"
network_alpha=None, # For LoRA: alpha (often same as rank). For LoHA: "rank_initial" or "hada_alpha"
conv_dim=None, # For LoRA: conv_rank. For LoHA: "conv_factor"
conv_alpha=None, # For LoRA: conv_alpha. For LoHA: "conv_rank_initial"
v2=None, v2=None,
sdxl=None, sdxl=None,
conv_dim=None,
v_parameterization=None, v_parameterization=None,
device=None, device=None,
save_precision=None, save_precision=None,
@ -106,12 +118,9 @@ def svd(
verbose=False, verbose=False,
): ):
def str_to_dtype(p): def str_to_dtype(p):
if p == "float": if p == "float": return torch.float
return torch.float if p == "fp16": return torch.float16
if p == "fp16": if p == "bf16": return torch.bfloat16
return torch.float16
if p == "bf16":
return torch.bfloat16
return None return None
assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time" assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time"
@ -119,222 +128,395 @@ def svd(
load_dtype = str_to_dtype(load_precision) if load_precision else None load_dtype = str_to_dtype(load_precision) if load_precision else None
save_dtype = str_to_dtype(save_precision) if save_precision else torch.float save_dtype = str_to_dtype(save_precision) if save_precision else torch.float
work_device = "cpu" work_device = "cpu" # Perform SVD and weight manipulation on CPU then move
compute_device = device if device else "cpu"
# Handle default alpha values based on dim values if not provided
if network_alpha is None: network_alpha = network_dim
if conv_dim is None: conv_dim = network_dim # default conv_dim to network_dim
if conv_alpha is None: conv_alpha = conv_dim # default conv_alpha to conv_dim
# Load models # Load models
if not sdxl: if not sdxl:
logger.info(f"Loading original SD model: {model_org}") logger.info(f"Loading original SD model: {model_org_path}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org_path, load_dtype)
text_encoders_o = [text_encoder_o] text_encoders_o = [text_encoder_o]
if load_dtype:
text_encoder_o.to(load_dtype)
unet_o.to(load_dtype)
logger.info(f"Loading tuned SD model: {model_tuned}") logger.info(f"Loading tuned SD model: {model_tuned_path}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned_path, load_dtype)
text_encoders_t = [text_encoder_t] text_encoders_t = [text_encoder_t]
if load_dtype:
text_encoder_t.to(load_dtype)
unet_t.to(load_dtype)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else: else: # SDXL
device_org = load_original_model_to or "cpu" device_org = load_original_model_to or "cpu"
device_tuned = load_tuned_model_to or "cpu" device_tuned = load_tuned_model_to or "cpu"
logger.info(f"Loading original SDXL model: {model_org}") logger.info(f"Loading original SDXL model: {model_org_path}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org_path, device_org, load_dtype
) )
text_encoders_o = [text_encoder_o1, text_encoder_o2] text_encoders_o = [text_encoder_o1, text_encoder_o2]
if load_dtype:
text_encoder_o1.to(load_dtype)
text_encoder_o2.to(load_dtype)
unet_o.to(load_dtype)
logger.info(f"Loading tuned SDXL model: {model_tuned}") logger.info(f"Loading tuned SDXL model: {model_tuned_path}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned_path, device_tuned, load_dtype
) )
text_encoders_t = [text_encoder_t1, text_encoder_t2] text_encoders_t = [text_encoder_t1, text_encoder_t2]
if load_dtype:
text_encoder_t1.to(load_dtype)
text_encoder_t2.to(load_dtype)
unet_t.to(load_dtype)
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# Create LoRA network # Create temporary LoRA network to identify modules and get original weights
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} if conv_dim else {} # Use a minimal fixed dimension for this stage as we only need module structure and original weights
temp_lora_kwargs = {"conv_dim": 1, "conv_alpha": 1.0} # Minimal conv settings for network creation
lora_network_o = lora.create_network(1.0, 1, 1.0, None, text_encoders_o, unet_o, **temp_lora_kwargs)
lora_network_t = lora.create_network(1.0, 1, 1.0, None, text_encoders_t, unet_t, **temp_lora_kwargs)
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras)
# Define a small initial dimension for memory efficiency
init_dim = 4 # Small value to minimize memory usage
# Create LoRA networks with minimal dimension
lora_network_o = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), "Model versions differ (SD1.x vs SD2.x)"
# Compute differences
diffs = {} diffs = {}
text_encoder_different = False text_encoder_differs = False
for lora_o, lora_t in zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras): # Text Encoders
lora_name = lora_o.lora_name for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device) module_key_name = lora_o.lora_name # e.g. "lora_te1_text_model_encoder_layers_0_mlp_fc1"
org_weight = lora_o.org_module.weight.to(device=work_device, dtype=torch.float)
tuned_weight = lora_t.org_module.weight.to(device=work_device, dtype=torch.float)
diff = tuned_weight - org_weight
if torch.max(torch.abs(diff)) > min_diff:
text_encoder_differs = True
logger.info(f"Text encoder {i+1} module {module_key_name} differs: max diff {torch.max(torch.abs(diff))}")
diffs[module_key_name] = diff
else:
logger.info(f"Text encoder {i+1} module {module_key_name} has no significant difference.")
# Free memory
lora_o.org_module.weight = None lora_o.org_module.weight = None
lora_t.org_module.weight = None lora_t.org_module.weight = None
del org_weight, tuned_weight
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: # UNet
text_encoder_different = True
logger.info(f"Text encoder differs: max diff {torch.max(torch.abs(diff))} > {min_diff}")
diffs[lora_name] = diff
for text_encoder in text_encoders_t:
del text_encoder
if not text_encoder_different:
logger.warning("Text encoders are identical. Extracting U-Net only.")
lora_network_o.text_encoder_loras = []
diffs.clear()
for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras): for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras):
lora_name = lora_o.lora_name module_key_name = lora_o.lora_name # e.g. "lora_unet_input_blocks_1_1_proj_in"
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device) org_weight = lora_o.org_module.weight.to(device=work_device, dtype=torch.float)
tuned_weight = lora_t.org_module.weight.to(device=work_device, dtype=torch.float)
diff = tuned_weight - org_weight
if torch.max(torch.abs(diff)) > min_diff:
logger.info(f"UNet module {module_key_name} differs: max diff {torch.max(torch.abs(diff))}")
diffs[module_key_name] = diff
else:
logger.info(f"UNet module {module_key_name} has no significant difference.")
lora_o.org_module.weight = None lora_o.org_module.weight = None
lora_t.org_module.weight = None lora_t.org_module.weight = None
diffs[lora_name] = diff del org_weight, tuned_weight
del lora_network_t, unet_t if not text_encoder_differs:
logger.warning("Text encoder weights are identical or below min_diff. Text encoder LoRA modules will not be included.")
# Remove text encoder diffs if none were significant
diffs = {k: v for k, v in diffs.items() if "unet" in k}
# Filter relevant modules del lora_network_o, lora_network_t, text_encoders_o, text_encoders_t, unet_o, unet_t
lora_names = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras) torch.cuda.empty_cache()
lora_module_weights = {} # This will store the final decomposed weights for LoRA/LoHA
logger.info(f"Extracting and resizing {algo.upper()} modules via SVD")
# Extract and resize LoRA using SVD
logger.info("Extracting and resizing LoRA via SVD")
lora_weights = {}
with torch.no_grad(): with torch.no_grad():
for lora_name in tqdm(lora_names): for module_key_name, mat_diff in tqdm(diffs.items()):
mat = diffs[lora_name] if compute_device != "cpu": # Move to GPU for SVD if specified
if device: mat_diff = mat_diff.to(compute_device)
mat = mat.to(device)
mat = mat.to(torch.float)
conv2d = len(mat.size()) == 4 # Determine if the layer is convolutional and its properties
kernel_size = mat.size()[2:4] if conv2d else None is_conv = len(mat_diff.shape) == 4
conv2d_3x3 = conv2d and kernel_size != (1, 1) kernel_size = None
out_dim, in_dim = mat.size()[0:2] if is_conv:
kernel_size = mat_diff.shape[2:4]
if conv2d: # For LoRA, conv_dim/alpha are specific to 3x3 convs. For others, it uses network_dim/alpha.
mat = mat.flatten(start_dim=1) if conv2d_3x3 else mat.squeeze() # For LoHA, we use conv_dim/alpha for any conv layer, and network_dim/alpha for linear.
# This logic can be refined based on how LyCORIS typically handles different conv kernel sizes for LoHA.
# Here, we'll use conv parameters if it's a conv layer, otherwise network parameters.
U, S, Vh = torch.linalg.svd(mat) current_dim_target = conv_dim if is_conv else network_dim
current_alpha_target = conv_alpha if is_conv else network_alpha
# For LoHA, 'alpha_target' is the rank of the first SVD (rank_initial)
# and 'dim_target' is the rank of the second SVD (rank_factor).
# For LoRA, 'dim_target' is the rank of the SVD, and 'alpha_target' is the scaling factor.
# Reshape convolutional weights for SVD
original_shape = mat_diff.shape
if is_conv:
if kernel_size == (1, 1):
mat_diff = mat_diff.squeeze() # Becomes (out_channels, in_channels)
else: # kernel_size (3,3) or others
mat_diff = mat_diff.flatten(start_dim=1) # Becomes (out_channels, in_channels*k_w*k_h)
out_features, in_features = mat_diff.shape[0], mat_diff.shape[1]
# Perform first SVD
try:
U, S, Vh = torch.linalg.svd(mat_diff)
except Exception as e:
logger.error(f"SVD failed for {module_key_name} with shape {mat_diff.shape}: {e}")
continue
if compute_device != "cpu": # Move results back to CPU if computation was on GPU
U, S, Vh = U.cpu(), S.cpu(), Vh.cpu()
# Determine rank for the first SVD (rank_initial for LoHA, rank for LoRA)
max_rank_initial = min(out_features, in_features) # Theoretical max rank
# Default rank_initial to current_alpha_target (which is network_alpha or conv_alpha)
rank_initial = current_alpha_target
# Determine rank
max_rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
if dynamic_method: if dynamic_method:
if S[0] <= MIN_SV: if S[0] <= MIN_SV:
rank = 1 determined_rank = 1
elif dynamic_method == "sv_ratio": elif dynamic_method == "sv_ratio": determined_rank = index_sv_ratio(S, dynamic_param)
rank = index_sv_ratio(S, dynamic_param) elif dynamic_method == "sv_cumulative": determined_rank = index_sv_cumulative(S, dynamic_param)
elif dynamic_method == "sv_cumulative": elif dynamic_method == "sv_fro": determined_rank = index_sv_fro(S, dynamic_param)
rank = index_sv_cumulative(S, dynamic_param) elif dynamic_method == "sv_knee": determined_rank = index_sv_knee_improved(S, MIN_SV)
elif dynamic_method == "sv_fro": elif dynamic_method == "sv_cumulative_knee": determined_rank = index_sv_cumulative_knee(S, MIN_SV)
rank = index_sv_fro(S, dynamic_param) elif dynamic_method == "sv_rel_decrease": determined_rank = index_sv_rel_decrease(S, dynamic_param)
elif dynamic_method == "sv_knee": else: determined_rank = rank_initial # Fallback if dynamic method unknown
rank = index_sv_knee(S) rank_initial = min(determined_rank, current_alpha_target, max_rank_initial)
elif dynamic_method == "sv_rel_decrease":
rank = index_sv_rel_decrease(S, dynamic_param)
rank = min(rank, max_rank, in_dim, out_dim)
else: else:
rank = min(max_rank, in_dim, out_dim) rank_initial = min(current_alpha_target, max_rank_initial)
rank_initial = max(1, rank_initial) # Ensure rank is at least 1
# Truncate SVD components # --- LoRA specific decomposition ---
U = U[:, :rank] @ torch.diag(S[:rank]) if algo == 'lora':
Vh = Vh[:rank, :] lora_down = Vh[:rank_initial, :]
lora_up = U[:, :rank_initial] @ torch.diag(S[:rank_initial])
# Clamp values # Clamp values
dist = torch.cat([U.flatten(), Vh.flatten()]) dist = torch.cat([lora_up.flatten(), lora_down.flatten()])
hi_val = torch.quantile(dist, clamp_quantile) hi_val = torch.quantile(dist, clamp_quantile) if clamp_quantile < 1.0 else dist.abs().max()
U = U.clamp(-hi_val, hi_val) lora_up = lora_up.clamp(-hi_val, hi_val)
Vh = Vh.clamp(-hi_val, hi_val) lora_down = lora_down.clamp(-hi_val, hi_val)
if conv2d: # Reshape for conv layers if necessary
U = U.reshape(out_dim, rank, 1, 1) if is_conv:
Vh = Vh.reshape(rank, in_dim, *kernel_size) if kernel_size == (1,1):
# These are already (out_c, rank) and (rank, in_c)
# Some LoRA impls might expect them to be 4D (out_c, rank, 1, 1) and (rank, in_c, 1, 1)
lora_up = lora_up.reshape(out_features, rank_initial, 1, 1)
lora_down = lora_down.reshape(rank_initial, in_features, 1, 1)
else: # e.g. 3x3 conv
# lora_down was (rank, in_c*k_w*k_h), needs to be (rank, in_c, k_w, k_h)
# lora_up was (out_c, rank)
lora_up = lora_up.reshape(out_features, rank_initial, 1, 1) # often up is kept as 1x1 conv like
lora_down = lora_down.reshape(rank_initial, original_shape[1], *kernel_size)
U = U.to(work_device, dtype=save_dtype).contiguous() lora_module_weights[f"{module_key_name}.lora_down.weight"] = lora_down.to(work_device, dtype=save_dtype).contiguous()
Vh = Vh.to(work_device, dtype=save_dtype).contiguous() lora_module_weights[f"{module_key_name}.lora_up.weight"] = lora_up.to(work_device, dtype=save_dtype).contiguous()
lora_weights[lora_name] = (U, Vh) lora_module_weights[f"{module_key_name}.alpha"] = torch.tensor(float(current_alpha_target), dtype=save_dtype) # Use actual alpha target from params
# --- LoHA specific decomposition ---
elif algo == 'loha':
lora_down_equivalent = Vh[:rank_initial, :]
lora_up_equivalent = U[:, :rank_initial] @ torch.diag(S[:rank_initial])
# current_dim_target is the "factor" for LoHA's second SVD
rank_factor = min(current_dim_target, rank_initial) # Factor cannot exceed rank_initial
if is_conv and kernel_size != (1,1): # For conv3x3, factor also limited by in/out features
rank_factor = min(rank_factor, original_shape[1], original_shape[0]) #original_shape[1] is in_channels
else: # Linear or Conv1x1
rank_factor = min(rank_factor, in_features, out_features)
rank_factor = max(1, rank_factor)
# Decompose Lora_Down_Equivalent (shape: rank_initial x in_features_eff)
# Target: hada_w1_b (rank_initial x rank_factor) @ hada_w1_a (rank_factor x in_features_eff)
if compute_device != "cpu": lora_down_equivalent = lora_down_equivalent.to(compute_device)
Ud, Sd, Vhd = torch.linalg.svd(lora_down_equivalent)
if compute_device != "cpu": Ud, Sd, Vhd = Ud.cpu(), Sd.cpu(), Vhd.cpu()
hada_w1_a = Vhd[:rank_factor, :]
hada_w1_b = Ud[:, :rank_factor] @ torch.diag(Sd[:rank_factor])
# Decompose Lora_Up_Equivalent (shape: out_features_eff x rank_initial)
# Target: hada_w2_b (out_features_eff x rank_factor) @ hada_w2_a (rank_factor x rank_initial)
if compute_device != "cpu": lora_up_equivalent = lora_up_equivalent.to(compute_device)
Uu, Su, Vhu = torch.linalg.svd(lora_up_equivalent)
if compute_device != "cpu": Uu, Su, Vhu = Uu.cpu(), Su.cpu(), Vhu.cpu()
hada_w2_a = Vhu[:rank_factor, :]
hada_w2_b = Uu[:, :rank_factor] @ torch.diag(Su[:rank_factor])
# Clamp LoHA components
dist_w1a = hada_w1_a.flatten()
dist_w1b = hada_w1_b.flatten()
dist_w2a = hada_w2_a.flatten()
dist_w2b = hada_w2_b.flatten()
if clamp_quantile < 1.0:
hi_val_w1a = torch.quantile(dist_w1a.abs(), clamp_quantile)
hi_val_w1b = torch.quantile(dist_w1b.abs(), clamp_quantile)
hi_val_w2a = torch.quantile(dist_w2a.abs(), clamp_quantile)
hi_val_w2b = torch.quantile(dist_w2b.abs(), clamp_quantile)
else: # Use max abs value if quantile is 1.0 or more
hi_val_w1a = dist_w1a.abs().max()
hi_val_w1b = dist_w1b.abs().max()
hi_val_w2a = dist_w2a.abs().max()
hi_val_w2b = dist_w2b.abs().max()
hada_w1_a = hada_w1_a.clamp(-hi_val_w1a, hi_val_w1a)
hada_w1_b = hada_w1_b.clamp(-hi_val_w1b, hi_val_w1b)
hada_w2_a = hada_w2_a.clamp(-hi_val_w2a, hi_val_w2a)
hada_w2_b = hada_w2_b.clamp(-hi_val_w2b, hi_val_w2b)
# LoHA weights are typically stored as 2D matrices.
# LyCORIS library handles reshaping or uses 1x1 convs internally.
lora_module_weights[f"{module_key_name}.hada_w1_a"] = hada_w1_a.to(work_device, dtype=save_dtype).contiguous()
lora_module_weights[f"{module_key_name}.hada_w1_b"] = hada_w1_b.to(work_device, dtype=save_dtype).contiguous()
lora_module_weights[f"{module_key_name}.hada_w2_a"] = hada_w2_a.to(work_device, dtype=save_dtype).contiguous()
lora_module_weights[f"{module_key_name}.hada_w2_b"] = hada_w2_b.to(work_device, dtype=save_dtype).contiguous()
lora_module_weights[f"{module_key_name}.alpha"] = torch.tensor(float(current_alpha_target), dtype=save_dtype) # This is rank_initial
# Verbose output
if verbose: if verbose:
s_sum = float(torch.sum(S)) s_sum = float(torch.sum(S))
s_rank = float(torch.sum(S[:rank])) s_rank_initial = float(torch.sum(S[:rank_initial]))
fro = float(torch.sqrt(torch.sum(S.pow(2)))) fro_initial = float(torch.sqrt(torch.sum(S.pow(2))))
fro_rank = float(torch.sqrt(torch.sum(S[:rank].pow(2)))) fro_rank_initial = float(torch.sqrt(torch.sum(S[:rank_initial].pow(2))))
ratio = S[0] / S[rank - 1] if rank > 1 else float('inf') # This verbose output is for the first SVD. Adding verbose for second SVD would be more complex.
logger.info(f"{lora_name:75} | sum(S) retained: {s_rank/s_sum:.1%}, fro retained: {fro_rank/fro:.1%}, max ratio: {ratio:.1f}, rank: {rank}") logger.info(
f"{module_key_name[:75]:75} | Algo: {algo.upper()}, Rank/Alpha (initial): {rank_initial}, Dim/Factor (final): {rank_factor if algo=='loha' else rank_initial} "
f"| Sum(S) Retained (1st SVD): {s_rank_initial/s_sum if s_sum > 0 else 0:.1%}, Fro Retained (1st SVD): {fro_rank_initial/fro_initial if fro_initial > 0 else 0:.1%}"
)
# Create state dict del U, S, Vh, mat_diff
lora_sd = {} if algo == 'loha':
for lora_name, (up_weight, down_weight) in lora_weights.items(): del lora_down_equivalent, lora_up_equivalent, Ud, Sd, Vhd, Uu, Su, Vhu
lora_sd[lora_name + ".lora_up.weight"] = up_weight del hada_w1_a, hada_w1_b, hada_w2_a, hada_w2_b
lora_sd[lora_name + ".lora_down.weight"] = down_weight else: # lora
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype) del lora_down, lora_up
if compute_device != "cpu":
torch.cuda.empty_cache()
# Load and save LoRA
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd) if not lora_module_weights:
lora_network_save.apply_to(text_encoders_o, unet_o) logger.error("No LoRA/LoHA modules were generated. This might be due to models being too similar or min_diff being too high.")
info = lora_network_save.load_state_dict(lora_sd) return
logger.info(f"Loaded extracted and resized LoRA weights: {info}")
os.makedirs(os.path.dirname(save_to), exist_ok=True) os.makedirs(os.path.dirname(save_to), exist_ok=True)
# Metadata # Metadata
net_kwargs = {"conv_dim": str(conv_dim), "conv_alpha": str(float(conv_dim))} if conv_dim else {}
metadata = { metadata = {
"ss_v2": str(v2), "ss_v2": str(v2) if v2 is not None else "false", # More explicit "false"
"ss_base_model_version": model_version, "ss_base_model_version": model_version,
"ss_network_module": "networks.lora", "ss_sdxl_model_version": "1.0" if sdxl else "null", # LyCORIS convention
"ss_network_dim": str(dim) if not dynamic_method else "Dynamic", "ss_network_module": "networks.lora" if algo == "lora" else "lycoris.kohya", # For LoHA
"ss_network_alpha": str(float(dim)) if not dynamic_method else "Dynamic", # For LoRA, dim and alpha are usually the same as what's passed.
"ss_network_args": json.dumps(net_kwargs), # For LoHA, network_dim is the 'factor' (our network_dim/conv_dim),
# and network_alpha is the 'rank_initial' (our network_alpha/conv_alpha).
"ss_network_dim": str(network_dim) if not dynamic_method or algo == "loha" else "Dynamic", # For LoHA, this is 'factor'
"ss_network_alpha": str(network_alpha) if not dynamic_method or algo == "loha" else "Dynamic", # For LoHA, this is 'rank_initial'
} }
net_kwargs = {}
if algo == "lora":
if conv_dim is not None: # Only add if conv_dim was actually specified for LoRA
net_kwargs["conv_dim"] = str(conv_dim)
net_kwargs["conv_alpha"] = str(float(conv_alpha if conv_alpha is not None else conv_dim))
elif algo == "loha":
net_kwargs["algo"] = "loha"
# For LoHA, conv_dim and conv_alpha are distinct concepts (factor and rank_initial for conv layers)
net_kwargs["conv_dim"] = str(conv_dim)
net_kwargs["conv_alpha"] = str(float(conv_alpha))
# LyCORIS sometimes uses dropout, but we are not implementing it here.
# net_kwargs["dropout"] = "0" # Example if we had dropout
metadata["ss_network_args"] = json.dumps(net_kwargs) if net_kwargs else "null" # LyCORIS uses "null" for empty
if not no_metadata: if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0] title = os.path.splitext(os.path.basename(save_to))[0]
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title) # sai_model_spec might need adjustment if it expects specific lora types not lycoris
metadata.update(sai_metadata) try:
sai_metadata = sai_model_spec.build_metadata(
None, v2, v_parameterization, sdxl, True,
is_lycoris=(algo != "lora"), # Pass if it's LyCORIS
is_lora=(algo == "lora"), # Pass if it's LoRA
creation_time=time.time(), title=title
)
metadata.update(sai_metadata)
except TypeError as e:
logger.warning(f"Could not generate full SAI metadata, possibly due to outdated sai_model_spec.py or new flags: {e}")
logger.warning("Falling back to basic metadata for SAI fields.")
metadata.update({ # Basic fallback
"sai_model_name": title,
"sai_base_model": model_version,
"sai_is_sdxl": str(sdxl).lower(),
})
save_to_file(save_to, lora_module_weights, save_dtype, metadata)
logger.info(f"{algo.upper()} saved to: {save_to}")
save_to_file(save_to, lora_sd, lora_sd, save_dtype, metadata)
logger.info(f"LoRA saved to: {save_to}")
def setup_parser(): def setup_parser():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model") parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model")
parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2)") parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2 if applicable)")
parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model") parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model")
parser.add_argument("--load_precision", choices=[None, "float", "fp16", "bf16"], help="Precision for loading models")
parser.add_argument("--save_precision", choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA") parser.add_argument("--model_org_path", type=str, required=True, help="Path to the original Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--model_org", required=True, help="Original Stable Diffusion model (ckpt/safetensors)") parser.add_argument("--model_tuned_path", type=str, required=True, help="Path to the tuned Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--model_tuned", required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)") parser.add_argument("--save_to", type=str, required=True, help="Output file name for the LoRA/LoHA (ckpt/safetensors)")
parser.add_argument("--save_to", required=True, help="Output file name (ckpt/safetensors)")
parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers") parser.add_argument("--algo", type=str, default="lora", choices=["lora", "loha"], help="Algorithm to use: lora or loha")
parser.add_argument("--conv_dim", type=int, help="Max dimension (rank) of LoRA for Conv2d-3x3")
parser.add_argument("--device", default="cuda", help="Device for computation (e.g., cuda)") parser.add_argument("--network_dim", type=int, default=4, help="Network dimension. For LoRA: rank. For LoHA: 'factor' or 'hada_dim'.")
parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights") parser.add_argument("--network_alpha", type=int, default=None, help="Network alpha. For LoRA: alpha (often same as rank). For LoHA: 'rank_initial' or 'hada_alpha'. Defaults to network_dim if not set.")
parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract") parser.add_argument("--conv_dim", type=int, default=None, help="Conv dimension for conv layers. For LoRA: rank. For LoHA: 'factor'. Defaults to network_dim if not set.")
parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata") parser.add_argument("--conv_alpha", type=int, default=None, help="Conv alpha for conv layers. For LoRA: alpha. For LoHA: 'rank_initial'. Defaults to conv_dim if not set.")
parser.add_argument("--load_original_model_to", help="Device for original model (SDXL only)")
parser.add_argument("--load_tuned_model_to", help="Device for tuned model (SDXL only)") parser.add_argument("--load_precision", type=str, choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for loading models (None means default float32)")
parser.add_argument("--dynamic_method", choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease"], help="Dynamic rank reduction method") parser.add_argument("--save_precision", type=str, choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA/LoHA (None means float32)")
parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction")
parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info") parser.add_argument("--device", type=str, default=None, help="Device for SVD computation (e.g., 'cuda', 'cpu'). If None, defaults to 'cuda' if available, else 'cpu'. SVD results are moved to CPU for storage.")
parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights (0.0 to 1.0). 1.0 means clamp to max abs value.")
parser.add_argument("--min_diff", type=float, default=1e-6, help="Minimum weight difference threshold for a module to be considered for extraction.") # Lowered default
parser.add_argument("--no_metadata", action="store_true", help="Do not save detailed metadata (minimal ss_ metadata will still be saved).")
parser.add_argument("--load_original_model_to", type=str, default=None, help="Device to load original model to (SDXL only, e.g., 'cpu', 'cuda:0')")
parser.add_argument("--load_tuned_model_to", type=str, default=None, help="Device to load tuned model to (SDXL only, e.g., 'cpu', 'cuda:1')")
parser.add_argument("--dynamic_method", type=str, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"], default=None, help="Dynamic method to determine rank/alpha (for the first SVD in LoHA). Overrides fixed network_alpha/conv_alpha if set.")
parser.add_argument("--dynamic_param", type=float, default=0.9, help="Parameter for the chosen dynamic_method (e.g., target ratio/cumulative sum/Frobenius norm percentage).")
parser.add_argument("--verbose", action="store_true", help="Show detailed SVD info for each module.")
return parser return parser
if __name__ == "__main__": if __name__ == "__main__":
parser = setup_parser() parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
if args.dynamic_method and not args.dynamic_param:
raise ValueError("Dynamic method requires a dynamic parameter") if args.dynamic_method and (args.dynamic_param is None):
svd(**vars(args)) # Default dynamic_param for methods if not specified, or raise error
if args.dynamic_method in ["sv_cumulative", "sv_fro"]:
args.dynamic_param = 0.99 # Example: 99% variance/energy
logger.info(f"Dynamic method {args.dynamic_method} chosen, dynamic_param defaulted to {args.dynamic_param}")
elif args.dynamic_method in ["sv_ratio"]:
args.dynamic_param = 1000 # Example: ratio of 1000
logger.info(f"Dynamic method {args.dynamic_method} chosen, dynamic_param defaulted to {args.dynamic_param}")
elif args.dynamic_method in ["sv_rel_decrease"]:
args.dynamic_param = 0.05 # Example: 5% relative decrease
logger.info(f"Dynamic method {args.dynamic_method} chosen, dynamic_param defaulted to {args.dynamic_param}")
# sv_knee and sv_cumulative_knee do not require dynamic_param in this implementation.
elif args.dynamic_method not in ["sv_knee", "sv_cumulative_knee"]:
parser.error("--dynamic_method requires --dynamic_param for most methods.")
# Default device selection
if args.device is None:
args.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {args.device} for SVD computation.")
# Rename args for clarity before passing to function
func_args = vars(args).copy()
func_args["model_org_path"] = func_args.pop("model_org_path")
func_args["model_tuned_path"] = func_args.pop("model_tuned_path")
svd_decomposition(**func_args)

View File

@ -0,0 +1,432 @@
import argparse
import json
import os
import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, model_util, sdxl_model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
MIN_SV = 1e-6
def index_sv_cumulative(S, target):
    """Pick a rank from the normalized cumulative sum of the singular values.

    The running sum of ``S`` is divided by ``sum(S)`` and searched for
    ``target``; the insertion position plus one is the candidate rank.
    The candidate is capped at ``len(S) - 1`` and then floored at 1, so the
    floor wins when ``S`` has a single element.
    """
    total = float(S.sum())
    cumulative_fraction = S.cumsum(dim=0) / total
    position = int(torch.searchsorted(cumulative_fraction, target)) + 1
    highest_allowed = len(S) - 1
    return max(1, min(position, highest_allowed))
def index_sv_fro(S, target):
    """Pick a rank from the normalized cumulative sum of squared singular values.

    The running sum of ``S**2`` is divided by ``sum(S**2)`` and searched for
    ``target**2``; the insertion position plus one is the candidate rank,
    which is capped at ``len(S) - 1`` and then floored at 1.
    """
    energy = S.pow(2)
    total_energy = float(energy.sum())
    cumulative_energy = energy.cumsum(dim=0) / total_energy
    position = int(torch.searchsorted(cumulative_energy, target**2)) + 1
    return max(1, min(position, len(S) - 1))
def index_sv_ratio(S, target):
    """Pick a rank by counting singular values above ``S[0] / target``.

    The count is capped at ``len(S) - 1`` and then floored at 1.
    """
    threshold = S[0] / target
    kept = int((S > threshold).sum().item())
    return max(1, min(kept, len(S) - 1))
def index_sv_knee_improved(S, MIN_SV_KNEE=1e-8):
    """Pick a rank at the knee of the singular-value curve.

    Both axes are rescaled to [0, 1]: the values via ``(S - S[-1]) /
    (S[0] - S[-1])`` and the indices via an evenly spaced grid. The knee is
    the point farthest from the straight line joining (0, 1) and (1, 0),
    i.e. the point maximizing ``|x + y - 1|`` (the constant ``1/sqrt(2)``
    factor of the true distance does not change the argmax).

    Returns 1 when there are fewer than three values or when
    ``S[0] - S[-1]`` is below ``MIN_SV_KNEE``; otherwise returns the
    1-based knee position capped at ``len(S) - 1``.
    """
    count = len(S)
    # A knee needs at least three points to be defined.
    if count < 3:
        return 1

    largest = S[0]
    smallest = S[-1]
    spread = largest - smallest
    # A (nearly) flat spectrum has no meaningful knee; fall back to rank 1.
    if spread < MIN_SV_KNEE:
        return 1

    heights = (S - smallest) / spread
    positions = torch.linspace(0, 1, count, device=S.device, dtype=S.dtype)

    # Unnormalized distance of every point from the line x + y - 1 = 0.
    gap = torch.abs(positions + heights - 1)
    knee = torch.argmax(gap).item()

    # 0-based index -> 1-based rank, kept within [1, count - 1].
    return max(1, min(knee + 1, count - 1))
def index_sv_cumulative_knee(S, min_sv_threshold=1e-8):
    """Pick a rank at the knee of the cumulative singular-value curve.

    The cumulative sum of ``S`` is divided by ``sum(S)`` and then rescaled so
    its first entry maps to 0 and its last to 1; indices are rescaled to
    [0, 1] as well. The knee is the point farthest from the diagonal joining
    (0, 0) and (1, 1), i.e. the point maximizing ``|y - x|``.

    Returns 1 when there are fewer than three values, when ``sum(S)`` is
    below ``min_sv_threshold``, or when the cumulative curve spans less than
    ``min_sv_threshold``; otherwise returns the 1-based knee position capped
    at ``len(S) - 1``.
    """
    count = len(S)
    # A knee needs at least three points to be defined.
    if count < 3:
        return 1

    total = torch.sum(S)
    # Nothing worth keeping: every singular value is (almost) zero.
    if total < min_sv_threshold:
        return 1

    cumulative = torch.cumsum(S, dim=0) / total
    first = cumulative[0]
    last = cumulative[count - 1]
    span = last - first
    # A flat cumulative curve means the values after S[0] add almost nothing.
    if span < min_sv_threshold:
        return 1

    heights = (cumulative - first) / span
    positions = torch.linspace(0, 1, count, device=S.device, dtype=S.dtype)

    # Unnormalized distance of every point from the diagonal y = x.
    gap = torch.abs(heights - positions)
    knee = torch.argmax(gap).item()

    # 0-based index -> 1-based rank, kept within [1, count - 1].
    return max(1, min(knee + 1, count - 1))
def index_sv_rel_decrease(S, tau=0.1):
    """Pick a rank at the first sharp relative drop between neighbours.

    Computes ``S[i + 1] / S[i]`` for every consecutive pair and returns the
    1-based position of the first pair whose ratio is below ``tau``. Returns
    1 when ``S`` has fewer than two values and ``len(S) - 1`` when no ratio
    falls below ``tau``.
    """
    if len(S) < 2:
        return 1

    consecutive = S[1:] / S[:-1]
    # k + 1 is the number of values kept before the drop at position k.
    first_drop = next(
        (k + 1 for k, ratio in enumerate(consecutive) if ratio < tau),
        None,
    )
    if first_drop is not None:
        return first_drop

    # No drop found: keep as many values as rank reduction allows.
    return len(S) - 1
def save_to_file(file_name, model, state_dict, dtype, metadata=None):
    """Cast the tensors of ``state_dict`` in place and write ``model`` to disk.

    When ``dtype`` is not None, every ``torch.Tensor`` value in
    ``state_dict`` is replaced by its ``.to(dtype)`` version; other values
    are left alone. ``model`` is then written with ``save_file`` (passing
    ``metadata``) if ``file_name`` ends in ``.safetensors``, and with
    ``torch.save`` otherwise, in which case ``metadata`` is not written.
    """
    if dtype is not None:
        # Snapshot the items so entries can be reassigned while looping.
        for name, value in list(state_dict.items()):
            if isinstance(value, torch.Tensor):
                state_dict[name] = value.to(dtype)

    extension = os.path.splitext(file_name)[1]
    if extension != ".safetensors":
        torch.save(model, file_name)
        return
    save_file(model, file_name, metadata)
def svd(
    model_org=None,
    model_tuned=None,
    save_to=None,
    dim=4,
    v2=None,
    sdxl=None,
    conv_dim=None,
    v_parameterization=None,
    device=None,
    save_precision=None,
    clamp_quantile=0.99,
    min_diff=0.01,
    no_metadata=False,
    load_precision=None,
    load_original_model_to=None,
    load_tuned_model_to=None,
    dynamic_method=None,
    dynamic_param=None,
    verbose=False,
):
    """Extract a LoRA from the weight difference of two checkpoints via SVD.

    Both checkpoints are loaded, the per-module weight difference
    (tuned - original) is decomposed with ``torch.linalg.svd``, truncated to
    a rank, clamped, and saved as ``lora_up`` / ``lora_down`` / ``alpha``
    tensors to ``save_to``.

    Args:
        model_org: Path of the original checkpoint.
        model_tuned: Path of the tuned checkpoint.
        save_to: Output file; ``.safetensors`` is written with safetensors,
            anything else with ``torch.save``.
        dim: Maximum rank for every module except non-1x1 convolutions when
            ``conv_dim`` is given.
        v2: Load the checkpoints as Stable Diffusion v2.x.
        sdxl: Load the checkpoints as SDXL (mutually exclusive with ``v2``).
        conv_dim: Maximum rank for convolutions whose kernel is not 1x1.
        v_parameterization: Metadata flag; defaults to ``v2`` when None.
        device: Device the difference matrix is moved to before the SVD;
            when falsy it stays on CPU. Results are always stored on CPU.
        save_precision: ``"float"``, ``"fp16"`` or ``"bf16"``; falsy means
            ``torch.float``.
        clamp_quantile: Quantile of the combined U/Vh values used as the
            symmetric clamping bound.
        min_diff: The text encoders count as changed only if some module's
            maximum absolute difference exceeds this; otherwise only the
            U-Net is extracted.
        no_metadata: Skip the ``sai_model_spec`` metadata (the ``ss_*``
            entries are always written).
        load_precision: Optional dtype name the loaded models are cast to.
        load_original_model_to: Device for the original model (SDXL only).
        load_tuned_model_to: Device for the tuned model (SDXL only).
        dynamic_method: Name of a per-module rank selection method, or a
            falsy value to use the fixed maximum rank.
        dynamic_param: Parameter passed to ``sv_ratio``, ``sv_cumulative``,
            ``sv_fro`` and ``sv_rel_decrease``; unused by the knee methods.
        verbose: Log how much of each spectrum the chosen rank retains.

    Raises:
        AssertionError: If both ``v2`` and ``sdxl`` are set, or the two
            models expose a different number of text-encoder modules.
        ValueError: If ``dynamic_method`` is not a known method name.
    """

    def str_to_dtype(p):
        if p == "float":
            return torch.float
        if p == "fp16":
            return torch.float16
        if p == "bf16":
            return torch.bfloat16
        return None

    assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time"
    v_parameterization = v2 if v_parameterization is None else v_parameterization
    load_dtype = str_to_dtype(load_precision) if load_precision else None
    save_dtype = str_to_dtype(save_precision) if save_precision else torch.float
    work_device = "cpu"

    # Load models
    if not sdxl:
        logger.info(f"Loading original SD model: {model_org}")
        text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
        text_encoders_o = [text_encoder_o]
        if load_dtype:
            text_encoder_o.to(load_dtype)
            unet_o.to(load_dtype)

        logger.info(f"Loading tuned SD model: {model_tuned}")
        text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
        text_encoders_t = [text_encoder_t]
        if load_dtype:
            text_encoder_t.to(load_dtype)
            unet_t.to(load_dtype)

        model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
    else:
        device_org = load_original_model_to or "cpu"
        device_tuned = load_tuned_model_to or "cpu"

        logger.info(f"Loading original SDXL model: {model_org}")
        text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
        )
        text_encoders_o = [text_encoder_o1, text_encoder_o2]
        if load_dtype:
            text_encoder_o1.to(load_dtype)
            text_encoder_o2.to(load_dtype)
            unet_o.to(load_dtype)

        logger.info(f"Loading tuned SDXL model: {model_tuned}")
        text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
        )
        text_encoders_t = [text_encoder_t1, text_encoder_t2]
        if load_dtype:
            text_encoder_t1.to(load_dtype)
            text_encoder_t2.to(load_dtype)
            unet_t.to(load_dtype)

        model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0

    # Create LoRA network
    kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} if conv_dim else {}

    # These networks are only used to enumerate the target modules, so a
    # small dimension keeps their own weights cheap.
    init_dim = 4
    lora_network_o = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_o, unet_o, **kwargs)
    lora_network_t = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_t, unet_t, **kwargs)
    assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), "Model versions differ (SD1.x vs SD2.x)"

    # Compute differences
    diffs = {}
    text_encoder_different = False
    for lora_o, lora_t in zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras):
        lora_name = lora_o.lora_name
        diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
        # Drop the source weights as soon as the difference is taken.
        lora_o.org_module.weight = None
        lora_t.org_module.weight = None

        if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
            text_encoder_different = True
            logger.info(f"Text encoder differs: max diff {torch.max(torch.abs(diff))} > {min_diff}")

        diffs[lora_name] = diff

    # NOTE: this only unbinds the loop variable; the encoders stay
    # referenced by text_encoders_t.
    for text_encoder in text_encoders_t:
        del text_encoder

    if not text_encoder_different:
        logger.warning("Text encoders are identical. Extracting U-Net only.")
        lora_network_o.text_encoder_loras = []
        diffs.clear()

    for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras):
        lora_name = lora_o.lora_name
        diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
        lora_o.org_module.weight = None
        lora_t.org_module.weight = None
        diffs[lora_name] = diff

    del lora_network_t, unet_t

    # Filter relevant modules
    lora_names = set(module.lora_name for module in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)

    # Extract and resize LoRA using SVD
    logger.info("Extracting and resizing LoRA via SVD")
    lora_weights = {}
    with torch.no_grad():
        for lora_name in tqdm(lora_names):
            mat = diffs[lora_name]
            if device:
                mat = mat.to(device)
            mat = mat.to(torch.float)

            conv2d = len(mat.size()) == 4
            kernel_size = mat.size()[2:4] if conv2d else None
            conv2d_3x3 = conv2d and kernel_size != (1, 1)
            out_dim, in_dim = mat.size()[0:2]

            if conv2d:
                # Non-1x1 kernels are flattened into the input dimension;
                # 1x1 kernels just lose their singleton spatial axes.
                mat = mat.flatten(start_dim=1) if conv2d_3x3 else mat.squeeze()

            U, S, Vh = torch.linalg.svd(mat)

            # Determine rank
            max_rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
            if dynamic_method:
                if S[0] <= MIN_SV:
                    rank = 1
                elif dynamic_method == "sv_ratio":
                    rank = index_sv_ratio(S, dynamic_param)
                elif dynamic_method == "sv_cumulative":
                    rank = index_sv_cumulative(S, dynamic_param)
                elif dynamic_method == "sv_fro":
                    rank = index_sv_fro(S, dynamic_param)
                elif dynamic_method == "sv_knee":
                    rank = index_sv_knee_improved(S, MIN_SV)
                elif dynamic_method == "sv_cumulative_knee":
                    rank = index_sv_cumulative_knee(S, MIN_SV)
                elif dynamic_method == "sv_rel_decrease":
                    rank = index_sv_rel_decrease(S, dynamic_param)
                else:
                    # Without this branch `rank` would be unbound (or carry
                    # over from a previous module) for an unknown method.
                    raise ValueError(f"Unknown dynamic_method: {dynamic_method}")
                rank = min(rank, max_rank, in_dim, out_dim)
            else:
                rank = min(max_rank, in_dim, out_dim)

            # Truncate SVD components; the singular values are folded into U.
            U = U[:, :rank] @ torch.diag(S[:rank])
            Vh = Vh[:rank, :]

            # Clamp values
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp_quantile)
            U = U.clamp(-hi_val, hi_val)
            Vh = Vh.clamp(-hi_val, hi_val)

            if conv2d:
                U = U.reshape(out_dim, rank, 1, 1)
                Vh = Vh.reshape(rank, in_dim, *kernel_size)

            U = U.to(work_device, dtype=save_dtype).contiguous()
            Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
            lora_weights[lora_name] = (U, Vh)

            # Verbose output
            if verbose:
                s_sum = float(torch.sum(S))
                s_rank = float(torch.sum(S[:rank]))
                fro = float(torch.sqrt(torch.sum(S.pow(2))))
                fro_rank = float(torch.sqrt(torch.sum(S[:rank].pow(2))))
                # An all-zero difference has a zero spectrum; truncating it
                # loses nothing, so report full retention instead of
                # dividing by zero.
                sum_retained = s_rank / s_sum if s_sum > 0 else 1.0
                fro_retained = fro_rank / fro if fro > 0 else 1.0
                ratio = S[0] / S[rank - 1] if rank > 1 else float('inf')
                logger.info(f"{lora_name:75} | sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max ratio: {ratio:.1f}, rank: {rank}")

    # Create state dict; alpha is stored equal to the extracted rank.
    lora_sd = {}
    for lora_name, (up_weight, down_weight) in lora_weights.items():
        lora_sd[lora_name + ".lora_up.weight"] = up_weight
        lora_sd[lora_name + ".lora_down.weight"] = down_weight
        lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype)

    # Load and save LoRA
    lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
    lora_network_save.apply_to(text_encoders_o, unet_o)
    info = lora_network_save.load_state_dict(lora_sd)
    logger.info(f"Loaded extracted and resized LoRA weights: {info}")

    # A bare file name has an empty dirname, which os.makedirs rejects.
    save_dir = os.path.dirname(save_to)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Metadata
    net_kwargs = {"conv_dim": str(conv_dim), "conv_alpha": str(float(conv_dim))} if conv_dim else {}
    metadata = {
        "ss_v2": str(v2),
        "ss_base_model_version": model_version,
        "ss_network_module": "networks.lora",
        "ss_network_dim": str(dim) if not dynamic_method else "Dynamic",
        "ss_network_alpha": str(float(dim)) if not dynamic_method else "Dynamic",
        "ss_network_args": json.dumps(net_kwargs),
    }
    if not no_metadata:
        title = os.path.splitext(os.path.basename(save_to))[0]
        sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
        metadata.update(sai_metadata)

    save_to_file(save_to, lora_sd, lora_sd, save_dtype, metadata)
    logger.info(f"LoRA saved to: {save_to}")
def setup_parser():
    """Build the argument parser for the LoRA extraction command line.

    Options are declared as a table of ``(flag, add_argument keyword
    arguments)`` pairs and registered in order, so the help output lists
    them in the order given here.
    """
    precision_choices = [None, "float", "fp16", "bf16"]
    dynamic_method_choices = [None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"]

    option_table = (
        ("--v2", dict(action="store_true", help="Load Stable Diffusion v2.x model")),
        ("--v_parameterization", dict(action="store_true", help="Set v-parameterization metadata (defaults to v2)")),
        ("--sdxl", dict(action="store_true", help="Load Stable Diffusion SDXL base model")),
        ("--load_precision", dict(choices=precision_choices, help="Precision for loading models")),
        ("--save_precision", dict(choices=precision_choices, default=None, help="Precision for saving LoRA")),
        ("--model_org", dict(required=True, help="Original Stable Diffusion model (ckpt/safetensors)")),
        ("--model_tuned", dict(required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)")),
        ("--save_to", dict(required=True, help="Output file name (ckpt/safetensors)")),
        ("--dim", dict(type=int, default=4, help="Max dimension (rank) of LoRA for linear layers")),
        ("--conv_dim", dict(type=int, help="Max dimension (rank) of LoRA for Conv2d-3x3")),
        ("--device", dict(default="cuda", help="Device for computation (e.g., cuda)")),
        ("--clamp_quantile", dict(type=float, default=0.99, help="Quantile for clamping weights")),
        ("--min_diff", dict(type=float, default=0.01, help="Minimum weight difference to extract")),
        ("--no_metadata", dict(action="store_true", help="Omit detailed metadata")),
        ("--load_original_model_to", dict(help="Device for original model (SDXL only)")),
        ("--load_tuned_model_to", dict(help="Device for tuned model (SDXL only)")),
        ("--dynamic_param", dict(type=float, help="Parameter for dynamic rank reduction")),
        ("--verbose", dict(action="store_true", help="Show detailed rank reduction info")),
        ("--dynamic_method", dict(choices=dynamic_method_choices, help="Dynamic rank reduction method")),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser
if __name__ == "__main__":
    # Script entry point: parse the command line and run the extraction.
    parser = setup_parser()
    args = parser.parse_args()

    # Only these methods read dynamic_param inside svd(); sv_knee and
    # sv_cumulative_knee take no parameter and must not be rejected.
    # `is None` (rather than falsiness) keeps an explicit 0.0 valid.
    methods_needing_param = ("sv_ratio", "sv_fro", "sv_cumulative", "sv_rel_decrease")
    if args.dynamic_method in methods_needing_param and args.dynamic_param is None:
        parser.error(f"--dynamic_method {args.dynamic_method} requires --dynamic_param")

    svd(**vars(args))

View File

@ -0,0 +1,175 @@
import torch
from safetensors.torch import load_file, save_file
from collections import OrderedDict
import os
import argparse # Import argparse
def extract_model_differences(base_model_path, finetuned_model_path, output_delta_path=None, save_dtype_str="float32"):
    """
    Calculates the difference between the state dictionaries of a fine-tuned model
    and a base model.

    For every key present in both files with floating-point tensors of equal
    shape, the delta (fine-tuned minus base) is computed in float32. Keys found
    only in the fine-tuned file are copied unchanged; keys found only in the
    base file are reported and left out. Keys are processed in the order they
    appear in the loaded state dictionaries, so the result and the log output
    are reproducible between runs.

    Args:
        base_model_path (str): Path to the base model .safetensors file.
        finetuned_model_path (str): Path to the fine-tuned model .safetensors file.
        output_delta_path (str, optional): Path to save the resulting delta weights
            .safetensors file. If None (or if there is nothing to save), no file
            is written.
        save_dtype_str (str, optional): Data type used for floating-point tensors
            in the saved file ('float32', 'float16', 'bfloat16'). Any other value
            prints a warning and falls back to 'float32'. It does not affect the
            returned dictionary.

    Returns:
        OrderedDict: A state dictionary containing the delta weights (float32 for
            diffed keys, original dtype for keys unique to the fine-tuned model).
            Returns None if either model fails to load. A failure while saving
            is printed but does not change the return value.
    """
    print(f"Loading base model from: {base_model_path}")
    try:
        # Ensure model is loaded to CPU to avoid CUDA issues if not needed for diffing
        base_state_dict = load_file(base_model_path, device="cpu")
        print(f"Base model loaded. Found {len(base_state_dict)} tensors.")
    except Exception as e:
        print(f"Error loading base model: {e}")
        return None

    print(f"\nLoading fine-tuned model from: {finetuned_model_path}")
    try:
        finetuned_state_dict = load_file(finetuned_model_path, device="cpu")
        print(f"Fine-tuned model loaded. Found {len(finetuned_state_dict)} tensors.")
    except Exception as e:
        print(f"Error loading fine-tuned model: {e}")
        return None

    delta_state_dict = OrderedDict()
    diff_count = 0
    skipped_count = 0
    error_count = 0
    unique_to_finetuned_count = 0

    print("\nCalculating differences...")

    # Partition the keys while keeping dictionary order; going through sets
    # here would make the output order depend on string hashing.
    common_keys = [key for key in finetuned_state_dict if key in base_state_dict]
    keys_only_in_finetuned = [key for key in finetuned_state_dict if key not in base_state_dict]
    keys_only_in_base = [key for key in base_state_dict if key not in finetuned_state_dict]

    for key in common_keys:
        ft_tensor = finetuned_state_dict[key]
        base_tensor = base_state_dict[key]

        # Non-floating tensors are not diffed and do not appear in the delta.
        if not (ft_tensor.is_floating_point() and base_tensor.is_floating_point()):
            skipped_count += 1
            continue

        if ft_tensor.shape != base_tensor.shape:
            print(f"Skipping key '{key}': Shape mismatch (FT: {ft_tensor.shape}, Base: {base_tensor.shape}).")
            skipped_count += 1
            continue

        try:
            # Subtract in float32 for precision; the cast to the requested
            # dtype happens only when the file is written.
            delta_tensor = ft_tensor.to(dtype=torch.float32) - base_tensor.to(dtype=torch.float32)
            delta_state_dict[key] = delta_tensor
            diff_count += 1
        except Exception as e:
            # Best effort: one bad tensor should not abort the whole run.
            print(f"Error calculating difference for key '{key}': {e}")
            error_count += 1

    for key in keys_only_in_finetuned:
        print(f"Warning: Key '{key}' (Shape: {finetuned_state_dict[key].shape}, Dtype: {finetuned_state_dict[key].dtype}) is present in fine-tuned model but not in base model. Storing as is.")
        delta_state_dict[key] = finetuned_state_dict[key]  # Store the tensor from the finetuned model
        unique_to_finetuned_count += 1

    if keys_only_in_base:
        print(f"\nWarning: {len(keys_only_in_base)} key(s) are present only in the base model and will not be in the delta file.")
        for key in keys_only_in_base[:5]:  # Print first 5 as examples
            print(f" - Example key only in base: {key}")
        if len(keys_only_in_base) > 5:
            print(f" ... and {len(keys_only_in_base) - 5} more.")

    print("\nDifference calculation complete.")
    print(f" {diff_count} layers successfully diffed.")
    print(f" {unique_to_finetuned_count} layers unique to fine-tuned model (added as is).")
    print(f" {skipped_count} common layers skipped (shape/type mismatch).")
    print(f" {error_count} common layers had errors during diffing.")

    if output_delta_path and delta_state_dict:
        save_dtype = torch.float32  # Default
        if save_dtype_str == "float16":
            save_dtype = torch.float16
        elif save_dtype_str == "bfloat16":
            save_dtype = torch.bfloat16
        elif save_dtype_str != "float32":
            print(f"Warning: Invalid save_dtype '{save_dtype_str}'. Defaulting to float32.")
            save_dtype_str = "float32"  # for print message

        print(f"\nPreparing to save delta weights with dtype: {save_dtype_str}")

        final_save_dict = OrderedDict()
        for k, v_tensor in delta_state_dict.items():
            if v_tensor.is_floating_point():
                final_save_dict[k] = v_tensor.to(dtype=save_dtype)
            else:
                final_save_dict[k] = v_tensor  # Keep non-float as is (e.g. int tensors if any)

        try:
            save_file(final_save_dict, output_delta_path)
            print(f"Delta weights saved to: {output_delta_path}")
        except Exception as e:
            print(f"Error saving delta weights: {e}")
            import traceback
            traceback.print_exc()

    return delta_state_dict
if __name__ == "__main__":
    # Script entry point: validate the inputs, pick an output path, run the
    # extraction and report the outcome.
    import sys  # sys.exit works even when the site-provided exit() is absent

    parser = argparse.ArgumentParser(description="Extract weight differences between a fine-tuned and a base SDXL model.")
    parser.add_argument("base_model_path", type=str, help="File path for the BASE SDXL model (.safetensors).")
    parser.add_argument("finetuned_model_path", type=str, help="File path for the FINE-TUNED SDXL model (.safetensors).")
    parser.add_argument("--output_path", type=str, default=None,
                        help="Optional: File path to save the delta weights (.safetensors). "
                             "If not provided, defaults to 'model_deltas/delta_[finetuned_model_name].safetensors'.")
    parser.add_argument("--save_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"],
                        help="Data type for saving the delta weights. Choose from 'float32', 'float16', 'bfloat16'. "
                             "Defaults to 'float32'.")
    args = parser.parse_args()

    print("--- Model Difference Extraction Script ---")

    if not os.path.exists(args.base_model_path):
        print(f"Error: Base model file not found at {args.base_model_path}")
        sys.exit(1)
    if not os.path.exists(args.finetuned_model_path):
        print(f"Error: Fine-tuned model file not found at {args.finetuned_model_path}")
        sys.exit(1)

    output_delta_file = args.output_path
    if output_delta_file is None:
        finetuned_basename = os.path.splitext(os.path.basename(args.finetuned_model_path))[0]
        output_delta_file = os.path.join("model_deltas", f"delta_{finetuned_basename}.safetensors")

    # Create the target directory for both the default and a user-given
    # path; a bare file name has no directory part and needs nothing.
    if output_delta_file:
        output_dir_for_file = os.path.dirname(output_delta_file)
        if output_dir_for_file:
            os.makedirs(output_dir_for_file, exist_ok=True)

    differences = extract_model_differences(
        args.base_model_path,
        args.finetuned_model_path,
        output_delta_path=output_delta_file,
        save_dtype_str=args.save_dtype
    )

    # None signals a loading failure; an empty dictionary is a valid
    # (if unusual) result and must not be reported as an error.
    if differences is not None:
        print(f"\nExtraction process finished. {len(differences)} total keys in the delta state_dict.")
    else:
        print("\nCould not extract differences due to errors during model loading.")