Code cleanup

pull/3264/head
bmaltais 2025-05-18 17:36:24 -04:00
parent 1e0aba6b55
commit 7767a5a3ec
3 changed files with 1299 additions and 401 deletions

397
tools/Untitled-2.txt Normal file
View File

@ -0,0 +1,397 @@
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision bf16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_fro_0.65.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 16 --device cuda --sdxl --target_fro_retained 0.5 --group_size 6 --svd_mode per_layer --dynamic_param 0.65
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision bf16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_cumulative_0.9.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 16 --device cuda --sdxl --target_fro_retained 0.5 --group_size 6 --svd_mode per_layer --dynamic_param 0.9 --dynamic_method sv_cumulative
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_fro_0.5v2.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 256 --device cuda --sdxl --dynamic_param 0.5 --dynamic_method sv_fro --verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_cumulative_0.5v2.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 768 --device cuda --sdxl --dynamic_param 0.5 --dynamic_method sv_cumulative --verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_ratio_0.5.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 768 --device cuda --sdxl --dynamic_param 0.5 --dynamic_method sv_ratio --verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_knee.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 512 --device cuda --sdxl --dynamic_method sv_knee --verbose --dynamic_param 0.5
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py `
--save_precision fp16 `
--save_to E:/lora/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES_sv_cumulative_knee.safetensors `
--model_tuned E:/models/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES.safetensors `
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--dim 512 `
--device cuda `
--sdxl `
--dynamic_method sv_fro `
--verbose `
--dynamic_param 0.25
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
--save_precision fp16 `
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned E:/models/sdxl/xxxRay_v11.safetensors `
--save_to E:/lora/sdxl/xxxRay_v11_sv_cumulative_knee.safetensors `
--dim 384 `
--device cuda `
--sdxl `
--dynamic_method sv_cumulative_knee `
--verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
--save_precision fp16 `
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned E:/models/sdxl/xxxRay_v11.safetensors `
--save_to E:/lora/sdxl/xxxRay_v11_sv_fro_0.85_1024.safetensors `
--dim 1024 `
--device cuda `
--sdxl `
--dynamic_method sv_fro `
--dynamic_param 0.85 `
--verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
--save_precision fp16 `
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned E:/models/sdxl/proteus_v06.safetensors `
--save_to E:/lora/sdxl/proteus_v06_sv_cumulative_knee_1024.safetensors `
--dim 1024 `
--device cuda `
--sdxl `
--dynamic_method sv_cumulative_knee `
--verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_v2.safetensors `
--rank 4 `
--iterations 200 `
--lr 0.005 `
--device cuda `
--precision fp32 `
--verbose `
--verbose_layer_debug `
--save_weights_dtype fp16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_64_4000steps.safetensors `
--rank 64 `
--initial_alpha 32 `
--max_rank_doublings 2 `
--max_iterations 16000 `
--min_iterations 200 `
--target_loss 9.9999999e-8 `
--lr 0.05 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_16_16000steps.safetensors `
--rank 16 `
--initial_alpha 8 `
--max_rank_retries 3 `
--rank_increase_factor 1.5 `
--max_iterations 8000 `
--min_iterations 200 `
--target_loss 9.9999999e-8 `
--lr 0.05 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_16_8000steps.safetensors `
--rank 16 `
--initial_alpha 16 `
--max_rank_retries 6 `
--rank_increase_factor 2 `
--max_iterations 8000 `
--min_iterations 200 `
--target_loss 9.9999999e-8 `
--lr 0.05 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES.safetensors `
E:/lora/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES_loha_16_8000steps.safetensors `
--rank 16 `
--initial_alpha 16 `
--max_rank_retries 6 `
--rank_increase_factor 2 `
--max_iterations 8000 `
--min_iterations 200 `
--target_loss 9.9999999e-8 `
--lr 0.1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/aetherverseXL_v10.safetensors `
E:/lora/sdxl/aetherverseXL_v10_loha_9e-8.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 27 `
--rank_increase_factor 1.2 `
--max_iterations 16000 `
--min_iterations 400 `
--target_loss 9e-8 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 200 `
--advanced_projection_decay_cap_min 0.5 `
--advanced_projection_decay_cap_max 1.05 `
--min_progress_loss_ratio 0.000001 `
--projection_sample_interval 1 `
--projection_min_ema_history 100
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/aetherverseXL_v10.safetensors `
E:/lora/sdxl/aetherverseXL_v10_loha_9e-8.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 30 `
--rank_increase_factor 1.2 `
--max_iterations 8000 `
--min_iterations 200 `
--target_loss 9e-8 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 100 `
--save_every_n_layers 10 `
--keep_n_resume_files 10 `
--rank_search_strategy binary_search_min_rank
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/aetherverseXL_v10.safetensors `
E:/lora/sdxl/aetherverseXL_v10_loha_9e-8.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 8 `
--rank_increase_factor 2 `
--max_iterations 8000 `
--min_iterations 400 `
--target_loss 9e-8 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 100 `
--save_every_n_layers 10 `
--keep_n_resume_files 10 `
--skip_delta_threshold 3e-7 `
--rank_search_strategy binary_search_min_rank
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\lr_finder.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/xxxRay_v11.safetensors `
--lr_finder_num_layers 16 `
--lr_finder_min_lr 1e-8 `
--lr_finder_max_lr 0.2 `
--lr_finder_num_steps 120 `
--lr_finder_iters_per_step 40 `
--rank 8 `
--initial_alpha 8.0 `
--precision bf16 `
--device cuda `
--lr_finder_plot `
--lr_finder_show_plot
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/xxxRay_v11.safetensors `
E:/lora/sdxl/xxxRay_v11_loha_1e-7.safetensors `
--rank 2 `
--initial_alpha 2 `
--max_rank_retries 7 `
--rank_increase_factor 2 `
--max_iterations 8000 `
--min_iterations 400 `
--target_loss 1e-7 `
--lr 1e-01 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 100 `
--save_every_n_layers 10 `
--keep_n_resume_files 10 `
--skip_delta_threshold 1e-7 `
--rank_search_strategy binary_search_min_rank `
--probe_aggressive_early_stop
D:\kohya_ss\venv\Scripts\python.exe D:\kohya_ss\tools\model_diff_report.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
--top_n_diff 15 --plot_histograms --plot_histograms_top_n 3 --output_dir ./analysis_results
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES.safetensors `
E:/lora/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES_loha_3e-7.safetensors `
--rank 1 `
--initial_alpha 1 `
--max_rank_retries 10 `
--rank_increase_factor 2 `
--max_iterations 8000 `
--min_iterations 400 `
--target_loss 3e-7 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 100 `
--save_every_n_layers 10 `
--keep_n_resume_files 10 `
--skip_delta_threshold 6e-7 `
--rank_search_strategy binary_search_min_rank `
--probe_aggressive_early_stop
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/proteus_v06.safetensors `
E:/lora/sdxl/proteus_v06_1e-7.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 27 `
--rank_increase_factor 1.2 `
--max_iterations 16000 `
--min_iterations 400 `
--target_loss 1e-7 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 200 `
--save_every_n_layers 10 `
--keep_n_resume_files 10 `
--rank_search_strategy binary_search_min_rank
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_1e-8v3.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 29 `
--rank_increase_factor 1.2 `
--max_iterations 16000 `
--min_iterations 400 `
--target_loss 1e-8 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 200 `
--advanced_projection_decay_cap_min 0.5 `
--advanced_projection_decay_cap_max 1.05 `
--min_progress_loss_ratio 0.000001 `
--projection_sample_interval 1 `
--projection_min_ema_history 100 `
--continue_training_from_loha E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_1e-8v2_resume_L422.safetensors
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES.safetensors `
E:/lora/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES_loha_9e-8.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 27 `
--rank_increase_factor 1.2 `
--max_iterations 16000 `
--min_iterations 400 `
--target_loss 1e-7 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 200 `
--advanced_projection_decay_cap_min 0.5 `
--advanced_projection_decay_cap_max 1.05 `
--min_progress_loss_ratio 0.000001 `
--projection_sample_interval 1 `
--projection_min_ema_history 100
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_model_difference.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
--save_dtype float16
--model_org_path E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned_path E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
--algo loha `
--network_alpha 64 `
--network_dim 4 `
--conv_alpha 64 `
--conv_dim 4 `
--device cuda `
--sdxl `
--save_precision fp16 `
--verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_to E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors --algo loha --sdxl --dim 32 --conv_dim 32 --dynamic_method sv_cumulative --dynamic_param 0.99 --save_precision fp16 --device cuda --verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py ^
--model_org_path "D:\StableDiffusion\models\sdxl_base_1.0.safetensors" ^
--model_tuned_path "D:\StableDiffusion\models\my_sdxl_finetune.safetensors" ^
--save_to "C:\LoRA_Extractor\output\my_loha_sdxl.safetensors" ^
--sdxl ^
--algo loha ^
--network_alpha 64 ^
--network_dim 4 ^
--conv_alpha 64 ^
--conv_dim 4 ^
--save_precision bf16 ^
--device cuda ^
--verbose
sv_cumulative_knee
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_two_pass_energy_512.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --total_rank_budget 2048 --device cuda --sdxl --svd_mode per_layer --dynamic_param 1.0 --dynamic_method two_pass_energy --verbose --min_rank 4 --max_rank 32
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py ^
--save_precision bf16 ^
--save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_two_pass_energy_512.safetensors ^
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors ^
--model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors ^
--dim 512 ^
--device cuda ^
--sdxl ^
--target_fro_retained 0.5 ^
--group_size 6 ^
--svd_mode per_layer ^
--dynamic_method two_pass_energy ^
--dynamic_param 1.0 ^
--min_rank 4 ^
--verbose

View File

@ -13,23 +13,19 @@ if sd_scripts_dir_path not in sys.path:
try:
from library import sai_model_spec, model_util, sdxl_model_util
from library.utils import setup_logging
from networks import lora # <--- CORRECTED LORA IMPORT
from networks import lora
except ImportError as e:
print(f"Error importing from sd-scripts. Please check your sd-scripts folder structure.")
print(f"Attempted to load from: {sd_scripts_dir_path}")
print(f"Original error: {e}")
print("Current sys.path relevant entries:")
for p in sys.path:
if "sd-scripts" in p or "kohya_ss" in p: # Print relevant paths for debugging
if "sd-scripts" in p or "kohya_ss" in p:
print(p)
# Ensure 'networks' directory exists in 'sd-scripts' and contains 'lora.py'
# Also ensure 'sd-scripts/networks/__init__.py' exists.
raise
# --- The rest of your script ---
import argparse
import json
# import os # Already imported
import time
import torch
from safetensors.torch import load_file, save_file
@ -41,6 +37,7 @@ logger = logging.getLogger(__name__)
MIN_SV = 1e-6
# --- Singular Value Indexing Functions (Unchanged) ---
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
@ -63,483 +60,442 @@ def index_sv_ratio(S, target):
index = max(1, min(index, len(S) - 1))
return index
def index_sv_knee_improved(S, MIN_SV_KNEE=1e-8): # MIN_SV_KNEE can be same as global MIN_SV or specific
"""
Determine rank using the knee point detection method with normalization.
Normalizes singular values and their indices to the [0,1] range
to make the knee detection scale-invariant.
"""
def index_sv_knee(S, MIN_SV_KNEE=1e-8):
n = len(S)
if n < 3: # Need at least 3 points to detect a knee
return 1
# S is expected to be sorted in descending order.
if n < 3: return 1
s_max, s_min = S[0], S[-1]
# Handle flat or nearly flat singular value spectrum
if s_max - s_min < MIN_SV_KNEE:
# If all singular values are almost the same, a knee is not well-defined.
# Returning 1 is a conservative choice for low rank.
# Alternatively, n // 2 or n - 1 could be chosen depending on desired behavior.
return 1
# Normalize singular values to [0, 1]
# s_normalized[0] will be 1, s_normalized[n-1] will be 0
if s_max - s_min < MIN_SV_KNEE: return 1
s_normalized = (S - s_min) / (s_max - s_min)
# Normalize indices to [0, 1]
# x_normalized[0] will be 0, x_normalized[n-1] will be 1
x_normalized = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
# The line in normalized space connects (x_norm[0], s_norm[0]) to (x_norm[n-1], s_norm[n-1])
# This is (0, 1) to (1, 0).
# The equation of the line passing through (0,1) and (1,0) is x + y - 1 = 0.
# So, A=1, B=1, C=-1 for the line equation Ax + By + C = 0.
# Calculate the perpendicular distance from each point (x_normalized[i], s_normalized[i]) to this line.
# Distance = |A*x_i + B*y_i + C| / sqrt(A^2 + B^2)
# Distance = |1*x_normalized + 1*s_normalized - 1| / sqrt(1^2 + 1^2)
# Distance = |x_normalized + s_normalized - 1| / sqrt(2)
# The sqrt(2) denominator is constant and doesn't affect argmax, so it can be omitted for finding the index.
distances = (x_normalized + s_normalized - 1).abs()
# Find the 0-based index of the point with the maximum distance.
knee_index_0based = torch.argmax(distances).item()
# Convert 0-based index to 1-based rank.
rank = knee_index_0based + 1
# Clamp rank similar to original: must be > 0 and <= n-1 (typical for rank reduction)
# If knee_index_0based is n-1 (last point), rank becomes n. min(n, n-1) results in n-1.
rank = max(1, min(rank, n - 1))
return rank
def index_sv_cumulative_knee(S, min_sv_threshold=1e-8):
"""
Determine rank using the knee point detection method on the normalized cumulative sum of singular values.
This method identifies a point where adding more singular values contributes diminishingly to the total sum.
"""
n = len(S)
if n < 3: # Need at least 3 points to detect a knee
return 1
if n < 3: return 1
s_sum = torch.sum(S)
# If all singular values are zero or very small, return rank 1.
if s_sum < min_sv_threshold:
return 1
# Calculate cumulative sum of singular values, normalized by the total sum.
# y_values[0] = S[0]/s_sum, ..., y_values[n-1] = 1.0
if s_sum < min_sv_threshold: return 1
y_values = torch.cumsum(S, dim=0) / s_sum
# Normalize these y_values (cumulative sums) to the range [0,1] for knee detection.
y_min, y_max = y_values[0], y_values[n-1] # y_max is typically 1.0
# If the normalized cumulative sum curve is very flat (e.g., S[0] captures almost all energy),
# it implies the first few components are dominant.
if y_max - y_min < min_sv_threshold: # Using min_sv_threshold here as a sensitivity for flatness
# This condition means (S[0] + ... + S[n-1]) - S[0] is small relative to sum(S) if n>1
# Effectively, S[1:] components are negligible.
return 1
# y_norm[0] = 0, y_norm[n-1] = 1 (represents the normalized cumulative sum from start to end)
y_min, y_max = y_values[0], y_values[n-1]
if y_max - y_min < min_sv_threshold: return 1
y_norm = (y_values - y_min) / (y_max - y_min)
# x_values are indices, normalized to [0, 1]
# x_norm[0] = 0, ..., x_norm[n-1] = 1
x_norm = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
# The "knee" is the point on the curve (x_norm[i], y_norm[i]) that is farthest
# from the line connecting the start and end of this normalized curve.
# In this normalized space, the line connects (0,0) to (1,1).
# The equation of this line is Y = X, or X - Y = 0.
# The distance from a point (x_i, y_i) to the line X - Y = 0 is |x_i - y_i| / sqrt(1^2 + (-1)^2).
# We can maximize |x_i - y_i| (or |y_i - x_i|) as sqrt(2) is a constant factor.
distances = (y_norm - x_norm).abs() # y_norm is expected to be >= x_norm for a concave cumulative curve.
# Find the 0-based index of the point with the maximum distance.
distances = (y_norm - x_norm).abs()
knee_index_0based = torch.argmax(distances).item()
# Convert 0-based index to 1-based rank.
rank = knee_index_0based + 1
# Clamp rank to be between 1 and n-1 (as n elements give n-1 possible ranks for truncation).
# A rank of n means no truncation. n-1 is the highest sensible rank for reduction.
rank = max(1, min(rank, n - 1))
return rank
def index_sv_rel_decrease(S, tau=0.1):
"""Determine rank based on relative decrease threshold."""
if len(S) < 2:
# For matrices with fewer than 2 singular values, a relative decrease
# isn't meaningful. Returning 1 is a sensible default.
return 1
# Compute ratios of consecutive singular values
# S is sorted descending, so S[:-1] >= S[1:]
# ratios will be <= 1.0
ratios = S[1:] / S[:-1] # Example: S=[10,1,0.5], ratios=[0.1, 0.5]
# Find the smallest k such that S[k+1]/S[k] < tau.
# The rank would then be k+1, as we include S[k].
for k in range(len(ratios)): # k ranges from 0 to len(S)-2
if len(S) < 2: return 1
ratios = S[1:] / S[:-1]
for k in range(len(ratios)):
if ratios[k] < tau:
# We found a significant drop after the k-th singular value.
# So, we keep k+1 singular values (indices 0 to k).
# The rank is k+1. Since k >= 0, k+1 >= 1.
return k + 1
# If no drop below tau was found, it means all relative decreases were >= tau.
# In this case, this method suggests using all available singular values.
# The actual rank will be capped later by args.dim/conv_dim and matrix dimensions.
return k + 1
return len(S)
def save_to_file(file_name, model_to_save, state_dict_content, dtype, metadata=None): # Renamed params for clarity
if dtype is not None:
for key in list(state_dict_content.keys()):
if isinstance(state_dict_content[key], torch.Tensor):
state_dict_content[key] = state_dict_content[key].to(dtype)
# --- Utility Functions ---
def _str_to_dtype(p):
if p == "float": return torch.float
if p == "fp16": return torch.float16
if p == "bf16": return torch.bfloat16
return None
def save_to_file(file_name, state_dict_to_save, dtype, metadata=None):
# Make a copy to modify for dtype conversion if necessary
state_dict_final = {}
for key, value in state_dict_to_save.items():
if isinstance(value, torch.Tensor) and dtype is not None:
state_dict_final[key] = value.to(dtype)
else:
state_dict_final[key] = value # Handles non-tensors or when dtype is None
# save_file from safetensors expects a state_dict as the first argument if metadata is also passed.
# torch.save would also expect the state_dict.
# The 'model' variable being passed to save_file should be the state_dict itself.
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model_to_save, file_name, metadata=metadata) # Pass metadata correctly
save_file(state_dict_final, file_name, metadata=metadata)
else:
torch.save(model_to_save, file_name)
torch.save(state_dict_final, file_name)
# --- Refactored Helper Functions ---
def _load_sd_model_components(model_path, is_v2, load_dtype_torch):
logger.info(f"Loading SD model from: {model_path} (to CPU initially)")
# model_util usually loads to CPU by default, then we cast dtype
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(is_v2, model_path)
del vae # Not used
text_encoders = [text_encoder]
if load_dtype_torch:
for te in text_encoders:
te.to(load_dtype_torch)
unet.to(load_dtype_torch)
return text_encoders, unet
def _load_sdxl_model_components(model_path, target_device_override, load_dtype_torch):
# Prioritize CPU loading unless target_device_override is explicitly GPU
# This 'target_device_override' comes from args.load_original_model_to / args.load_tuned_model_to
actual_load_device = target_device_override if target_device_override else "cpu"
logger.info(f"Loading SDXL model from: {model_path} to device: {actual_load_device}")
text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_path, actual_load_device
)
del vae # Not used
text_encoders = [text_encoder1, text_encoder2]
if load_dtype_torch: # Apply dtype cast after loading to the specified device
for te in text_encoders:
te.to(load_dtype_torch)
unet.to(load_dtype_torch)
return text_encoders, unet
def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_device, min_diff_thresh, module_type_str):
diffs_map = {}
is_different_flag = False
first_diff_logged = False
for lora_o, lora_t in zip(module_loras_o, module_loras_t):
lora_name = lora_o.lora_name
if lora_o.org_module is None or lora_t.org_module is None or \
not hasattr(lora_o.org_module, 'weight') or lora_o.org_module.weight is None or \
not hasattr(lora_t.org_module, 'weight') or lora_t.org_module.weight is None:
logger.warning(f"Skipping {lora_name} in {module_type_str} due to missing org_module or weight.")
continue
# Weights are expected to be on CPU after loading, or on specified load device.
# Move them to diff_calc_device ONLY if they are not already there.
# diff_calc_device will be CPU in the corrected flow.
weight_o = lora_o.org_module.weight
weight_t = lora_t.org_module.weight
if str(weight_o.device) != str(diff_calc_device):
weight_o = weight_o.to(diff_calc_device)
if str(weight_t.device) != str(diff_calc_device):
weight_t = weight_t.to(diff_calc_device)
diff = weight_t - weight_o # Diff happens on diff_calc_device (CPU)
# No need to set lora_o.org_module.weight to None here, original weights might be reused
# by other parts of sd-scripts if this script is integrated.
# We will del the entire model objects (unet_t, text_encoders_t) later.
diffs_map[lora_name] = diff # diff is on diff_calc_device (CPU)
current_max_diff = torch.max(torch.abs(diff))
if not is_different_flag and current_max_diff > min_diff_thresh:
is_different_flag = True
if not first_diff_logged:
logger.info(f"{module_type_str} '{lora_name}' differs: max diff {current_max_diff} > {min_diff_thresh}")
first_diff_logged = True
return diffs_map, is_different_flag
def _determine_rank(S_values, dynamic_method_name, dynamic_param_value, max_rank_limit,
module_eff_in_dim, module_eff_out_dim, min_sv_threshold=MIN_SV):
if not S_values.numel() or S_values[0] <= min_sv_threshold: return 1
rank = 0
if dynamic_method_name == "sv_ratio": rank = index_sv_ratio(S_values, dynamic_param_value)
elif dynamic_method_name == "sv_cumulative": rank = index_sv_cumulative(S_values, dynamic_param_value)
elif dynamic_method_name == "sv_fro": rank = index_sv_fro(S_values, dynamic_param_value)
elif dynamic_method_name == "sv_knee": rank = index_sv_knee(S_values, min_sv_threshold)
elif dynamic_method_name == "sv_cumulative_knee": rank = index_sv_cumulative_knee(S_values, min_sv_threshold)
elif dynamic_method_name == "sv_rel_decrease": rank = index_sv_rel_decrease(S_values, dynamic_param_value)
else: rank = max_rank_limit
rank = min(rank, max_rank_limit, module_eff_in_dim, module_eff_out_dim, len(S_values))
rank = max(1, rank)
return rank
def _construct_lora_weights_from_svd_components(U_full, S_all_values, Vh_full, rank,
clamp_quantile_val, is_conv2d, is_conv2d_3x3,
conv_kernel_size,
module_out_channels, module_in_channels,
# svd_comp_device, # U,S,Vh are on this device
target_device_for_final_weights, target_dtype_for_final_weights):
# U_full, S_all_values, Vh_full are assumed to be on the SVD computation device.
S_k = S_all_values[:rank]
U_k = U_full[:, :rank]
Vh_k = Vh_full[:rank, :]
S_k_non_negative = torch.clamp(S_k, min=0.0)
s_sqrt = torch.sqrt(S_k_non_negative)
U_final = U_k * s_sqrt.unsqueeze(0) # on svd_comp_device
Vh_final = Vh_k * s_sqrt.unsqueeze(1) # on svd_comp_device
# Clamping happens on svd_comp_device
dist = torch.cat([U_final.flatten(), Vh_final.flatten()])
hi_val = torch.quantile(dist, clamp_quantile_val)
if hi_val == 0 and torch.max(torch.abs(dist)) > 1e-9:
logger.debug(f"Clamping hi_val is zero for non-zero distribution. Max abs val: {torch.max(torch.abs(dist))}. Quantile: {clamp_quantile_val}")
U_clamped = U_final.clamp(-hi_val, hi_val)
Vh_clamped = Vh_final.clamp(-hi_val, hi_val)
if is_conv2d: # Reshaping also on svd_comp_device
U_clamped = U_clamped.reshape(module_out_channels, rank, 1, 1)
if is_conv2d_3x3:
Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, *conv_kernel_size)
else:
Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, 1, 1)
# Move to final target device and dtype at the very end
U_clamped = U_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous()
Vh_clamped = Vh_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous()
return U_clamped, Vh_clamped
def _log_svd_stats(lora_module_name, S_all_values, rank_used, min_sv_for_calc=MIN_SV):
if not S_all_values.numel():
logger.info(f"{lora_module_name:75} | rank: {rank_used}, SVD not performed (empty singular values).")
return
# S_all_values might be on GPU, move to CPU for float conversion and sum if not already
S_cpu = S_all_values.to('cpu')
s_sum_total = float(torch.sum(S_cpu))
s_sum_rank = float(torch.sum(S_cpu[:rank_used]))
fro_orig_total = float(torch.sqrt(torch.sum(S_cpu.pow(2))))
fro_reconstructed_rank = float(torch.sqrt(torch.sum(S_cpu[:rank_used].pow(2))))
ratio_sv = float('inf')
if rank_used > 0 and S_cpu[rank_used - 1].abs() > min_sv_for_calc:
ratio_sv = S_cpu[0] / S_cpu[rank_used - 1] # Ensure S_cpu[0] is also float for division
sum_s_retained_percentage = (s_sum_rank / s_sum_total) if s_sum_total > min_sv_for_calc else 1.0
fro_retained_percentage = (fro_reconstructed_rank / fro_orig_total) if fro_orig_total > min_sv_for_calc else 1.0
logger.info(
f"{lora_module_name:75} | rank: {rank_used}, "
f"sum(S) retained: {sum_s_retained_percentage:.2%}, "
f"Frobenius norm retained: {fro_retained_percentage:.2%}, "
f"max_retained_sv/min_retained_sv ratio: {ratio_sv:.2f}"
)
def _prepare_lora_metadata(output_path, is_v2_flag, base_model_ver, network_conv_dim_val,
use_dynamic_method_flag, network_dim_config_val,
is_v_param_flag, is_sdxl_flag, skip_sai_meta):
net_kwargs = {"conv_dim": str(network_conv_dim_val), "conv_alpha": str(float(network_conv_dim_val))} if network_conv_dim_val is not None else {}
if use_dynamic_method_flag:
network_dim_meta = "Dynamic"
network_alpha_meta = "Dynamic"
else:
network_dim_meta = str(network_dim_config_val)
network_alpha_meta = str(float(network_dim_config_val))
final_metadata = {
"ss_v2": str(is_v2_flag),
"ss_base_model_version": base_model_ver,
"ss_network_module": "networks.lora",
"ss_network_dim": network_dim_meta,
"ss_network_alpha": network_alpha_meta,
"ss_network_args": json.dumps(net_kwargs),
"ss_lowram": "False",
"ss_num_train_images": "N/A",
}
if not skip_sai_meta:
title = os.path.splitext(os.path.basename(output_path))[0]
is_sd2_for_meta = True
sai_metadata_content = sai_model_spec.build_metadata(
training_info=None, v2=is_v2_flag, v_parameterization=is_v_param_flag,
sdxl=is_sdxl_flag, is_sd2=is_sd2_for_meta, is_v_pred_like=False,
unet_use_linear_projection_in_v2=False, creation_time=time.time(), title=title,
)
sai_metadata_cleaned = {k: v for k, v in sai_metadata_content.items() if v is not None}
final_metadata.update(sai_metadata_cleaned)
return final_metadata
# --- Main SVD Function ---
def svd(
model_org=None,
model_tuned=None,
save_to=None,
dim=4,
v2=None,
sdxl=None,
conv_dim=None,
v_parameterization=None,
device=None,
save_precision=None,
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
load_precision=None,
load_original_model_to=None,
load_tuned_model_to=None,
dynamic_method=None,
dynamic_param=None,
verbose=False,
model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None,
conv_dim=None, v_parameterization=None, device=None, save_precision=None,
clamp_quantile=0.99, min_diff=0.01, no_metadata=False, load_precision=None,
load_original_model_to=None, load_tuned_model_to=None,
dynamic_method=None, dynamic_param=None, verbose=False,
):
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time"
v_parameterization = v2 if v_parameterization is None else v_parameterization
load_dtype_torch = _str_to_dtype(load_precision)
save_dtype_torch = _str_to_dtype(save_precision) if save_precision else torch.float
# Device for SVD computation itself. Defaults to CUDA if available, else CPU.
svd_computation_device = torch.device(device if device else "cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using SVD computation device: {svd_computation_device}")
# Device for calculating weight differences. This should ideally be CPU to avoid GPU->CPU transfers if models loaded to CPU.
diff_calculation_device = torch.device("cpu")
logger.info(f"Calculating weight differences on: {diff_calculation_device}")
# Device for final LoRA weights before saving (usually CPU).
final_weights_device = torch.device("cpu")
load_dtype = str_to_dtype(load_precision) if load_precision else None
save_dtype = str_to_dtype(save_precision) if save_precision else torch.float
work_device = "cpu"
# Load models
if not sdxl:
logger.info(f"Loading original SD model: {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o]
if load_dtype:
text_encoder_o.to(load_dtype)
unet_o.to(load_dtype)
logger.info(f"Loading tuned SD model: {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t]
if load_dtype:
text_encoder_t.to(load_dtype)
unet_t.to(load_dtype)
# _load_sd_model_components loads to CPU, then applies dtype
text_encoders_o, unet_o = _load_sd_model_components(model_org, v2, load_dtype_torch)
text_encoders_t, unet_t = _load_sd_model_components(model_tuned, v2, load_dtype_torch)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else:
device_org = load_original_model_to or "cpu"
device_tuned = load_tuned_model_to or "cpu"
logger.info(f"Loading original SDXL model: {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
)
text_encoders_o = [text_encoder_o1, text_encoder_o2]
if load_dtype:
text_encoder_o1.to(load_dtype)
text_encoder_o2.to(load_dtype)
unet_o.to(load_dtype)
logger.info(f"Loading tuned SDXL model: {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
)
text_encoders_t = [text_encoder_t1, text_encoder_t2]
if load_dtype:
text_encoder_t1.to(load_dtype)
text_encoder_t2.to(load_dtype)
unet_t.to(load_dtype)
# _load_sdxl_model_components uses load_original_model_to/load_tuned_model_to if provided, otherwise defaults to CPU.
text_encoders_o, unet_o = _load_sdxl_model_components(model_org, load_original_model_to, load_dtype_torch)
text_encoders_t, unet_t = _load_sdxl_model_components(model_tuned, load_tuned_model_to, load_dtype_torch)
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# Create LoRA network
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} if conv_dim else {}
# Create LoRA networks (initially with small dim for structure)
init_dim_val = 1
# Conv_dim for network creation should be based on user's conv_dim, or init_dim_val if not set
# This is for the structure of the LoRA network object.
lora_conv_dim_init = conv_dim if conv_dim is not None else init_dim_val
kwargs_lora = {"conv_dim": lora_conv_dim_init, "conv_alpha": lora_conv_dim_init} # alpha matches dim for init
lora_network_o = lora.create_network(1.0, init_dim_val, init_dim_val, None, text_encoders_o, unet_o, **kwargs_lora)
lora_network_t = lora.create_network(1.0, init_dim_val, init_dim_val, None, text_encoders_t, unet_t, **kwargs_lora)
# Define a small initial dimension for memory efficiency
init_dim = 4 # Small value to minimize memory usage
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), \
f"Model versions differ: {len(lora_network_o.text_encoder_loras)} vs {len(lora_network_t.text_encoder_loras)} TEs"
# Create LoRA networks with minimal dimension
lora_network_o = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_t, unet_t, **kwargs)
# Compute differences on diff_calculation_device (CPU)
all_diffs = {}
te_diffs, text_encoder_different = _calculate_module_diffs_and_check(
lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras,
diff_calculation_device, min_diff, "Text Encoder"
)
if text_encoder_different:
all_diffs.update(te_diffs)
else:
logger.warning("Text encoders are considered identical based on min_diff. Not extracting TE LoRA.")
lora_network_o.text_encoder_loras = []
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), "Model versions differ (SD1.x vs SD2.x)"
del text_encoders_t # Free memory
# Compute differences
diffs = {}
text_encoder_different = False
for lora_o, lora_t in zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras):
lora_name = lora_o.lora_name
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
lora_o.org_module.weight = None
lora_t.org_module.weight = None
unet_diffs, _ = _calculate_module_diffs_and_check(
lora_network_o.unet_loras, lora_network_t.unet_loras,
diff_calculation_device, min_diff, "U-Net"
)
all_diffs.update(unet_diffs) # All diffs are now on diff_calculation_device (CPU)
del lora_network_t, unet_t # Free memory
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True
logger.info(f"Text encoder differs: max diff {torch.max(torch.abs(diff))} > {min_diff}")
diffs[lora_name] = diff
for text_encoder in text_encoders_t:
del text_encoder
if not text_encoder_different:
logger.warning("Text encoders are identical. Extracting U-Net only.")
lora_network_o.text_encoder_loras = []
diffs.clear()
for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras):
lora_name = lora_o.lora_name
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
lora_o.org_module.weight = None
lora_t.org_module.weight = None
diffs[lora_name] = diff
del lora_network_t, unet_t
# Filter relevant modules
lora_names = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)
# Extract and resize LoRA using SVD
lora_names_to_process = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)
logger.info("Extracting and resizing LoRA via SVD")
lora_weights = {}
with torch.no_grad():
for lora_name in tqdm(lora_names):
mat = diffs[lora_name]
if device:
mat = mat.to(device)
mat = mat.to(torch.float)
for lora_name in tqdm(lora_names_to_process):
if lora_name not in all_diffs:
logger.warning(f"Skipping {lora_name} as no diff was calculated for it.")
continue
conv2d = len(mat.size()) == 4
kernel_size = mat.size()[2:4] if conv2d else None
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = mat.size()[0:2]
if conv2d:
mat = mat.flatten(start_dim=1) if conv2d_3x3 else mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
# Determine rank
max_rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
if dynamic_method:
if S[0] <= MIN_SV:
rank = 1
elif dynamic_method == "sv_ratio":
rank = index_sv_ratio(S, dynamic_param)
elif dynamic_method == "sv_cumulative":
rank = index_sv_cumulative(S, dynamic_param)
elif dynamic_method == "sv_fro":
rank = index_sv_fro(S, dynamic_param)
elif dynamic_method == "sv_knee":
rank = index_sv_knee_improved(S, MIN_SV) # Pass MIN_SV or a specific threshold
elif dynamic_method == "sv_cumulative_knee": # New method
rank = index_sv_cumulative_knee(S, MIN_SV) # Pass MIN_SV or a specific threshold
elif dynamic_method == "sv_rel_decrease":
rank = index_sv_rel_decrease(S, dynamic_param)
rank = min(rank, max_rank, in_dim, out_dim)
else:
rank = min(max_rank, in_dim, out_dim)
original_diff_tensor = all_diffs[lora_name] # This is on diff_calculation_device (CPU)
rank = max(1, rank) # Ensure rank is at least 1
# Truncate SVD components and distribute sqrt(S)
S_k = S[:rank]
U_k = U[:, :rank]
Vh_k = Vh[:rank, :]
# Ensure S_k values are non-negative before sqrt to avoid NaN from tiny negative SVD artifacts
S_k_non_negative = torch.clamp(S_k, min=0.0) # Use 0.0 for float tensor
s_sqrt = torch.sqrt(S_k_non_negative)
is_conv2d_layer = len(original_diff_tensor.size()) == 4
kernel_s = original_diff_tensor.size()[2:4] if is_conv2d_layer else None
is_conv2d_3x3_layer = is_conv2d_layer and kernel_s != (1, 1)
# Distribute s_sqrt: U_final = U_k * diag(s_sqrt), Vh_final = diag(s_sqrt) * Vh_k
# Using efficient broadcasting for multiplication:
U_final = U_k * s_sqrt.unsqueeze(0) # (out_dim, rank) * (1, rank)
Vh_final = Vh_k * s_sqrt.unsqueeze(1) # (rank, in_dim_effective) * (rank, 1)
module_true_out_channels, module_true_in_channels = original_diff_tensor.size()[0:2]
# Clamp values (applied to U_final, Vh_final)
# The distribution of values in U_final and Vh_final might be different
# than the original U and Vh, so the effect of clamping might change.
dist = torch.cat([U_final.flatten(), Vh_final.flatten()])
hi_val = torch.quantile(dist, clamp_quantile)
U_clamped = U_final.clamp(-hi_val, hi_val)
Vh_clamped = Vh_final.clamp(-hi_val, hi_val)
# Move diff tensor to SVD computation device, ensure it's float32 for SVD
mat_for_svd = original_diff_tensor.to(svd_computation_device, dtype=torch.float)
if conv2d:
# U_clamped is (out_dim, rank)
U_clamped = U_clamped.reshape(out_dim, rank, 1, 1)
# Vh_clamped is (rank, in_dim * possibly_kernel_dims)
# It needs to be reshaped back to (rank, in_dim, kernel_h, kernel_w)
if conv2d_3x3 : # Original mat was (out_dim, in_dim * k_h * k_w)
Vh_clamped = Vh_clamped.reshape(rank, in_dim, *kernel_size)
else: # Original mat was (out_dim, in_dim) for 1x1 conv, kernel_size is (1,1)
Vh_clamped = Vh_clamped.reshape(rank, in_dim, *kernel_size) # kernel_size is (1,1) here
if is_conv2d_layer:
if is_conv2d_3x3_layer:
mat_for_svd = mat_for_svd.flatten(start_dim=1)
else:
mat_for_svd = mat_for_svd.squeeze()
U_clamped = U_clamped.to(work_device, dtype=save_dtype).contiguous()
Vh_clamped = Vh_clamped.to(work_device, dtype=save_dtype).contiguous()
lora_weights[lora_name] = (U_clamped, Vh_clamped)
if mat_for_svd.numel() == 0 or mat_for_svd.shape[0] == 0 or mat_for_svd.shape[1] == 0 :
logger.warning(f"Skipping SVD for {lora_name} due to empty/invalid shape: {mat_for_svd.shape}")
continue
try:
U_full, S_full, Vh_full = torch.linalg.svd(mat_for_svd) # SVD on svd_computation_device
except Exception as e:
logger.error(f"SVD failed for {lora_name} with shape {mat_for_svd.shape}. Error: {e}")
continue
eff_out_dim, eff_in_dim = mat_for_svd.shape[0], mat_for_svd.shape[1]
current_max_rank = dim if not is_conv2d_3x3_layer or conv_dim is None else conv_dim
rank = _determine_rank(S_full, dynamic_method, dynamic_param,
current_max_rank, eff_in_dim, eff_out_dim, MIN_SV)
U_clamped, Vh_clamped = _construct_lora_weights_from_svd_components(
U_full, S_full, Vh_full, rank, clamp_quantile,
is_conv2d_layer, is_conv2d_3x3_layer, kernel_s,
module_true_out_channels, module_true_in_channels,
final_weights_device, save_dtype_torch # Final weights to CPU with target dtype
)
lora_weights[lora_name] = (U_clamped, Vh_clamped) # U_clamped, Vh_clamped are on final_weights_device (CPU)
# Verbose output (S values are pre-modification for accurate reporting of original SVD properties)
if verbose:
s_sum_total = float(torch.sum(S))
s_sum_rank = float(torch.sum(S[:rank])) # Sum of the singular values actually used for reconstruction
fro_orig_total = float(torch.sqrt(torch.sum(S.pow(2))))
fro_reconstructed_rank = float(torch.sqrt(torch.sum(S[:rank].pow(2)))) # Frobenius norm of the matrix part represented by chosen rank
# Ratio of the largest retained singular value to the smallest retained singular value
# S is sorted, S[0] is max. S[rank-1] is the smallest singular value included if rank > 0.
ratio_sv = S[0] / S[rank - 1] if rank > 0 and S[rank - 1].abs() > MIN_SV else float('inf') # Avoid division by zero or tiny number
# Ensure denominators are not zero for percentages
sum_s_retained_percentage = (s_sum_rank / s_sum_total) if s_sum_total > MIN_SV else 1.0
fro_retained_percentage = (fro_reconstructed_rank / fro_orig_total) if fro_orig_total > MIN_SV else 1.0
_log_svd_stats(lora_name, S_full, rank, MIN_SV) # S_full is on svd_computation_device
logger.info(
f"{lora_name:75} | rank: {rank}, "
f"sum(S) retained: {sum_s_retained_percentage:.2%}, "
f"Frobenius norm retained: {fro_retained_percentage:.2%}, "
f"max_retained_sv/min_retained_sv ratio: {ratio_sv:.2f}"
)
# Create state dict
# Create state dict for LoRA (all components are on final_weights_device (CPU))
lora_sd = {}
for lora_name, (up_weight, down_weight) in lora_weights.items():
lora_sd[lora_name + ".lora_up.weight"] = up_weight
lora_sd[lora_name + ".lora_down.weight"] = down_weight
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype) # alpha is rank
# Load and save LoRA
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
lora_network_save.apply_to(text_encoders_o, unet_o) # This applies weights, not strictly necessary if just saving sd
info = lora_network_save.load_state_dict(lora_sd) # This populates the network object with the weights from lora_sd
logger.info(f"Loaded extracted and resized LoRA weights into network object: {info}")
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype_torch, device=final_weights_device)
# Clean up original models from memory if they are still around and large (especially if on GPU)
del text_encoders_o, unet_o, lora_network_o, all_diffs
if 'torch' in sys.modules and hasattr(torch, 'cuda') and torch.cuda.is_available():
torch.cuda.empty_cache()
os.makedirs(os.path.dirname(save_to), exist_ok=True)
# Metadata
net_kwargs = {"conv_dim": str(conv_dim), "conv_alpha": str(float(conv_dim))} if conv_dim else {}
# Determine network_dim and network_alpha for metadata based on dynamic method
if dynamic_method:
network_dim_meta = "Dynamic"
network_alpha_meta = "Dynamic" # Alpha is rank, which is dynamic
else:
network_dim_meta = str(dim)
network_alpha_meta = str(float(dim)) # Alpha is rank, which is dim
metadata = {
"ss_v2": str(v2),
"ss_base_model_version": model_version,
"ss_network_module": "networks.lora",
"ss_network_dim": network_dim_meta,
"ss_network_alpha": network_alpha_meta, # Alpha is typically the rank
"ss_network_args": json.dumps(net_kwargs),
"ss_lowram": "False", # Assuming not specifically lowram mode
"ss_num_train_images": "N/A", # Not applicable for extraction
# Add other relevant metadata as per sai_model_spec or conventions
}
if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0]
# Build sai_metadata, ensuring it includes necessary fields like 'ss_sd_model_hash' if possible
# For extraction, some training-specific metadata might not be relevant or available.
sai_metadata = sai_model_spec.build_metadata(
None, # training_info (usually from train_util or fine_tune) - can be None for extraction
v2,
v_parameterization,
sdxl,
True, # is_sd2
False, # is_v_pred_like
time.time(),
title=title,
# model_hash=None, # Original model hash if available
# tuned_model_hash=None # Tuned model hash if available
)
# Filter out None values from sai_metadata if any, or handle them in build_metadata
sai_metadata_cleaned = {k: v for k, v in sai_metadata.items() if v is not None}
metadata.update(sai_metadata_cleaned)
# Use the state_dict 'lora_sd' for saving, not the network object 'lora_network_save'
save_to_file(save_to, lora_sd, lora_sd, save_dtype, metadata) # Pass lora_sd as the model/state_dict to save
metadata_to_save = _prepare_lora_metadata(
save_to, v2, model_version, conv_dim,
bool(dynamic_method), dim,
v_parameterization, sdxl, no_metadata
)
save_to_file(save_to, lora_sd, save_dtype_torch, metadata_to_save) # save_dtype_torch applied again if not None
logger.info(f"LoRA saved to: {save_to}")
def setup_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model")
parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2)")
parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model")
parser.add_argument("--load_precision", choices=[None, "float", "fp16", "bf16"], help="Precision for loading models")
parser.add_argument("--save_precision", choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA")
parser.add_argument("--model_org", required=True, help="Original Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--model_tuned", required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--save_to", required=True, help="Output file name (ckpt/safetensors)")
parser.add_argument("--load_precision", type=str, choices=["float", "fp16", "bf16"], default=None, help="Precision for loading models (applied after initial load)")
parser.add_argument("--save_precision", type=str, choices=["float", "fp16", "bf16"], default="float", help="Precision for saving LoRA weights")
parser.add_argument("--model_org", type=str, required=True, help="Original Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--model_tuned", type=str, required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--save_to", type=str, required=True, help="Output file name (ckpt/safetensors)")
parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers")
parser.add_argument("--conv_dim", type=int, help="Max dimension (rank) of LoRA for Conv2d-3x3")
parser.add_argument("--device", default="cuda", help="Device for computation (e.g., cuda)")
parser.add_argument("--conv_dim", type=int, default=None, help="Max dimension (rank) of LoRA for Conv2d-3x3. Defaults to 'dim' if not set.")
parser.add_argument("--device", type=str, default=None, help="Device for SVD computation (e.g., cuda, cpu). Defaults to cuda if available, else cpu.")
parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights")
parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract")
parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata")
parser.add_argument("--load_original_model_to", help="Device for original model (SDXL only)")
parser.add_argument("--load_tuned_model_to", help="Device for tuned model (SDXL only)")
parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract LoRA for a module")
parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata from SAI and Kohya_ss")
parser.add_argument("--load_original_model_to", type=str, default=None, help="Device for original model (e.g. 'cpu', 'cuda:0'). Defaults to CPU.")
parser.add_argument("--load_tuned_model_to", type=str, default=None, help="Device for tuned model (e.g. 'cpu', 'cuda:0'). Defaults to CPU.")
parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction")
parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info")
parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info for each module")
parser.add_argument(
"--dynamic_method",
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"], # Added "sv_cumulative_knee"
help="Dynamic rank reduction method"
"--dynamic_method", type=str,
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"],
default=None, help="Dynamic rank reduction method"
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
if args.conv_dim is None:
args.conv_dim = args.dim
logger.info(f"--conv_dim not set, using value of --dim: {args.conv_dim}")
methods_requiring_param = ["sv_ratio", "sv_fro", "sv_cumulative", "sv_rel_decrease"]
if args.dynamic_method in methods_requiring_param and args.dynamic_param is None:
raise ValueError(f"Dynamic method '{args.dynamic_method}' requires --dynamic_param to be set.")
parser.error(f"Dynamic method '{args.dynamic_method}' requires --dynamic_param to be set.")
# Add a check for rank > 0 if not dynamic, or ensure dynamic methods return rank >= 1
if not args.dynamic_method and args.dim <= 0:
raise ValueError(f"--dim (rank) must be > 0. Got {args.dim}")
if args.conv_dim is not None and args.conv_dim <=0:
raise ValueError(f"--conv_dim (rank) must be > 0 if specified. Got {args.conv_dim}")
if not args.dynamic_method:
if args.dim <= 0: parser.error(f"--dim (rank) must be > 0. Got {args.dim}")
if args.conv_dim <=0: parser.error(f"--conv_dim (rank) must be > 0. Got {args.conv_dim}")
if MIN_SV <= 0: logger.warning(f"Global MIN_SV ({MIN_SV}) should be positive.")
svd(**vars(args))
svd(**vars(args))

View File

@ -0,0 +1,545 @@
import sys
import os
# 1. Add sd-scripts directory to sys.path
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
sd_scripts_dir_path = os.path.join(project_root, "sd-scripts")
if sd_scripts_dir_path not in sys.path:
sys.path.insert(0, sd_scripts_dir_path)
# Now you can import from the library package and the networks package
try:
from library import sai_model_spec, model_util, sdxl_model_util
from library.utils import setup_logging
from networks import lora # <--- CORRECTED LORA IMPORT
except ImportError as e:
print(f"Error importing from sd-scripts. Please check your sd-scripts folder structure.")
print(f"Attempted to load from: {sd_scripts_dir_path}")
print(f"Original error: {e}")
print("Current sys.path relevant entries:")
for p in sys.path:
if "sd-scripts" in p or "kohya_ss" in p: # Print relevant paths for debugging
print(p)
# Ensure 'networks' directory exists in 'sd-scripts' and contains 'lora.py'
# Also ensure 'sd-scripts/networks/__init__.py' exists.
raise
# --- The rest of your script ---
import argparse
import json
# import os # Already imported
import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
setup_logging()
import logging
logger = logging.getLogger(__name__)
MIN_SV = 1e-6
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S) - 1))
return index
def index_sv_fro(S, target):
S_squared = S.pow(2)
S_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S) - 1))
return index
def index_sv_ratio(S, target):
max_sv = S[0]
min_sv = max_sv / target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S) - 1))
return index
def index_sv_knee_improved(S, MIN_SV_KNEE=1e-8): # MIN_SV_KNEE can be same as global MIN_SV or specific
"""
Determine rank using the knee point detection method with normalization.
Normalizes singular values and their indices to the [0,1] range
to make the knee detection scale-invariant.
"""
n = len(S)
if n < 3: # Need at least 3 points to detect a knee
return 1
# S is expected to be sorted in descending order.
s_max, s_min = S[0], S[-1]
# Handle flat or nearly flat singular value spectrum
if s_max - s_min < MIN_SV_KNEE:
# If all singular values are almost the same, a knee is not well-defined.
# Returning 1 is a conservative choice for low rank.
# Alternatively, n // 2 or n - 1 could be chosen depending on desired behavior.
return 1
# Normalize singular values to [0, 1]
# s_normalized[0] will be 1, s_normalized[n-1] will be 0
s_normalized = (S - s_min) / (s_max - s_min)
# Normalize indices to [0, 1]
# x_normalized[0] will be 0, x_normalized[n-1] will be 1
x_normalized = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
# The line in normalized space connects (x_norm[0], s_norm[0]) to (x_norm[n-1], s_norm[n-1])
# This is (0, 1) to (1, 0).
# The equation of the line passing through (0,1) and (1,0) is x + y - 1 = 0.
# So, A=1, B=1, C=-1 for the line equation Ax + By + C = 0.
# Calculate the perpendicular distance from each point (x_normalized[i], s_normalized[i]) to this line.
# Distance = |A*x_i + B*y_i + C| / sqrt(A^2 + B^2)
# Distance = |1*x_normalized + 1*s_normalized - 1| / sqrt(1^2 + 1^2)
# Distance = |x_normalized + s_normalized - 1| / sqrt(2)
# The sqrt(2) denominator is constant and doesn't affect argmax, so it can be omitted for finding the index.
distances = (x_normalized + s_normalized - 1).abs()
# Find the 0-based index of the point with the maximum distance.
knee_index_0based = torch.argmax(distances).item()
# Convert 0-based index to 1-based rank.
rank = knee_index_0based + 1
# Clamp rank similar to original: must be > 0 and <= n-1 (typical for rank reduction)
# If knee_index_0based is n-1 (last point), rank becomes n. min(n, n-1) results in n-1.
rank = max(1, min(rank, n - 1))
return rank
def index_sv_cumulative_knee(S, min_sv_threshold=1e-8):
"""
Determine rank using the knee point detection method on the normalized cumulative sum of singular values.
This method identifies a point where adding more singular values contributes diminishingly to the total sum.
"""
n = len(S)
if n < 3: # Need at least 3 points to detect a knee
return 1
s_sum = torch.sum(S)
# If all singular values are zero or very small, return rank 1.
if s_sum < min_sv_threshold:
return 1
# Calculate cumulative sum of singular values, normalized by the total sum.
# y_values[0] = S[0]/s_sum, ..., y_values[n-1] = 1.0
y_values = torch.cumsum(S, dim=0) / s_sum
# Normalize these y_values (cumulative sums) to the range [0,1] for knee detection.
y_min, y_max = y_values[0], y_values[n-1] # y_max is typically 1.0
# If the normalized cumulative sum curve is very flat (e.g., S[0] captures almost all energy),
# it implies the first few components are dominant.
if y_max - y_min < min_sv_threshold: # Using min_sv_threshold here as a sensitivity for flatness
# This condition means (S[0] + ... + S[n-1]) - S[0] is small relative to sum(S) if n>1
# Effectively, S[1:] components are negligible.
return 1
# y_norm[0] = 0, y_norm[n-1] = 1 (represents the normalized cumulative sum from start to end)
y_norm = (y_values - y_min) / (y_max - y_min)
# x_values are indices, normalized to [0, 1]
# x_norm[0] = 0, ..., x_norm[n-1] = 1
x_norm = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
# The "knee" is the point on the curve (x_norm[i], y_norm[i]) that is farthest
# from the line connecting the start and end of this normalized curve.
# In this normalized space, the line connects (0,0) to (1,1).
# The equation of this line is Y = X, or X - Y = 0.
# The distance from a point (x_i, y_i) to the line X - Y = 0 is |x_i - y_i| / sqrt(1^2 + (-1)^2).
# We can maximize |x_i - y_i| (or |y_i - x_i|) as sqrt(2) is a constant factor.
distances = (y_norm - x_norm).abs() # y_norm is expected to be >= x_norm for a concave cumulative curve.
# Find the 0-based index of the point with the maximum distance.
knee_index_0based = torch.argmax(distances).item()
# Convert 0-based index to 1-based rank.
rank = knee_index_0based + 1
# Clamp rank to be between 1 and n-1 (as n elements give n-1 possible ranks for truncation).
# A rank of n means no truncation. n-1 is the highest sensible rank for reduction.
rank = max(1, min(rank, n - 1))
return rank
def index_sv_rel_decrease(S, tau=0.1):
"""Determine rank based on relative decrease threshold."""
if len(S) < 2:
# For matrices with fewer than 2 singular values, a relative decrease
# isn't meaningful. Returning 1 is a sensible default.
return 1
# Compute ratios of consecutive singular values
# S is sorted descending, so S[:-1] >= S[1:]
# ratios will be <= 1.0
ratios = S[1:] / S[:-1] # Example: S=[10,1,0.5], ratios=[0.1, 0.5]
# Find the smallest k such that S[k+1]/S[k] < tau.
# The rank would then be k+1, as we include S[k].
for k in range(len(ratios)): # k ranges from 0 to len(S)-2
if ratios[k] < tau:
# We found a significant drop after the k-th singular value.
# So, we keep k+1 singular values (indices 0 to k).
# The rank is k+1. Since k >= 0, k+1 >= 1.
return k + 1
# If no drop below tau was found, it means all relative decreases were >= tau.
# In this case, this method suggests using all available singular values.
# The actual rank will be capped later by args.dim/conv_dim and matrix dimensions.
return len(S)
def save_to_file(file_name, model_to_save, state_dict_content, dtype, metadata=None): # Renamed params for clarity
if dtype is not None:
for key in list(state_dict_content.keys()):
if isinstance(state_dict_content[key], torch.Tensor):
state_dict_content[key] = state_dict_content[key].to(dtype)
# save_file from safetensors expects a state_dict as the first argument if metadata is also passed.
# torch.save would also expect the state_dict.
# The 'model' variable being passed to save_file should be the state_dict itself.
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model_to_save, file_name, metadata=metadata) # Pass metadata correctly
else:
torch.save(model_to_save, file_name)
def svd(
model_org=None,
model_tuned=None,
save_to=None,
dim=4,
v2=None,
sdxl=None,
conv_dim=None,
v_parameterization=None,
device=None,
save_precision=None,
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
load_precision=None,
load_original_model_to=None,
load_tuned_model_to=None,
dynamic_method=None,
dynamic_param=None,
verbose=False,
):
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time"
v_parameterization = v2 if v_parameterization is None else v_parameterization
load_dtype = str_to_dtype(load_precision) if load_precision else None
save_dtype = str_to_dtype(save_precision) if save_precision else torch.float
work_device = "cpu"
# Load models
if not sdxl:
logger.info(f"Loading original SD model: {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o]
if load_dtype:
text_encoder_o.to(load_dtype)
unet_o.to(load_dtype)
logger.info(f"Loading tuned SD model: {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t]
if load_dtype:
text_encoder_t.to(load_dtype)
unet_t.to(load_dtype)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else:
device_org = load_original_model_to or "cpu"
device_tuned = load_tuned_model_to or "cpu"
logger.info(f"Loading original SDXL model: {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
)
text_encoders_o = [text_encoder_o1, text_encoder_o2]
if load_dtype:
text_encoder_o1.to(load_dtype)
text_encoder_o2.to(load_dtype)
unet_o.to(load_dtype)
logger.info(f"Loading tuned SDXL model: {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
)
text_encoders_t = [text_encoder_t1, text_encoder_t2]
if load_dtype:
text_encoder_t1.to(load_dtype)
text_encoder_t2.to(load_dtype)
unet_t.to(load_dtype)
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# Create LoRA network
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} if conv_dim else {}
# Define a small initial dimension for memory efficiency
init_dim = 4 # Small value to minimize memory usage
# Create LoRA networks with minimal dimension
lora_network_o = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), "Model versions differ (SD1.x vs SD2.x)"
# Compute differences
diffs = {}
text_encoder_different = False
for lora_o, lora_t in zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras):
lora_name = lora_o.lora_name
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
lora_o.org_module.weight = None
lora_t.org_module.weight = None
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True
logger.info(f"Text encoder differs: max diff {torch.max(torch.abs(diff))} > {min_diff}")
diffs[lora_name] = diff
for text_encoder in text_encoders_t:
del text_encoder
if not text_encoder_different:
logger.warning("Text encoders are identical. Extracting U-Net only.")
lora_network_o.text_encoder_loras = []
diffs.clear()
for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras):
lora_name = lora_o.lora_name
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
lora_o.org_module.weight = None
lora_t.org_module.weight = None
diffs[lora_name] = diff
del lora_network_t, unet_t
# Filter relevant modules
lora_names = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)
# Extract and resize LoRA using SVD
logger.info("Extracting and resizing LoRA via SVD")
lora_weights = {}
with torch.no_grad():
for lora_name in tqdm(lora_names):
mat = diffs[lora_name]
if device:
mat = mat.to(device)
mat = mat.to(torch.float)
conv2d = len(mat.size()) == 4
kernel_size = mat.size()[2:4] if conv2d else None
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = mat.size()[0:2]
if conv2d:
mat = mat.flatten(start_dim=1) if conv2d_3x3 else mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
# Determine rank
max_rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
if dynamic_method:
if S[0] <= MIN_SV:
rank = 1
elif dynamic_method == "sv_ratio":
rank = index_sv_ratio(S, dynamic_param)
elif dynamic_method == "sv_cumulative":
rank = index_sv_cumulative(S, dynamic_param)
elif dynamic_method == "sv_fro":
rank = index_sv_fro(S, dynamic_param)
elif dynamic_method == "sv_knee":
rank = index_sv_knee_improved(S, MIN_SV) # Pass MIN_SV or a specific threshold
elif dynamic_method == "sv_cumulative_knee": # New method
rank = index_sv_cumulative_knee(S, MIN_SV) # Pass MIN_SV or a specific threshold
elif dynamic_method == "sv_rel_decrease":
rank = index_sv_rel_decrease(S, dynamic_param)
rank = min(rank, max_rank, in_dim, out_dim)
else:
rank = min(max_rank, in_dim, out_dim)
rank = max(1, rank) # Ensure rank is at least 1
# Truncate SVD components and distribute sqrt(S)
S_k = S[:rank]
U_k = U[:, :rank]
Vh_k = Vh[:rank, :]
# Ensure S_k values are non-negative before sqrt to avoid NaN from tiny negative SVD artifacts
S_k_non_negative = torch.clamp(S_k, min=0.0) # Use 0.0 for float tensor
s_sqrt = torch.sqrt(S_k_non_negative)
# Distribute s_sqrt: U_final = U_k * diag(s_sqrt), Vh_final = diag(s_sqrt) * Vh_k
# Using efficient broadcasting for multiplication:
U_final = U_k * s_sqrt.unsqueeze(0) # (out_dim, rank) * (1, rank)
Vh_final = Vh_k * s_sqrt.unsqueeze(1) # (rank, in_dim_effective) * (rank, 1)
# Clamp values (applied to U_final, Vh_final)
# The distribution of values in U_final and Vh_final might be different
# than the original U and Vh, so the effect of clamping might change.
dist = torch.cat([U_final.flatten(), Vh_final.flatten()])
hi_val = torch.quantile(dist, clamp_quantile)
U_clamped = U_final.clamp(-hi_val, hi_val)
Vh_clamped = Vh_final.clamp(-hi_val, hi_val)
if conv2d:
# U_clamped is (out_dim, rank)
U_clamped = U_clamped.reshape(out_dim, rank, 1, 1)
# Vh_clamped is (rank, in_dim * possibly_kernel_dims)
# It needs to be reshaped back to (rank, in_dim, kernel_h, kernel_w)
if conv2d_3x3 : # Original mat was (out_dim, in_dim * k_h * k_w)
Vh_clamped = Vh_clamped.reshape(rank, in_dim, *kernel_size)
else: # Original mat was (out_dim, in_dim) for 1x1 conv, kernel_size is (1,1)
Vh_clamped = Vh_clamped.reshape(rank, in_dim, *kernel_size) # kernel_size is (1,1) here
U_clamped = U_clamped.to(work_device, dtype=save_dtype).contiguous()
Vh_clamped = Vh_clamped.to(work_device, dtype=save_dtype).contiguous()
lora_weights[lora_name] = (U_clamped, Vh_clamped)
# Verbose output (S values are pre-modification for accurate reporting of original SVD properties)
if verbose:
s_sum_total = float(torch.sum(S))
s_sum_rank = float(torch.sum(S[:rank])) # Sum of the singular values actually used for reconstruction
fro_orig_total = float(torch.sqrt(torch.sum(S.pow(2))))
fro_reconstructed_rank = float(torch.sqrt(torch.sum(S[:rank].pow(2)))) # Frobenius norm of the matrix part represented by chosen rank
# Ratio of the largest retained singular value to the smallest retained singular value
# S is sorted, S[0] is max. S[rank-1] is the smallest singular value included if rank > 0.
ratio_sv = S[0] / S[rank - 1] if rank > 0 and S[rank - 1].abs() > MIN_SV else float('inf') # Avoid division by zero or tiny number
# Ensure denominators are not zero for percentages
sum_s_retained_percentage = (s_sum_rank / s_sum_total) if s_sum_total > MIN_SV else 1.0
fro_retained_percentage = (fro_reconstructed_rank / fro_orig_total) if fro_orig_total > MIN_SV else 1.0
logger.info(
f"{lora_name:75} | rank: {rank}, "
f"sum(S) retained: {sum_s_retained_percentage:.2%}, "
f"Frobenius norm retained: {fro_retained_percentage:.2%}, "
f"max_retained_sv/min_retained_sv ratio: {ratio_sv:.2f}"
)
# Create state dict
lora_sd = {}
for lora_name, (up_weight, down_weight) in lora_weights.items():
lora_sd[lora_name + ".lora_up.weight"] = up_weight
lora_sd[lora_name + ".lora_down.weight"] = down_weight
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype) # alpha is rank
# Load and save LoRA
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
lora_network_save.apply_to(text_encoders_o, unet_o) # This applies weights, not strictly necessary if just saving sd
info = lora_network_save.load_state_dict(lora_sd) # This populates the network object with the weights from lora_sd
logger.info(f"Loaded extracted and resized LoRA weights into network object: {info}")
os.makedirs(os.path.dirname(save_to), exist_ok=True)
# Metadata
net_kwargs = {"conv_dim": str(conv_dim), "conv_alpha": str(float(conv_dim))} if conv_dim else {}
# Determine network_dim and network_alpha for metadata based on dynamic method
if dynamic_method:
network_dim_meta = "Dynamic"
network_alpha_meta = "Dynamic" # Alpha is rank, which is dynamic
else:
network_dim_meta = str(dim)
network_alpha_meta = str(float(dim)) # Alpha is rank, which is dim
metadata = {
"ss_v2": str(v2),
"ss_base_model_version": model_version,
"ss_network_module": "networks.lora",
"ss_network_dim": network_dim_meta,
"ss_network_alpha": network_alpha_meta, # Alpha is typically the rank
"ss_network_args": json.dumps(net_kwargs),
"ss_lowram": "False", # Assuming not specifically lowram mode
"ss_num_train_images": "N/A", # Not applicable for extraction
# Add other relevant metadata as per sai_model_spec or conventions
}
if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0]
# Build sai_metadata, ensuring it includes necessary fields like 'ss_sd_model_hash' if possible
# For extraction, some training-specific metadata might not be relevant or available.
sai_metadata = sai_model_spec.build_metadata(
None, # training_info (usually from train_util or fine_tune) - can be None for extraction
v2,
v_parameterization,
sdxl,
True, # is_sd2
False, # is_v_pred_like
time.time(),
title=title,
# model_hash=None, # Original model hash if available
# tuned_model_hash=None # Tuned model hash if available
)
# Filter out None values from sai_metadata if any, or handle them in build_metadata
sai_metadata_cleaned = {k: v for k, v in sai_metadata.items() if v is not None}
metadata.update(sai_metadata_cleaned)
# Use the state_dict 'lora_sd' for saving, not the network object 'lora_network_save'
save_to_file(save_to, lora_sd, lora_sd, save_dtype, metadata) # Pass lora_sd as the model/state_dict to save
logger.info(f"LoRA saved to: {save_to}")
def setup_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model")
parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2)")
parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model")
parser.add_argument("--load_precision", choices=[None, "float", "fp16", "bf16"], help="Precision for loading models")
parser.add_argument("--save_precision", choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA")
parser.add_argument("--model_org", required=True, help="Original Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--model_tuned", required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--save_to", required=True, help="Output file name (ckpt/safetensors)")
parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers")
parser.add_argument("--conv_dim", type=int, help="Max dimension (rank) of LoRA for Conv2d-3x3")
parser.add_argument("--device", default="cuda", help="Device for computation (e.g., cuda)")
parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights")
parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract")
parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata")
parser.add_argument("--load_original_model_to", help="Device for original model (SDXL only)")
parser.add_argument("--load_tuned_model_to", help="Device for tuned model (SDXL only)")
parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction")
parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info")
parser.add_argument(
"--dynamic_method",
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"], # Added "sv_cumulative_knee"
help="Dynamic rank reduction method"
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
methods_requiring_param = ["sv_ratio", "sv_fro", "sv_cumulative", "sv_rel_decrease"]
if args.dynamic_method in methods_requiring_param and args.dynamic_param is None:
raise ValueError(f"Dynamic method '{args.dynamic_method}' requires --dynamic_param to be set.")
# Add a check for rank > 0 if not dynamic, or ensure dynamic methods return rank >= 1
if not args.dynamic_method and args.dim <= 0:
raise ValueError(f"--dim (rank) must be > 0. Got {args.dim}")
if args.conv_dim is not None and args.conv_dim <=0:
raise ValueError(f"--conv_dim (rank) must be > 0 if specified. Got {args.conv_dim}")
svd(**vars(args))