mirror of https://github.com/bmaltais/kohya_ss
Merge branch 'lora' into dev
commit
932dbaae57
|
|
@ -0,0 +1,159 @@
|
|||
import safetensors.torch
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
import sys # To redirect stdout
|
||||
import traceback
|
||||
|
||||
class Logger(object):
|
||||
def __init__(self, filename="loha_analysis_output.txt"):
|
||||
self.terminal = sys.stdout
|
||||
self.log = open(filename, "w", encoding='utf-8')
|
||||
|
||||
def write(self, message):
|
||||
self.terminal.write(message)
|
||||
self.log.write(message)
|
||||
|
||||
def flush(self):
|
||||
# This flush method is needed for python 3 compatibility.
|
||||
# This handles the flush command, which shutil.copytree or os.system uses.
|
||||
self.terminal.flush()
|
||||
self.log.flush()
|
||||
|
||||
def close(self):
|
||||
self.log.close()
|
||||
|
||||
def analyze_safetensors_file(filepath, output_filename="loha_analysis_output.txt"):
|
||||
"""
|
||||
Analyzes a .safetensors file to extract and print its metadata
|
||||
and tensor information (keys, shapes, dtypes) to a file.
|
||||
"""
|
||||
original_stdout = sys.stdout
|
||||
logger = Logger(filename=output_filename)
|
||||
sys.stdout = logger
|
||||
|
||||
try:
|
||||
print(f"--- Analyzing: {filepath} ---\n")
|
||||
print(f"--- Output will be saved to: {output_filename} ---\n")
|
||||
|
||||
# Load the tensors to get their structure
|
||||
state_dict = safetensors.torch.load_file(filepath, device="cpu") # Load to CPU to avoid potential CUDA issues
|
||||
|
||||
print("--- Tensor Information ---")
|
||||
if not state_dict:
|
||||
print("No tensors found in the state dictionary.")
|
||||
else:
|
||||
# Sort keys for consistent output
|
||||
sorted_keys = sorted(state_dict.keys())
|
||||
current_module_prefix = ""
|
||||
|
||||
# First, identify all unique module prefixes for better grouping
|
||||
module_prefixes = sorted(list(set([".".join(key.split(".")[:-1]) for key in sorted_keys if "." in key])))
|
||||
|
||||
for prefix in module_prefixes:
|
||||
if not prefix: # Skip keys that don't seem to be part of a module (e.g. global metadata tensors if any)
|
||||
continue
|
||||
print(f"\nModule: {prefix}")
|
||||
for key in sorted_keys:
|
||||
if key.startswith(prefix + "."):
|
||||
tensor = state_dict[key]
|
||||
print(f" - Key: {key}")
|
||||
print(f" Shape: {list(tensor.shape)}, Dtype: {tensor.dtype}") # Output shape as list for clarity
|
||||
if key.endswith((".alpha", ".dim")):
|
||||
try:
|
||||
value = tensor.item()
|
||||
# Check if value is float and format if it is
|
||||
if isinstance(value, float):
|
||||
print(f" Value: {value:.8f}") # Format float to a certain precision
|
||||
else:
|
||||
print(f" Value: {value}")
|
||||
except Exception as e:
|
||||
print(f" Value: Could not extract scalar value ({tensor}, error: {e})")
|
||||
elif tensor.numel() < 10: # Print small tensors' values
|
||||
print(f" Values (first few): {tensor.flatten()[:10].tolist()}")
|
||||
|
||||
|
||||
# Print keys that might not fit the module pattern (e.g., older formats or single tensors)
|
||||
print("\n--- Other Tensor Keys (if any, not fitting typical module.parameter pattern) ---")
|
||||
other_keys_found = False
|
||||
for key in sorted_keys:
|
||||
if not any(key.startswith(p + ".") for p in module_prefixes if p):
|
||||
other_keys_found = True
|
||||
tensor = state_dict[key]
|
||||
print(f" - Key: {key}")
|
||||
print(f" Shape: {list(tensor.shape)}, Dtype: {tensor.dtype}")
|
||||
if key.endswith((".alpha", ".dim")) or tensor.numel() == 1:
|
||||
try:
|
||||
value = tensor.item()
|
||||
if isinstance(value, float):
|
||||
print(f" Value: {value:.8f}")
|
||||
else:
|
||||
print(f" Value: {value}")
|
||||
except Exception as e:
|
||||
print(f" Value: Could not extract scalar value ({tensor}, error: {e})")
|
||||
|
||||
if not other_keys_found:
|
||||
print("No other keys found.")
|
||||
|
||||
print(f"\nTotal tensor keys found: {len(state_dict)}")
|
||||
|
||||
print("\n--- Metadata (from safetensors header) ---")
|
||||
metadata_content = OrderedDict()
|
||||
malformed_metadata_keys = []
|
||||
try:
|
||||
# Use safe_open to access the metadata separately
|
||||
with safetensors.safe_open(filepath, framework="pt", device="cpu") as f:
|
||||
metadata_keys = f.metadata()
|
||||
if metadata_keys is None:
|
||||
print("No metadata dictionary found in the file header (f.metadata() returned None).")
|
||||
else:
|
||||
for k in metadata_keys.keys():
|
||||
try:
|
||||
metadata_content[k] = metadata_keys.get(k)
|
||||
except Exception as e:
|
||||
malformed_metadata_keys.append((k, str(e)))
|
||||
metadata_content[k] = f"[Error reading value: {e}]"
|
||||
except Exception as e:
|
||||
print(f"Could not open or read metadata using safe_open: {e}")
|
||||
traceback.print_exc(file=sys.stdout)
|
||||
|
||||
if not metadata_content and not malformed_metadata_keys:
|
||||
print("No metadata content extracted.")
|
||||
else:
|
||||
for key, value in metadata_content.items():
|
||||
print(f"- {key}: {value}")
|
||||
if key == "ss_network_args" and value and not value.startswith("[Error"):
|
||||
try:
|
||||
parsed_args = json.loads(value)
|
||||
print(" Parsed ss_network_args:")
|
||||
for arg_key, arg_value in parsed_args.items():
|
||||
print(f" - {arg_key}: {arg_value}")
|
||||
except json.JSONDecodeError:
|
||||
print(" (ss_network_args is not a valid JSON string)")
|
||||
if malformed_metadata_keys:
|
||||
print("\n--- Malformed Metadata Keys (could not be read) ---")
|
||||
for key, error_msg in malformed_metadata_keys:
|
||||
print(f"- {key}: Error: {error_msg}")
|
||||
|
||||
print("\n--- End of Analysis ---")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n!!! An error occurred during analysis !!!")
|
||||
print(str(e))
|
||||
traceback.print_exc(file=sys.stdout) # Print full traceback to the log file
|
||||
finally:
|
||||
sys.stdout = original_stdout # Restore standard output
|
||||
logger.close()
|
||||
print(f"\nAnalysis complete. Output saved to: {output_filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_file_path = input("Enter the path to your working LoHA .safetensors file: ")
|
||||
output_file_name = "loha_analysis_results.txt" # You can change this default
|
||||
|
||||
# Suggest a default output name based on input file if desired
|
||||
# import os
|
||||
# base_name = os.path.splitext(os.path.basename(input_file_path))[0]
|
||||
# output_file_name = f"{base_name}_analysis.txt"
|
||||
|
||||
print(f"The analysis will be saved to: {output_file_name}")
|
||||
analyze_safetensors_file(input_file_path, output_filename=output_file_name)
|
||||
|
|
@ -0,0 +1,204 @@
|
|||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
|
||||
# --- Script Configuration ---
|
||||
# This script generates a minimal, non-functional LoHA (LyCORIS Hadamard Product Adaptation)
|
||||
# .safetensors file, designed to be structurally compatible with ComfyUI and
|
||||
# based on the analysis of a working SDXL LoHA file.
|
||||
|
||||
# --- Global LoHA Parameters (mimicking metadata from your working file) ---
|
||||
# These can be overridden per layer if needed for more complex dummies.
|
||||
# From your metadata: ss_network_dim: 32, ss_network_alpha: 32.0
|
||||
DEFAULT_RANK = 32
|
||||
DEFAULT_ALPHA = 32.0
|
||||
CONV_RANK = 8 # From your ss_network_args: "conv_dim": "8"
|
||||
CONV_ALPHA = 4.0 # From your ss_network_args: "conv_alpha": "4"
|
||||
|
||||
# Define example target layers.
|
||||
# We'll use names and dimensions that are representative of SDXL and your analysis.
|
||||
# Format: (layer_name, in_dim, out_dim, rank, alpha)
|
||||
# Note: For Conv2d, in_dim = in_channels, out_dim = out_channels.
|
||||
# The hada_wX_b for conv will have shape (rank, in_channels * kernel_h * kernel_w)
|
||||
# For simplicity in this dummy, we'll primarily focus on linear/attention
|
||||
# layers first, and then add one representative conv-like layer.
|
||||
|
||||
# Layer that previously caused error:
|
||||
# "ERROR loha diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight shape '[640, 640]' is invalid..."
|
||||
# This corresponds to lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_v
|
||||
# In your working LoHA, similar attention layers (e.g., *_attn1_to_k) have out_dim=640, in_dim=640, rank=32, alpha=32.0
|
||||
|
||||
EXAMPLE_LAYERS_CONFIG = [
|
||||
# UNet Attention Layers (mimicking typical SDXL structure)
|
||||
{
|
||||
"name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_q", # Query
|
||||
"in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
{
|
||||
"name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_k", # Key
|
||||
"in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
{
|
||||
"name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_v", # Value - this one errored previously
|
||||
"in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
{
|
||||
"name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_out_0", # Output Projection
|
||||
"in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
# A deeper UNet attention block
|
||||
{
|
||||
"name": "lora_unet_middle_block_1_transformer_blocks_0_attn1_to_q",
|
||||
"in_dim": 1280, "out_dim": 1280, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
{
|
||||
"name": "lora_unet_middle_block_1_transformer_blocks_0_attn1_to_out_0",
|
||||
"in_dim": 1280, "out_dim": 1280, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
# Example UNet "Convolutional" LoHA (e.g., for a ResBlock's conv layer)
|
||||
# Based on your lora_unet_input_blocks_1_0_in_layers_2 which had rank 8, alpha 4
|
||||
# Assuming original conv was Conv2d(320, 320, kernel_size=3, padding=1)
|
||||
{
|
||||
"name": "lora_unet_input_blocks_1_0_in_layers_2",
|
||||
"in_dim": 320, # in_channels
|
||||
"out_dim": 320, # out_channels
|
||||
"rank": CONV_RANK,
|
||||
"alpha": CONV_ALPHA,
|
||||
"is_conv": True,
|
||||
"kernel_size": 3 # Assume 3x3 kernel for this example
|
||||
},
|
||||
# Example Text Encoder Layer (CLIP-L, first one from your list)
|
||||
# lora_te1_text_model_encoder_layers_0_mlp_fc1 (original Linear(768, 3072))
|
||||
{
|
||||
"name": "lora_te1_text_model_encoder_layers_0_mlp_fc1",
|
||||
"in_dim": 768, "out_dim": 3072, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
]
|
||||
|
||||
# Use bfloat16 as seen in the analysis
|
||||
DTYPE = torch.bfloat16
|
||||
|
||||
# --- Main Script ---
|
||||
def create_dummy_loha_file(filepath="dummy_loha_corrected.safetensors"):
|
||||
"""
|
||||
Creates and saves a dummy LoHA .safetensors file with corrected structure
|
||||
and metadata based on analysis of a working file.
|
||||
"""
|
||||
state_dict = OrderedDict()
|
||||
metadata = OrderedDict()
|
||||
|
||||
print(f"Generating dummy LoHA with default rank={DEFAULT_RANK}, default alpha={DEFAULT_ALPHA}")
|
||||
print(f"Targeting DTYPE: {DTYPE}")
|
||||
|
||||
for layer_config in EXAMPLE_LAYERS_CONFIG:
|
||||
layer_name = layer_config["name"]
|
||||
in_dim = layer_config["in_dim"]
|
||||
out_dim = layer_config["out_dim"]
|
||||
rank = layer_config["rank"]
|
||||
alpha = layer_config["alpha"]
|
||||
is_conv = layer_config["is_conv"]
|
||||
|
||||
print(f"Processing layer: {layer_name} (in: {in_dim}, out: {out_dim}, rank: {rank}, alpha: {alpha}, conv: {is_conv})")
|
||||
|
||||
# --- LoHA Tensor Shapes Correction based on analysis ---
|
||||
# hada_wX_a (maps to original layer's out_features): (out_dim, rank)
|
||||
# hada_wX_b (maps from original layer's in_features): (rank, in_dim)
|
||||
# For Convolutions, in_dim refers to in_channels, out_dim to out_channels.
|
||||
# For hada_wX_b in conv, the effective input dimension includes kernel size.
|
||||
|
||||
if is_conv:
|
||||
kernel_size = layer_config.get("kernel_size", 3) # Default to 3x3 if not specified
|
||||
# This is for LoHA types that decompose the full kernel (e.g. LyCORIS full conv):
|
||||
# (rank, in_channels * kernel_h * kernel_w)
|
||||
# For simpler conv LoHA (like applying to 1x1 equivalent), it might just be (rank, in_channels)
|
||||
# The analysis for `lora_unet_input_blocks_1_0_in_layers_2` showed hada_w1_b as [8, 2880]
|
||||
# where in_dim=320, rank=8. 2880 = 320 * 9 (i.e., in_channels * kernel_h * kernel_w for 3x3)
|
||||
# This indicates a full kernel decomposition.
|
||||
eff_in_dim_conv_b = in_dim * kernel_size * kernel_size
|
||||
|
||||
hada_w1_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01
|
||||
hada_w1_b = torch.randn(rank, eff_in_dim_conv_b, dtype=DTYPE) * 0.01
|
||||
hada_w2_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01
|
||||
hada_w2_b = torch.randn(rank, eff_in_dim_conv_b, dtype=DTYPE) * 0.01
|
||||
else: # Linear layers
|
||||
hada_w1_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01
|
||||
hada_w1_b = torch.randn(rank, in_dim, dtype=DTYPE) * 0.01
|
||||
hada_w2_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01
|
||||
hada_w2_b = torch.randn(rank, in_dim, dtype=DTYPE) * 0.01
|
||||
|
||||
state_dict[f"{layer_name}.hada_w1_a"] = hada_w1_a
|
||||
state_dict[f"{layer_name}.hada_w1_b"] = hada_w1_b
|
||||
state_dict[f"{layer_name}.hada_w2_a"] = hada_w2_a
|
||||
state_dict[f"{layer_name}.hada_w2_b"] = hada_w2_b
|
||||
|
||||
# Alpha tensor (scalar)
|
||||
state_dict[f"{layer_name}.alpha"] = torch.tensor(float(alpha), dtype=DTYPE)
|
||||
|
||||
# IMPORTANT: No per-module ".dim" tensor, as per analysis of working file.
|
||||
# Rank is implicit in weight shapes and global metadata.
|
||||
|
||||
# --- Metadata (mimicking the working LoHA file) ---
|
||||
metadata["ss_network_module"] = "lycoris.kohya"
|
||||
metadata["ss_network_dim"] = str(DEFAULT_RANK) # Global/default rank
|
||||
metadata["ss_network_alpha"] = str(DEFAULT_ALPHA) # Global/default alpha
|
||||
metadata["ss_network_algo"] = "loha" # Also specified inside ss_network_args by convention
|
||||
|
||||
# Mimic ss_network_args from your file
|
||||
network_args = {
|
||||
"conv_dim": str(CONV_RANK),
|
||||
"conv_alpha": str(CONV_ALPHA),
|
||||
"algo": "loha",
|
||||
# Add other args from your file if they seem relevant for loading structure,
|
||||
# but these are the most critical for type/rank.
|
||||
"dropout": "0.0", # From your file, though value might not matter for dummy
|
||||
"rank_dropout": "0", # from your file
|
||||
"module_dropout": "0", # from your file
|
||||
"use_tucker": "False", # from your file
|
||||
"use_scalar": "False", # from your file
|
||||
"rank_dropout_scale": "False", # from your file
|
||||
"train_norm": "False" # from your file
|
||||
}
|
||||
metadata["ss_network_args"] = json.dumps(network_args)
|
||||
|
||||
# Other potentially useful metadata from your working file (optional for basic loading)
|
||||
metadata["ss_sd_model_name"] = "sd_xl_base_1.0.safetensors" # Example base model
|
||||
metadata["ss_resolution"] = "(1024,1024)" # Example, format might vary
|
||||
metadata["modelspec.sai_model_spec"] = "1.0.0"
|
||||
metadata["modelspec.implementation"] = "https_//github.com/Stability-AI/generative-models" # fixed typo
|
||||
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base/lora" # Even for LoHA, this is often used
|
||||
metadata["ss_mixed_precision"] = "bf16"
|
||||
metadata["ss_note"] = "Dummy LoHA (corrected) for ComfyUI validation. Not trained."
|
||||
|
||||
|
||||
# --- Save the State Dictionary with Metadata ---
|
||||
try:
|
||||
save_file(state_dict, filepath, metadata=metadata)
|
||||
print(f"\nSuccessfully saved dummy LoHA file to: {filepath}")
|
||||
print("\nFile structure (tensor keys):")
|
||||
for key in state_dict.keys():
|
||||
print(f"- {key}: shape {state_dict[key].shape}, dtype {state_dict[key].dtype}")
|
||||
print("\nMetadata:")
|
||||
for key, value in metadata.items():
|
||||
print(f"- {key}: {value}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\nError saving file: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_dummy_loha_file()
|
||||
|
||||
# --- Verification Note for ComfyUI ---
|
||||
# 1. Place `dummy_loha_corrected.safetensors` into `ComfyUI/models/loras/`.
|
||||
# 2. Load an SDXL base model in ComfyUI.
|
||||
# 3. Add a "Load LoRA" node and select `dummy_loha_corrected.safetensors`.
|
||||
# 4. Connect the LoRA node between the checkpoint loader and the KSampler.
|
||||
#
|
||||
# Expected outcome:
|
||||
# - ComfyUI should load the file without "key not loaded" or "dimension mismatch" errors
|
||||
# for the layers defined in EXAMPLE_LAYERS_CONFIG.
|
||||
# - The LoRA node should correctly identify it as a LoHA/LyCORIS model.
|
||||
# - If you have layers in your SDXL model that match the names in EXAMPLE_LAYERS_CONFIG,
|
||||
# ComfyUI will attempt to apply these (random) weights.
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
|
||||
--save_precision fp16 `
|
||||
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
|
||||
--model_tuned E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
|
||||
--save_to E:/lora/sdxl/dreamshaperXL_alpha2Xl10_sv_fro_0.9_1024.safetensors `
|
||||
--dim 1024 `
|
||||
--device cuda `
|
||||
--sdxl `
|
||||
--dynamic_method sv_fro `
|
||||
--dynamic_param 0.9 `
|
||||
--verbose
|
||||
|
||||
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
|
||||
--save_precision fp16 `
|
||||
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
|
||||
--model_tuned E:/models/sdxl/proteus_v06.safetensors `
|
||||
--save_to E:/lora/sdxl/proteus_v06_sv_cumulative_knee_1024.safetensors `
|
||||
--dim 1024 `
|
||||
--device cuda `
|
||||
--sdxl `
|
||||
--dynamic_method sv_cumulative_knee `
|
||||
--verbose
|
||||
|
||||
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\lr_finder.py `
|
||||
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
|
||||
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
|
||||
--lr_finder_num_layers 16 `
|
||||
--lr_finder_min_lr 1e-8 `
|
||||
--lr_finder_max_lr 0.2 `
|
||||
--lr_finder_num_steps 120 `
|
||||
--lr_finder_iters_per_step 40 `
|
||||
--rank 8 `
|
||||
--initial_alpha 8.0 `
|
||||
--precision bf16 `
|
||||
--device cuda `
|
||||
--lr_finder_plot `
|
||||
--lr_finder_show_plot
|
||||
|
||||
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
|
||||
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
|
||||
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
|
||||
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_1e-7.safetensors `
|
||||
--rank 2 `
|
||||
--initial_alpha 2 `
|
||||
--max_rank_retries 7 `
|
||||
--rank_increase_factor 2 `
|
||||
--max_iterations 8000 `
|
||||
--min_iterations 400 `
|
||||
--target_loss 1e-7 `
|
||||
--lr 1e-01 `
|
||||
--device cuda `
|
||||
--precision fp32 `
|
||||
--verbose `
|
||||
--save_weights_dtype bf16 `
|
||||
--progress_check_interval 100 `
|
||||
--save_every_n_layers 10 `
|
||||
--keep_n_resume_files 10 `
|
||||
--skip_delta_threshold 1e-7 `
|
||||
--rank_search_strategy binary_search_min_rank `
|
||||
--probe_aggressive_early_stop
|
||||
|
||||
D:\kohya_ss\venv\Scripts\python.exe D:\kohya_ss\tools\model_diff_report.py `
|
||||
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
|
||||
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
|
||||
--top_n_diff 15 --plot_histograms --plot_histograms_top_n 3 --output_dir ./analysis_results
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,535 @@
|
|||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
import logging # Import for logging
|
||||
|
||||
# NEW: Add diffusers import for model loading
|
||||
try:
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
except ImportError:
|
||||
print("Diffusers library not found. Please install it: pip install diffusers transformers accelerate")
|
||||
raise
|
||||
|
||||
# --- Localized Logging Setup ---
|
||||
def _local_setup_logging(log_level=logging.INFO):
|
||||
"""
|
||||
Sets up basic logging to console.
|
||||
"""
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s %(levelname)-8s %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
_local_setup_logging() # Initialize logging
|
||||
logger = logging.getLogger(__name__) # Get logger for this module
|
||||
|
||||
MIN_SV = 1e-6
|
||||
|
||||
# --- Localized sd-scripts constants and utility functions ---
|
||||
_LOCAL_MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_v10"
|
||||
|
||||
def _local_get_model_version_str_for_sd1_sd2(is_v2: bool, is_v_parameterization: bool) -> str:
|
||||
if is_v2:
|
||||
return "v2-v" if is_v_parameterization else "v2"
|
||||
return "v1"
|
||||
|
||||
# --- Localized LoRA Placeholder and Network Creation ---
|
||||
class LocalLoRAModulePlaceholder:
|
||||
def __init__(self, lora_name: str, org_module: torch.nn.Module):
|
||||
self.lora_name = lora_name
|
||||
self.org_module = org_module
|
||||
# Add other attributes if _calculate_module_diffs_and_check needs them,
|
||||
# but it primarily uses .lora_name and .org_module.weight
|
||||
|
||||
def _local_create_network_placeholders(text_encoders: list, unet: torch.nn.Module, lora_conv_dim_init: int):
|
||||
"""
|
||||
Creates placeholders for LoRA-able modules in text encoders and UNet.
|
||||
Mimics the module identification and naming of sd-scripts' lora.create_network.
|
||||
`lora_conv_dim_init`: If > 0, Conv2d layers are considered for LoRA.
|
||||
"""
|
||||
unet_loras = []
|
||||
text_encoder_loras = []
|
||||
|
||||
# Target U-Net modules
|
||||
for name, module in unet.named_modules():
|
||||
lora_name = "lora_unet_" + name.replace(".", "_")
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
unet_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
|
||||
elif isinstance(module, torch.nn.Conv2d):
|
||||
if lora_conv_dim_init > 0: # Only consider conv layers if conv_dim > 0
|
||||
# Kernel size check might be relevant if sd-scripts has specific logic,
|
||||
# but for diffing, any conv is a candidate if conv_dim > 0.
|
||||
# SVD will later handle rank based on actual layer type (1x1 vs 3x3).
|
||||
unet_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
|
||||
|
||||
# Target Text Encoder modules
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
if text_encoder is None: # SDXL can have None TEs if not loaded
|
||||
continue
|
||||
# Determine prefix based on number of text encoders (for SDXL compatibility)
|
||||
te_prefix = f"lora_te{i+1}_" if len(text_encoders) > 1 else "lora_te_"
|
||||
|
||||
for name, module in text_encoder.named_modules():
|
||||
lora_name = te_prefix + name.replace(".", "_")
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
text_encoder_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
|
||||
# Conv2d in text encoders is rare but check just in case (sd-scripts might)
|
||||
elif isinstance(module, torch.nn.Conv2d):
|
||||
if lora_conv_dim_init > 0:
|
||||
text_encoder_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
|
||||
|
||||
logger.info(f"Found {len(text_encoder_loras)} LoRA-able placeholder modules in Text Encoders.")
|
||||
logger.info(f"Found {len(unet_loras)} LoRA-able placeholder modules in U-Net.")
|
||||
return text_encoder_loras, unet_loras
|
||||
|
||||
|
||||
# --- Singular Value Indexing Functions (Unchanged) ---
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
return index
|
||||
|
||||
def index_sv_fro(S, target):
|
||||
S_squared = S.pow(2)
|
||||
S_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
return index
|
||||
|
||||
def index_sv_ratio(S, target):
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv / target
|
||||
index = int(torch.sum(S > min_sv).item())
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
return index
|
||||
|
||||
def index_sv_knee(S, MIN_SV_KNEE=1e-8):
|
||||
n = len(S)
|
||||
if n < 3: return 1
|
||||
s_max, s_min = S[0], S[-1]
|
||||
if s_max - s_min < MIN_SV_KNEE: return 1
|
||||
s_normalized = (S - s_min) / (s_max - s_min)
|
||||
x_normalized = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
|
||||
distances = (x_normalized + s_normalized - 1).abs()
|
||||
knee_index_0based = torch.argmax(distances).item()
|
||||
rank = knee_index_0based + 1
|
||||
rank = max(1, min(rank, n - 1))
|
||||
return rank
|
||||
|
||||
def index_sv_cumulative_knee(S, min_sv_threshold=1e-8):
|
||||
n = len(S)
|
||||
if n < 3: return 1
|
||||
s_sum = torch.sum(S)
|
||||
if s_sum < min_sv_threshold: return 1
|
||||
y_values = torch.cumsum(S, dim=0) / s_sum
|
||||
y_min, y_max = y_values[0], y_values[n-1]
|
||||
if y_max - y_min < min_sv_threshold: return 1
|
||||
y_norm = (y_values - y_min) / (y_max - y_min)
|
||||
x_norm = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
|
||||
distances = (y_norm - x_norm).abs()
|
||||
knee_index_0based = torch.argmax(distances).item()
|
||||
rank = knee_index_0based + 1
|
||||
rank = max(1, min(rank, n - 1))
|
||||
return rank
|
||||
|
||||
def index_sv_rel_decrease(S, tau=0.1):
|
||||
if len(S) < 2: return 1
|
||||
ratios = S[1:] / S[:-1]
|
||||
for k in range(len(ratios)):
|
||||
if ratios[k] < tau:
|
||||
return k + 1
|
||||
return len(S)
|
||||
|
||||
# --- Utility Functions ---
|
||||
def _str_to_dtype(p):
|
||||
if p == "float": return torch.float
|
||||
if p == "fp16": return torch.float16
|
||||
if p == "bf16": return torch.bfloat16
|
||||
return None
|
||||
|
||||
def save_to_file(file_name, state_dict_to_save, dtype, metadata=None):
|
||||
state_dict_final = {}
|
||||
for key, value in state_dict_to_save.items():
|
||||
if isinstance(value, torch.Tensor) and dtype is not None:
|
||||
state_dict_final[key] = value.to(dtype)
|
||||
else:
|
||||
state_dict_final[key] = value
|
||||
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(state_dict_final, file_name, metadata=metadata)
|
||||
else:
|
||||
torch.save(state_dict_final, file_name)
|
||||
|
||||
def _build_local_sai_metadata(title, creation_time, is_v2_flag, is_v_param_flag, is_sdxl_flag):
|
||||
metadata = {}
|
||||
metadata["ss_sd_model_name"] = str(title)
|
||||
metadata["ss_creation_time"] = str(int(creation_time))
|
||||
if is_sdxl_flag:
|
||||
metadata["ss_base_model_version"] = "sdxl_v10"
|
||||
metadata["ss_sdxl_model_version"] = "1.0"
|
||||
if is_v_param_flag:
|
||||
metadata["ss_v_parameterization"] = "true"
|
||||
elif is_v2_flag:
|
||||
metadata["ss_base_model_version"] = "sd_v2"
|
||||
if is_v_param_flag:
|
||||
metadata["ss_v_parameterization"] = "true"
|
||||
else:
|
||||
metadata["ss_base_model_version"] = "sd_v1"
|
||||
if is_v_param_flag:
|
||||
metadata["ss_v_parameterization"] = "true"
|
||||
return metadata
|
||||
|
||||
# --- MODIFIED Helper Functions for Model Loading ---
|
||||
def _load_sd_model_components(model_path, is_v2_flag, target_device_override, load_dtype_torch):
|
||||
logger.info(f"Loading SD model using Diffusers.StableDiffusionPipeline from: {model_path}")
|
||||
pipeline = StableDiffusionPipeline.from_single_file(
|
||||
model_path,
|
||||
torch_dtype=load_dtype_torch
|
||||
)
|
||||
eff_device = target_device_override if target_device_override else "cpu"
|
||||
text_encoder = pipeline.text_encoder.to(eff_device)
|
||||
unet = pipeline.unet.to(eff_device)
|
||||
text_encoders = [text_encoder]
|
||||
logger.info(f"Loaded SD model components. UNet device: {unet.device}, TextEncoder device: {text_encoder.device}")
|
||||
return text_encoders, unet
|
||||
|
||||
def _load_sdxl_model_components(model_path, target_device_override, load_dtype_torch):
|
||||
actual_load_device = target_device_override if target_device_override else "cpu"
|
||||
logger.info(f"Loading SDXL model using Diffusers.StableDiffusionXLPipeline from: {model_path} to device: {actual_load_device}")
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
model_path,
|
||||
torch_dtype=load_dtype_torch
|
||||
)
|
||||
pipeline.to(actual_load_device)
|
||||
text_encoder = pipeline.text_encoder
|
||||
text_encoder_2 = pipeline.text_encoder_2
|
||||
unet = pipeline.unet
|
||||
text_encoders = [text_encoder, text_encoder_2]
|
||||
logger.info(f"Loaded SDXL model components. UNet device: {unet.device}, TextEncoder1 device: {text_encoder.device}, TextEncoder2 device: {text_encoder_2.device}")
|
||||
return text_encoders, unet
|
||||
|
||||
def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_device, min_diff_thresh, module_type_str):
|
||||
diffs_map = {}
|
||||
is_different_flag = False
|
||||
first_diff_logged = False
|
||||
for lora_o, lora_t in zip(module_loras_o, module_loras_t):
|
||||
lora_name = lora_o.lora_name
|
||||
if lora_o.org_module is None or lora_t.org_module is None or \
|
||||
not hasattr(lora_o.org_module, 'weight') or lora_o.org_module.weight is None or \
|
||||
not hasattr(lora_t.org_module, 'weight') or lora_t.org_module.weight is None:
|
||||
logger.warning(f"Skipping {lora_name} in {module_type_str} due to missing org_module or weight.")
|
||||
continue
|
||||
weight_o = lora_o.org_module.weight
|
||||
weight_t = lora_t.org_module.weight
|
||||
if str(weight_o.device) != str(diff_calc_device): weight_o = weight_o.to(diff_calc_device)
|
||||
if str(weight_t.device) != str(diff_calc_device): weight_t = weight_t.to(diff_calc_device)
|
||||
diff = weight_t - weight_o
|
||||
diffs_map[lora_name] = diff
|
||||
current_max_diff = torch.max(torch.abs(diff))
|
||||
if not is_different_flag and current_max_diff > min_diff_thresh:
|
||||
is_different_flag = True
|
||||
if not first_diff_logged:
|
||||
logger.info(f"{module_type_str} '{lora_name}' differs: max diff {current_max_diff} > {min_diff_thresh}")
|
||||
first_diff_logged = True
|
||||
return diffs_map, is_different_flag
|
||||
|
||||
def _determine_rank(S_values, dynamic_method_name, dynamic_param_value, max_rank_limit,
|
||||
module_eff_in_dim, module_eff_out_dim, min_sv_threshold=MIN_SV):
|
||||
if not S_values.numel() or S_values[0] <= min_sv_threshold: return 1
|
||||
rank = 0
|
||||
if dynamic_method_name == "sv_ratio": rank = index_sv_ratio(S_values, dynamic_param_value)
|
||||
elif dynamic_method_name == "sv_cumulative": rank = index_sv_cumulative(S_values, dynamic_param_value)
|
||||
elif dynamic_method_name == "sv_fro": rank = index_sv_fro(S_values, dynamic_param_value)
|
||||
elif dynamic_method_name == "sv_knee": rank = index_sv_knee(S_values, min_sv_threshold)
|
||||
elif dynamic_method_name == "sv_cumulative_knee": rank = index_sv_cumulative_knee(S_values, min_sv_threshold)
|
||||
elif dynamic_method_name == "sv_rel_decrease": rank = index_sv_rel_decrease(S_values, dynamic_param_value)
|
||||
else: rank = max_rank_limit
|
||||
rank = min(rank, max_rank_limit, module_eff_in_dim, module_eff_out_dim, len(S_values))
|
||||
rank = max(1, rank)
|
||||
return rank
|
||||
|
||||
def _construct_lora_weights_from_svd_components(U_full, S_all_values, Vh_full, rank,
|
||||
clamp_quantile_val, is_conv2d, is_conv2d_3x3,
|
||||
conv_kernel_size,
|
||||
module_out_channels, module_in_channels,
|
||||
target_device_for_final_weights, target_dtype_for_final_weights):
|
||||
S_k = S_all_values[:rank]
|
||||
U_k = U_full[:, :rank]
|
||||
Vh_k = Vh_full[:rank, :]
|
||||
S_k_non_negative = torch.clamp(S_k, min=0.0)
|
||||
s_sqrt = torch.sqrt(S_k_non_negative)
|
||||
U_final = U_k * s_sqrt.unsqueeze(0)
|
||||
Vh_final = Vh_k * s_sqrt.unsqueeze(1)
|
||||
dist = torch.cat([U_final.flatten(), Vh_final.flatten()])
|
||||
hi_val = torch.quantile(dist, clamp_quantile_val)
|
||||
if hi_val == 0 and torch.max(torch.abs(dist)) > 1e-9:
|
||||
logger.debug(f"Clamping hi_val is zero for non-zero distribution. Max abs val: {torch.max(torch.abs(dist))}. Quantile: {clamp_quantile_val}")
|
||||
U_clamped = U_final.clamp(-hi_val, hi_val)
|
||||
Vh_clamped = Vh_final.clamp(-hi_val, hi_val)
|
||||
if is_conv2d:
|
||||
U_clamped = U_clamped.reshape(module_out_channels, rank, 1, 1)
|
||||
if is_conv2d_3x3:
|
||||
Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, *conv_kernel_size)
|
||||
else:
|
||||
Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, 1, 1)
|
||||
U_clamped = U_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous()
|
||||
Vh_clamped = Vh_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous()
|
||||
return U_clamped, Vh_clamped
|
||||
|
||||
def _log_svd_stats(lora_module_name, S_all_values, rank_used, min_sv_for_calc=MIN_SV):
|
||||
if not S_all_values.numel():
|
||||
logger.info(f"{lora_module_name:75} | rank: {rank_used}, SVD not performed (empty singular values).")
|
||||
return
|
||||
S_cpu = S_all_values.to('cpu')
|
||||
s_sum_total = float(torch.sum(S_cpu))
|
||||
s_sum_rank = float(torch.sum(S_cpu[:rank_used]))
|
||||
fro_orig_total = float(torch.sqrt(torch.sum(S_cpu.pow(2))))
|
||||
fro_reconstructed_rank = float(torch.sqrt(torch.sum(S_cpu[:rank_used].pow(2))))
|
||||
ratio_sv = float('inf')
|
||||
if rank_used > 0 and S_cpu[rank_used - 1].abs() > min_sv_for_calc:
|
||||
ratio_sv = S_cpu[0] / S_cpu[rank_used - 1]
|
||||
sum_s_retained_percentage = (s_sum_rank / s_sum_total) if s_sum_total > min_sv_for_calc else 1.0
|
||||
fro_retained_percentage = (fro_reconstructed_rank / fro_orig_total) if fro_orig_total > min_sv_for_calc else 1.0
|
||||
logger.info(
|
||||
f"{lora_module_name:75} | rank: {rank_used}, "
|
||||
f"sum(S) retained: {sum_s_retained_percentage:.2%}, "
|
||||
f"Frobenius norm retained: {fro_retained_percentage:.2%}, "
|
||||
f"max_retained_sv/min_retained_sv ratio: {ratio_sv:.2f}"
|
||||
)
|
||||
|
||||
def _prepare_lora_metadata(output_path, is_v2_flag, kohya_base_model_version_str, network_conv_dim_val,
|
||||
use_dynamic_method_flag, network_dim_config_val,
|
||||
is_v_param_flag, is_sdxl_flag, skip_sai_meta):
|
||||
net_kwargs = {"conv_dim": str(network_conv_dim_val), "conv_alpha": str(float(network_conv_dim_val))} if network_conv_dim_val is not None else {}
|
||||
if use_dynamic_method_flag:
|
||||
network_dim_meta = "Dynamic"
|
||||
network_alpha_meta = "Dynamic"
|
||||
else:
|
||||
network_dim_meta = str(network_dim_config_val)
|
||||
network_alpha_meta = str(float(network_dim_config_val))
|
||||
final_metadata = {
|
||||
"ss_v2": str(is_v2_flag),
|
||||
"ss_base_model_version": kohya_base_model_version_str,
|
||||
"ss_network_module": "networks.lora", # This remains for compatibility with tools expecting it
|
||||
"ss_network_dim": network_dim_meta,
|
||||
"ss_network_alpha": network_alpha_meta,
|
||||
"ss_network_args": json.dumps(net_kwargs),
|
||||
"ss_lowram": "False",
|
||||
"ss_num_train_images": "N/A",
|
||||
}
|
||||
if not skip_sai_meta:
|
||||
title = os.path.splitext(os.path.basename(output_path))[0]
|
||||
current_time = time.time()
|
||||
sai_metadata_content = _build_local_sai_metadata(
|
||||
title=title, creation_time=current_time, is_v2_flag=is_v2_flag,
|
||||
is_v_param_flag=is_v_param_flag, is_sdxl_flag=is_sdxl_flag
|
||||
)
|
||||
final_metadata.update(sai_metadata_content)
|
||||
return final_metadata
|
||||
|
||||
# --- Main SVD Function ---
|
||||
def svd(
|
||||
model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None,
|
||||
conv_dim=None, v_parameterization=None, device=None, save_precision=None,
|
||||
clamp_quantile=0.99, min_diff=0.01, no_metadata=False, load_precision=None,
|
||||
load_original_model_to=None, load_tuned_model_to=None,
|
||||
dynamic_method=None, dynamic_param=None, verbose=False,
|
||||
):
|
||||
actual_v_parameterization = v2 if v_parameterization is None else v_parameterization
|
||||
load_dtype_torch = _str_to_dtype(load_precision)
|
||||
save_dtype_torch = _str_to_dtype(save_precision) if save_precision else torch.float
|
||||
|
||||
svd_computation_device = torch.device(device if device else "cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Using SVD computation device: {svd_computation_device}")
|
||||
diff_calculation_device = torch.device("cpu")
|
||||
logger.info(f"Calculating weight differences on: {diff_calculation_device}")
|
||||
final_weights_device = torch.device("cpu")
|
||||
|
||||
if not sdxl:
|
||||
text_encoders_o, unet_o = _load_sd_model_components(model_org, v2, load_original_model_to, load_dtype_torch)
|
||||
text_encoders_t, unet_t = _load_sd_model_components(model_tuned, v2, load_tuned_model_to, load_dtype_torch)
|
||||
kohya_model_version = _local_get_model_version_str_for_sd1_sd2(v2, actual_v_parameterization)
|
||||
else:
|
||||
text_encoders_o, unet_o = _load_sdxl_model_components(model_org, load_original_model_to, load_dtype_torch)
|
||||
text_encoders_t, unet_t = _load_sdxl_model_components(model_tuned, load_tuned_model_to, load_dtype_torch)
|
||||
kohya_model_version = _LOCAL_MODEL_VERSION_SDXL_BASE_V1_0
|
||||
|
||||
# Determine lora_conv_dim_init based on conv_dim argument for network creation
|
||||
# The original script used init_dim_val (1) if conv_dim was None.
|
||||
# Here, conv_dim is already defaulted to args.dim if None by the main block.
|
||||
# So, lora_conv_dim_init will be args.conv_dim (which defaults to args.dim).
|
||||
# If args.conv_dim was explicitly 0, this would be 0.
|
||||
lora_conv_dim_init_val = conv_dim # conv_dim is args.conv_dim (or args.dim)
|
||||
|
||||
# Create LoRA placeholders using the localized function
|
||||
text_encoder_loras_o, unet_loras_o = _local_create_network_placeholders(text_encoders_o, unet_o, lora_conv_dim_init_val)
|
||||
text_encoder_loras_t, unet_loras_t = _local_create_network_placeholders(text_encoders_t, unet_t, lora_conv_dim_init_val) # same conv_dim logic for tuned
|
||||
|
||||
# Group LoRA placeholders for easier processing (mimicking LoraNetwork structure somewhat)
|
||||
class LocalLoraNetworkPlaceholder:
|
||||
def __init__(self, te_loras, unet_loras_list):
|
||||
self.text_encoder_loras = te_loras
|
||||
self.unet_loras = unet_loras_list
|
||||
|
||||
lora_network_o = LocalLoraNetworkPlaceholder(text_encoder_loras_o, unet_loras_o)
|
||||
lora_network_t = LocalLoraNetworkPlaceholder(text_encoder_loras_t, unet_loras_t)
|
||||
|
||||
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), \
|
||||
f"Model versions (based on identified LoRA-able TE modules) differ: {len(lora_network_o.text_encoder_loras)} vs {len(lora_network_t.text_encoder_loras)} TEs"
|
||||
|
||||
all_diffs = {}
|
||||
te_diffs, text_encoder_different = _calculate_module_diffs_and_check(
|
||||
lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras,
|
||||
diff_calculation_device, min_diff, "Text Encoder"
|
||||
)
|
||||
|
||||
if text_encoder_different:
|
||||
all_diffs.update(te_diffs)
|
||||
else:
|
||||
logger.warning("Text encoders are considered identical based on min_diff. Not extracting TE LoRA.")
|
||||
# To prevent processing empty list later, ensure it's empty if no diffs
|
||||
lora_network_o.text_encoder_loras = []
|
||||
del text_encoders_t # Free memory early
|
||||
|
||||
unet_diffs, _ = _calculate_module_diffs_and_check(
|
||||
lora_network_o.unet_loras, lora_network_t.unet_loras,
|
||||
diff_calculation_device, min_diff, "U-Net"
|
||||
)
|
||||
all_diffs.update(unet_diffs)
|
||||
del lora_network_t, unet_t # Free memory early
|
||||
|
||||
# Ensure lora_names_to_process only includes modules from lora_network_o
|
||||
# that are actually present (e.g., if TEs were skipped)
|
||||
lora_names_to_process = set()
|
||||
if text_encoder_different: # Only add TE loras if they were deemed different
|
||||
lora_names_to_process.update(p.lora_name for p in lora_network_o.text_encoder_loras)
|
||||
lora_names_to_process.update(p.lora_name for p in lora_network_o.unet_loras)
|
||||
|
||||
logger.info("Extracting and resizing LoRA via SVD")
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name in tqdm(lora_names_to_process):
|
||||
if lora_name not in all_diffs:
|
||||
logger.warning(f"Skipping {lora_name} as no diff was calculated for it (e.g., Text Encoders were identical).")
|
||||
continue
|
||||
original_diff_tensor = all_diffs[lora_name]
|
||||
is_conv2d_layer = len(original_diff_tensor.size()) == 4
|
||||
kernel_s = original_diff_tensor.size()[2:4] if is_conv2d_layer else None
|
||||
is_conv2d_3x3_layer = is_conv2d_layer and kernel_s != (1, 1)
|
||||
module_true_out_channels, module_true_in_channels = original_diff_tensor.size()[0:2]
|
||||
mat_for_svd = original_diff_tensor.to(svd_computation_device, dtype=torch.float)
|
||||
if is_conv2d_layer:
|
||||
if is_conv2d_3x3_layer: mat_for_svd = mat_for_svd.flatten(start_dim=1)
|
||||
else: mat_for_svd = mat_for_svd.squeeze()
|
||||
if mat_for_svd.numel() == 0 or mat_for_svd.shape[0] == 0 or mat_for_svd.shape[1] == 0 :
|
||||
logger.warning(f"Skipping SVD for {lora_name} due to empty/invalid shape: {mat_for_svd.shape}")
|
||||
continue
|
||||
try:
|
||||
U_full, S_full, Vh_full = torch.linalg.svd(mat_for_svd)
|
||||
except Exception as e:
|
||||
logger.error(f"SVD failed for {lora_name} with shape {mat_for_svd.shape}. Error: {e}")
|
||||
continue
|
||||
|
||||
# Max rank for SVD is based on 'dim' for linear and 'conv_dim' for conv3x3
|
||||
# The original `current_max_rank` logic was:
|
||||
# current_max_rank = dim if not is_conv2d_3x3_layer or conv_dim is None else conv_dim
|
||||
# Here, `dim` is args.dim and `conv_dim` is args.conv_dim (defaulted to args.dim)
|
||||
module_specific_max_rank = conv_dim if is_conv2d_3x3_layer else dim
|
||||
|
||||
eff_out_dim, eff_in_dim = mat_for_svd.shape[0], mat_for_svd.shape[1]
|
||||
rank = _determine_rank(S_full, dynamic_method, dynamic_param,
|
||||
module_specific_max_rank, eff_in_dim, eff_out_dim, MIN_SV)
|
||||
U_clamped, Vh_clamped = _construct_lora_weights_from_svd_components(
|
||||
U_full, S_full, Vh_full, rank, clamp_quantile,
|
||||
is_conv2d_layer, is_conv2d_3x3_layer, kernel_s,
|
||||
module_true_out_channels, module_true_in_channels,
|
||||
final_weights_device, save_dtype_torch
|
||||
)
|
||||
lora_weights[lora_name] = (U_clamped, Vh_clamped)
|
||||
if verbose: _log_svd_stats(lora_name, S_full, rank, MIN_SV)
|
||||
|
||||
lora_sd = {}
|
||||
for lora_name, (up_weight, down_weight) in lora_weights.items():
|
||||
lora_sd[lora_name + ".lora_up.weight"] = up_weight
|
||||
lora_sd[lora_name + ".lora_down.weight"] = down_weight
|
||||
# Alpha is set to the rank (dim of down_weight's 0th axis, which is rank)
|
||||
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype_torch, device=final_weights_device)
|
||||
|
||||
del text_encoders_o, unet_o, lora_network_o, all_diffs # Clean up original models and placeholders
|
||||
if 'torch' in sys.modules and hasattr(torch, 'cuda') and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if not os.path.exists(os.path.dirname(save_to)) and os.path.dirname(save_to) != "":
|
||||
os.makedirs(os.path.dirname(save_to), exist_ok=True)
|
||||
|
||||
metadata_to_save = _prepare_lora_metadata(
|
||||
output_path=save_to,
|
||||
is_v2_flag=v2,
|
||||
kohya_base_model_version_str=kohya_model_version,
|
||||
network_conv_dim_val=conv_dim, # This is args.conv_dim (defaulted to args.dim)
|
||||
use_dynamic_method_flag=bool(dynamic_method),
|
||||
network_dim_config_val=dim, # This is args.dim
|
||||
is_v_param_flag=actual_v_parameterization,
|
||||
is_sdxl_flag=sdxl,
|
||||
skip_sai_meta=no_metadata
|
||||
)
|
||||
|
||||
save_to_file(save_to, lora_sd, save_dtype_torch, metadata_to_save)
|
||||
logger.info(f"LoRA saved to: {save_to}")
|
||||
|
||||
def setup_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model")
|
||||
parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2 if --v2 is set)")
|
||||
parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model")
|
||||
parser.add_argument("--load_precision", type=str, choices=["float", "fp16", "bf16"], default=None, help="Precision for loading models (applied after initial load)")
|
||||
parser.add_argument("--save_precision", type=str, choices=["float", "fp16", "bf16"], default="float", help="Precision for saving LoRA weights")
|
||||
parser.add_argument("--model_org", type=str, required=True, help="Original Stable Diffusion model (ckpt/safetensors)")
|
||||
parser.add_argument("--model_tuned", type=str, required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)")
|
||||
parser.add_argument("--save_to", type=str, required=True, help="Output file name (ckpt/safetensors)")
|
||||
parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers")
|
||||
parser.add_argument("--conv_dim", type=int, default=None, help="Max dimension (rank) of LoRA for Conv2d-3x3. Defaults to 'dim' if not set.")
|
||||
parser.add_argument("--device", type=str, default=None, help="Device for SVD computation (e.g., cuda, cpu). Defaults to cuda if available, else cpu.")
|
||||
parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights")
|
||||
parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract LoRA for a module")
|
||||
parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata from SAI and Kohya_ss")
|
||||
parser.add_argument("--load_original_model_to", type=str, default=None, help="Device for original model (e.g. 'cpu', 'cuda:0'). Defaults to CPU for SD1/2, honored for SDXL.")
|
||||
parser.add_argument("--load_tuned_model_to", type=str, default=None, help="Device for tuned model (e.g. 'cpu', 'cuda:0'). Defaults to CPU for SD1/2, honored for SDXL.")
|
||||
parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction")
|
||||
parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info for each module")
|
||||
parser.add_argument(
|
||||
"--dynamic_method", type=str,
|
||||
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"],
|
||||
default=None, help="Dynamic rank reduction method"
|
||||
)
|
||||
return parser
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.conv_dim is None:
|
||||
args.conv_dim = args.dim # Default conv_dim to dim if not provided
|
||||
logger.info(f"--conv_dim not set, using value of --dim: {args.conv_dim}")
|
||||
|
||||
methods_requiring_param = ["sv_ratio", "sv_fro", "sv_cumulative", "sv_rel_decrease"]
|
||||
if args.dynamic_method in methods_requiring_param and args.dynamic_param is None:
|
||||
parser.error(f"Dynamic method '{args.dynamic_method}' requires --dynamic_param to be set.")
|
||||
|
||||
if not args.dynamic_method: # Ranks must be positive if not using dynamic method
|
||||
if args.dim <= 0: parser.error(f"--dim (rank) must be > 0. Got {args.dim}")
|
||||
if args.conv_dim <=0: parser.error(f"--conv_dim (rank) must be > 0. Got {args.conv_dim}") # Check after defaulting
|
||||
|
||||
if MIN_SV <= 0: logger.warning(f"Global MIN_SV ({MIN_SV}) should be positive.")
|
||||
|
||||
svd_args = vars(args).copy()
|
||||
svd(**svd_args)
|
||||
|
|
@ -1,360 +0,0 @@
|
|||
# extract approximating LoRA by svd from two SD models
|
||||
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||
# Thanks to cloneofsimo!
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from library import sai_model_spec, model_util, sdxl_model_util
|
||||
import lora
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# CLAMP_QUANTILE = 0.99
|
||||
# MIN_DIFF = 1e-1
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(model, file_name)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def svd(
|
||||
model_org=None,
|
||||
model_tuned=None,
|
||||
save_to=None,
|
||||
dim=4,
|
||||
v2=None,
|
||||
sdxl=None,
|
||||
conv_dim=None,
|
||||
v_parameterization=None,
|
||||
device=None,
|
||||
save_precision=None,
|
||||
clamp_quantile=0.99,
|
||||
min_diff=0.01,
|
||||
no_metadata=False,
|
||||
load_precision=None,
|
||||
load_original_model_to=None,
|
||||
load_tuned_model_to=None,
|
||||
):
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
return torch.float
|
||||
if p == "fp16":
|
||||
return torch.float16
|
||||
if p == "bf16":
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
|
||||
if v_parameterization is None:
|
||||
v_parameterization = v2
|
||||
|
||||
load_dtype = str_to_dtype(load_precision) if load_precision else None
|
||||
save_dtype = str_to_dtype(save_precision)
|
||||
work_device = "cpu"
|
||||
|
||||
# load models
|
||||
if not sdxl:
|
||||
logger.info(f"loading original SD model : {model_org}")
|
||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
||||
text_encoders_o = [text_encoder_o]
|
||||
if load_dtype is not None:
|
||||
text_encoder_o = text_encoder_o.to(load_dtype)
|
||||
unet_o = unet_o.to(load_dtype)
|
||||
|
||||
logger.info(f"loading tuned SD model : {model_tuned}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
||||
text_encoders_t = [text_encoder_t]
|
||||
if load_dtype is not None:
|
||||
text_encoder_t = text_encoder_t.to(load_dtype)
|
||||
unet_t = unet_t.to(load_dtype)
|
||||
|
||||
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
|
||||
else:
|
||||
device_org = load_original_model_to if load_original_model_to else "cpu"
|
||||
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
|
||||
|
||||
logger.info(f"loading original SDXL model : {model_org}")
|
||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
|
||||
)
|
||||
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
||||
if load_dtype is not None:
|
||||
text_encoder_o1 = text_encoder_o1.to(load_dtype)
|
||||
text_encoder_o2 = text_encoder_o2.to(load_dtype)
|
||||
unet_o = unet_o.to(load_dtype)
|
||||
|
||||
logger.info(f"loading original SDXL model : {model_tuned}")
|
||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
|
||||
)
|
||||
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
||||
if load_dtype is not None:
|
||||
text_encoder_t1 = text_encoder_t1.to(load_dtype)
|
||||
text_encoder_t2 = text_encoder_t2.to(load_dtype)
|
||||
unet_t = unet_t.to(load_dtype)
|
||||
|
||||
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
||||
|
||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||
if conv_dim is None:
|
||||
kwargs = {}
|
||||
else:
|
||||
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
|
||||
|
||||
lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
|
||||
lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
|
||||
assert len(lora_network_o.text_encoder_loras) == len(
|
||||
lora_network_t.text_encoder_loras
|
||||
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
||||
|
||||
# get diffs
|
||||
diffs = {}
|
||||
text_encoder_different = False
|
||||
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
|
||||
|
||||
# clear weight to save memory
|
||||
module_o.weight = None
|
||||
module_t.weight = None
|
||||
|
||||
# Text Encoder might be same
|
||||
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
||||
text_encoder_different = True
|
||||
logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
||||
|
||||
diffs[lora_name] = diff
|
||||
|
||||
# clear target Text Encoder to save memory
|
||||
for text_encoder in text_encoders_t:
|
||||
del text_encoder
|
||||
|
||||
if not text_encoder_different:
|
||||
logger.warning("Text encoder is same. Extract U-Net only.")
|
||||
lora_network_o.text_encoder_loras = []
|
||||
diffs = {} # clear diffs
|
||||
|
||||
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
|
||||
|
||||
# clear weight to save memory
|
||||
module_o.weight = None
|
||||
module_t.weight = None
|
||||
|
||||
diffs[lora_name] = diff
|
||||
|
||||
# clear LoRA network, target U-Net to save memory
|
||||
del lora_network_o
|
||||
del lora_network_t
|
||||
del unet_t
|
||||
|
||||
# make LoRA with svd
|
||||
logger.info("calculating by svd")
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name, mat in tqdm(list(diffs.items())):
|
||||
if args.device:
|
||||
mat = mat.to(args.device)
|
||||
mat = mat.to(torch.float) # calc by float
|
||||
|
||||
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
||||
conv2d = len(mat.size()) == 4
|
||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
|
||||
rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
|
||||
out_dim, in_dim = mat.size()[0:2]
|
||||
|
||||
if device:
|
||||
mat = mat.to(device)
|
||||
|
||||
# logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
||||
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||
|
||||
if conv2d:
|
||||
if conv2d_3x3:
|
||||
mat = mat.flatten(start_dim=1)
|
||||
else:
|
||||
mat = mat.squeeze()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
|
||||
U = U[:, :rank]
|
||||
S = S[:rank]
|
||||
U = U @ torch.diag(S)
|
||||
|
||||
Vh = Vh[:rank, :]
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, clamp_quantile)
|
||||
low_val = -hi_val
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, rank, 1, 1)
|
||||
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
|
||||
U = U.to(work_device, dtype=save_dtype).contiguous()
|
||||
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
|
||||
|
||||
lora_weights[lora_name] = (U, Vh)
|
||||
|
||||
# make state dict for LoRA
|
||||
lora_sd = {}
|
||||
for lora_name, (up_weight, down_weight) in lora_weights.items():
|
||||
lora_sd[lora_name + ".lora_up.weight"] = up_weight
|
||||
lora_sd[lora_name + ".lora_down.weight"] = down_weight
|
||||
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])
|
||||
|
||||
# load state dict to LoRA and save it
|
||||
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
|
||||
lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
|
||||
|
||||
info = lora_network_save.load_state_dict(lora_sd)
|
||||
logger.info(f"Loading extracted LoRA weights: {info}")
|
||||
|
||||
dir_name = os.path.dirname(save_to)
|
||||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name, exist_ok=True)
|
||||
|
||||
# minimum metadata
|
||||
net_kwargs = {}
|
||||
if conv_dim is not None:
|
||||
net_kwargs["conv_dim"] = str(conv_dim)
|
||||
net_kwargs["conv_alpha"] = str(float(conv_dim))
|
||||
|
||||
metadata = {
|
||||
"ss_v2": str(v2),
|
||||
"ss_base_model_version": model_version,
|
||||
"ss_network_module": "networks.lora",
|
||||
"ss_network_dim": str(dim),
|
||||
"ss_network_alpha": str(float(dim)),
|
||||
"ss_network_args": json.dumps(net_kwargs),
|
||||
}
|
||||
|
||||
if not no_metadata:
|
||||
title = os.path.splitext(os.path.basename(save_to))[0]
|
||||
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
lora_network_save.save_weights(save_to, save_dtype, metadata)
|
||||
logger.info(f"LoRA weights are saved to: {save_to}")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
|
||||
parser.add_argument(
|
||||
"--v_parameterization",
|
||||
action="store_true",
|
||||
default=None,
|
||||
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_org",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_tuned",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
||||
parser.add_argument(
|
||||
"--conv_dim",
|
||||
type=int,
|
||||
default=None,
|
||||
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
||||
)
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
parser.add_argument(
|
||||
"--clamp_quantile",
|
||||
type=float,
|
||||
default=0.99,
|
||||
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_diff",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
|
||||
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_metadata",
|
||||
action="store_true",
|
||||
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
|
||||
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_original_model_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_tuned_model_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
svd(**vars(args))
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import argparse # Import argparse
|
||||
|
||||
def extract_model_differences(base_model_path, finetuned_model_path, output_delta_path=None, save_dtype_str="float32"):
|
||||
"""
|
||||
Calculates the difference between the state dictionaries of a fine-tuned model
|
||||
and a base model.
|
||||
|
||||
Args:
|
||||
base_model_path (str): Path to the base model .safetensors file.
|
||||
finetuned_model_path (str): Path to the fine-tuned model .safetensors file.
|
||||
output_delta_path (str, optional): Path to save the resulting delta weights
|
||||
.safetensors file. If None, not saved.
|
||||
save_dtype_str (str, optional): Data type to save the delta weights ('float32', 'float16', 'bfloat16').
|
||||
Defaults to 'float32'.
|
||||
Returns:
|
||||
OrderedDict: A state dictionary containing the delta weights.
|
||||
Returns None if loading fails or other critical errors.
|
||||
"""
|
||||
print(f"Loading base model from: {base_model_path}")
|
||||
try:
|
||||
# Ensure model is loaded to CPU to avoid CUDA issues if not needed for diffing
|
||||
base_state_dict = load_file(base_model_path, device="cpu")
|
||||
print(f"Base model loaded. Found {len(base_state_dict)} tensors.")
|
||||
except Exception as e:
|
||||
print(f"Error loading base model: {e}")
|
||||
return None
|
||||
|
||||
print(f"\nLoading fine-tuned model from: {finetuned_model_path}")
|
||||
try:
|
||||
finetuned_state_dict = load_file(finetuned_model_path, device="cpu")
|
||||
print(f"Fine-tuned model loaded. Found {len(finetuned_state_dict)} tensors.")
|
||||
except Exception as e:
|
||||
print(f"Error loading fine-tuned model: {e}")
|
||||
return None
|
||||
|
||||
delta_state_dict = OrderedDict()
|
||||
diff_count = 0
|
||||
skipped_count = 0
|
||||
error_count = 0
|
||||
unique_to_finetuned_count = 0
|
||||
unique_to_base_count = 0
|
||||
|
||||
print("\nCalculating differences...")
|
||||
|
||||
# Keys in finetuned model
|
||||
finetuned_keys = set(finetuned_state_dict.keys())
|
||||
base_keys = set(base_state_dict.keys())
|
||||
|
||||
common_keys = finetuned_keys.intersection(base_keys)
|
||||
keys_only_in_finetuned = finetuned_keys - base_keys
|
||||
keys_only_in_base = base_keys - finetuned_keys
|
||||
|
||||
for key in common_keys:
|
||||
ft_tensor = finetuned_state_dict[key]
|
||||
base_tensor = base_state_dict[key]
|
||||
|
||||
if not (ft_tensor.is_floating_point() and base_tensor.is_floating_point()):
|
||||
# print(f"Skipping key '{key}': Non-floating point tensors (FT: {ft_tensor.dtype}, Base: {base_tensor.dtype}).")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
if ft_tensor.shape != base_tensor.shape:
|
||||
print(f"Skipping key '{key}': Shape mismatch (FT: {ft_tensor.shape}, Base: {base_tensor.shape}).")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
# Calculate difference in float32 for precision, then cast to save_dtype
|
||||
delta_tensor = ft_tensor.to(dtype=torch.float32) - base_tensor.to(dtype=torch.float32)
|
||||
delta_state_dict[key] = delta_tensor
|
||||
diff_count += 1
|
||||
except Exception as e:
|
||||
print(f"Error calculating difference for key '{key}': {e}")
|
||||
error_count += 1
|
||||
|
||||
for key in keys_only_in_finetuned:
|
||||
print(f"Warning: Key '{key}' (Shape: {finetuned_state_dict[key].shape}, Dtype: {finetuned_state_dict[key].dtype}) is present in fine-tuned model but not in base model. Storing as is.")
|
||||
delta_state_dict[key] = finetuned_state_dict[key] # Store the tensor from the finetuned model
|
||||
unique_to_finetuned_count += 1
|
||||
|
||||
if keys_only_in_base:
|
||||
print(f"\nWarning: {len(keys_only_in_base)} key(s) are present only in the base model and will not be in the delta file.")
|
||||
for key in list(keys_only_in_base)[:5]: # Print first 5 as examples
|
||||
print(f" - Example key only in base: {key}")
|
||||
if len(keys_only_in_base) > 5:
|
||||
print(f" ... and {len(keys_only_in_base) - 5} more.")
|
||||
|
||||
|
||||
print(f"\nDifference calculation complete.")
|
||||
print(f" {diff_count} layers successfully diffed.")
|
||||
print(f" {unique_to_finetuned_count} layers unique to fine-tuned model (added as is).")
|
||||
print(f" {skipped_count} common layers skipped (shape/type mismatch).")
|
||||
print(f" {error_count} common layers had errors during diffing.")
|
||||
|
||||
if output_delta_path and delta_state_dict:
|
||||
save_dtype = torch.float32 # Default
|
||||
if save_dtype_str == "float16":
|
||||
save_dtype = torch.float16
|
||||
elif save_dtype_str == "bfloat16":
|
||||
save_dtype = torch.bfloat16
|
||||
elif save_dtype_str != "float32":
|
||||
print(f"Warning: Invalid save_dtype '{save_dtype_str}'. Defaulting to float32.")
|
||||
save_dtype_str = "float32" # for print message
|
||||
|
||||
print(f"\nPreparing to save delta weights with dtype: {save_dtype_str}")
|
||||
|
||||
final_save_dict = OrderedDict()
|
||||
for k, v_tensor in delta_state_dict.items():
|
||||
if v_tensor.is_floating_point():
|
||||
final_save_dict[k] = v_tensor.to(dtype=save_dtype)
|
||||
else:
|
||||
final_save_dict[k] = v_tensor # Keep non-float as is (e.g. int tensors if any)
|
||||
|
||||
try:
|
||||
save_file(final_save_dict, output_delta_path)
|
||||
print(f"Delta weights saved to: {output_delta_path}")
|
||||
except Exception as e:
|
||||
print(f"Error saving delta weights: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
return delta_state_dict
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Extract weight differences between a fine-tuned and a base SDXL model.")
|
||||
parser.add_argument("base_model_path", type=str, help="File path for the BASE SDXL model (.safetensors).")
|
||||
parser.add_argument("finetuned_model_path", type=str, help="File path for the FINE-TUNED SDXL model (.safetensors).")
|
||||
parser.add_argument("--output_path", type=str, default=None,
|
||||
help="Optional: File path to save the delta weights (.safetensors). "
|
||||
"If not provided, defaults to 'model_deltas/delta_[finetuned_model_name].safetensors'.")
|
||||
parser.add_argument("--save_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"],
|
||||
help="Data type for saving the delta weights. Choose from 'float32', 'float16', 'bfloat16'. "
|
||||
"Defaults to 'float32'.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("--- Model Difference Extraction Script ---")
|
||||
|
||||
if not os.path.exists(args.base_model_path):
|
||||
print(f"Error: Base model file not found at {args.base_model_path}")
|
||||
exit(1)
|
||||
if not os.path.exists(args.finetuned_model_path):
|
||||
print(f"Error: Fine-tuned model file not found at {args.finetuned_model_path}")
|
||||
exit(1)
|
||||
|
||||
output_delta_file = args.output_path
|
||||
if output_delta_file is None:
|
||||
output_dir = "model_deltas"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
finetuned_basename = os.path.splitext(os.path.basename(args.finetuned_model_path))[0]
|
||||
output_delta_file = os.path.join(output_dir, f"delta_{finetuned_basename}.safetensors")
|
||||
|
||||
# Ensure the output directory exists if a full path is given
|
||||
if output_delta_file:
|
||||
output_dir_for_file = os.path.dirname(output_delta_file)
|
||||
if output_dir_for_file and not os.path.exists(output_dir_for_file):
|
||||
os.makedirs(output_dir_for_file, exist_ok=True)
|
||||
|
||||
|
||||
differences = extract_model_differences(
|
||||
args.base_model_path,
|
||||
args.finetuned_model_path,
|
||||
output_delta_path=output_delta_file,
|
||||
save_dtype_str=args.save_dtype
|
||||
)
|
||||
|
||||
if differences:
|
||||
print(f"\nExtraction process finished. {len(differences)} total keys in the delta state_dict.")
|
||||
else:
|
||||
print("\nCould not extract differences due to errors during model loading.")
|
||||
Loading…
Reference in New Issue