mirror of https://github.com/bmaltais/kohya_ss
wip
parent
be30140562
commit
ece84c6d06
|
|
@ -0,0 +1,159 @@
|
|||
import safetensors.torch
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
import sys # To redirect stdout
|
||||
import traceback
|
||||
|
||||
class Logger(object):
|
||||
def __init__(self, filename="loha_analysis_output.txt"):
|
||||
self.terminal = sys.stdout
|
||||
self.log = open(filename, "w", encoding='utf-8')
|
||||
|
||||
def write(self, message):
|
||||
self.terminal.write(message)
|
||||
self.log.write(message)
|
||||
|
||||
def flush(self):
|
||||
# This flush method is needed for python 3 compatibility.
|
||||
# This handles the flush command, which shutil.copytree or os.system uses.
|
||||
self.terminal.flush()
|
||||
self.log.flush()
|
||||
|
||||
def close(self):
|
||||
self.log.close()
|
||||
|
||||
def analyze_safetensors_file(filepath, output_filename="loha_analysis_output.txt"):
|
||||
"""
|
||||
Analyzes a .safetensors file to extract and print its metadata
|
||||
and tensor information (keys, shapes, dtypes) to a file.
|
||||
"""
|
||||
original_stdout = sys.stdout
|
||||
logger = Logger(filename=output_filename)
|
||||
sys.stdout = logger
|
||||
|
||||
try:
|
||||
print(f"--- Analyzing: {filepath} ---\n")
|
||||
print(f"--- Output will be saved to: {output_filename} ---\n")
|
||||
|
||||
# Load the tensors to get their structure
|
||||
state_dict = safetensors.torch.load_file(filepath, device="cpu") # Load to CPU to avoid potential CUDA issues
|
||||
|
||||
print("--- Tensor Information ---")
|
||||
if not state_dict:
|
||||
print("No tensors found in the state dictionary.")
|
||||
else:
|
||||
# Sort keys for consistent output
|
||||
sorted_keys = sorted(state_dict.keys())
|
||||
current_module_prefix = ""
|
||||
|
||||
# First, identify all unique module prefixes for better grouping
|
||||
module_prefixes = sorted(list(set([".".join(key.split(".")[:-1]) for key in sorted_keys if "." in key])))
|
||||
|
||||
for prefix in module_prefixes:
|
||||
if not prefix: # Skip keys that don't seem to be part of a module (e.g. global metadata tensors if any)
|
||||
continue
|
||||
print(f"\nModule: {prefix}")
|
||||
for key in sorted_keys:
|
||||
if key.startswith(prefix + "."):
|
||||
tensor = state_dict[key]
|
||||
print(f" - Key: {key}")
|
||||
print(f" Shape: {list(tensor.shape)}, Dtype: {tensor.dtype}") # Output shape as list for clarity
|
||||
if key.endswith((".alpha", ".dim")):
|
||||
try:
|
||||
value = tensor.item()
|
||||
# Check if value is float and format if it is
|
||||
if isinstance(value, float):
|
||||
print(f" Value: {value:.8f}") # Format float to a certain precision
|
||||
else:
|
||||
print(f" Value: {value}")
|
||||
except Exception as e:
|
||||
print(f" Value: Could not extract scalar value ({tensor}, error: {e})")
|
||||
elif tensor.numel() < 10: # Print small tensors' values
|
||||
print(f" Values (first few): {tensor.flatten()[:10].tolist()}")
|
||||
|
||||
|
||||
# Print keys that might not fit the module pattern (e.g., older formats or single tensors)
|
||||
print("\n--- Other Tensor Keys (if any, not fitting typical module.parameter pattern) ---")
|
||||
other_keys_found = False
|
||||
for key in sorted_keys:
|
||||
if not any(key.startswith(p + ".") for p in module_prefixes if p):
|
||||
other_keys_found = True
|
||||
tensor = state_dict[key]
|
||||
print(f" - Key: {key}")
|
||||
print(f" Shape: {list(tensor.shape)}, Dtype: {tensor.dtype}")
|
||||
if key.endswith((".alpha", ".dim")) or tensor.numel() == 1:
|
||||
try:
|
||||
value = tensor.item()
|
||||
if isinstance(value, float):
|
||||
print(f" Value: {value:.8f}")
|
||||
else:
|
||||
print(f" Value: {value}")
|
||||
except Exception as e:
|
||||
print(f" Value: Could not extract scalar value ({tensor}, error: {e})")
|
||||
|
||||
if not other_keys_found:
|
||||
print("No other keys found.")
|
||||
|
||||
print(f"\nTotal tensor keys found: {len(state_dict)}")
|
||||
|
||||
print("\n--- Metadata (from safetensors header) ---")
|
||||
metadata_content = OrderedDict()
|
||||
malformed_metadata_keys = []
|
||||
try:
|
||||
# Use safe_open to access the metadata separately
|
||||
with safetensors.safe_open(filepath, framework="pt", device="cpu") as f:
|
||||
metadata_keys = f.metadata()
|
||||
if metadata_keys is None:
|
||||
print("No metadata dictionary found in the file header (f.metadata() returned None).")
|
||||
else:
|
||||
for k in metadata_keys.keys():
|
||||
try:
|
||||
metadata_content[k] = metadata_keys.get(k)
|
||||
except Exception as e:
|
||||
malformed_metadata_keys.append((k, str(e)))
|
||||
metadata_content[k] = f"[Error reading value: {e}]"
|
||||
except Exception as e:
|
||||
print(f"Could not open or read metadata using safe_open: {e}")
|
||||
traceback.print_exc(file=sys.stdout)
|
||||
|
||||
if not metadata_content and not malformed_metadata_keys:
|
||||
print("No metadata content extracted.")
|
||||
else:
|
||||
for key, value in metadata_content.items():
|
||||
print(f"- {key}: {value}")
|
||||
if key == "ss_network_args" and value and not value.startswith("[Error"):
|
||||
try:
|
||||
parsed_args = json.loads(value)
|
||||
print(" Parsed ss_network_args:")
|
||||
for arg_key, arg_value in parsed_args.items():
|
||||
print(f" - {arg_key}: {arg_value}")
|
||||
except json.JSONDecodeError:
|
||||
print(" (ss_network_args is not a valid JSON string)")
|
||||
if malformed_metadata_keys:
|
||||
print("\n--- Malformed Metadata Keys (could not be read) ---")
|
||||
for key, error_msg in malformed_metadata_keys:
|
||||
print(f"- {key}: Error: {error_msg}")
|
||||
|
||||
print("\n--- End of Analysis ---")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n!!! An error occurred during analysis !!!")
|
||||
print(str(e))
|
||||
traceback.print_exc(file=sys.stdout) # Print full traceback to the log file
|
||||
finally:
|
||||
sys.stdout = original_stdout # Restore standard output
|
||||
logger.close()
|
||||
print(f"\nAnalysis complete. Output saved to: {output_filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_file_path = input("Enter the path to your working LoHA .safetensors file: ")
|
||||
output_file_name = "loha_analysis_results.txt" # You can change this default
|
||||
|
||||
# Suggest a default output name based on input file if desired
|
||||
# import os
|
||||
# base_name = os.path.splitext(os.path.basename(input_file_path))[0]
|
||||
# output_file_name = f"{base_name}_analysis.txt"
|
||||
|
||||
print(f"The analysis will be saved to: {output_file_name}")
|
||||
analyze_safetensors_file(input_file_path, output_filename=output_file_name)
|
||||
|
|
@ -0,0 +1,204 @@
|
|||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
|
||||
# --- Script Configuration ---
|
||||
# This script generates a minimal, non-functional LoHA (LyCORIS Hadamard Product Adaptation)
|
||||
# .safetensors file, designed to be structurally compatible with ComfyUI and
|
||||
# based on the analysis of a working SDXL LoHA file.
|
||||
|
||||
# --- Global LoHA Parameters (mimicking metadata from your working file) ---
|
||||
# These can be overridden per layer if needed for more complex dummies.
|
||||
# From your metadata: ss_network_dim: 32, ss_network_alpha: 32.0
|
||||
DEFAULT_RANK = 32
|
||||
DEFAULT_ALPHA = 32.0
|
||||
CONV_RANK = 8 # From your ss_network_args: "conv_dim": "8"
|
||||
CONV_ALPHA = 4.0 # From your ss_network_args: "conv_alpha": "4"
|
||||
|
||||
# Define example target layers.
|
||||
# We'll use names and dimensions that are representative of SDXL and your analysis.
|
||||
# Format: (layer_name, in_dim, out_dim, rank, alpha)
|
||||
# Note: For Conv2d, in_dim = in_channels, out_dim = out_channels.
|
||||
# The hada_wX_b for conv will have shape (rank, in_channels * kernel_h * kernel_w)
|
||||
# For simplicity in this dummy, we'll primarily focus on linear/attention
|
||||
# layers first, and then add one representative conv-like layer.
|
||||
|
||||
# Layer that previously caused error:
|
||||
# "ERROR loha diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight shape '[640, 640]' is invalid..."
|
||||
# This corresponds to lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_v
|
||||
# In your working LoHA, similar attention layers (e.g., *_attn1_to_k) have out_dim=640, in_dim=640, rank=32, alpha=32.0
|
||||
|
||||
EXAMPLE_LAYERS_CONFIG = [
|
||||
# UNet Attention Layers (mimicking typical SDXL structure)
|
||||
{
|
||||
"name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_q", # Query
|
||||
"in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
{
|
||||
"name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_k", # Key
|
||||
"in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
{
|
||||
"name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_v", # Value - this one errored previously
|
||||
"in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
{
|
||||
"name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_out_0", # Output Projection
|
||||
"in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
# A deeper UNet attention block
|
||||
{
|
||||
"name": "lora_unet_middle_block_1_transformer_blocks_0_attn1_to_q",
|
||||
"in_dim": 1280, "out_dim": 1280, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
{
|
||||
"name": "lora_unet_middle_block_1_transformer_blocks_0_attn1_to_out_0",
|
||||
"in_dim": 1280, "out_dim": 1280, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
# Example UNet "Convolutional" LoHA (e.g., for a ResBlock's conv layer)
|
||||
# Based on your lora_unet_input_blocks_1_0_in_layers_2 which had rank 8, alpha 4
|
||||
# Assuming original conv was Conv2d(320, 320, kernel_size=3, padding=1)
|
||||
{
|
||||
"name": "lora_unet_input_blocks_1_0_in_layers_2",
|
||||
"in_dim": 320, # in_channels
|
||||
"out_dim": 320, # out_channels
|
||||
"rank": CONV_RANK,
|
||||
"alpha": CONV_ALPHA,
|
||||
"is_conv": True,
|
||||
"kernel_size": 3 # Assume 3x3 kernel for this example
|
||||
},
|
||||
# Example Text Encoder Layer (CLIP-L, first one from your list)
|
||||
# lora_te1_text_model_encoder_layers_0_mlp_fc1 (original Linear(768, 3072))
|
||||
{
|
||||
"name": "lora_te1_text_model_encoder_layers_0_mlp_fc1",
|
||||
"in_dim": 768, "out_dim": 3072, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
]
|
||||
|
||||
# Use bfloat16 as seen in the analysis
|
||||
DTYPE = torch.bfloat16
|
||||
|
||||
# --- Main Script ---
|
||||
def create_dummy_loha_file(filepath="dummy_loha_corrected.safetensors"):
|
||||
"""
|
||||
Creates and saves a dummy LoHA .safetensors file with corrected structure
|
||||
and metadata based on analysis of a working file.
|
||||
"""
|
||||
state_dict = OrderedDict()
|
||||
metadata = OrderedDict()
|
||||
|
||||
print(f"Generating dummy LoHA with default rank={DEFAULT_RANK}, default alpha={DEFAULT_ALPHA}")
|
||||
print(f"Targeting DTYPE: {DTYPE}")
|
||||
|
||||
for layer_config in EXAMPLE_LAYERS_CONFIG:
|
||||
layer_name = layer_config["name"]
|
||||
in_dim = layer_config["in_dim"]
|
||||
out_dim = layer_config["out_dim"]
|
||||
rank = layer_config["rank"]
|
||||
alpha = layer_config["alpha"]
|
||||
is_conv = layer_config["is_conv"]
|
||||
|
||||
print(f"Processing layer: {layer_name} (in: {in_dim}, out: {out_dim}, rank: {rank}, alpha: {alpha}, conv: {is_conv})")
|
||||
|
||||
# --- LoHA Tensor Shapes Correction based on analysis ---
|
||||
# hada_wX_a (maps to original layer's out_features): (out_dim, rank)
|
||||
# hada_wX_b (maps from original layer's in_features): (rank, in_dim)
|
||||
# For Convolutions, in_dim refers to in_channels, out_dim to out_channels.
|
||||
# For hada_wX_b in conv, the effective input dimension includes kernel size.
|
||||
|
||||
if is_conv:
|
||||
kernel_size = layer_config.get("kernel_size", 3) # Default to 3x3 if not specified
|
||||
# This is for LoHA types that decompose the full kernel (e.g. LyCORIS full conv):
|
||||
# (rank, in_channels * kernel_h * kernel_w)
|
||||
# For simpler conv LoHA (like applying to 1x1 equivalent), it might just be (rank, in_channels)
|
||||
# The analysis for `lora_unet_input_blocks_1_0_in_layers_2` showed hada_w1_b as [8, 2880]
|
||||
# where in_dim=320, rank=8. 2880 = 320 * 9 (i.e., in_channels * kernel_h * kernel_w for 3x3)
|
||||
# This indicates a full kernel decomposition.
|
||||
eff_in_dim_conv_b = in_dim * kernel_size * kernel_size
|
||||
|
||||
hada_w1_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01
|
||||
hada_w1_b = torch.randn(rank, eff_in_dim_conv_b, dtype=DTYPE) * 0.01
|
||||
hada_w2_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01
|
||||
hada_w2_b = torch.randn(rank, eff_in_dim_conv_b, dtype=DTYPE) * 0.01
|
||||
else: # Linear layers
|
||||
hada_w1_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01
|
||||
hada_w1_b = torch.randn(rank, in_dim, dtype=DTYPE) * 0.01
|
||||
hada_w2_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01
|
||||
hada_w2_b = torch.randn(rank, in_dim, dtype=DTYPE) * 0.01
|
||||
|
||||
state_dict[f"{layer_name}.hada_w1_a"] = hada_w1_a
|
||||
state_dict[f"{layer_name}.hada_w1_b"] = hada_w1_b
|
||||
state_dict[f"{layer_name}.hada_w2_a"] = hada_w2_a
|
||||
state_dict[f"{layer_name}.hada_w2_b"] = hada_w2_b
|
||||
|
||||
# Alpha tensor (scalar)
|
||||
state_dict[f"{layer_name}.alpha"] = torch.tensor(float(alpha), dtype=DTYPE)
|
||||
|
||||
# IMPORTANT: No per-module ".dim" tensor, as per analysis of working file.
|
||||
# Rank is implicit in weight shapes and global metadata.
|
||||
|
||||
# --- Metadata (mimicking the working LoHA file) ---
|
||||
metadata["ss_network_module"] = "lycoris.kohya"
|
||||
metadata["ss_network_dim"] = str(DEFAULT_RANK) # Global/default rank
|
||||
metadata["ss_network_alpha"] = str(DEFAULT_ALPHA) # Global/default alpha
|
||||
metadata["ss_network_algo"] = "loha" # Also specified inside ss_network_args by convention
|
||||
|
||||
# Mimic ss_network_args from your file
|
||||
network_args = {
|
||||
"conv_dim": str(CONV_RANK),
|
||||
"conv_alpha": str(CONV_ALPHA),
|
||||
"algo": "loha",
|
||||
# Add other args from your file if they seem relevant for loading structure,
|
||||
# but these are the most critical for type/rank.
|
||||
"dropout": "0.0", # From your file, though value might not matter for dummy
|
||||
"rank_dropout": "0", # from your file
|
||||
"module_dropout": "0", # from your file
|
||||
"use_tucker": "False", # from your file
|
||||
"use_scalar": "False", # from your file
|
||||
"rank_dropout_scale": "False", # from your file
|
||||
"train_norm": "False" # from your file
|
||||
}
|
||||
metadata["ss_network_args"] = json.dumps(network_args)
|
||||
|
||||
# Other potentially useful metadata from your working file (optional for basic loading)
|
||||
metadata["ss_sd_model_name"] = "sd_xl_base_1.0.safetensors" # Example base model
|
||||
metadata["ss_resolution"] = "(1024,1024)" # Example, format might vary
|
||||
metadata["modelspec.sai_model_spec"] = "1.0.0"
|
||||
metadata["modelspec.implementation"] = "https_//github.com/Stability-AI/generative-models" # fixed typo
|
||||
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base/lora" # Even for LoHA, this is often used
|
||||
metadata["ss_mixed_precision"] = "bf16"
|
||||
metadata["ss_note"] = "Dummy LoHA (corrected) for ComfyUI validation. Not trained."
|
||||
|
||||
|
||||
# --- Save the State Dictionary with Metadata ---
|
||||
try:
|
||||
save_file(state_dict, filepath, metadata=metadata)
|
||||
print(f"\nSuccessfully saved dummy LoHA file to: {filepath}")
|
||||
print("\nFile structure (tensor keys):")
|
||||
for key in state_dict.keys():
|
||||
print(f"- {key}: shape {state_dict[key].shape}, dtype {state_dict[key].dtype}")
|
||||
print("\nMetadata:")
|
||||
for key, value in metadata.items():
|
||||
print(f"- {key}: {value}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\nError saving file: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_dummy_loha_file()
|
||||
|
||||
# --- Verification Note for ComfyUI ---
|
||||
# 1. Place `dummy_loha_corrected.safetensors` into `ComfyUI/models/loras/`.
|
||||
# 2. Load an SDXL base model in ComfyUI.
|
||||
# 3. Add a "Load LoRA" node and select `dummy_loha_corrected.safetensors`.
|
||||
# 4. Connect the LoRA node between the checkpoint loader and the KSampler.
|
||||
#
|
||||
# Expected outcome:
|
||||
# - ComfyUI should load the file without "key not loaded" or "dimension mismatch" errors
|
||||
# for the layers defined in EXAMPLE_LAYERS_CONFIG.
|
||||
# - The LoRA node should correctly identify it as a LoHA/LyCORIS model.
|
||||
# - If you have layers in your SDXL model that match the names in EXAMPLE_LAYERS_CONFIG,
|
||||
# ComfyUI will attempt to apply these (random) weights.
|
||||
|
|
@ -0,0 +1,397 @@
|
|||
import argparse
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from safetensors.torch import save_file, load_file
|
||||
from tqdm import tqdm
|
||||
import math
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
import signal
|
||||
import sys
|
||||
|
||||
# --- Global variables (ensure they are defined as before) ---
|
||||
extracted_loha_state_dict_global = OrderedDict()
|
||||
layer_optimization_stats_global = []
|
||||
args_global = None
|
||||
processed_layers_count_global = 0
|
||||
skipped_identical_count_global = 0
|
||||
skipped_other_count_global = 0
|
||||
keys_scanned_this_run_global = 0 # This will track scans for the outer pbar
|
||||
save_attempted_on_interrupt = False
|
||||
outer_pbar_global = None
|
||||
|
||||
# --- optimize_loha_for_layer and get_module_shape_info_from_weight (UNCHANGED from your last version) ---
|
||||
def optimize_loha_for_layer(
|
||||
layer_name: str, delta_W_target: torch.Tensor, out_dim: int, in_dim_effective: int,
|
||||
k_h: int, k_w: int, rank: int, initial_alpha_val: float, lr: float = 1e-3,
|
||||
max_iterations: int = 1000, min_iterations: int = 100, target_loss: float = None,
|
||||
weight_decay: float = 1e-4, device: str = 'cuda', dtype: torch.dtype = torch.float32,
|
||||
is_conv: bool = True, verbose_layer_debug: bool = False
|
||||
):
|
||||
delta_W_target = delta_W_target.to(device, dtype=dtype)
|
||||
|
||||
if is_conv:
|
||||
k_ops = k_h * k_w
|
||||
hada_w1_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype))
|
||||
hada_w1_b = nn.Parameter(torch.empty(rank, in_dim_effective * k_ops, device=device, dtype=dtype))
|
||||
hada_w2_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype))
|
||||
hada_w2_b = nn.Parameter(torch.empty(rank, in_dim_effective * k_ops, device=device, dtype=dtype))
|
||||
else: # Linear
|
||||
hada_w1_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype))
|
||||
hada_w1_b = nn.Parameter(torch.empty(rank, in_dim_effective, device=device, dtype=dtype))
|
||||
hada_w2_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype))
|
||||
hada_w2_b = nn.Parameter(torch.empty(rank, in_dim_effective, device=device, dtype=dtype))
|
||||
|
||||
nn.init.kaiming_uniform_(hada_w1_a, a=math.sqrt(5))
|
||||
nn.init.normal_(hada_w1_b, std=0.02)
|
||||
nn.init.kaiming_uniform_(hada_w2_a, a=math.sqrt(5))
|
||||
nn.init.normal_(hada_w2_b, std=0.02)
|
||||
alpha_param = nn.Parameter(torch.tensor(initial_alpha_val, device=device, dtype=dtype))
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
[hada_w1_a, hada_w1_b, hada_w2_a, hada_w2_b, alpha_param], lr=lr, weight_decay=weight_decay
|
||||
)
|
||||
patience_epochs = max(10, int(max_iterations * 0.05))
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience_epochs, factor=0.5, min_lr=1e-7, verbose=False)
|
||||
|
||||
iter_pbar = tqdm(range(max_iterations), desc=f"Opt: {layer_name}", leave=False, dynamic_ncols=True, position=1)
|
||||
|
||||
final_loss = float('inf')
|
||||
stopped_early_by_loss = False
|
||||
iterations_actually_done = 0
|
||||
|
||||
for i in iter_pbar:
|
||||
iterations_actually_done = i + 1
|
||||
if save_attempted_on_interrupt:
|
||||
print(f"\n Interrupt during opt of {layer_name}. Stopping this layer after {i} iters.")
|
||||
break
|
||||
|
||||
optimizer.zero_grad()
|
||||
eff_alpha_scale = alpha_param / rank
|
||||
|
||||
if is_conv:
|
||||
term1_flat = hada_w1_a @ hada_w1_b; term1_reshaped = term1_flat.view(out_dim, in_dim_effective, k_h, k_w)
|
||||
term2_flat = hada_w2_a @ hada_w2_b; term2_reshaped = term2_flat.view(out_dim, in_dim_effective, k_h, k_w)
|
||||
delta_W_loha = eff_alpha_scale * term1_reshaped * term2_reshaped
|
||||
else:
|
||||
term1 = hada_w1_a @ hada_w1_b; term2 = hada_w2_a @ hada_w2_b
|
||||
delta_W_loha = eff_alpha_scale * term1 * term2
|
||||
|
||||
loss = F.mse_loss(delta_W_loha, delta_W_target)
|
||||
final_loss = loss.item()
|
||||
loss.backward(); optimizer.step(); scheduler.step(loss)
|
||||
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
iter_pbar.set_postfix_str(f"Loss={final_loss:.3e}, AlphaP={alpha_param.item():.2f}, LR={current_lr:.1e}", refresh=True)
|
||||
|
||||
if verbose_layer_debug and (i == 0 or (i + 1) % (max_iterations // 10 if max_iterations >= 10 else 1) == 0 or i == max_iterations - 1):
|
||||
iter_pbar.write(f" Debug {layer_name} - Iter {i+1}/{max_iterations}: Loss: {final_loss:.6e}, LR: {current_lr:.2e}, AlphaP: {alpha_param.item():.4f}")
|
||||
|
||||
if target_loss is not None and i >= min_iterations -1 :
|
||||
if final_loss <= target_loss:
|
||||
if verbose_layer_debug or (args_global and args_global.verbose):
|
||||
iter_pbar.write(f" Target loss {target_loss:.2e} reached for {layer_name} at iter {i+1}.")
|
||||
stopped_early_by_loss = True; break
|
||||
|
||||
if not save_attempted_on_interrupt :
|
||||
iter_pbar.set_description_str(f"Opt: {layer_name} (Done)")
|
||||
iter_pbar.set_postfix_str(f"FinalLoss={final_loss:.2e}, It={iterations_actually_done}{', EarlyStop' if stopped_early_by_loss else ''}")
|
||||
iter_pbar.close()
|
||||
|
||||
if save_attempted_on_interrupt and not stopped_early_by_loss and iterations_actually_done < max_iterations:
|
||||
return {'final_loss': final_loss, 'stopped_early': False, 'iterations_done': iterations_actually_done, 'interrupted_mid_layer': True}
|
||||
return {
|
||||
'hada_w1_a': hada_w1_a.data.cpu().contiguous(), 'hada_w1_b': hada_w1_b.data.cpu().contiguous(),
|
||||
'hada_w2_a': hada_w2_a.data.cpu().contiguous(), 'hada_w2_b': hada_w2_b.data.cpu().contiguous(),
|
||||
'alpha': alpha_param.data.cpu().contiguous(), 'final_loss': final_loss,
|
||||
'stopped_early': stopped_early_by_loss, 'iterations_done': iterations_actually_done,
|
||||
'interrupted_mid_layer': False
|
||||
}
|
||||
|
||||
def get_module_shape_info_from_weight(weight_tensor: torch.Tensor):
|
||||
if len(weight_tensor.shape) == 4: is_conv = True; out_dim, in_dim_effective, k_h, k_w = weight_tensor.shape; groups = 1; return out_dim, in_dim_effective, k_h, k_w, groups, is_conv
|
||||
elif len(weight_tensor.shape) == 2: is_conv = False; out_dim, in_dim = weight_tensor.shape; return out_dim, in_dim, None, None, 1, is_conv
|
||||
return None
|
||||
|
||||
# --- perform_graceful_save (UNCHANGED from previous version) ---
|
||||
def perform_graceful_save(output_path_override=None):
|
||||
global extracted_loha_state_dict_global, layer_optimization_stats_global, args_global
|
||||
global processed_layers_count_global, save_attempted_on_interrupt
|
||||
global skipped_identical_count_global, skipped_other_count_global, keys_scanned_this_run_global
|
||||
|
||||
if not extracted_loha_state_dict_global: print("No layers were processed enough to save."); return
|
||||
args_to_use = args_global
|
||||
if not args_to_use: print("Error: Global args not available for saving metadata."); return
|
||||
final_save_path = output_path_override if output_path_override else args_to_use.save_to
|
||||
if args_to_use.save_weights_dtype == "fp16": final_save_dtype_torch = torch.float16
|
||||
elif args_to_use.save_weights_dtype == "bf16": final_save_dtype_torch = torch.bfloat16
|
||||
else: final_save_dtype_torch = torch.float32
|
||||
final_state_dict_to_save = OrderedDict()
|
||||
for k, v_tensor in extracted_loha_state_dict_global.items():
|
||||
if v_tensor.is_floating_point(): final_state_dict_to_save[k] = v_tensor.to(final_save_dtype_torch)
|
||||
else: final_state_dict_to_save[k] = v_tensor
|
||||
print(f"\nAttempting to save {len(final_state_dict_to_save)} LoHA param sets for {processed_layers_count_global} layers to {final_save_path}")
|
||||
eff_global_network_alpha_val = args_to_use.initial_alpha; eff_global_network_alpha_str = f"{eff_global_network_alpha_val:.8f}"
|
||||
global_rank_str = str(args_to_use.rank)
|
||||
conv_rank_str = str(args_to_use.conv_rank if args_to_use.conv_rank is not None else args_to_use.rank)
|
||||
eff_conv_alpha_val = args_to_use.initial_conv_alpha; conv_alpha_str = f"{eff_conv_alpha_val:.8f}"
|
||||
network_args_dict = {
|
||||
"algo": "loha", "dim": global_rank_str, "alpha": eff_global_network_alpha_str,
|
||||
"conv_dim": conv_rank_str, "conv_alpha": conv_alpha_str,
|
||||
"dropout": str(args_to_use.dropout), "rank_dropout": str(args_to_use.rank_dropout), "module_dropout": str(args_to_use.module_dropout),
|
||||
"use_tucker": "false", "use_scalar": "false", "block_size": "1",}
|
||||
sf_metadata = {
|
||||
"ss_network_module": "lycoris.kohya", "ss_network_rank": global_rank_str,
|
||||
"ss_network_alpha": eff_global_network_alpha_str, "ss_network_algo": "loha",
|
||||
"ss_network_args": json.dumps(network_args_dict),
|
||||
"ss_comment": f"Extracted LoHA (Interrupt: {save_attempted_on_interrupt}). OptPrec: {args_to_use.precision}. SaveDtype: {args_to_use.save_weights_dtype}. ATOL: {args_to_use.atol_fp32_check}. Layers: {processed_layers_count_global}. MaxIter: {args_to_use.max_iterations}. TargetLoss: {args_to_use.target_loss}",
|
||||
"ss_base_model_name": os.path.splitext(os.path.basename(args_to_use.base_model_path))[0],
|
||||
"ss_ft_model_name": os.path.splitext(os.path.basename(args_to_use.ft_model_path))[0],
|
||||
"ss_save_weights_dtype": args_to_use.save_weights_dtype, "ss_optimization_precision": args_to_use.precision,}
|
||||
json_metadata_for_file = {
|
||||
"comfyui_lora_type": "LyCORIS_LoHa", "model_name": os.path.splitext(os.path.basename(final_save_path))[0],
|
||||
"base_model_path": args_to_use.base_model_path, "ft_model_path": args_to_use.ft_model_path,
|
||||
"loha_extraction_settings": {k: str(v) if isinstance(v, type(os.pathsep)) else v for k,v in vars(args_to_use).items()},
|
||||
"extraction_summary":{"processed_layers_count": processed_layers_count_global, "skipped_identical_count": skipped_identical_count_global, "skipped_other_count": skipped_other_count_global, "total_candidate_keys_scanned": keys_scanned_this_run_global,},
|
||||
"layer_optimization_details": layer_optimization_stats_global, "embedded_safetensors_metadata": sf_metadata, "interrupted_save": save_attempted_on_interrupt}
|
||||
if final_save_path.endswith(".safetensors"):
|
||||
try: save_file(final_state_dict_to_save, final_save_path, metadata=sf_metadata); print(f"LoHA state_dict saved to: {final_save_path}")
|
||||
except Exception as e: print(f"Error saving .safetensors file: {e}"); return
|
||||
metadata_json_file_path = os.path.splitext(final_save_path)[0] + "_extraction_metadata.json"
|
||||
try:
|
||||
with open(metadata_json_file_path, 'w') as f: json.dump(json_metadata_for_file, f, indent=4)
|
||||
print(f"Extended metadata saved to: {metadata_json_file_path}")
|
||||
except Exception as e: print(f"Could not save extended metadata JSON: {e}")
|
||||
else: print(f"Saving to .pt not fully supported with interrupt metadata.")
|
||||
|
||||
# --- handle_interrupt (UNCHANGED from previous version) ---
|
||||
def handle_interrupt(signum, frame):
|
||||
global save_attempted_on_interrupt, outer_pbar_global
|
||||
print("\n" + "="*30 + "\nCtrl+C (SIGINT) detected!\n" + "="*30)
|
||||
if save_attempted_on_interrupt: print("Save already attempted. Force exiting."); os._exit(1); return
|
||||
save_attempted_on_interrupt = True
|
||||
if outer_pbar_global: outer_pbar_global.close()
|
||||
print("Attempting to save progress for processed layers...")
|
||||
perform_graceful_save()
|
||||
print("Graceful save attempt finished. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def main(cli_args):
|
||||
global args_global, extracted_loha_state_dict_global, layer_optimization_stats_global
|
||||
global processed_layers_count_global, save_attempted_on_interrupt, outer_pbar_global
|
||||
global skipped_identical_count_global, skipped_other_count_global, keys_scanned_this_run_global
|
||||
|
||||
args_global = cli_args
|
||||
signal.signal(signal.SIGINT, handle_interrupt)
|
||||
|
||||
if args_global.precision == "fp16": target_opt_dtype = torch.float16
|
||||
elif args_global.precision == "bf16": target_opt_dtype = torch.bfloat16
|
||||
else: target_opt_dtype = torch.float32
|
||||
|
||||
if args_global.save_weights_dtype == "fp16": final_save_dtype = torch.float16
|
||||
elif args_global.save_weights_dtype == "bf16": final_save_dtype = torch.bfloat16
|
||||
else: final_save_dtype = torch.float32
|
||||
|
||||
print(f"Using device: {args_global.device}, Opt Dtype: {target_opt_dtype}, Save Dtype: {final_save_dtype}")
|
||||
if args_global.target_loss: print(f"Target Loss: {args_global.target_loss:.2e} (after {args_global.min_iterations} min iters)")
|
||||
print(f"Max Iters/Layer: {args_global.max_iterations}")
|
||||
|
||||
print(f"Loading base: {args_global.base_model_path}")
|
||||
if args_global.base_model_path.endswith(".safetensors"): base_model_sd = load_file(args_global.base_model_path, device='cpu')
|
||||
else: base_model_sd = torch.load(args_global.base_model_path, map_location='cpu'); base_model_sd = base_model_sd.get('state_dict', base_model_sd)
|
||||
|
||||
print(f"Loading fine-tuned: {args_global.ft_model_path}")
|
||||
if args_global.ft_model_path.endswith(".safetensors"): ft_model_sd = load_file(args_global.ft_model_path, device='cpu')
|
||||
else: ft_model_sd = torch.load(args_global.ft_model_path, map_location='cpu'); ft_model_sd = ft_model_sd.get('state_dict', ft_model_sd)
|
||||
|
||||
extracted_loha_state_dict_global = OrderedDict()
|
||||
layer_optimization_stats_global = []
|
||||
processed_layers_count_global = 0; skipped_identical_count_global = 0; skipped_other_count_global = 0; keys_scanned_this_run_global = 0
|
||||
|
||||
all_candidate_keys = []
|
||||
for k in base_model_sd.keys():
|
||||
if k.endswith('.weight') and k in ft_model_sd and (len(base_model_sd[k].shape) == 2 or len(base_model_sd[k].shape) == 4):
|
||||
all_candidate_keys.append(k)
|
||||
elif k.endswith('.weight') and k not in ft_model_sd and args_global.verbose:
|
||||
print(f"Note: Base key '{k}' not in FT model.")
|
||||
all_candidate_keys.sort()
|
||||
total_candidate_keys_to_scan = len(all_candidate_keys)
|
||||
|
||||
print(f"Found {total_candidate_keys_to_scan} candidate '.weight' keys common to both models and of suitable shape.")
|
||||
|
||||
# The outer progress bar will now track the scan through ALL candidate keys
|
||||
outer_pbar_global = tqdm(total=total_candidates_to_scan, desc=f"Scanning Layers (Processed 0)", dynamic_ncols=True, position=0)
|
||||
|
||||
try:
|
||||
for key_name in all_candidate_keys:
|
||||
if save_attempted_on_interrupt: break
|
||||
keys_scanned_this_run_global += 1
|
||||
outer_pbar_global.update(1) # Update for every key scanned
|
||||
|
||||
if args_global.max_layers is not None and args_global.max_layers > 0 and processed_layers_count_global >= args_global.max_layers:
|
||||
if args_global.verbose: print(f"\nReached max_layers limit ({args_global.max_layers}). All further candidate layers will be skipped by the scan.")
|
||||
# No 'break' here, let the loop finish scanning so the pbar completes to 100% of total_candidates_to_scan
|
||||
# The actual optimization will be skipped by the condition above.
|
||||
# Update description to show scanning is finishing due to max_layers
|
||||
outer_pbar_global.set_description_str(f"Scan (Max Layers Reached: {processed_layers_count_global}/{args_global.max_layers}, Scanned {keys_scanned_this_run_global}/{total_candidates_to_scan})")
|
||||
skipped_other_count_global +=1 # Count as skipped_other if we don't even check identical
|
||||
continue # Continue to scan remaining keys but don't process them
|
||||
|
||||
|
||||
base_W = base_model_sd[key_name].to(dtype=torch.float32)
|
||||
ft_W = ft_model_sd[key_name].to(dtype=torch.float32)
|
||||
|
||||
if base_W.shape != ft_W.shape:
|
||||
if args_global.verbose: print(f"\nSkipping {key_name} (scan): shape mismatch.")
|
||||
skipped_other_count_global +=1; continue
|
||||
shape_info = get_module_shape_info_from_weight(base_W)
|
||||
if shape_info is None:
|
||||
if args_global.verbose: print(f"\nSkipping {key_name} (scan): not Linear or Conv2d.")
|
||||
skipped_other_count_global +=1; continue
|
||||
|
||||
delta_W_fp32 = (ft_W - base_W)
|
||||
if torch.allclose(delta_W_fp32, torch.zeros_like(delta_W_fp32), atol=args_global.atol_fp32_check):
|
||||
if args_global.verbose: print(f"\nSkipping {key_name} (scan): weights effectively identical.")
|
||||
skipped_identical_count_global += 1; continue
|
||||
|
||||
# This layer WILL be processed. Update description.
|
||||
max_layers_target_str = f"/{args_global.max_layers}" if args_global.max_layers is not None and args_global.max_layers > 0 else ""
|
||||
outer_pbar_global.set_description_str(f"Optimizing L{processed_layers_count_global + 1}{max_layers_target_str} (Scanned {keys_scanned_this_run_global}/{total_candidates_to_scan})")
|
||||
|
||||
original_module_path = key_name[:-len(".weight")]
|
||||
loha_key_prefix = ""
|
||||
if original_module_path.startswith("model.diffusion_model."): loha_key_prefix = "lora_unet_" + original_module_path[len("model.diffusion_model."):].replace(".", "_")
|
||||
elif original_module_path.startswith("conditioner.embedders.0.transformer."): loha_key_prefix = "lora_te1_" + original_module_path[len("conditioner.embedders.0.transformer."):].replace(".", "_")
|
||||
elif original_module_path.startswith("conditioner.embedders.1.model.transformer."): loha_key_prefix = "lora_te2_" + original_module_path[len("conditioner.embedders.1.model.transformer."):].replace(".", "_")
|
||||
else: loha_key_prefix = "lora_" + original_module_path.replace(".", "_")
|
||||
|
||||
if args_global.verbose: tqdm.write(f"\n Orig: {key_name} -> LoHA: {loha_key_prefix}") # Use tqdm.write for multiline
|
||||
|
||||
out_dim, in_dim_effective, k_h, k_w, _, is_conv = shape_info
|
||||
delta_W_target_for_opt = delta_W_fp32.to(dtype=target_opt_dtype)
|
||||
current_rank = args_global.conv_rank if is_conv and args_global.conv_rank is not None else args_global.rank
|
||||
current_initial_alpha = args_global.initial_conv_alpha if is_conv else args_global.initial_alpha
|
||||
|
||||
tqdm.write(f"Optimizing Layer {processed_layers_count_global + 1}{max_layers_target_str}: {loha_key_prefix} (Orig: {original_module_path}, Shp: {list(base_W.shape)}, R: {current_rank}, Alpha_init: {current_initial_alpha:.1f})")
|
||||
|
||||
|
||||
try:
|
||||
opt_results = optimize_loha_for_layer(
|
||||
layer_name=loha_key_prefix, delta_W_target=delta_W_target_for_opt,
|
||||
out_dim=out_dim, in_dim_effective=in_dim_effective, k_h=k_h, k_w=k_w, rank=current_rank,
|
||||
initial_alpha_val=current_initial_alpha, lr=args_global.lr,
|
||||
max_iterations=args_global.max_iterations, min_iterations=args_global.min_iterations,
|
||||
target_loss=args_global.target_loss, weight_decay=args_global.weight_decay,
|
||||
device=args_global.device, dtype=target_opt_dtype, is_conv=is_conv,
|
||||
verbose_layer_debug=args_global.verbose_layer_debug
|
||||
)
|
||||
|
||||
if not opt_results.get('interrupted_mid_layer'):
|
||||
for p_name, p_val in opt_results.items():
|
||||
if p_name not in ['final_loss', 'stopped_early', 'iterations_done', 'interrupted_mid_layer']:
|
||||
extracted_loha_state_dict_global[f'{loha_key_prefix}.{p_name}'] = p_val.to(final_save_dtype)
|
||||
|
||||
layer_optimization_stats_global.append({
|
||||
"name": loha_key_prefix, "original_name": original_module_path,
|
||||
"final_loss": opt_results['final_loss'], "iterations_done": opt_results['iterations_done'],
|
||||
"stopped_early_by_loss_target": opt_results['stopped_early']
|
||||
})
|
||||
tqdm.write(f" Layer {loha_key_prefix} Done. Loss: {opt_results['final_loss']:.4e}, Iters: {opt_results['iterations_done']}{', Stopped by Loss' if opt_results['stopped_early'] else ''}")
|
||||
|
||||
if args_global.use_bias:
|
||||
original_bias_key = f"{original_module_path}.bias"
|
||||
if original_bias_key in ft_model_sd and original_bias_key in base_model_sd:
|
||||
base_B = base_model_sd[original_bias_key].to(dtype=torch.float32); ft_B = ft_model_sd[original_bias_key].to(dtype=torch.float32)
|
||||
if not torch.allclose(base_B, ft_B, atol=args_global.atol_fp32_check):
|
||||
extracted_loha_state_dict_global[original_bias_key] = ft_B.cpu().to(final_save_dtype)
|
||||
if args_global.verbose: tqdm.write(f" Saved differing bias for {original_bias_key}")
|
||||
elif original_bias_key in ft_model_sd and original_bias_key not in base_model_sd:
|
||||
if args_global.verbose: tqdm.write(f" Bias {original_bias_key} in FT only. Saving.")
|
||||
extracted_loha_state_dict_global[original_bias_key] = ft_model_sd[original_bias_key].cpu().to(final_save_dtype)
|
||||
|
||||
processed_layers_count_global += 1
|
||||
# Outer pbar update is now at the start of the loop for each scanned key.
|
||||
# The description will reflect the processed count.
|
||||
else:
|
||||
if args_global.verbose: tqdm.write(f" Opt for {loha_key_prefix} interrupted; not saving params.")
|
||||
except Exception as e:
|
||||
print(f"\nError during optimization for {original_module_path} ({loha_key_prefix}): {e}")
|
||||
import traceback; traceback.print_exc()
|
||||
skipped_other_count_global +=1
|
||||
|
||||
finally:
|
||||
if outer_pbar_global:
|
||||
if outer_pbar_global.n < outer_pbar_global.total: # Fill to 100% if loop broke early
|
||||
outer_pbar_global.update(outer_pbar_global.total - outer_pbar_global.n)
|
||||
outer_pbar_global.close()
|
||||
|
||||
if not save_attempted_on_interrupt: # Normal completion
|
||||
print("\n--- Final Optimization Summary (Normal Completion) ---")
|
||||
for stat in layer_optimization_stats_global:
|
||||
print(f"Layer: {stat['name']}, Final Loss: {stat['final_loss']:.4e}, Iters: {stat['iterations_done']}{', Stopped by Loss' if stat['stopped_early_by_loss_target'] else ''}")
|
||||
print(f"\n--- Overall Summary ---")
|
||||
print(f"Total candidate weight keys: {total_candidates_to_scan}")
|
||||
print(f"Total keys scanned by loop: {keys_scanned_this_run_global}")
|
||||
print(f"Processed {processed_layers_count_global} layers for LoHA extraction.")
|
||||
print(f"Skipped {skipped_identical_count_global} layers (identical).")
|
||||
print(f"Skipped {skipped_other_count_global} layers (other reasons).")
|
||||
perform_graceful_save()
|
||||
else: # Interrupted
|
||||
print("\nProcess was interrupted. Saved data for fully completed layers.")
|
||||
|
||||
# --- __main__ block (UNCHANGED from previous version) ---
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Extract LoHA parameters by optimizing against weight differences.")
|
||||
parser.add_argument("base_model_path", type=str, help="Path to the base model state_dict file (.pt, .pth, .safetensors)")
|
||||
parser.add_argument("ft_model_path", type=str, help="Path to the fine-tuned model state_dict file (.pt, .pth, .safetensors)")
|
||||
parser.add_argument("save_to", type=str, help="Path to save the extracted LoHA file (recommended .safetensors)")
|
||||
|
||||
parser.add_argument("--rank", type=int, default=4, help="Default rank for LoHA decomposition (used for linear layers and as fallback for conv).")
|
||||
parser.add_argument("--conv_rank", type=int, default=None, help="Specific rank for convolutional LoHA layers. Defaults to --rank if not set.")
|
||||
|
||||
parser.add_argument("--initial_alpha", type=float, default=None,
|
||||
help="Global initial alpha for optimization (used for linear and as fallback for conv). Defaults to 'rank'. This is also used for 'ss_network_alpha'.")
|
||||
parser.add_argument("--initial_conv_alpha", type=float, default=None,
|
||||
help="Specific initial alpha for convolutional LoHA layers. Defaults to '--initial_alpha' or conv_rank if neither initial_alpha nor initial_conv_alpha is set, it defaults to the respective rank.")
|
||||
|
||||
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate for LoHA optimization per layer.")
|
||||
parser.add_argument("--max_iterations", type=int, default=1000, help="Maximum number of optimization iterations per layer.")
|
||||
parser.add_argument("--min_iterations", type=int, default=100, help="Minimum number of optimization iterations per layer before checking target loss. Default 100.")
|
||||
parser.add_argument("--target_loss", type=float, default=None, help="Target MSE loss to achieve for stopping optimization for a layer. If None, runs for max_iterations.")
|
||||
|
||||
parser.add_argument("--weight_decay", type=float, default=1e-5, help="Weight decay for LoHA optimization.")
|
||||
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device for computation ('cuda' or 'cpu').")
|
||||
parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"],
|
||||
help="Computation precision for LoHA optimization. This is 'ss_mixed_precision'. Default: fp32")
|
||||
parser.add_argument("--save_weights_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"],
|
||||
help="Data type for saving the final LoHA weights in the .safetensors file. Default: bf16.")
|
||||
parser.add_argument("--atol_fp32_check", type=float, default=1e-6,
|
||||
help="Absolute tolerance for fp32 weight difference check to consider layers identical and skip them.")
|
||||
parser.add_argument("--use_bias", action="store_true", help="If set, save fine-tuned bias terms if they differ from the base model's bias (saved with original key names).")
|
||||
|
||||
parser.add_argument("--dropout", type=float, default=0.0, help="General dropout for LoHA modules (for metadata).")
|
||||
parser.add_argument("--rank_dropout", type=float, default=0.0, help="Rank dropout rate for LoHA modules (for metadata).")
|
||||
parser.add_argument("--module_dropout", type=float, default=0.0, help="Module dropout rate for LoHA modules (for metadata).")
|
||||
|
||||
parser.add_argument("--max_layers", type=int, default=None,
|
||||
help="Process at most N differing layers for quick testing. Layers are sorted by name before processing.")
|
||||
parser.add_argument("--verbose", action="store_true", help="Print general verbose information during processing.")
|
||||
parser.add_argument("--verbose_layer_debug", action="store_true",
|
||||
help="Print very detailed per-iteration debug info for each layer's optimization (can be very spammy).")
|
||||
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(parsed_args.base_model_path): print(f"Error: Base model path not found: {parsed_args.base_model_path}"); exit(1)
|
||||
if not os.path.exists(parsed_args.ft_model_path): print(f"Error: Fine-tuned model path not found: {parsed_args.ft_model_path}"); exit(1)
|
||||
|
||||
save_dir = os.path.dirname(parsed_args.save_to)
|
||||
if save_dir and not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True); print(f"Created directory: {save_dir}")
|
||||
|
||||
if parsed_args.initial_alpha is None: parsed_args.initial_alpha = float(parsed_args.rank)
|
||||
conv_rank_for_alpha_default = parsed_args.conv_rank if parsed_args.conv_rank is not None else parsed_args.rank
|
||||
if parsed_args.initial_conv_alpha is None: parsed_args.initial_conv_alpha = parsed_args.initial_alpha
|
||||
|
||||
main(parsed_args)
|
||||
|
|
@ -6,7 +6,7 @@ import torch
|
|||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from library import sai_model_spec, model_util, sdxl_model_util
|
||||
import lora
|
||||
import lora # Assuming this is your existing lora script/library
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
|
|
@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
MIN_SV = 1e-6
|
||||
|
||||
# ... (Keep all your existing helper functions: index_sv_cumulative, index_sv_fro, etc.)
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
|
||||
|
|
@ -36,62 +37,73 @@ def index_sv_ratio(S, target):
|
|||
index = max(1, min(index, len(S) - 1))
|
||||
return index
|
||||
|
||||
def index_sv_knee(S):
|
||||
"""Determine rank using the knee point detection method."""
|
||||
def index_sv_knee_improved(S, MIN_SV_KNEE=1e-8): # MIN_SV_KNEE can be same as global MIN_SV or specific
|
||||
n = len(S)
|
||||
if n < 3: # Need at least 3 points to detect a knee
|
||||
return 1
|
||||
if n < 3: return 1
|
||||
s_max, s_min = S[0], S[-1]
|
||||
if s_max - s_min < MIN_SV_KNEE: return 1
|
||||
s_normalized = (S - s_min) / (s_max - s_min)
|
||||
x_normalized = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
|
||||
distances = (x_normalized + s_normalized - 1).abs()
|
||||
knee_index_0based = torch.argmax(distances).item()
|
||||
rank = knee_index_0based + 1
|
||||
rank = max(1, min(rank, n - 1))
|
||||
return rank
|
||||
|
||||
# Line coefficients from (1, S[0]) to (n, S[-1])
|
||||
a = S[0] - S[-1]
|
||||
b = n - 1
|
||||
c = 1 * S[-1] - n * S[0]
|
||||
|
||||
# Compute distances for each k
|
||||
distances = []
|
||||
for k in range(1, n + 1):
|
||||
dist = abs(a * k + b * S[k - 1] + c) / (a**2 + b**2)**0.5
|
||||
distances.append(dist)
|
||||
|
||||
# Find index of maximum distance (add 1 because k starts at 1)
|
||||
index = torch.argmax(torch.tensor(distances)).item() + 1
|
||||
index = max(1, min(index, n - 1))
|
||||
return index
|
||||
def index_sv_cumulative_knee(S, min_sv_threshold=1e-8):
|
||||
n = len(S)
|
||||
if n < 3: return 1
|
||||
s_sum = torch.sum(S)
|
||||
if s_sum < min_sv_threshold: return 1
|
||||
y_values = torch.cumsum(S, dim=0) / s_sum
|
||||
y_min, y_max = y_values[0], y_values[n-1]
|
||||
if y_max - y_min < min_sv_threshold: return 1
|
||||
y_norm = (y_values - y_min) / (y_max - y_min)
|
||||
x_norm = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
|
||||
distances = (y_norm - x_norm).abs()
|
||||
knee_index_0based = torch.argmax(distances).item()
|
||||
rank = knee_index_0based + 1
|
||||
rank = max(1, min(rank, n - 1))
|
||||
return rank
|
||||
|
||||
def index_sv_rel_decrease(S, tau=0.1):
|
||||
"""Determine rank based on relative decrease threshold."""
|
||||
if len(S) < 2:
|
||||
return 1
|
||||
|
||||
# Compute ratios of consecutive singular values
|
||||
if len(S) < 2: return 1
|
||||
ratios = S[1:] / S[:-1]
|
||||
|
||||
# Find the smallest k where ratio < tau
|
||||
for k in range(len(ratios)):
|
||||
if ratios[k] < tau:
|
||||
return max(1, k + 1) # k + 1 because we want rank after the drop
|
||||
return max(1, k + 1)
|
||||
return min(len(S), len(S) - 1 if len(S) > 1 else 1)
|
||||
|
||||
# If no drop below tau, return max rank
|
||||
return min(len(S), len(S) - 1)
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype, metadata=None):
|
||||
def save_to_file(file_name, model_sd, dtype, metadata=None): # Changed model to model_sd for clarity
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if isinstance(state_dict[key], torch.Tensor):
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(model, file_name, metadata)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
for key in list(model_sd.keys()):
|
||||
if isinstance(model_sd[key], torch.Tensor):
|
||||
model_sd[key] = model_sd[key].to(dtype)
|
||||
|
||||
def svd(
|
||||
model_org=None,
|
||||
model_tuned=None,
|
||||
# Filter out non-tensor metadata if it accidentally gets into model_sd
|
||||
final_sd = {k: v for k, v in model_sd.items() if isinstance(v, torch.Tensor)}
|
||||
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(final_sd, file_name, metadata=metadata) # Pass metadata here
|
||||
else:
|
||||
# For .pt, metadata is typically not saved in this manner.
|
||||
# If you need to save metadata with .pt, you might save a dict like {'state_dict': final_sd, 'metadata': metadata}
|
||||
torch.save(final_sd, file_name)
|
||||
if metadata:
|
||||
logger.warning(".pt format does not standardly support metadata like safetensors. Metadata not saved in file.")
|
||||
|
||||
def svd_decomposition(
|
||||
model_org_path=None, # Renamed for clarity
|
||||
model_tuned_path=None, # Renamed for clarity
|
||||
save_to=None,
|
||||
dim=4,
|
||||
algo="lora", # New: lora or loha
|
||||
network_dim=4, # For LoRA: rank. For LoHA: "factor" or "hada_dim"
|
||||
network_alpha=None, # For LoRA: alpha (often same as rank). For LoHA: "rank_initial" or "hada_alpha"
|
||||
conv_dim=None, # For LoRA: conv_rank. For LoHA: "conv_factor"
|
||||
conv_alpha=None, # For LoRA: conv_alpha. For LoHA: "conv_rank_initial"
|
||||
v2=None,
|
||||
sdxl=None,
|
||||
conv_dim=None,
|
||||
v_parameterization=None,
|
||||
device=None,
|
||||
save_precision=None,
|
||||
|
|
@ -106,12 +118,9 @@ def svd(
|
|||
verbose=False,
|
||||
):
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
return torch.float
|
||||
if p == "fp16":
|
||||
return torch.float16
|
||||
if p == "bf16":
|
||||
return torch.bfloat16
|
||||
if p == "float": return torch.float
|
||||
if p == "fp16": return torch.float16
|
||||
if p == "bf16": return torch.bfloat16
|
||||
return None
|
||||
|
||||
assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time"
|
||||
|
|
@ -119,222 +128,395 @@ def svd(
|
|||
|
||||
load_dtype = str_to_dtype(load_precision) if load_precision else None
|
||||
save_dtype = str_to_dtype(save_precision) if save_precision else torch.float
|
||||
work_device = "cpu"
|
||||
work_device = "cpu" # Perform SVD and weight manipulation on CPU then move
|
||||
compute_device = device if device else "cpu"
|
||||
|
||||
# Handle default alpha values based on dim values if not provided
|
||||
if network_alpha is None: network_alpha = network_dim
|
||||
if conv_dim is None: conv_dim = network_dim # default conv_dim to network_dim
|
||||
if conv_alpha is None: conv_alpha = conv_dim # default conv_alpha to conv_dim
|
||||
|
||||
# Load models
|
||||
if not sdxl:
|
||||
logger.info(f"Loading original SD model: {model_org}")
|
||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
||||
logger.info(f"Loading original SD model: {model_org_path}")
|
||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org_path, load_dtype)
|
||||
text_encoders_o = [text_encoder_o]
|
||||
if load_dtype:
|
||||
text_encoder_o.to(load_dtype)
|
||||
unet_o.to(load_dtype)
|
||||
|
||||
logger.info(f"Loading tuned SD model: {model_tuned}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
||||
logger.info(f"Loading tuned SD model: {model_tuned_path}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned_path, load_dtype)
|
||||
text_encoders_t = [text_encoder_t]
|
||||
if load_dtype:
|
||||
text_encoder_t.to(load_dtype)
|
||||
unet_t.to(load_dtype)
|
||||
|
||||
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
|
||||
else:
|
||||
else: # SDXL
|
||||
device_org = load_original_model_to or "cpu"
|
||||
device_tuned = load_tuned_model_to or "cpu"
|
||||
|
||||
logger.info(f"Loading original SDXL model: {model_org}")
|
||||
logger.info(f"Loading original SDXL model: {model_org_path}")
|
||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org_path, device_org, load_dtype
|
||||
)
|
||||
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
||||
if load_dtype:
|
||||
text_encoder_o1.to(load_dtype)
|
||||
text_encoder_o2.to(load_dtype)
|
||||
unet_o.to(load_dtype)
|
||||
|
||||
logger.info(f"Loading tuned SDXL model: {model_tuned}")
|
||||
logger.info(f"Loading tuned SDXL model: {model_tuned_path}")
|
||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned_path, device_tuned, load_dtype
|
||||
)
|
||||
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
||||
if load_dtype:
|
||||
text_encoder_t1.to(load_dtype)
|
||||
text_encoder_t2.to(load_dtype)
|
||||
unet_t.to(load_dtype)
|
||||
|
||||
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
||||
|
||||
# Create LoRA network
|
||||
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} if conv_dim else {}
|
||||
# Create temporary LoRA network to identify modules and get original weights
|
||||
# Use a minimal fixed dimension for this stage as we only need module structure and original weights
|
||||
temp_lora_kwargs = {"conv_dim": 1, "conv_alpha": 1.0} # Minimal conv settings for network creation
|
||||
lora_network_o = lora.create_network(1.0, 1, 1.0, None, text_encoders_o, unet_o, **temp_lora_kwargs)
|
||||
lora_network_t = lora.create_network(1.0, 1, 1.0, None, text_encoders_t, unet_t, **temp_lora_kwargs)
|
||||
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras)
|
||||
|
||||
# Define a small initial dimension for memory efficiency
|
||||
init_dim = 4 # Small value to minimize memory usage
|
||||
|
||||
# Create LoRA networks with minimal dimension
|
||||
lora_network_o = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_o, unet_o, **kwargs)
|
||||
lora_network_t = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_t, unet_t, **kwargs)
|
||||
|
||||
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), "Model versions differ (SD1.x vs SD2.x)"
|
||||
|
||||
# Compute differences
|
||||
diffs = {}
|
||||
text_encoder_different = False
|
||||
for lora_o, lora_t in zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras):
|
||||
lora_name = lora_o.lora_name
|
||||
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
|
||||
text_encoder_differs = False
|
||||
# Text Encoders
|
||||
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
|
||||
module_key_name = lora_o.lora_name # e.g. "lora_te1_text_model_encoder_layers_0_mlp_fc1"
|
||||
org_weight = lora_o.org_module.weight.to(device=work_device, dtype=torch.float)
|
||||
tuned_weight = lora_t.org_module.weight.to(device=work_device, dtype=torch.float)
|
||||
diff = tuned_weight - org_weight
|
||||
|
||||
if torch.max(torch.abs(diff)) > min_diff:
|
||||
text_encoder_differs = True
|
||||
logger.info(f"Text encoder {i+1} module {module_key_name} differs: max diff {torch.max(torch.abs(diff))}")
|
||||
diffs[module_key_name] = diff
|
||||
else:
|
||||
logger.info(f"Text encoder {i+1} module {module_key_name} has no significant difference.")
|
||||
|
||||
# Free memory
|
||||
lora_o.org_module.weight = None
|
||||
lora_t.org_module.weight = None
|
||||
del org_weight, tuned_weight
|
||||
|
||||
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
||||
text_encoder_different = True
|
||||
logger.info(f"Text encoder differs: max diff {torch.max(torch.abs(diff))} > {min_diff}")
|
||||
diffs[lora_name] = diff
|
||||
|
||||
for text_encoder in text_encoders_t:
|
||||
del text_encoder
|
||||
|
||||
if not text_encoder_different:
|
||||
logger.warning("Text encoders are identical. Extracting U-Net only.")
|
||||
lora_network_o.text_encoder_loras = []
|
||||
diffs.clear()
|
||||
|
||||
# UNet
|
||||
for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras):
|
||||
lora_name = lora_o.lora_name
|
||||
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
|
||||
module_key_name = lora_o.lora_name # e.g. "lora_unet_input_blocks_1_1_proj_in"
|
||||
org_weight = lora_o.org_module.weight.to(device=work_device, dtype=torch.float)
|
||||
tuned_weight = lora_t.org_module.weight.to(device=work_device, dtype=torch.float)
|
||||
diff = tuned_weight - org_weight
|
||||
|
||||
if torch.max(torch.abs(diff)) > min_diff:
|
||||
logger.info(f"UNet module {module_key_name} differs: max diff {torch.max(torch.abs(diff))}")
|
||||
diffs[module_key_name] = diff
|
||||
else:
|
||||
logger.info(f"UNet module {module_key_name} has no significant difference.")
|
||||
|
||||
lora_o.org_module.weight = None
|
||||
lora_t.org_module.weight = None
|
||||
diffs[lora_name] = diff
|
||||
del org_weight, tuned_weight
|
||||
|
||||
del lora_network_t, unet_t
|
||||
if not text_encoder_differs:
|
||||
logger.warning("Text encoder weights are identical or below min_diff. Text encoder LoRA modules will not be included.")
|
||||
# Remove text encoder diffs if none were significant
|
||||
diffs = {k: v for k, v in diffs.items() if "unet" in k}
|
||||
|
||||
# Filter relevant modules
|
||||
lora_names = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)
|
||||
del lora_network_o, lora_network_t, text_encoders_o, text_encoders_t, unet_o, unet_t
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
lora_module_weights = {} # This will store the final decomposed weights for LoRA/LoHA
|
||||
|
||||
logger.info(f"Extracting and resizing {algo.upper()} modules via SVD")
|
||||
|
||||
# Extract and resize LoRA using SVD
|
||||
logger.info("Extracting and resizing LoRA via SVD")
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name in tqdm(lora_names):
|
||||
mat = diffs[lora_name]
|
||||
if device:
|
||||
mat = mat.to(device)
|
||||
mat = mat.to(torch.float)
|
||||
for module_key_name, mat_diff in tqdm(diffs.items()):
|
||||
if compute_device != "cpu": # Move to GPU for SVD if specified
|
||||
mat_diff = mat_diff.to(compute_device)
|
||||
|
||||
conv2d = len(mat.size()) == 4
|
||||
kernel_size = mat.size()[2:4] if conv2d else None
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
out_dim, in_dim = mat.size()[0:2]
|
||||
# Determine if the layer is convolutional and its properties
|
||||
is_conv = len(mat_diff.shape) == 4
|
||||
kernel_size = None
|
||||
if is_conv:
|
||||
kernel_size = mat_diff.shape[2:4]
|
||||
|
||||
if conv2d:
|
||||
mat = mat.flatten(start_dim=1) if conv2d_3x3 else mat.squeeze()
|
||||
# For LoRA, conv_dim/alpha are specific to 3x3 convs. For others, it uses network_dim/alpha.
|
||||
# For LoHA, we use conv_dim/alpha for any conv layer, and network_dim/alpha for linear.
|
||||
# This logic can be refined based on how LyCORIS typically handles different conv kernel sizes for LoHA.
|
||||
# Here, we'll use conv parameters if it's a conv layer, otherwise network parameters.
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
current_dim_target = conv_dim if is_conv else network_dim
|
||||
current_alpha_target = conv_alpha if is_conv else network_alpha
|
||||
|
||||
# For LoHA, 'alpha_target' is the rank of the first SVD (rank_initial)
|
||||
# and 'dim_target' is the rank of the second SVD (rank_factor).
|
||||
# For LoRA, 'dim_target' is the rank of the SVD, and 'alpha_target' is the scaling factor.
|
||||
|
||||
# Reshape convolutional weights for SVD
|
||||
original_shape = mat_diff.shape
|
||||
if is_conv:
|
||||
if kernel_size == (1, 1):
|
||||
mat_diff = mat_diff.squeeze() # Becomes (out_channels, in_channels)
|
||||
else: # kernel_size (3,3) or others
|
||||
mat_diff = mat_diff.flatten(start_dim=1) # Becomes (out_channels, in_channels*k_w*k_h)
|
||||
|
||||
out_features, in_features = mat_diff.shape[0], mat_diff.shape[1]
|
||||
|
||||
# Perform first SVD
|
||||
try:
|
||||
U, S, Vh = torch.linalg.svd(mat_diff)
|
||||
except Exception as e:
|
||||
logger.error(f"SVD failed for {module_key_name} with shape {mat_diff.shape}: {e}")
|
||||
continue
|
||||
|
||||
if compute_device != "cpu": # Move results back to CPU if computation was on GPU
|
||||
U, S, Vh = U.cpu(), S.cpu(), Vh.cpu()
|
||||
|
||||
# Determine rank for the first SVD (rank_initial for LoHA, rank for LoRA)
|
||||
max_rank_initial = min(out_features, in_features) # Theoretical max rank
|
||||
|
||||
# Default rank_initial to current_alpha_target (which is network_alpha or conv_alpha)
|
||||
rank_initial = current_alpha_target
|
||||
|
||||
# Determine rank
|
||||
max_rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
|
||||
if dynamic_method:
|
||||
if S[0] <= MIN_SV:
|
||||
rank = 1
|
||||
elif dynamic_method == "sv_ratio":
|
||||
rank = index_sv_ratio(S, dynamic_param)
|
||||
elif dynamic_method == "sv_cumulative":
|
||||
rank = index_sv_cumulative(S, dynamic_param)
|
||||
elif dynamic_method == "sv_fro":
|
||||
rank = index_sv_fro(S, dynamic_param)
|
||||
elif dynamic_method == "sv_knee":
|
||||
rank = index_sv_knee(S)
|
||||
elif dynamic_method == "sv_rel_decrease":
|
||||
rank = index_sv_rel_decrease(S, dynamic_param)
|
||||
rank = min(rank, max_rank, in_dim, out_dim)
|
||||
determined_rank = 1
|
||||
elif dynamic_method == "sv_ratio": determined_rank = index_sv_ratio(S, dynamic_param)
|
||||
elif dynamic_method == "sv_cumulative": determined_rank = index_sv_cumulative(S, dynamic_param)
|
||||
elif dynamic_method == "sv_fro": determined_rank = index_sv_fro(S, dynamic_param)
|
||||
elif dynamic_method == "sv_knee": determined_rank = index_sv_knee_improved(S, MIN_SV)
|
||||
elif dynamic_method == "sv_cumulative_knee": determined_rank = index_sv_cumulative_knee(S, MIN_SV)
|
||||
elif dynamic_method == "sv_rel_decrease": determined_rank = index_sv_rel_decrease(S, dynamic_param)
|
||||
else: determined_rank = rank_initial # Fallback if dynamic method unknown
|
||||
rank_initial = min(determined_rank, current_alpha_target, max_rank_initial)
|
||||
else:
|
||||
rank = min(max_rank, in_dim, out_dim)
|
||||
rank_initial = min(current_alpha_target, max_rank_initial)
|
||||
rank_initial = max(1, rank_initial) # Ensure rank is at least 1
|
||||
|
||||
# Truncate SVD components
|
||||
U = U[:, :rank] @ torch.diag(S[:rank])
|
||||
Vh = Vh[:rank, :]
|
||||
# --- LoRA specific decomposition ---
|
||||
if algo == 'lora':
|
||||
lora_down = Vh[:rank_initial, :]
|
||||
lora_up = U[:, :rank_initial] @ torch.diag(S[:rank_initial])
|
||||
|
||||
# Clamp values
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, clamp_quantile)
|
||||
U = U.clamp(-hi_val, hi_val)
|
||||
Vh = Vh.clamp(-hi_val, hi_val)
|
||||
# Clamp values
|
||||
dist = torch.cat([lora_up.flatten(), lora_down.flatten()])
|
||||
hi_val = torch.quantile(dist, clamp_quantile) if clamp_quantile < 1.0 else dist.abs().max()
|
||||
lora_up = lora_up.clamp(-hi_val, hi_val)
|
||||
lora_down = lora_down.clamp(-hi_val, hi_val)
|
||||
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, rank, 1, 1)
|
||||
Vh = Vh.reshape(rank, in_dim, *kernel_size)
|
||||
# Reshape for conv layers if necessary
|
||||
if is_conv:
|
||||
if kernel_size == (1,1):
|
||||
# These are already (out_c, rank) and (rank, in_c)
|
||||
# Some LoRA impls might expect them to be 4D (out_c, rank, 1, 1) and (rank, in_c, 1, 1)
|
||||
lora_up = lora_up.reshape(out_features, rank_initial, 1, 1)
|
||||
lora_down = lora_down.reshape(rank_initial, in_features, 1, 1)
|
||||
else: # e.g. 3x3 conv
|
||||
# lora_down was (rank, in_c*k_w*k_h), needs to be (rank, in_c, k_w, k_h)
|
||||
# lora_up was (out_c, rank)
|
||||
lora_up = lora_up.reshape(out_features, rank_initial, 1, 1) # often up is kept as 1x1 conv like
|
||||
lora_down = lora_down.reshape(rank_initial, original_shape[1], *kernel_size)
|
||||
|
||||
U = U.to(work_device, dtype=save_dtype).contiguous()
|
||||
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
|
||||
lora_weights[lora_name] = (U, Vh)
|
||||
lora_module_weights[f"{module_key_name}.lora_down.weight"] = lora_down.to(work_device, dtype=save_dtype).contiguous()
|
||||
lora_module_weights[f"{module_key_name}.lora_up.weight"] = lora_up.to(work_device, dtype=save_dtype).contiguous()
|
||||
lora_module_weights[f"{module_key_name}.alpha"] = torch.tensor(float(current_alpha_target), dtype=save_dtype) # Use actual alpha target from params
|
||||
|
||||
# --- LoHA specific decomposition ---
|
||||
elif algo == 'loha':
|
||||
lora_down_equivalent = Vh[:rank_initial, :]
|
||||
lora_up_equivalent = U[:, :rank_initial] @ torch.diag(S[:rank_initial])
|
||||
|
||||
# current_dim_target is the "factor" for LoHA's second SVD
|
||||
rank_factor = min(current_dim_target, rank_initial) # Factor cannot exceed rank_initial
|
||||
if is_conv and kernel_size != (1,1): # For conv3x3, factor also limited by in/out features
|
||||
rank_factor = min(rank_factor, original_shape[1], original_shape[0]) #original_shape[1] is in_channels
|
||||
else: # Linear or Conv1x1
|
||||
rank_factor = min(rank_factor, in_features, out_features)
|
||||
rank_factor = max(1, rank_factor)
|
||||
|
||||
|
||||
# Decompose Lora_Down_Equivalent (shape: rank_initial x in_features_eff)
|
||||
# Target: hada_w1_b (rank_initial x rank_factor) @ hada_w1_a (rank_factor x in_features_eff)
|
||||
if compute_device != "cpu": lora_down_equivalent = lora_down_equivalent.to(compute_device)
|
||||
Ud, Sd, Vhd = torch.linalg.svd(lora_down_equivalent)
|
||||
if compute_device != "cpu": Ud, Sd, Vhd = Ud.cpu(), Sd.cpu(), Vhd.cpu()
|
||||
|
||||
hada_w1_a = Vhd[:rank_factor, :]
|
||||
hada_w1_b = Ud[:, :rank_factor] @ torch.diag(Sd[:rank_factor])
|
||||
|
||||
# Decompose Lora_Up_Equivalent (shape: out_features_eff x rank_initial)
|
||||
# Target: hada_w2_b (out_features_eff x rank_factor) @ hada_w2_a (rank_factor x rank_initial)
|
||||
if compute_device != "cpu": lora_up_equivalent = lora_up_equivalent.to(compute_device)
|
||||
Uu, Su, Vhu = torch.linalg.svd(lora_up_equivalent)
|
||||
if compute_device != "cpu": Uu, Su, Vhu = Uu.cpu(), Su.cpu(), Vhu.cpu()
|
||||
|
||||
hada_w2_a = Vhu[:rank_factor, :]
|
||||
hada_w2_b = Uu[:, :rank_factor] @ torch.diag(Su[:rank_factor])
|
||||
|
||||
# Clamp LoHA components
|
||||
dist_w1a = hada_w1_a.flatten()
|
||||
dist_w1b = hada_w1_b.flatten()
|
||||
dist_w2a = hada_w2_a.flatten()
|
||||
dist_w2b = hada_w2_b.flatten()
|
||||
|
||||
if clamp_quantile < 1.0:
|
||||
hi_val_w1a = torch.quantile(dist_w1a.abs(), clamp_quantile)
|
||||
hi_val_w1b = torch.quantile(dist_w1b.abs(), clamp_quantile)
|
||||
hi_val_w2a = torch.quantile(dist_w2a.abs(), clamp_quantile)
|
||||
hi_val_w2b = torch.quantile(dist_w2b.abs(), clamp_quantile)
|
||||
else: # Use max abs value if quantile is 1.0 or more
|
||||
hi_val_w1a = dist_w1a.abs().max()
|
||||
hi_val_w1b = dist_w1b.abs().max()
|
||||
hi_val_w2a = dist_w2a.abs().max()
|
||||
hi_val_w2b = dist_w2b.abs().max()
|
||||
|
||||
hada_w1_a = hada_w1_a.clamp(-hi_val_w1a, hi_val_w1a)
|
||||
hada_w1_b = hada_w1_b.clamp(-hi_val_w1b, hi_val_w1b)
|
||||
hada_w2_a = hada_w2_a.clamp(-hi_val_w2a, hi_val_w2a)
|
||||
hada_w2_b = hada_w2_b.clamp(-hi_val_w2b, hi_val_w2b)
|
||||
|
||||
# LoHA weights are typically stored as 2D matrices.
|
||||
# LyCORIS library handles reshaping or uses 1x1 convs internally.
|
||||
lora_module_weights[f"{module_key_name}.hada_w1_a"] = hada_w1_a.to(work_device, dtype=save_dtype).contiguous()
|
||||
lora_module_weights[f"{module_key_name}.hada_w1_b"] = hada_w1_b.to(work_device, dtype=save_dtype).contiguous()
|
||||
lora_module_weights[f"{module_key_name}.hada_w2_a"] = hada_w2_a.to(work_device, dtype=save_dtype).contiguous()
|
||||
lora_module_weights[f"{module_key_name}.hada_w2_b"] = hada_w2_b.to(work_device, dtype=save_dtype).contiguous()
|
||||
lora_module_weights[f"{module_key_name}.alpha"] = torch.tensor(float(current_alpha_target), dtype=save_dtype) # This is rank_initial
|
||||
|
||||
# Verbose output
|
||||
if verbose:
|
||||
s_sum = float(torch.sum(S))
|
||||
s_rank = float(torch.sum(S[:rank]))
|
||||
fro = float(torch.sqrt(torch.sum(S.pow(2))))
|
||||
fro_rank = float(torch.sqrt(torch.sum(S[:rank].pow(2))))
|
||||
ratio = S[0] / S[rank - 1] if rank > 1 else float('inf')
|
||||
logger.info(f"{lora_name:75} | sum(S) retained: {s_rank/s_sum:.1%}, fro retained: {fro_rank/fro:.1%}, max ratio: {ratio:.1f}, rank: {rank}")
|
||||
s_rank_initial = float(torch.sum(S[:rank_initial]))
|
||||
fro_initial = float(torch.sqrt(torch.sum(S.pow(2))))
|
||||
fro_rank_initial = float(torch.sqrt(torch.sum(S[:rank_initial].pow(2))))
|
||||
# This verbose output is for the first SVD. Adding verbose for second SVD would be more complex.
|
||||
logger.info(
|
||||
f"{module_key_name[:75]:75} | Algo: {algo.upper()}, Rank/Alpha (initial): {rank_initial}, Dim/Factor (final): {rank_factor if algo=='loha' else rank_initial} "
|
||||
f"| Sum(S) Retained (1st SVD): {s_rank_initial/s_sum if s_sum > 0 else 0:.1%}, Fro Retained (1st SVD): {fro_rank_initial/fro_initial if fro_initial > 0 else 0:.1%}"
|
||||
)
|
||||
|
||||
# Create state dict
|
||||
lora_sd = {}
|
||||
for lora_name, (up_weight, down_weight) in lora_weights.items():
|
||||
lora_sd[lora_name + ".lora_up.weight"] = up_weight
|
||||
lora_sd[lora_name + ".lora_down.weight"] = down_weight
|
||||
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype)
|
||||
del U, S, Vh, mat_diff
|
||||
if algo == 'loha':
|
||||
del lora_down_equivalent, lora_up_equivalent, Ud, Sd, Vhd, Uu, Su, Vhu
|
||||
del hada_w1_a, hada_w1_b, hada_w2_a, hada_w2_b
|
||||
else: # lora
|
||||
del lora_down, lora_up
|
||||
if compute_device != "cpu":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Load and save LoRA
|
||||
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
|
||||
lora_network_save.apply_to(text_encoders_o, unet_o)
|
||||
info = lora_network_save.load_state_dict(lora_sd)
|
||||
logger.info(f"Loaded extracted and resized LoRA weights: {info}")
|
||||
|
||||
if not lora_module_weights:
|
||||
logger.error("No LoRA/LoHA modules were generated. This might be due to models being too similar or min_diff being too high.")
|
||||
return
|
||||
|
||||
os.makedirs(os.path.dirname(save_to), exist_ok=True)
|
||||
|
||||
# Metadata
|
||||
net_kwargs = {"conv_dim": str(conv_dim), "conv_alpha": str(float(conv_dim))} if conv_dim else {}
|
||||
metadata = {
|
||||
"ss_v2": str(v2),
|
||||
"ss_v2": str(v2) if v2 is not None else "false", # More explicit "false"
|
||||
"ss_base_model_version": model_version,
|
||||
"ss_network_module": "networks.lora",
|
||||
"ss_network_dim": str(dim) if not dynamic_method else "Dynamic",
|
||||
"ss_network_alpha": str(float(dim)) if not dynamic_method else "Dynamic",
|
||||
"ss_network_args": json.dumps(net_kwargs),
|
||||
"ss_sdxl_model_version": "1.0" if sdxl else "null", # LyCORIS convention
|
||||
"ss_network_module": "networks.lora" if algo == "lora" else "lycoris.kohya", # For LoHA
|
||||
# For LoRA, dim and alpha are usually the same as what's passed.
|
||||
# For LoHA, network_dim is the 'factor' (our network_dim/conv_dim),
|
||||
# and network_alpha is the 'rank_initial' (our network_alpha/conv_alpha).
|
||||
"ss_network_dim": str(network_dim) if not dynamic_method or algo == "loha" else "Dynamic", # For LoHA, this is 'factor'
|
||||
"ss_network_alpha": str(network_alpha) if not dynamic_method or algo == "loha" else "Dynamic", # For LoHA, this is 'rank_initial'
|
||||
}
|
||||
|
||||
net_kwargs = {}
|
||||
if algo == "lora":
|
||||
if conv_dim is not None: # Only add if conv_dim was actually specified for LoRA
|
||||
net_kwargs["conv_dim"] = str(conv_dim)
|
||||
net_kwargs["conv_alpha"] = str(float(conv_alpha if conv_alpha is not None else conv_dim))
|
||||
elif algo == "loha":
|
||||
net_kwargs["algo"] = "loha"
|
||||
# For LoHA, conv_dim and conv_alpha are distinct concepts (factor and rank_initial for conv layers)
|
||||
net_kwargs["conv_dim"] = str(conv_dim)
|
||||
net_kwargs["conv_alpha"] = str(float(conv_alpha))
|
||||
# LyCORIS sometimes uses dropout, but we are not implementing it here.
|
||||
# net_kwargs["dropout"] = "0" # Example if we had dropout
|
||||
|
||||
metadata["ss_network_args"] = json.dumps(net_kwargs) if net_kwargs else "null" # LyCORIS uses "null" for empty
|
||||
|
||||
if not no_metadata:
|
||||
title = os.path.splitext(os.path.basename(save_to))[0]
|
||||
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
|
||||
metadata.update(sai_metadata)
|
||||
# sai_model_spec might need adjustment if it expects specific lora types not lycoris
|
||||
try:
|
||||
sai_metadata = sai_model_spec.build_metadata(
|
||||
None, v2, v_parameterization, sdxl, True,
|
||||
is_lycoris=(algo != "lora"), # Pass if it's LyCORIS
|
||||
is_lora=(algo == "lora"), # Pass if it's LoRA
|
||||
creation_time=time.time(), title=title
|
||||
)
|
||||
metadata.update(sai_metadata)
|
||||
except TypeError as e:
|
||||
logger.warning(f"Could not generate full SAI metadata, possibly due to outdated sai_model_spec.py or new flags: {e}")
|
||||
logger.warning("Falling back to basic metadata for SAI fields.")
|
||||
metadata.update({ # Basic fallback
|
||||
"sai_model_name": title,
|
||||
"sai_base_model": model_version,
|
||||
"sai_is_sdxl": str(sdxl).lower(),
|
||||
})
|
||||
|
||||
|
||||
save_to_file(save_to, lora_module_weights, save_dtype, metadata)
|
||||
logger.info(f"{algo.upper()} saved to: {save_to}")
|
||||
|
||||
save_to_file(save_to, lora_sd, lora_sd, save_dtype, metadata)
|
||||
logger.info(f"LoRA saved to: {save_to}")
|
||||
|
||||
def setup_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model")
|
||||
parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2)")
|
||||
parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2 if applicable)")
|
||||
parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model")
|
||||
parser.add_argument("--load_precision", choices=[None, "float", "fp16", "bf16"], help="Precision for loading models")
|
||||
parser.add_argument("--save_precision", choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA")
|
||||
parser.add_argument("--model_org", required=True, help="Original Stable Diffusion model (ckpt/safetensors)")
|
||||
parser.add_argument("--model_tuned", required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)")
|
||||
parser.add_argument("--save_to", required=True, help="Output file name (ckpt/safetensors)")
|
||||
parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers")
|
||||
parser.add_argument("--conv_dim", type=int, help="Max dimension (rank) of LoRA for Conv2d-3x3")
|
||||
parser.add_argument("--device", default="cuda", help="Device for computation (e.g., cuda)")
|
||||
parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights")
|
||||
parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract")
|
||||
parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata")
|
||||
parser.add_argument("--load_original_model_to", help="Device for original model (SDXL only)")
|
||||
parser.add_argument("--load_tuned_model_to", help="Device for tuned model (SDXL only)")
|
||||
parser.add_argument("--dynamic_method", choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease"], help="Dynamic rank reduction method")
|
||||
parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction")
|
||||
parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info")
|
||||
|
||||
parser.add_argument("--model_org_path", type=str, required=True, help="Path to the original Stable Diffusion model (ckpt/safetensors)")
|
||||
parser.add_argument("--model_tuned_path", type=str, required=True, help="Path to the tuned Stable Diffusion model (ckpt/safetensors)")
|
||||
parser.add_argument("--save_to", type=str, required=True, help="Output file name for the LoRA/LoHA (ckpt/safetensors)")
|
||||
|
||||
parser.add_argument("--algo", type=str, default="lora", choices=["lora", "loha"], help="Algorithm to use: lora or loha")
|
||||
|
||||
parser.add_argument("--network_dim", type=int, default=4, help="Network dimension. For LoRA: rank. For LoHA: 'factor' or 'hada_dim'.")
|
||||
parser.add_argument("--network_alpha", type=int, default=None, help="Network alpha. For LoRA: alpha (often same as rank). For LoHA: 'rank_initial' or 'hada_alpha'. Defaults to network_dim if not set.")
|
||||
parser.add_argument("--conv_dim", type=int, default=None, help="Conv dimension for conv layers. For LoRA: rank. For LoHA: 'factor'. Defaults to network_dim if not set.")
|
||||
parser.add_argument("--conv_alpha", type=int, default=None, help="Conv alpha for conv layers. For LoRA: alpha. For LoHA: 'rank_initial'. Defaults to conv_dim if not set.")
|
||||
|
||||
parser.add_argument("--load_precision", type=str, choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for loading models (None means default float32)")
|
||||
parser.add_argument("--save_precision", type=str, choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA/LoHA (None means float32)")
|
||||
|
||||
parser.add_argument("--device", type=str, default=None, help="Device for SVD computation (e.g., 'cuda', 'cpu'). If None, defaults to 'cuda' if available, else 'cpu'. SVD results are moved to CPU for storage.")
|
||||
parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights (0.0 to 1.0). 1.0 means clamp to max abs value.")
|
||||
parser.add_argument("--min_diff", type=float, default=1e-6, help="Minimum weight difference threshold for a module to be considered for extraction.") # Lowered default
|
||||
parser.add_argument("--no_metadata", action="store_true", help="Do not save detailed metadata (minimal ss_ metadata will still be saved).")
|
||||
|
||||
parser.add_argument("--load_original_model_to", type=str, default=None, help="Device to load original model to (SDXL only, e.g., 'cpu', 'cuda:0')")
|
||||
parser.add_argument("--load_tuned_model_to", type=str, default=None, help="Device to load tuned model to (SDXL only, e.g., 'cpu', 'cuda:1')")
|
||||
|
||||
parser.add_argument("--dynamic_method", type=str, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"], default=None, help="Dynamic method to determine rank/alpha (for the first SVD in LoHA). Overrides fixed network_alpha/conv_alpha if set.")
|
||||
parser.add_argument("--dynamic_param", type=float, default=0.9, help="Parameter for the chosen dynamic_method (e.g., target ratio/cumulative sum/Frobenius norm percentage).")
|
||||
parser.add_argument("--verbose", action="store_true", help="Show detailed SVD info for each module.")
|
||||
|
||||
return parser
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
if args.dynamic_method and not args.dynamic_param:
|
||||
raise ValueError("Dynamic method requires a dynamic parameter")
|
||||
svd(**vars(args))
|
||||
|
||||
if args.dynamic_method and (args.dynamic_param is None):
|
||||
# Default dynamic_param for methods if not specified, or raise error
|
||||
if args.dynamic_method in ["sv_cumulative", "sv_fro"]:
|
||||
args.dynamic_param = 0.99 # Example: 99% variance/energy
|
||||
logger.info(f"Dynamic method {args.dynamic_method} chosen, dynamic_param defaulted to {args.dynamic_param}")
|
||||
elif args.dynamic_method in ["sv_ratio"]:
|
||||
args.dynamic_param = 1000 # Example: ratio of 1000
|
||||
logger.info(f"Dynamic method {args.dynamic_method} chosen, dynamic_param defaulted to {args.dynamic_param}")
|
||||
elif args.dynamic_method in ["sv_rel_decrease"]:
|
||||
args.dynamic_param = 0.05 # Example: 5% relative decrease
|
||||
logger.info(f"Dynamic method {args.dynamic_method} chosen, dynamic_param defaulted to {args.dynamic_param}")
|
||||
# sv_knee and sv_cumulative_knee do not require dynamic_param in this implementation.
|
||||
elif args.dynamic_method not in ["sv_knee", "sv_cumulative_knee"]:
|
||||
parser.error("--dynamic_method requires --dynamic_param for most methods.")
|
||||
|
||||
|
||||
# Default device selection
|
||||
if args.device is None:
|
||||
args.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"Using device: {args.device} for SVD computation.")
|
||||
|
||||
# Rename args for clarity before passing to function
|
||||
func_args = vars(args).copy()
|
||||
func_args["model_org_path"] = func_args.pop("model_org_path")
|
||||
func_args["model_tuned_path"] = func_args.pop("model_tuned_path")
|
||||
|
||||
svd_decomposition(**func_args)
|
||||
|
|
@ -0,0 +1,432 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from library import sai_model_spec, model_util, sdxl_model_util
|
||||
import lora
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MIN_SV = 1e-6
|
||||
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
return index
|
||||
|
||||
def index_sv_fro(S, target):
|
||||
S_squared = S.pow(2)
|
||||
S_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
return index
|
||||
|
||||
def index_sv_ratio(S, target):
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv / target
|
||||
index = int(torch.sum(S > min_sv).item())
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
return index
|
||||
|
||||
def index_sv_knee_improved(S, MIN_SV_KNEE=1e-8): # MIN_SV_KNEE can be same as global MIN_SV or specific
|
||||
"""
|
||||
Determine rank using the knee point detection method with normalization.
|
||||
Normalizes singular values and their indices to the [0,1] range
|
||||
to make the knee detection scale-invariant.
|
||||
"""
|
||||
n = len(S)
|
||||
if n < 3: # Need at least 3 points to detect a knee
|
||||
return 1
|
||||
|
||||
# S is expected to be sorted in descending order.
|
||||
s_max, s_min = S[0], S[-1]
|
||||
|
||||
# Handle flat or nearly flat singular value spectrum
|
||||
if s_max - s_min < MIN_SV_KNEE:
|
||||
# If all singular values are almost the same, a knee is not well-defined.
|
||||
# Returning 1 is a conservative choice for low rank.
|
||||
# Alternatively, n // 2 or n - 1 could be chosen depending on desired behavior.
|
||||
return 1
|
||||
|
||||
# Normalize singular values to [0, 1]
|
||||
# s_normalized[0] will be 1, s_normalized[n-1] will be 0
|
||||
s_normalized = (S - s_min) / (s_max - s_min)
|
||||
|
||||
# Normalize indices to [0, 1]
|
||||
# x_normalized[0] will be 0, x_normalized[n-1] will be 1
|
||||
x_normalized = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
|
||||
|
||||
# The line in normalized space connects (x_norm[0], s_norm[0]) to (x_norm[n-1], s_norm[n-1])
|
||||
# This is (0, 1) to (1, 0).
|
||||
# The equation of the line passing through (0,1) and (1,0) is x + y - 1 = 0.
|
||||
# So, A=1, B=1, C=-1 for the line equation Ax + By + C = 0.
|
||||
|
||||
# Calculate the perpendicular distance from each point (x_normalized[i], s_normalized[i]) to this line.
|
||||
# Distance = |A*x_i + B*y_i + C| / sqrt(A^2 + B^2)
|
||||
# Distance = |1*x_normalized + 1*s_normalized - 1| / sqrt(1^2 + 1^2)
|
||||
# Distance = |x_normalized + s_normalized - 1| / sqrt(2)
|
||||
|
||||
# The sqrt(2) denominator is constant and doesn't affect argmax, so it can be omitted for finding the index.
|
||||
distances = (x_normalized + s_normalized - 1).abs()
|
||||
|
||||
# Find the 0-based index of the point with the maximum distance.
|
||||
knee_index_0based = torch.argmax(distances).item()
|
||||
|
||||
# Convert 0-based index to 1-based rank.
|
||||
rank = knee_index_0based + 1
|
||||
|
||||
# Clamp rank similar to original: must be > 0 and <= n-1 (typical for rank reduction)
|
||||
# If knee_index_0based is n-1 (last point), rank becomes n. min(n, n-1) results in n-1.
|
||||
rank = max(1, min(rank, n - 1))
|
||||
|
||||
return rank
|
||||
|
||||
def index_sv_cumulative_knee(S, min_sv_threshold=1e-8):
|
||||
"""
|
||||
Determine rank using the knee point detection method on the normalized cumulative sum of singular values.
|
||||
This method identifies a point where adding more singular values contributes diminishingly to the total sum.
|
||||
"""
|
||||
n = len(S)
|
||||
if n < 3: # Need at least 3 points to detect a knee
|
||||
return 1
|
||||
|
||||
s_sum = torch.sum(S)
|
||||
# If all singular values are zero or very small, return rank 1.
|
||||
if s_sum < min_sv_threshold:
|
||||
return 1
|
||||
|
||||
# Calculate cumulative sum of singular values, normalized by the total sum.
|
||||
# y_values[0] = S[0]/s_sum, ..., y_values[n-1] = 1.0
|
||||
y_values = torch.cumsum(S, dim=0) / s_sum
|
||||
|
||||
# Normalize these y_values (cumulative sums) to the range [0,1] for knee detection.
|
||||
y_min, y_max = y_values[0], y_values[n-1] # y_max is typically 1.0
|
||||
|
||||
# If the normalized cumulative sum curve is very flat (e.g., S[0] captures almost all energy),
|
||||
# it implies the first few components are dominant.
|
||||
if y_max - y_min < min_sv_threshold: # Using min_sv_threshold here as a sensitivity for flatness
|
||||
# This condition means (S[0] + ... + S[n-1]) - S[0] is small relative to sum(S) if n>1
|
||||
# Effectively, S[1:] components are negligible.
|
||||
return 1
|
||||
|
||||
# y_norm[0] = 0, y_norm[n-1] = 1 (represents the normalized cumulative sum from start to end)
|
||||
y_norm = (y_values - y_min) / (y_max - y_min)
|
||||
|
||||
# x_values are indices, normalized to [0, 1]
|
||||
# x_norm[0] = 0, ..., x_norm[n-1] = 1
|
||||
x_norm = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
|
||||
|
||||
# The "knee" is the point on the curve (x_norm[i], y_norm[i]) that is farthest
|
||||
# from the line connecting the start and end of this normalized curve.
|
||||
# In this normalized space, the line connects (0,0) to (1,1).
|
||||
# The equation of this line is Y = X, or X - Y = 0.
|
||||
# The distance from a point (x_i, y_i) to the line X - Y = 0 is |x_i - y_i| / sqrt(1^2 + (-1)^2).
|
||||
# We can maximize |x_i - y_i| (or |y_i - x_i|) as sqrt(2) is a constant factor.
|
||||
distances = (y_norm - x_norm).abs() # y_norm is expected to be >= x_norm for a concave cumulative curve.
|
||||
|
||||
# Find the 0-based index of the point with the maximum distance.
|
||||
knee_index_0based = torch.argmax(distances).item()
|
||||
|
||||
# Convert 0-based index to 1-based rank.
|
||||
rank = knee_index_0based + 1
|
||||
|
||||
# Clamp rank to be between 1 and n-1 (as n elements give n-1 possible ranks for truncation).
|
||||
# A rank of n means no truncation. n-1 is the highest sensible rank for reduction.
|
||||
rank = max(1, min(rank, n - 1))
|
||||
|
||||
return rank
|
||||
|
||||
def index_sv_rel_decrease(S, tau=0.1):
|
||||
"""Determine rank based on relative decrease threshold."""
|
||||
if len(S) < 2:
|
||||
return 1
|
||||
|
||||
# Compute ratios of consecutive singular values
|
||||
ratios = S[1:] / S[:-1]
|
||||
|
||||
# Find the smallest k where ratio < tau
|
||||
for k in range(len(ratios)):
|
||||
if ratios[k] < tau:
|
||||
return max(1, k + 1) # k + 1 because we want rank after the drop
|
||||
|
||||
# If no drop below tau, return max rank
|
||||
return min(len(S), len(S) - 1)
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype, metadata=None):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if isinstance(state_dict[key], torch.Tensor):
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(model, file_name, metadata)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
def svd(
|
||||
model_org=None,
|
||||
model_tuned=None,
|
||||
save_to=None,
|
||||
dim=4,
|
||||
v2=None,
|
||||
sdxl=None,
|
||||
conv_dim=None,
|
||||
v_parameterization=None,
|
||||
device=None,
|
||||
save_precision=None,
|
||||
clamp_quantile=0.99,
|
||||
min_diff=0.01,
|
||||
no_metadata=False,
|
||||
load_precision=None,
|
||||
load_original_model_to=None,
|
||||
load_tuned_model_to=None,
|
||||
dynamic_method=None,
|
||||
dynamic_param=None,
|
||||
verbose=False,
|
||||
):
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
return torch.float
|
||||
if p == "fp16":
|
||||
return torch.float16
|
||||
if p == "bf16":
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time"
|
||||
v_parameterization = v2 if v_parameterization is None else v_parameterization
|
||||
|
||||
load_dtype = str_to_dtype(load_precision) if load_precision else None
|
||||
save_dtype = str_to_dtype(save_precision) if save_precision else torch.float
|
||||
work_device = "cpu"
|
||||
|
||||
# Load models
|
||||
if not sdxl:
|
||||
logger.info(f"Loading original SD model: {model_org}")
|
||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
||||
text_encoders_o = [text_encoder_o]
|
||||
if load_dtype:
|
||||
text_encoder_o.to(load_dtype)
|
||||
unet_o.to(load_dtype)
|
||||
|
||||
logger.info(f"Loading tuned SD model: {model_tuned}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
||||
text_encoders_t = [text_encoder_t]
|
||||
if load_dtype:
|
||||
text_encoder_t.to(load_dtype)
|
||||
unet_t.to(load_dtype)
|
||||
|
||||
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
|
||||
else:
|
||||
device_org = load_original_model_to or "cpu"
|
||||
device_tuned = load_tuned_model_to or "cpu"
|
||||
|
||||
logger.info(f"Loading original SDXL model: {model_org}")
|
||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
|
||||
)
|
||||
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
||||
if load_dtype:
|
||||
text_encoder_o1.to(load_dtype)
|
||||
text_encoder_o2.to(load_dtype)
|
||||
unet_o.to(load_dtype)
|
||||
|
||||
logger.info(f"Loading tuned SDXL model: {model_tuned}")
|
||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
|
||||
)
|
||||
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
||||
if load_dtype:
|
||||
text_encoder_t1.to(load_dtype)
|
||||
text_encoder_t2.to(load_dtype)
|
||||
unet_t.to(load_dtype)
|
||||
|
||||
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
||||
|
||||
# Create LoRA network
|
||||
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} if conv_dim else {}
|
||||
|
||||
# Define a small initial dimension for memory efficiency
|
||||
init_dim = 4 # Small value to minimize memory usage
|
||||
|
||||
# Create LoRA networks with minimal dimension
|
||||
lora_network_o = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_o, unet_o, **kwargs)
|
||||
lora_network_t = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_t, unet_t, **kwargs)
|
||||
|
||||
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), "Model versions differ (SD1.x vs SD2.x)"
|
||||
|
||||
# Compute differences
|
||||
diffs = {}
|
||||
text_encoder_different = False
|
||||
for lora_o, lora_t in zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras):
|
||||
lora_name = lora_o.lora_name
|
||||
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
|
||||
lora_o.org_module.weight = None
|
||||
lora_t.org_module.weight = None
|
||||
|
||||
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
||||
text_encoder_different = True
|
||||
logger.info(f"Text encoder differs: max diff {torch.max(torch.abs(diff))} > {min_diff}")
|
||||
diffs[lora_name] = diff
|
||||
|
||||
for text_encoder in text_encoders_t:
|
||||
del text_encoder
|
||||
|
||||
if not text_encoder_different:
|
||||
logger.warning("Text encoders are identical. Extracting U-Net only.")
|
||||
lora_network_o.text_encoder_loras = []
|
||||
diffs.clear()
|
||||
|
||||
for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras):
|
||||
lora_name = lora_o.lora_name
|
||||
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
|
||||
lora_o.org_module.weight = None
|
||||
lora_t.org_module.weight = None
|
||||
diffs[lora_name] = diff
|
||||
|
||||
del lora_network_t, unet_t
|
||||
|
||||
# Filter relevant modules
|
||||
lora_names = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)
|
||||
|
||||
# Extract and resize LoRA using SVD
|
||||
logger.info("Extracting and resizing LoRA via SVD")
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name in tqdm(lora_names):
|
||||
mat = diffs[lora_name]
|
||||
if device:
|
||||
mat = mat.to(device)
|
||||
mat = mat.to(torch.float)
|
||||
|
||||
conv2d = len(mat.size()) == 4
|
||||
kernel_size = mat.size()[2:4] if conv2d else None
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
out_dim, in_dim = mat.size()[0:2]
|
||||
|
||||
if conv2d:
|
||||
mat = mat.flatten(start_dim=1) if conv2d_3x3 else mat.squeeze()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
|
||||
# Determine rank
|
||||
max_rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
|
||||
if dynamic_method:
|
||||
if S[0] <= MIN_SV:
|
||||
rank = 1
|
||||
elif dynamic_method == "sv_ratio":
|
||||
rank = index_sv_ratio(S, dynamic_param)
|
||||
elif dynamic_method == "sv_cumulative":
|
||||
rank = index_sv_cumulative(S, dynamic_param)
|
||||
elif dynamic_method == "sv_fro":
|
||||
rank = index_sv_fro(S, dynamic_param)
|
||||
elif dynamic_method == "sv_knee":
|
||||
rank = index_sv_knee_improved(S, MIN_SV) # Pass MIN_SV or a specific threshold
|
||||
elif dynamic_method == "sv_cumulative_knee": # New method
|
||||
rank = index_sv_cumulative_knee(S, MIN_SV) # Pass MIN_SV or a specific threshold
|
||||
elif dynamic_method == "sv_rel_decrease":
|
||||
rank = index_sv_rel_decrease(S, dynamic_param)
|
||||
rank = min(rank, max_rank, in_dim, out_dim)
|
||||
else:
|
||||
rank = min(max_rank, in_dim, out_dim)
|
||||
|
||||
# Truncate SVD components
|
||||
U = U[:, :rank] @ torch.diag(S[:rank])
|
||||
Vh = Vh[:rank, :]
|
||||
|
||||
# Clamp values
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, clamp_quantile)
|
||||
U = U.clamp(-hi_val, hi_val)
|
||||
Vh = Vh.clamp(-hi_val, hi_val)
|
||||
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, rank, 1, 1)
|
||||
Vh = Vh.reshape(rank, in_dim, *kernel_size)
|
||||
|
||||
U = U.to(work_device, dtype=save_dtype).contiguous()
|
||||
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
|
||||
lora_weights[lora_name] = (U, Vh)
|
||||
|
||||
# Verbose output
|
||||
if verbose:
|
||||
s_sum = float(torch.sum(S))
|
||||
s_rank = float(torch.sum(S[:rank]))
|
||||
fro = float(torch.sqrt(torch.sum(S.pow(2))))
|
||||
fro_rank = float(torch.sqrt(torch.sum(S[:rank].pow(2))))
|
||||
ratio = S[0] / S[rank - 1] if rank > 1 else float('inf')
|
||||
logger.info(f"{lora_name:75} | sum(S) retained: {s_rank/s_sum:.1%}, fro retained: {fro_rank/fro:.1%}, max ratio: {ratio:.1f}, rank: {rank}")
|
||||
|
||||
# Create state dict
|
||||
lora_sd = {}
|
||||
for lora_name, (up_weight, down_weight) in lora_weights.items():
|
||||
lora_sd[lora_name + ".lora_up.weight"] = up_weight
|
||||
lora_sd[lora_name + ".lora_down.weight"] = down_weight
|
||||
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype)
|
||||
|
||||
# Load and save LoRA
|
||||
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
|
||||
lora_network_save.apply_to(text_encoders_o, unet_o)
|
||||
info = lora_network_save.load_state_dict(lora_sd)
|
||||
logger.info(f"Loaded extracted and resized LoRA weights: {info}")
|
||||
|
||||
os.makedirs(os.path.dirname(save_to), exist_ok=True)
|
||||
|
||||
# Metadata
|
||||
net_kwargs = {"conv_dim": str(conv_dim), "conv_alpha": str(float(conv_dim))} if conv_dim else {}
|
||||
metadata = {
|
||||
"ss_v2": str(v2),
|
||||
"ss_base_model_version": model_version,
|
||||
"ss_network_module": "networks.lora",
|
||||
"ss_network_dim": str(dim) if not dynamic_method else "Dynamic",
|
||||
"ss_network_alpha": str(float(dim)) if not dynamic_method else "Dynamic",
|
||||
"ss_network_args": json.dumps(net_kwargs),
|
||||
}
|
||||
if not no_metadata:
|
||||
title = os.path.splitext(os.path.basename(save_to))[0]
|
||||
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
save_to_file(save_to, lora_sd, lora_sd, save_dtype, metadata)
|
||||
logger.info(f"LoRA saved to: {save_to}")
|
||||
|
||||
def setup_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model")
|
||||
parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2)")
|
||||
parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model")
|
||||
parser.add_argument("--load_precision", choices=[None, "float", "fp16", "bf16"], help="Precision for loading models")
|
||||
parser.add_argument("--save_precision", choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA")
|
||||
parser.add_argument("--model_org", required=True, help="Original Stable Diffusion model (ckpt/safetensors)")
|
||||
parser.add_argument("--model_tuned", required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)")
|
||||
parser.add_argument("--save_to", required=True, help="Output file name (ckpt/safetensors)")
|
||||
parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers")
|
||||
parser.add_argument("--conv_dim", type=int, help="Max dimension (rank) of LoRA for Conv2d-3x3")
|
||||
parser.add_argument("--device", default="cuda", help="Device for computation (e.g., cuda)")
|
||||
parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights")
|
||||
parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract")
|
||||
parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata")
|
||||
parser.add_argument("--load_original_model_to", help="Device for original model (SDXL only)")
|
||||
parser.add_argument("--load_tuned_model_to", help="Device for tuned model (SDXL only)")
|
||||
parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction")
|
||||
parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info")
|
||||
parser.add_argument(
|
||||
"--dynamic_method",
|
||||
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"], # Added "sv_cumulative_knee"
|
||||
help="Dynamic rank reduction method"
|
||||
)
|
||||
return parser
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
if args.dynamic_method and not args.dynamic_param:
|
||||
raise ValueError("Dynamic method requires a dynamic parameter")
|
||||
svd(**vars(args))
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import argparse # Import argparse
|
||||
|
||||
def extract_model_differences(base_model_path, finetuned_model_path, output_delta_path=None, save_dtype_str="float32"):
|
||||
"""
|
||||
Calculates the difference between the state dictionaries of a fine-tuned model
|
||||
and a base model.
|
||||
|
||||
Args:
|
||||
base_model_path (str): Path to the base model .safetensors file.
|
||||
finetuned_model_path (str): Path to the fine-tuned model .safetensors file.
|
||||
output_delta_path (str, optional): Path to save the resulting delta weights
|
||||
.safetensors file. If None, not saved.
|
||||
save_dtype_str (str, optional): Data type to save the delta weights ('float32', 'float16', 'bfloat16').
|
||||
Defaults to 'float32'.
|
||||
Returns:
|
||||
OrderedDict: A state dictionary containing the delta weights.
|
||||
Returns None if loading fails or other critical errors.
|
||||
"""
|
||||
print(f"Loading base model from: {base_model_path}")
|
||||
try:
|
||||
# Ensure model is loaded to CPU to avoid CUDA issues if not needed for diffing
|
||||
base_state_dict = load_file(base_model_path, device="cpu")
|
||||
print(f"Base model loaded. Found {len(base_state_dict)} tensors.")
|
||||
except Exception as e:
|
||||
print(f"Error loading base model: {e}")
|
||||
return None
|
||||
|
||||
print(f"\nLoading fine-tuned model from: {finetuned_model_path}")
|
||||
try:
|
||||
finetuned_state_dict = load_file(finetuned_model_path, device="cpu")
|
||||
print(f"Fine-tuned model loaded. Found {len(finetuned_state_dict)} tensors.")
|
||||
except Exception as e:
|
||||
print(f"Error loading fine-tuned model: {e}")
|
||||
return None
|
||||
|
||||
delta_state_dict = OrderedDict()
|
||||
diff_count = 0
|
||||
skipped_count = 0
|
||||
error_count = 0
|
||||
unique_to_finetuned_count = 0
|
||||
unique_to_base_count = 0
|
||||
|
||||
print("\nCalculating differences...")
|
||||
|
||||
# Keys in finetuned model
|
||||
finetuned_keys = set(finetuned_state_dict.keys())
|
||||
base_keys = set(base_state_dict.keys())
|
||||
|
||||
common_keys = finetuned_keys.intersection(base_keys)
|
||||
keys_only_in_finetuned = finetuned_keys - base_keys
|
||||
keys_only_in_base = base_keys - finetuned_keys
|
||||
|
||||
for key in common_keys:
|
||||
ft_tensor = finetuned_state_dict[key]
|
||||
base_tensor = base_state_dict[key]
|
||||
|
||||
if not (ft_tensor.is_floating_point() and base_tensor.is_floating_point()):
|
||||
# print(f"Skipping key '{key}': Non-floating point tensors (FT: {ft_tensor.dtype}, Base: {base_tensor.dtype}).")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
if ft_tensor.shape != base_tensor.shape:
|
||||
print(f"Skipping key '{key}': Shape mismatch (FT: {ft_tensor.shape}, Base: {base_tensor.shape}).")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
# Calculate difference in float32 for precision, then cast to save_dtype
|
||||
delta_tensor = ft_tensor.to(dtype=torch.float32) - base_tensor.to(dtype=torch.float32)
|
||||
delta_state_dict[key] = delta_tensor
|
||||
diff_count += 1
|
||||
except Exception as e:
|
||||
print(f"Error calculating difference for key '{key}': {e}")
|
||||
error_count += 1
|
||||
|
||||
for key in keys_only_in_finetuned:
|
||||
print(f"Warning: Key '{key}' (Shape: {finetuned_state_dict[key].shape}, Dtype: {finetuned_state_dict[key].dtype}) is present in fine-tuned model but not in base model. Storing as is.")
|
||||
delta_state_dict[key] = finetuned_state_dict[key] # Store the tensor from the finetuned model
|
||||
unique_to_finetuned_count += 1
|
||||
|
||||
if keys_only_in_base:
|
||||
print(f"\nWarning: {len(keys_only_in_base)} key(s) are present only in the base model and will not be in the delta file.")
|
||||
for key in list(keys_only_in_base)[:5]: # Print first 5 as examples
|
||||
print(f" - Example key only in base: {key}")
|
||||
if len(keys_only_in_base) > 5:
|
||||
print(f" ... and {len(keys_only_in_base) - 5} more.")
|
||||
|
||||
|
||||
print(f"\nDifference calculation complete.")
|
||||
print(f" {diff_count} layers successfully diffed.")
|
||||
print(f" {unique_to_finetuned_count} layers unique to fine-tuned model (added as is).")
|
||||
print(f" {skipped_count} common layers skipped (shape/type mismatch).")
|
||||
print(f" {error_count} common layers had errors during diffing.")
|
||||
|
||||
if output_delta_path and delta_state_dict:
|
||||
save_dtype = torch.float32 # Default
|
||||
if save_dtype_str == "float16":
|
||||
save_dtype = torch.float16
|
||||
elif save_dtype_str == "bfloat16":
|
||||
save_dtype = torch.bfloat16
|
||||
elif save_dtype_str != "float32":
|
||||
print(f"Warning: Invalid save_dtype '{save_dtype_str}'. Defaulting to float32.")
|
||||
save_dtype_str = "float32" # for print message
|
||||
|
||||
print(f"\nPreparing to save delta weights with dtype: {save_dtype_str}")
|
||||
|
||||
final_save_dict = OrderedDict()
|
||||
for k, v_tensor in delta_state_dict.items():
|
||||
if v_tensor.is_floating_point():
|
||||
final_save_dict[k] = v_tensor.to(dtype=save_dtype)
|
||||
else:
|
||||
final_save_dict[k] = v_tensor # Keep non-float as is (e.g. int tensors if any)
|
||||
|
||||
try:
|
||||
save_file(final_save_dict, output_delta_path)
|
||||
print(f"Delta weights saved to: {output_delta_path}")
|
||||
except Exception as e:
|
||||
print(f"Error saving delta weights: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
return delta_state_dict
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Extract weight differences between a fine-tuned and a base SDXL model.")
|
||||
parser.add_argument("base_model_path", type=str, help="File path for the BASE SDXL model (.safetensors).")
|
||||
parser.add_argument("finetuned_model_path", type=str, help="File path for the FINE-TUNED SDXL model (.safetensors).")
|
||||
parser.add_argument("--output_path", type=str, default=None,
|
||||
help="Optional: File path to save the delta weights (.safetensors). "
|
||||
"If not provided, defaults to 'model_deltas/delta_[finetuned_model_name].safetensors'.")
|
||||
parser.add_argument("--save_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"],
|
||||
help="Data type for saving the delta weights. Choose from 'float32', 'float16', 'bfloat16'. "
|
||||
"Defaults to 'float32'.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("--- Model Difference Extraction Script ---")
|
||||
|
||||
if not os.path.exists(args.base_model_path):
|
||||
print(f"Error: Base model file not found at {args.base_model_path}")
|
||||
exit(1)
|
||||
if not os.path.exists(args.finetuned_model_path):
|
||||
print(f"Error: Fine-tuned model file not found at {args.finetuned_model_path}")
|
||||
exit(1)
|
||||
|
||||
output_delta_file = args.output_path
|
||||
if output_delta_file is None:
|
||||
output_dir = "model_deltas"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
finetuned_basename = os.path.splitext(os.path.basename(args.finetuned_model_path))[0]
|
||||
output_delta_file = os.path.join(output_dir, f"delta_{finetuned_basename}.safetensors")
|
||||
|
||||
# Ensure the output directory exists if a full path is given
|
||||
if output_delta_file:
|
||||
output_dir_for_file = os.path.dirname(output_delta_file)
|
||||
if output_dir_for_file and not os.path.exists(output_dir_for_file):
|
||||
os.makedirs(output_dir_for_file, exist_ok=True)
|
||||
|
||||
|
||||
differences = extract_model_differences(
|
||||
args.base_model_path,
|
||||
args.finetuned_model_path,
|
||||
output_delta_path=output_delta_file,
|
||||
save_dtype_str=args.save_dtype
|
||||
)
|
||||
|
||||
if differences:
|
||||
print(f"\nExtraction process finished. {len(differences)} total keys in the delta state_dict.")
|
||||
else:
|
||||
print("\nCould not extract differences due to errors during model loading.")
|
||||
Loading…
Reference in New Issue