Cleanup files

pull/3264/head
bmaltais 2025-05-27 08:47:38 -04:00
parent 829d5a6af3
commit c0be9f70da
10 changed files with 66 additions and 2692 deletions

@ -1 +1 @@
Subproject commit e2ed26510450cf147da1b66aea5154d04d0942ec Subproject commit 5753b8ff6bc045c27c1c61535e35195da860269c

View File

@ -1,287 +0,0 @@
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision bf16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_fro_0.65.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 16 --device cuda --sdxl --target_fro_retained 0.5 --group_size 6 --svd_mode per_layer --dynamic_param 0.65
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision bf16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_cumulative_0.9.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 16 --device cuda --sdxl --target_fro_retained 0.5 --group_size 6 --svd_mode per_layer --dynamic_param 0.9 --dynamic_method sv_cumulative
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_fro_0.5v2.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 256 --device cuda --sdxl --dynamic_param 0.5 --dynamic_method sv_fro --verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_cumulative_0.5v2.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 768 --device cuda --sdxl --dynamic_param 0.5 --dynamic_method sv_cumulative --verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_ratio_0.5.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 768 --device cuda --sdxl --dynamic_param 0.5 --dynamic_method sv_ratio --verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_knee.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 512 --device cuda --sdxl --dynamic_method sv_knee --verbose --dynamic_param 0.5
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py `
--save_precision fp16 `
--save_to E:/lora/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES_sv_cumulative_knee.safetensors `
--model_tuned E:/models/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES.safetensors `
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--dim 512 `
--device cuda `
--sdxl `
--dynamic_method sv_cumulative_knee `
--verbose `
--dynamic_param 0.25
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_v2.safetensors `
--rank 4 `
--iterations 200 `
--lr 0.005 `
--device cuda `
--precision fp32 `
--verbose `
--verbose_layer_debug `
--save_weights_dtype fp16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_64_4000steps.safetensors `
--rank 64 `
--initial_alpha 32 `
--max_rank_doublings 2 `
--max_iterations 16000 `
--min_iterations 200 `
--target_loss 9.9999999e-8 `
--lr 0.05 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_16_16000steps.safetensors `
--rank 16 `
--initial_alpha 8 `
--max_rank_retries 3 `
--rank_increase_factor 1.5 `
--max_iterations 8000 `
--min_iterations 200 `
--target_loss 9.9999999e-8 `
--lr 0.05 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_16_8000steps.safetensors `
--rank 16 `
--initial_alpha 16 `
--max_rank_retries 6 `
--rank_increase_factor 2 `
--max_iterations 8000 `
--min_iterations 200 `
--target_loss 9.9999999e-8 `
--lr 0.05 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES.safetensors `
E:/lora/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES_loha_16_8000steps.safetensors `
--rank 16 `
--initial_alpha 16 `
--max_rank_retries 6 `
--rank_increase_factor 2 `
--max_iterations 8000 `
--min_iterations 200 `
--target_loss 9.9999999e-8 `
--lr 0.1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/aetherverseXL_v10.safetensors `
E:/lora/sdxl/aetherverseXL_v10_loha_9e-8.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 27 `
--rank_increase_factor 1.2 `
--max_iterations 16000 `
--min_iterations 400 `
--target_loss 9e-8 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 200 `
--advanced_projection_decay_cap_min 0.5 `
--advanced_projection_decay_cap_max 1.05 `
--min_progress_loss_ratio 0.000001 `
--projection_sample_interval 1 `
--projection_min_ema_history 100
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/aetherverseXL_v10.safetensors `
E:/lora/sdxl/aetherverseXL_v10_loha_9e-8.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 27 `
--rank_increase_factor 1.2 `
--max_iterations 16000 `
--min_iterations 400 `
--target_loss 9e-8 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 200 `
--save_every_n_layers 10 `
--keep_n_resume_files 10
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/proteus_v06.safetensors `
E:/lora/sdxl/proteus_v06_1e-7.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 27 `
--rank_increase_factor 1.2 `
--max_iterations 16000 `
--min_iterations 400 `
--target_loss 1e-7 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 200 `
--save_every_n_layers 10 `
--keep_n_resume_files 10
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_1e-8v3.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 29 `
--rank_increase_factor 1.2 `
--max_iterations 16000 `
--min_iterations 400 `
--target_loss 1e-8 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 200 `
--advanced_projection_decay_cap_min 0.5 `
--advanced_projection_decay_cap_max 1.05 `
--min_progress_loss_ratio 0.000001 `
--projection_sample_interval 1 `
--projection_min_ema_history 100 `
--continue_training_from_loha E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_1e-8v2_resume_L422.safetensors
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES.safetensors `
E:/lora/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES_loha_9e-8.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 27 `
--rank_increase_factor 1.2 `
--max_iterations 16000 `
--min_iterations 400 `
--target_loss 1e-7 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 200 `
--advanced_projection_decay_cap_min 0.5 `
--advanced_projection_decay_cap_max 1.05 `
--min_progress_loss_ratio 0.000001 `
--projection_sample_interval 1 `
--projection_min_ema_history 100
C:\Users\berna\Downloads\Dune_Movie_Loha2.safetensors
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py `
--save_to E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha.safetensors `
--model_org_path E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned_path E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
--algo loha `
--network_alpha 64 `
--network_dim 4 `
--conv_alpha 64 `
--conv_dim 4 `
--device cuda `
--sdxl `
--save_precision fp16 `
--verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_model_difference.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
--save_dtype float16
--model_org_path E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned_path E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
--algo loha `
--network_alpha 64 `
--network_dim 4 `
--conv_alpha 64 `
--conv_dim 4 `
--device cuda `
--sdxl `
--save_precision fp16 `
--verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_to E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors --algo loha --sdxl --dim 32 --conv_dim 32 --dynamic_method sv_cumulative --dynamic_param 0.99 --save_precision fp16 --device cuda --verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py ^
--model_org_path "D:\StableDiffusion\models\sdxl_base_1.0.safetensors" ^
--model_tuned_path "D:\StableDiffusion\models\my_sdxl_finetune.safetensors" ^
--save_to "C:\LoRA_Extractor\output\my_loha_sdxl.safetensors" ^
--sdxl ^
--algo loha ^
--network_alpha 64 ^
--network_dim 4 ^
--conv_alpha 64 ^
--conv_dim 4 ^
--save_precision bf16 ^
--device cuda ^
--verbose
sv_cumulative_knee
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_two_pass_energy_512.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --total_rank_budget 2048 --device cuda --sdxl --svd_mode per_layer --dynamic_param 1.0 --dynamic_method two_pass_energy --verbose --min_rank 4 --max_rank 32
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py ^
--save_precision bf16 ^
--save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_two_pass_energy_512.safetensors ^
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors ^
--model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors ^
--dim 512 ^
--device cuda ^
--sdxl ^
--target_fro_retained 0.5 ^
--group_size 6 ^
--svd_mode per_layer ^
--dynamic_method two_pass_energy ^
--dynamic_param 1.0 ^
--min_rank 4 ^
--verbose

View File

@ -1,397 +0,0 @@
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision bf16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_fro_0.65.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 16 --device cuda --sdxl --target_fro_retained 0.5 --group_size 6 --svd_mode per_layer --dynamic_param 0.65
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision bf16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_cumulative_0.9.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 16 --device cuda --sdxl --target_fro_retained 0.5 --group_size 6 --svd_mode per_layer --dynamic_param 0.9 --dynamic_method sv_cumulative
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_fro_0.5v2.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 256 --device cuda --sdxl --dynamic_param 0.5 --dynamic_method sv_fro --verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_cumulative_0.5v2.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 768 --device cuda --sdxl --dynamic_param 0.5 --dynamic_method sv_cumulative --verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_ratio_0.5.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 768 --device cuda --sdxl --dynamic_param 0.5 --dynamic_method sv_ratio --verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_sv_knee.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --dim 512 --device cuda --sdxl --dynamic_method sv_knee --verbose --dynamic_param 0.5
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py `
--save_precision fp16 `
--save_to E:/lora/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES_sv_cumulative_knee.safetensors `
--model_tuned E:/models/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES.safetensors `
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--dim 512 `
--device cuda `
--sdxl `
--dynamic_method sv_fro `
--verbose `
--dynamic_param 0.25
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
--save_precision fp16 `
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned E:/models/sdxl/xxxRay_v11.safetensors `
--save_to E:/lora/sdxl/xxxRay_v11_sv_cumulative_knee.safetensors `
--dim 384 `
--device cuda `
--sdxl `
--dynamic_method sv_cumulative_knee `
--verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
--save_precision fp16 `
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned E:/models/sdxl/xxxRay_v11.safetensors `
--save_to E:/lora/sdxl/xxxRay_v11_sv_fro_0.9_1024.safetensors `
--dim 1024 `
--device cuda `
--sdxl `
--dynamic_method sv_fro `
--dynamic_param 0.9 `
--verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
--save_precision fp16 `
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned E:/models/sdxl/proteus_v06.safetensors `
--save_to E:/lora/sdxl/proteus_v06_sv_cumulative_knee_1024.safetensors `
--dim 1024 `
--device cuda `
--sdxl `
--dynamic_method sv_cumulative_knee `
--verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_v2.safetensors `
--rank 4 `
--iterations 200 `
--lr 0.005 `
--device cuda `
--precision fp32 `
--verbose `
--verbose_layer_debug `
--save_weights_dtype fp16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_64_4000steps.safetensors `
--rank 64 `
--initial_alpha 32 `
--max_rank_doublings 2 `
--max_iterations 16000 `
--min_iterations 200 `
--target_loss 9.9999999e-8 `
--lr 0.05 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_16_16000steps.safetensors `
--rank 16 `
--initial_alpha 8 `
--max_rank_retries 3 `
--rank_increase_factor 1.5 `
--max_iterations 8000 `
--min_iterations 200 `
--target_loss 9.9999999e-8 `
--lr 0.05 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_16_8000steps.safetensors `
--rank 16 `
--initial_alpha 16 `
--max_rank_retries 6 `
--rank_increase_factor 2 `
--max_iterations 8000 `
--min_iterations 200 `
--target_loss 9.9999999e-8 `
--lr 0.05 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES.safetensors `
E:/lora/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES_loha_16_8000steps.safetensors `
--rank 16 `
--initial_alpha 16 `
--max_rank_retries 6 `
--rank_increase_factor 2 `
--max_iterations 8000 `
--min_iterations 200 `
--target_loss 9.9999999e-8 `
--lr 0.1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/aetherverseXL_v10.safetensors `
E:/lora/sdxl/aetherverseXL_v10_loha_9e-8.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 27 `
--rank_increase_factor 1.2 `
--max_iterations 16000 `
--min_iterations 400 `
--target_loss 9e-8 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 200 `
--advanced_projection_decay_cap_min 0.5 `
--advanced_projection_decay_cap_max 1.05 `
--min_progress_loss_ratio 0.000001 `
--projection_sample_interval 1 `
--projection_min_ema_history 100
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/aetherverseXL_v10.safetensors `
E:/lora/sdxl/aetherverseXL_v10_loha_9e-8.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 30 `
--rank_increase_factor 1.2 `
--max_iterations 8000 `
--min_iterations 200 `
--target_loss 9e-8 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 100 `
--save_every_n_layers 10 `
--keep_n_resume_files 10 `
--rank_search_strategy binary_search_min_rank
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/aetherverseXL_v10.safetensors `
E:/lora/sdxl/aetherverseXL_v10_loha_9e-8.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 8 `
--rank_increase_factor 2 `
--max_iterations 8000 `
--min_iterations 400 `
--target_loss 9e-8 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 100 `
--save_every_n_layers 10 `
--keep_n_resume_files 10 `
--skip_delta_threshold 3e-7 `
--rank_search_strategy binary_search_min_rank
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\lr_finder.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/xxxRay_v11.safetensors `
--lr_finder_num_layers 16 `
--lr_finder_min_lr 1e-8 `
--lr_finder_max_lr 0.2 `
--lr_finder_num_steps 120 `
--lr_finder_iters_per_step 40 `
--rank 8 `
--initial_alpha 8.0 `
--precision bf16 `
--device cuda `
--lr_finder_plot `
--lr_finder_show_plot
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/xxxRay_v11.safetensors `
E:/lora/sdxl/xxxRay_v11_loha_1e-7.safetensors `
--rank 2 `
--initial_alpha 2 `
--max_rank_retries 7 `
--rank_increase_factor 2 `
--max_iterations 8000 `
--min_iterations 400 `
--target_loss 1e-7 `
--lr 1e-01 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 100 `
--save_every_n_layers 10 `
--keep_n_resume_files 10 `
--skip_delta_threshold 1e-7 `
--rank_search_strategy binary_search_min_rank `
--probe_aggressive_early_stop
D:\kohya_ss\venv\Scripts\python.exe D:\kohya_ss\tools\model_diff_report.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
--top_n_diff 15 --plot_histograms --plot_histograms_top_n 3 --output_dir ./analysis_results
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES.safetensors `
E:/lora/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES_loha_3e-7.safetensors `
--rank 1 `
--initial_alpha 1 `
--max_rank_retries 10 `
--rank_increase_factor 2 `
--max_iterations 8000 `
--min_iterations 400 `
--target_loss 3e-7 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 100 `
--save_every_n_layers 10 `
--keep_n_resume_files 10 `
--skip_delta_threshold 6e-7 `
--rank_search_strategy binary_search_min_rank `
--probe_aggressive_early_stop
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/proteus_v06.safetensors `
E:/lora/sdxl/proteus_v06_1e-7.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 27 `
--rank_increase_factor 1.2 `
--max_iterations 16000 `
--min_iterations 400 `
--target_loss 1e-7 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 200 `
--save_every_n_layers 10 `
--keep_n_resume_files 10 `
--rank_search_strategy binary_search_min_rank
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_1e-8v3.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 29 `
--rank_increase_factor 1.2 `
--max_iterations 16000 `
--min_iterations 400 `
--target_loss 1e-8 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 200 `
--advanced_projection_decay_cap_min 0.5 `
--advanced_projection_decay_cap_max 1.05 `
--min_progress_loss_ratio 0.000001 `
--projection_sample_interval 1 `
--projection_min_ema_history 100 `
--continue_training_from_loha E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_1e-8v2_resume_L422.safetensors
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES.safetensors `
E:/lora/sdxl/lustifySDXLNSFW_oltFIXEDTEXTURES_loha_9e-8.safetensors `
--rank 4 `
--initial_alpha 4 `
--max_rank_retries 27 `
--rank_increase_factor 1.2 `
--max_iterations 16000 `
--min_iterations 400 `
--target_loss 1e-7 `
--lr 1e-1 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 200 `
--advanced_projection_decay_cap_min 0.5 `
--advanced_projection_decay_cap_max 1.05 `
--min_progress_loss_ratio 0.000001 `
--projection_sample_interval 1 `
--projection_min_ema_history 100
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_model_difference.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
--save_dtype float16
--model_org_path E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned_path E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
--algo loha `
--network_alpha 64 `
--network_dim 4 `
--conv_alpha 64 `
--conv_dim 4 `
--device cuda `
--sdxl `
--save_precision fp16 `
--verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_to E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors --algo loha --sdxl --dim 32 --conv_dim 32 --dynamic_method sv_cumulative --dynamic_param 0.99 --save_precision fp16 --device cuda --verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py ^
--model_org_path "D:\StableDiffusion\models\sdxl_base_1.0.safetensors" ^
--model_tuned_path "D:\StableDiffusion\models\my_sdxl_finetune.safetensors" ^
--save_to "C:\LoRA_Extractor\output\my_loha_sdxl.safetensors" ^
--sdxl ^
--algo loha ^
--network_alpha 64 ^
--network_dim 4 ^
--conv_alpha 64 ^
--conv_dim 4 ^
--save_precision bf16 ^
--device cuda ^
--verbose
sv_cumulative_knee
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py --save_precision fp16 --save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_two_pass_energy_512.safetensors --model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors --model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors --total_rank_budget 2048 --device cuda --sdxl --svd_mode per_layer --dynamic_param 1.0 --dynamic_method two_pass_energy --verbose --min_rank 4 --max_rank 32
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\sd-scripts\networks\extract_lora_from_models-nw.py ^
--save_precision bf16 ^
--save_to E:/lora/sdxl/cinemaDiffusoXL_beta03_two_pass_energy_512.safetensors ^
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors ^
--model_tuned E:/models/sdxl/cinemaDiffusoXL_beta03.safetensors ^
--dim 512 ^
--device cuda ^
--sdxl ^
--target_fro_retained 0.5 ^
--group_size 6 ^
--svd_mode per_layer ^
--dynamic_method two_pass_energy ^
--dynamic_param 1.0 ^
--min_rank 4 ^
--verbose

View File

@ -0,0 +1,65 @@
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
--save_precision fp16 `
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
--save_to E:/lora/sdxl/dreamshaperXL_alpha2Xl10_sv_fro_0.9_1024.safetensors `
--dim 1024 `
--device cuda `
--sdxl `
--dynamic_method sv_fro `
--dynamic_param 0.9 `
--verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
--save_precision fp16 `
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned E:/models/sdxl/proteus_v06.safetensors `
--save_to E:/lora/sdxl/proteus_v06_sv_cumulative_knee_1024.safetensors `
--dim 1024 `
--device cuda `
--sdxl `
--dynamic_method sv_cumulative_knee `
--verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\lr_finder.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
--lr_finder_num_layers 16 `
--lr_finder_min_lr 1e-8 `
--lr_finder_max_lr 0.2 `
--lr_finder_num_steps 120 `
--lr_finder_iters_per_step 40 `
--rank 8 `
--initial_alpha 8.0 `
--precision bf16 `
--device cuda `
--lr_finder_plot `
--lr_finder_show_plot
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_1e-7.safetensors `
--rank 2 `
--initial_alpha 2 `
--max_rank_retries 7 `
--rank_increase_factor 2 `
--max_iterations 8000 `
--min_iterations 400 `
--target_loss 1e-7 `
--lr 1e-01 `
--device cuda `
--precision fp32 `
--verbose `
--save_weights_dtype bf16 `
--progress_check_interval 100 `
--save_every_n_layers 10 `
--keep_n_resume_files 10 `
--skip_delta_threshold 1e-7 `
--rank_search_strategy binary_search_min_rank `
--probe_aggressive_early_stop
D:\kohya_ss\venv\Scripts\python.exe D:\kohya_ss\tools\model_diff_report.py `
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
--top_n_diff 15 --plot_histograms --plot_histograms_top_n 3 --output_dir ./analysis_results

View File

@ -1,662 +0,0 @@
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file, load_file
import safetensors # Import the main library to use safetensors.safe_open
from tqdm import tqdm
import math
import json
from collections import OrderedDict
import signal
import sys
import glob
# --- Global variables ---
extracted_loha_state_dict_global = OrderedDict()
layer_optimization_stats_global = []
args_global = None
processed_layers_this_session_count_global = 0
previously_completed_module_prefixes_global = set()
all_completed_module_prefixes_ever_global = set() # Tracks all module prefixes ever completed (resumed + current)
skipped_identical_count_global = 0
skipped_other_reason_count_global = 0
keys_scanned_this_run_global = 0
save_attempted_on_interrupt = False
outer_pbar_global = None
main_loop_completed_scan_flag_global = False # True if the main key loop finished a full scan
# --- optimize_loha_for_layer and get_module_shape_info_from_weight (UNCHANGED) ---
def optimize_loha_for_layer(
layer_name: str, delta_W_target: torch.Tensor, out_dim: int, in_dim_effective: int,
k_h: int, k_w: int, rank: int, initial_alpha_val: float, lr: float = 1e-3,
max_iterations: int = 1000, min_iterations: int = 100, target_loss: float = None,
weight_decay: float = 1e-4, device: str = 'cuda', dtype: torch.dtype = torch.float32,
is_conv: bool = True, verbose_layer_debug: bool = False
):
delta_W_target = delta_W_target.to(device, dtype=dtype)
if is_conv:
k_ops = k_h * k_w
hada_w1_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype)); nn.init.kaiming_uniform_(hada_w1_a, a=math.sqrt(5))
hada_w1_b = nn.Parameter(torch.empty(rank, in_dim_effective * k_ops, device=device, dtype=dtype)); nn.init.normal_(hada_w1_b, std=0.02)
hada_w2_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype)); nn.init.kaiming_uniform_(hada_w2_a, a=math.sqrt(5))
hada_w2_b = nn.Parameter(torch.empty(rank, in_dim_effective * k_ops, device=device, dtype=dtype)); nn.init.normal_(hada_w2_b, std=0.02)
else: # Linear
hada_w1_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype)); nn.init.kaiming_uniform_(hada_w1_a, a=math.sqrt(5))
hada_w1_b = nn.Parameter(torch.empty(rank, in_dim_effective, device=device, dtype=dtype)); nn.init.normal_(hada_w1_b, std=0.02)
hada_w2_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype)); nn.init.kaiming_uniform_(hada_w2_a, a=math.sqrt(5))
hada_w2_b = nn.Parameter(torch.empty(rank, in_dim_effective, device=device, dtype=dtype)); nn.init.normal_(hada_w2_b, std=0.02)
alpha_param = nn.Parameter(torch.tensor(initial_alpha_val, device=device, dtype=dtype))
optimizer = torch.optim.AdamW([hada_w1_a, hada_w1_b, hada_w2_a, hada_w2_b, alpha_param], lr=lr, weight_decay=weight_decay)
patience_epochs = max(10, int(max_iterations * 0.05))
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience_epochs, factor=0.5, min_lr=1e-7, verbose=False)
iter_pbar = tqdm(range(max_iterations), desc=f"Opt: {layer_name}", leave=False, dynamic_ncols=True, position=1, mininterval=0.5)
final_loss = float('inf'); stopped_early_by_loss = False; iterations_actually_done = 0
for i in iter_pbar:
iterations_actually_done = i + 1
if save_attempted_on_interrupt: print(f"\n Interrupt during opt of {layer_name}. Stopping layer after {i} iters."); break
optimizer.zero_grad(); eff_alpha_scale = alpha_param / rank
if is_conv:
term1_flat = hada_w1_a @ hada_w1_b; term1_reshaped = term1_flat.view(out_dim, in_dim_effective, k_h, k_w)
term2_flat = hada_w2_a @ hada_w2_b; term2_reshaped = term2_flat.view(out_dim, in_dim_effective, k_h, k_w)
delta_W_loha = eff_alpha_scale * term1_reshaped * term2_reshaped
else:
term1 = hada_w1_a @ hada_w1_b; term2 = hada_w2_a @ hada_w2_b
delta_W_loha = eff_alpha_scale * term1 * term2
loss = F.mse_loss(delta_W_loha, delta_W_target); final_loss = loss.item()
loss.backward(); optimizer.step(); scheduler.step(loss)
current_lr = optimizer.param_groups[0]['lr']
iter_pbar.set_postfix_str(f"Loss={final_loss:.3e}, AlphaP={alpha_param.item():.2f}, LR={current_lr:.1e}", refresh=True)
if verbose_layer_debug and (i == 0 or (i + 1) % (max_iterations // 10 if max_iterations >= 10 else 1) == 0 or i == max_iterations - 1):
iter_pbar.write(f" Debug {layer_name} - Iter {i+1}/{max_iterations}: Loss: {final_loss:.6e}, LR: {current_lr:.2e}, AlphaP: {alpha_param.item():.4f}")
if target_loss is not None and i >= min_iterations -1 and final_loss <= target_loss:
if verbose_layer_debug or (args_global and args_global.verbose): iter_pbar.write(f" Target loss {target_loss:.2e} reached for {layer_name} at iter {i+1}.")
stopped_early_by_loss = True; break
if not save_attempted_on_interrupt: iter_pbar.set_description_str(f"Opt: {layer_name} (Done)"); iter_pbar.set_postfix_str(f"FinalLoss={final_loss:.2e}, It={iterations_actually_done}{', EarlyStop' if stopped_early_by_loss else ''}")
iter_pbar.close()
if save_attempted_on_interrupt and not stopped_early_by_loss and iterations_actually_done < max_iterations:
return {'final_loss': final_loss, 'stopped_early': False, 'iterations_done': iterations_actually_done, 'interrupted_mid_layer': True}
return {'hada_w1_a': hada_w1_a.data.cpu().contiguous(), 'hada_w1_b': hada_w1_b.data.cpu().contiguous(),
'hada_w2_a': hada_w2_a.data.cpu().contiguous(), 'hada_w2_b': hada_w2_b.data.cpu().contiguous(),
'alpha': alpha_param.data.cpu().contiguous(), 'final_loss': final_loss,
'stopped_early': stopped_early_by_loss, 'iterations_done': iterations_actually_done,
'interrupted_mid_layer': False}
def get_module_shape_info_from_weight(weight_tensor: torch.Tensor):
if len(weight_tensor.shape) == 4: is_conv = True; out_dim, in_dim_effective, k_h, k_w = weight_tensor.shape; groups = 1; return out_dim, in_dim_effective, k_h, k_w, groups, is_conv
elif len(weight_tensor.shape) == 2: is_conv = False; out_dim, in_dim = weight_tensor.shape; return out_dim, in_dim, None, None, 1, is_conv
return None
# --- NEW: Helper function to generate intermediate filenames ---
def generate_intermediate_filename(base_save_path: str, num_total_completed_layers: int) -> str:
base, ext = os.path.splitext(base_save_path)
return f"{base}_resume_L{num_total_completed_layers}{ext}"
# --- NEW: Helper function to find the best file to resume from ---
def find_best_resume_file(intended_final_path: str) -> tuple[str | None, int]:
output_dir = os.path.dirname(intended_final_path)
if not output_dir: output_dir = "." # Current directory if no path part
base_save_name, save_ext = os.path.splitext(os.path.basename(intended_final_path))
potential_files = []
# Check the main intended file first
if os.path.exists(intended_final_path):
potential_files.append(intended_final_path)
# Check for intermediate files
intermediate_pattern = os.path.join(output_dir, f"{base_save_name}_resume_L*{save_ext}")
potential_files.extend(glob.glob(intermediate_pattern))
best_file_path = None
max_completed_modules = -1
if not potential_files:
print(" No existing main LoHA file or intermediate resume files found.")
return None, -1
print(f" Found potential resume files: {potential_files}")
for file_path in potential_files:
try:
if not os.path.exists(file_path): continue # Should not happen with glob but good check
with safetensors.safe_open(file_path, framework="pt", device="cpu") as f:
metadata = f.metadata()
if metadata and "ss_completed_loha_modules" in metadata:
num_completed = len(json.loads(metadata["ss_completed_loha_modules"]))
if num_completed > max_completed_modules:
max_completed_modules = num_completed
best_file_path = file_path
elif num_completed == max_completed_modules and best_file_path != intended_final_path and file_path == intended_final_path:
# Prefer the main file if module count is the same as an intermediate
best_file_path = file_path
elif max_completed_modules == -1: # If no file has metadata, consider the first one (or main one)
# This case handles files without the specific metadata, preferring the main file if it exists.
# It's a basic fallback; files with proper metadata will usually win.
if best_file_path is None or (file_path == intended_final_path and best_file_path != intended_final_path):
best_file_path = file_path # Fallback to considering the file itself if no metadata found yet
max_completed_modules = 0 # Treat as 0 if no metadata, to be potentially overridden
print(f" File {file_path} has no 'ss_completed_loha_modules' metadata. Treating as 0 completed for now.")
except Exception as e:
print(f" Warning: Could not read or parse metadata from {file_path}: {e}")
if best_file_path is None and file_path == intended_final_path and max_completed_modules == -1:
best_file_path = file_path # If primary file is broken, still note it as a candidate if nothing better
max_completed_modules = 0
if best_file_path:
print(f" Selected '{os.path.basename(best_file_path)}' for resume (contains {max_completed_modules} completed modules in metadata).")
elif not potential_files: # Already handled above, but as a safeguard
print(f" No existing LoHA file or intermediate files found matching pattern for: {intended_final_path}")
else:
print(f" Could not determine a best file to resume from among candidates, or no valid metadata found.")
return best_file_path, max_completed_modules
# --- NEW: Helper function to clean up intermediate files ---
def cleanup_intermediate_files(final_intended_path: str):
output_dir = os.path.dirname(final_intended_path)
if not output_dir: output_dir = "."
base_save_name, save_ext = os.path.splitext(os.path.basename(final_intended_path))
intermediate_pattern = os.path.join(output_dir, f"{base_save_name}_resume_L*{save_ext}")
cleaned_count = 0
for file_path in glob.glob(intermediate_pattern):
try:
os.remove(file_path)
if args_global and args_global.verbose: print(f" Cleaned up intermediate file: {file_path}")
cleaned_count +=1
except OSError as e:
print(f" Warning: Could not clean up intermediate file {file_path}: {e}")
if cleaned_count > 0:
print(f" Cleaned up {cleaned_count} intermediate file(s).")
# --- perform_graceful_save (MODIFIED to only require output_path_override) ---
def perform_graceful_save(output_path_to_save: str):
global extracted_loha_state_dict_global, layer_optimization_stats_global, args_global
global processed_layers_this_session_count_global, save_attempted_on_interrupt
global skipped_identical_count_global, skipped_other_reason_count_global, keys_scanned_this_run_global
global all_completed_module_prefixes_ever_global # Use this for metadata
# Ensure all_completed_module_prefixes_ever is up-to-date before saving
# This should already be handled by adding to it when layers are processed or resumed.
current_session_processed_prefixes = {stat['name'] for stat in layer_optimization_stats_global}
# `all_completed_module_prefixes_ever_global` should already include `previously_completed_module_prefixes_global`
# and any newly processed ones.
total_processed_ever = len(all_completed_module_prefixes_ever_global)
if not extracted_loha_state_dict_global and not previously_completed_module_prefixes_global : # Check against all_completed for empty save
# If all_completed is also empty, it means nothing was resumed and nothing new processed
if not all_completed_module_prefixes_ever_global:
print(f"No layers were processed or loaded to save to {output_path_to_save}. Save aborted.")
return
args_to_use = args_global
if not args_to_use: print("Error: Global args not available for saving metadata."); return
final_save_path = output_path_to_save # Use the direct path given
if args_to_use.save_weights_dtype == "fp16": final_save_dtype_torch = torch.float16
elif args_to_use.save_weights_dtype == "bf16": final_save_dtype_torch = torch.bfloat16
else: final_save_dtype_torch = torch.float32
final_state_dict_to_save = OrderedDict()
for k, v_tensor in extracted_loha_state_dict_global.items():
if v_tensor.is_floating_point(): final_state_dict_to_save[k] = v_tensor.to(final_save_dtype_torch)
else: final_state_dict_to_save[k] = v_tensor
# Metadata uses all_completed_module_prefixes_ever_global
print(f"\nAttempting to save LoHA for {total_processed_ever} unique modules in total "
f"({processed_layers_this_session_count_global} new this session) to {final_save_path}")
eff_global_network_alpha_val = args_to_use.initial_alpha; eff_global_network_alpha_str = f"{eff_global_network_alpha_val:.8f}"
global_rank_str = str(args_to_use.rank)
conv_rank_str = str(args_to_use.conv_rank if args_to_use.conv_rank is not None else args_to_use.rank)
eff_conv_alpha_val = args_to_use.initial_conv_alpha; conv_alpha_str = f"{eff_conv_alpha_val:.8f}"
network_args_dict = {
"algo": "loha", "dim": global_rank_str, "alpha": eff_global_network_alpha_str,
"conv_dim": conv_rank_str, "conv_alpha": conv_alpha_str,
"dropout": str(args_to_use.dropout), "rank_dropout": str(args_to_use.rank_dropout), "module_dropout": str(args_to_use.module_dropout),
"use_tucker": "false", "use_scalar": "false", "block_size": "1",}
sf_metadata = {
"ss_network_module": "lycoris.kohya", "ss_network_rank": global_rank_str,
"ss_network_alpha": eff_global_network_alpha_str, "ss_network_algo": "loha",
"ss_network_args": json.dumps(network_args_dict),
"ss_comment": f"Extracted LoHA (Interrupt: {save_attempted_on_interrupt}). OptPrec: {args_to_use.precision}. SaveDtype: {args_to_use.save_weights_dtype}. ATOL: {args_to_use.atol_fp32_check}. Layers: {total_processed_ever}. MaxIter: {args_to_use.max_iterations}. TargetLoss: {args_to_use.target_loss}",
"ss_base_model_name": os.path.splitext(os.path.basename(args_to_use.base_model_path))[0],
"ss_ft_model_name": os.path.splitext(os.path.basename(args_to_use.ft_model_path))[0],
"ss_save_weights_dtype": args_to_use.save_weights_dtype, "ss_optimization_precision": args_to_use.precision,
"ss_completed_loha_modules": json.dumps(list(all_completed_module_prefixes_ever_global)) # Use the global cumulative set
}
json_metadata_for_file = {
"comfyui_lora_type": "LyCORIS_LoHa", "model_name": os.path.splitext(os.path.basename(final_save_path))[0],
"base_model_path": args_to_use.base_model_path, "ft_model_path": args_to_use.ft_model_path,
"loha_extraction_settings": {k: str(v) if isinstance(v, type(os.pathsep)) else v for k,v in vars(args_to_use).items()},
"extraction_summary":{
"processed_layers_in_total_cumulative": total_processed_ever, # Cumulative
"processed_this_session": processed_layers_this_session_count_global,
"skipped_identical_count_this_session": skipped_identical_count_global,
"skipped_other_reason_count_this_session": skipped_other_reason_count_global,
"total_candidate_keys_scanned_in_loop_this_session": keys_scanned_this_run_global,
},
"layer_optimization_details_this_session": layer_optimization_stats_global,
"embedded_safetensors_metadata": sf_metadata,
"interrupted_save": save_attempted_on_interrupt
}
if final_save_path.endswith(".safetensors"):
try:
save_file(final_state_dict_to_save, final_save_path, metadata=sf_metadata)
print(f"LoHA state_dict saved to: {final_save_path}")
except Exception as e:
print(f"Error saving .safetensors file: {e}"); return
metadata_json_file_path = os.path.splitext(final_save_path)[0] + "_extraction_metadata.json"
try:
with open(metadata_json_file_path, 'w') as f: json.dump(json_metadata_for_file, f, indent=4)
print(f"Extended metadata saved to: {metadata_json_file_path}")
except Exception as e: print(f"Could not save extended metadata JSON: {e}")
else:
# Saving to .pt might not be fully robust with this new scheme if JSON metadata is critical
print(f"Saving to .pt not fully supported with extended metadata JSON. Saving basic .pt file.")
torch.save({'state_dict': final_state_dict_to_save, 'metadata': sf_metadata}, final_save_path)
print(f"LoHA state_dict saved to: {final_save_path} (basic .pt save)")
# --- handle_interrupt (MODIFIED to use intermediate filenames) ---
def handle_interrupt(signum, frame):
global save_attempted_on_interrupt, outer_pbar_global, args_global, all_completed_module_prefixes_ever_global
print("\n" + "="*30 + "\nCtrl+C (SIGINT) detected!\n" + "="*30)
if save_attempted_on_interrupt: print("Save already attempted. Force exiting."); os._exit(1); return
save_attempted_on_interrupt = True
if outer_pbar_global: outer_pbar_global.close() # Close main progress bar
# Close any active layer progress bar (it's trickier, this might not catch it if deep in opt)
# For simplicity, we rely on the check within optimize_loha_for_layer
print("Attempting to save progress for processed layers...")
if args_global and args_global.save_to:
num_layers_for_filename = len(all_completed_module_prefixes_ever_global)
interrupt_save_path = generate_intermediate_filename(args_global.save_to, num_layers_for_filename)
print(f"Interrupt save will be to: {interrupt_save_path}")
perform_graceful_save(output_path_to_save=interrupt_save_path)
else:
print("Cannot perform interrupt save: args_global or save_to path not defined.")
print("Graceful save attempt finished. Exiting.")
sys.exit(0)
def main(cli_args):
global args_global, extracted_loha_state_dict_global, layer_optimization_stats_global
global processed_layers_this_session_count_global, save_attempted_on_interrupt, outer_pbar_global
global skipped_identical_count_global, skipped_other_reason_count_global, keys_scanned_this_run_global
global previously_completed_module_prefixes_global, all_completed_module_prefixes_ever_global
global main_loop_completed_scan_flag_global
args_global = cli_args
signal.signal(signal.SIGINT, handle_interrupt)
if args_global.precision == "fp16": target_opt_dtype = torch.float16
elif args_global.precision == "bf16": target_opt_dtype = torch.bfloat16
else: target_opt_dtype = torch.float32
if args_global.save_weights_dtype == "fp16": final_save_dtype = torch.float16
elif args_global.save_weights_dtype == "bf16": final_save_dtype = torch.bfloat16
else: final_save_dtype = torch.float32
print(f"Using device: {args_global.device}, Opt Dtype: {target_opt_dtype}, Save Dtype: {final_save_dtype}")
if args_global.target_loss: print(f"Target Loss: {args_global.target_loss:.2e} (after {args_global.min_iterations} min iters)")
print(f"Max Iters/Layer: {args_global.max_iterations}")
# --- MODIFIED: Loading Existing LoHA for resuming (using find_best_resume_file) ---
chosen_resume_file = None
if not args_global.overwrite:
print(f"\nChecking for existing LoHA file or resume states for: {args_global.save_to}")
chosen_resume_file, num_modules_in_chosen_file = find_best_resume_file(args_global.save_to)
if chosen_resume_file:
print(f" Attempting to resume from: {chosen_resume_file} ({num_modules_in_chosen_file} modules reported in metadata).")
try:
file_metadata = None
with safetensors.safe_open(chosen_resume_file, framework="pt", device="cpu") as f:
file_metadata = f.metadata()
completed_modules_in_file = set()
if file_metadata and "ss_completed_loha_modules" in file_metadata:
try:
completed_modules_in_file = set(json.loads(file_metadata.get("ss_completed_loha_modules")))
# Verify count if possible, though num_modules_in_chosen_file is already from this.
if len(completed_modules_in_file) != num_modules_in_chosen_file and num_modules_in_chosen_file !=0 : # 0 can be if file had no metadata but was chosen
print(f" Warning: Metadata module count ({len(completed_modules_in_file)}) differs from initial scan count ({num_modules_in_chosen_file}). Using parsed set.")
except json.JSONDecodeError:
print(" Warning: Could not parse 'ss_completed_loha_modules' metadata from chosen file. Will not load specific tensors by prefix matching.")
else:
print(" 'ss_completed_loha_modules' not found in chosen file's metadata. Will not load specific tensors by prefix matching (might load all if no prefixes known).")
if completed_modules_in_file: # Only load if we have a list of modules to check against
print(" Loading tensors from chosen resume file...")
loaded_sd_for_resume = load_file(chosen_resume_file, device='cpu')
resumed_tensor_count = 0
for key, tensor_val in loaded_sd_for_resume.items():
module_prefix_for_check = ".".join(key.split('.')[:-1]) # e.g. lora_unet_..._block_0_fc1
is_bias_for_completed_module = key.endswith(".bias") and module_prefix_for_check in completed_modules_in_file
# Check if the tensor belongs to a module marked as completed
# This covers hada_w1_a, hada_w1_b etc. for LoHA layers, and biases.
if module_prefix_for_check in completed_modules_in_file or is_bias_for_completed_module :
extracted_loha_state_dict_global[key] = tensor_val
resumed_tensor_count +=1
previously_completed_module_prefixes_global = completed_modules_in_file
all_completed_module_prefixes_ever_global.update(previously_completed_module_prefixes_global) # Initialize with loaded
print(f" Successfully loaded {len(previously_completed_module_prefixes_global)} module prefixes "
f"with {resumed_tensor_count} tensors for resume from {os.path.basename(chosen_resume_file)}.")
del loaded_sd_for_resume
elif not completed_modules_in_file and num_modules_in_chosen_file == 0 and os.path.exists(chosen_resume_file):
# This case could mean an empty LoRA was found (e.g. from a previous failed start)
# or a file without the specific metadata was chosen by find_best_resume_file.
# We don't load anything specific but acknowledge the file existed.
print(f" Chosen resume file {os.path.basename(chosen_resume_file)} seems empty or has no LoHA module metadata. Starting new layer processing.")
# Optional: Load accompanying JSON metadata if it exists for the chosen_resume_file
resume_metadata_json_path = os.path.splitext(chosen_resume_file)[0] + "_extraction_metadata.json"
if os.path.exists(resume_metadata_json_path):
try:
with open(resume_metadata_json_path, 'r') as f_meta:
loaded_json_meta = json.load(f_meta)
# You could potentially load old layer_optimization_stats_global if needed for some cumulative report
# For now, we just acknowledge it.
print(f" Loaded accompanying metadata from: {os.path.basename(resume_metadata_json_path)}")
except Exception as e_json:
print(f" Could not load or parse JSON metadata from {resume_metadata_json_path}: {e_json}")
except Exception as e:
print(f" Error loading or parsing chosen LoHA file '{chosen_resume_file}': {e}. Starting fresh for new layers.")
extracted_loha_state_dict_global.clear()
previously_completed_module_prefixes_global.clear()
all_completed_module_prefixes_ever_global.clear()
else:
print(" No suitable existing LoHA file found to resume from. Starting fresh.")
# Globals are already empty, so no action needed.
elif args_global.overwrite and os.path.exists(args_global.save_to):
print(f"Overwriting specified output file as per --overwrite: {args_global.save_to}")
print(" Any existing intermediate resume files for this target will NOT be automatically cleaned with --overwrite until a new final save.")
extracted_loha_state_dict_global.clear()
previously_completed_module_prefixes_global.clear()
all_completed_module_prefixes_ever_global.clear()
# Note: We don't clean intermediates here because the user might want to revert.
# Cleanup happens on successful *final* save.
print(f"\nLoading base model: {args_global.base_model_path}")
if args_global.base_model_path.endswith(".safetensors"): base_model_sd = load_file(args_global.base_model_path, device='cpu')
else: base_model_sd = torch.load(args_global.base_model_path, map_location='cpu'); base_model_sd = base_model_sd.get('state_dict', base_model_sd)
print(f"Loading fine-tuned model: {args_global.ft_model_path}")
if args_global.ft_model_path.endswith(".safetensors"): ft_model_sd = load_file(args_global.ft_model_path, device='cpu')
else: ft_model_sd = torch.load(args_global.ft_model_path, map_location='cpu'); ft_model_sd = ft_model_sd.get('state_dict', ft_model_sd)
# Reset session-specific counters
processed_layers_this_session_count_global = 0
skipped_identical_count_global = 0 # For this session's scan
skipped_other_reason_count_global = 0 # For this session's scan
keys_scanned_this_run_global = 0
layer_optimization_stats_global.clear() # For this session's stats
main_loop_completed_scan_flag_global = False
all_candidate_keys = []
for k in base_model_sd.keys():
if k.endswith('.weight') and k in ft_model_sd and (len(base_model_sd[k].shape) == 2 or len(base_model_sd[k].shape) == 4):
all_candidate_keys.append(k)
all_candidate_keys.sort()
total_candidates_to_scan = len(all_candidate_keys)
print(f"Found {total_candidates_to_scan} candidate '.weight' keys common to both models and of suitable shape.")
outer_pbar_global = tqdm(total=total_candidates_to_scan, desc="Scanning Layers", dynamic_ncols=True, position=0)
try:
for key_name in all_candidate_keys:
if save_attempted_on_interrupt: break
keys_scanned_this_run_global += 1
outer_pbar_global.update(1)
original_module_path = key_name[:-len(".weight")]
loha_key_prefix = ""
if original_module_path.startswith("model.diffusion_model."): loha_key_prefix = "lora_unet_" + original_module_path[len("model.diffusion_model."):].replace(".", "_")
elif original_module_path.startswith("conditioner.embedders.0.transformer."): loha_key_prefix = "lora_te1_" + original_module_path[len("conditioner.embedders.0.transformer."):].replace(".", "_")
elif original_module_path.startswith("conditioner.embedders.1.model.transformer."): loha_key_prefix = "lora_te2_" + original_module_path[len("conditioner.embedders.1.model.transformer."):].replace(".", "_")
else: loha_key_prefix = "lora_" + original_module_path.replace(".", "_")
# Check if already processed (either resumed or done in this session earlier if logic allowed re-scanning)
if loha_key_prefix in all_completed_module_prefixes_ever_global:
if args_global.verbose:
if loha_key_prefix in previously_completed_module_prefixes_global:
tqdm.write(f"Skipping {loha_key_prefix} (scan): already processed (loaded from resumed LoHA).")
# else: # This case should not happen if all_completed_module_prefixes_ever_global is managed correctly
# tqdm.write(f"Skipping {loha_key_prefix} (scan): already processed in this session (should be rare).")
outer_pbar_global.set_description_str(f"Scan {keys_scanned_this_run_global}/{total_candidates_to_scan} (Resumed: {len(previously_completed_module_prefixes_global)}, New Opt: {processed_layers_this_session_count_global})")
continue
if args_global.max_layers is not None and args_global.max_layers > 0 and processed_layers_this_session_count_global >= args_global.max_layers:
# Still need to scan all keys to correctly determine if the job is "fully complete" later
# So, we just skip optimization but continue scanning.
if args_global.verbose and processed_layers_this_session_count_global == args_global.max_layers and (keys_scanned_this_run_global - (len(all_completed_module_prefixes_ever_global) - processed_layers_this_session_count_global) - skipped_identical_count_global - skipped_other_reason_count_global) == (args_global.max_layers +1) : # First time hitting this after max_layers
tqdm.write(f"\nReached max_layers limit ({args_global.max_layers}) for new layers this session. Continuing scan only to assess remaining layers.")
outer_pbar_global.set_description_str(f"Scan {keys_scanned_this_run_global}/{total_candidates_to_scan} (Max New Layers Reached, Opt Ths Sess: {processed_layers_this_session_count_global})")
# This key is not skipped due to being identical or other error, but due to max_layers.
# We don't increment skipped_other_reason_count_global here, as it's still a valid candidate for a future run.
continue # Continue scanning
base_W = base_model_sd[key_name].to(dtype=torch.float32)
ft_W = ft_model_sd[key_name].to(dtype=torch.float32)
if base_W.shape != ft_W.shape:
skipped_other_reason_count_global +=1
if args_global.verbose: tqdm.write(f"Skipping {key_name} (shape mismatch).")
continue
shape_info = get_module_shape_info_from_weight(base_W)
if shape_info is None:
skipped_other_reason_count_global +=1
if args_global.verbose: tqdm.write(f"Skipping {key_name} (unsupported shape).")
continue
delta_W_fp32 = (ft_W - base_W)
if torch.allclose(delta_W_fp32, torch.zeros_like(delta_W_fp32), atol=args_global.atol_fp32_check):
if args_global.verbose: tqdm.write(f"Skipping {key_name} (identical weights).")
skipped_identical_count_global += 1
continue
# If we reach here, this layer is a candidate for optimization in this session
max_layers_target_str = f"/{args_global.max_layers}" if args_global.max_layers is not None and args_global.max_layers > 0 else ""
outer_pbar_global.set_description_str(f"Optimizing L{processed_layers_this_session_count_global + 1}{max_layers_target_str} (Scan {keys_scanned_this_run_global}/{total_candidates_to_scan})")
if args_global.verbose: tqdm.write(f"\n Orig: {key_name} -> LoHA: {loha_key_prefix}")
out_dim, in_dim_effective, k_h, k_w, _, is_conv = shape_info
delta_W_target_for_opt = delta_W_fp32.to(dtype=target_opt_dtype)
current_rank = args_global.conv_rank if is_conv and args_global.conv_rank is not None else args_global.rank
current_initial_alpha = args_global.initial_conv_alpha if is_conv else args_global.initial_alpha
tqdm.write(f"Optimizing Layer {processed_layers_this_session_count_global + 1}{max_layers_target_str}: {loha_key_prefix} (Orig: {original_module_path}, Shp: {list(base_W.shape)}, R: {current_rank}, Alpha_init: {current_initial_alpha:.1f})")
try:
opt_results = optimize_loha_for_layer(
layer_name=loha_key_prefix, delta_W_target=delta_W_target_for_opt,
out_dim=out_dim, in_dim_effective=in_dim_effective, k_h=k_h, k_w=k_w, rank=current_rank,
initial_alpha_val=current_initial_alpha, lr=args_global.lr,
max_iterations=args_global.max_iterations, min_iterations=args_global.min_iterations,
target_loss=args_global.target_loss, weight_decay=args_global.weight_decay,
device=args_global.device, dtype=target_opt_dtype, is_conv=is_conv,
verbose_layer_debug=args_global.verbose_layer_debug
)
if not opt_results.get('interrupted_mid_layer'):
for p_name, p_val in opt_results.items():
if p_name not in ['final_loss', 'stopped_early', 'iterations_done', 'interrupted_mid_layer']:
extracted_loha_state_dict_global[f'{loha_key_prefix}.{p_name}'] = p_val.to(final_save_dtype)
layer_optimization_stats_global.append({
"name": loha_key_prefix, "original_name": original_module_path,
"final_loss": opt_results['final_loss'], "iterations_done": opt_results['iterations_done'],
"stopped_early_by_loss_target": opt_results['stopped_early']})
all_completed_module_prefixes_ever_global.add(loha_key_prefix) # Add to cumulative set
tqdm.write(f" Layer {loha_key_prefix} Done. Loss: {opt_results['final_loss']:.4e}, Iters: {opt_results['iterations_done']}{', Stopped by Loss' if opt_results['stopped_early'] else ''}")
if args_global.use_bias:
original_bias_key = f"{original_module_path}.bias"
# Check if bias exists in ft_model and differs from base (or base doesn't have it)
bias_differs = False
if original_bias_key in ft_model_sd:
ft_B = ft_model_sd[original_bias_key].to(dtype=torch.float32)
if original_bias_key in base_model_sd:
base_B = base_model_sd[original_bias_key].to(dtype=torch.float32)
if not torch.allclose(base_B, ft_B, atol=args_global.atol_fp32_check):
bias_differs = True
else: # Bias in FT but not in base
bias_differs = True
if bias_differs:
extracted_loha_state_dict_global[original_bias_key] = ft_B.cpu().to(final_save_dtype)
if args_global.verbose: tqdm.write(f" Saved differing/new bias for {original_bias_key}")
# Note: Bias keys are not added to "loha_key_prefix" sets as they don't have LoHA params.
# They are just carried over if different.
processed_layers_this_session_count_global += 1
else: # Interrupted mid-layer
if args_global.verbose: tqdm.write(f" Opt for {loha_key_prefix} interrupted; not saving params for this layer.")
# Do not add to all_completed_module_prefixes_ever_global or increment processed_layers_this_session_count_global
except Exception as e:
print(f"\nError during optimization for {original_module_path} ({loha_key_prefix}): {e}")
import traceback; traceback.print_exc()
skipped_other_reason_count_global +=1 # Count as skipped due to error during opt
# After the loop finishes (or breaks due to interrupt)
if not save_attempted_on_interrupt and keys_scanned_this_run_global == total_candidates_to_scan:
main_loop_completed_scan_flag_global = True
finally: # This will run whether the try block completes normally or an exception (like interrupt) occurs
if outer_pbar_global:
if not outer_pbar_global.disable and outer_pbar_global.n < outer_pbar_global.total:
outer_pbar_global.update(outer_pbar_global.total - outer_pbar_global.n) # Fill up the bar
outer_pbar_global.close()
# --- Save decision logic ---
if not save_attempted_on_interrupt: # If interrupted, handler already saved
print("\n--- Final Optimization Summary (This Session) ---")
for stat in layer_optimization_stats_global: print(f"Layer: {stat['name']}, Final Loss: {stat['final_loss']:.4e}, Iters: {stat['iterations_done']}{', Stopped by Loss' if stat['stopped_early_by_loss_target'] else ''}")
print(f"\n--- Overall Summary ---")
print(f"Total unique LoHA modules accumulated (resumed + new): {len(all_completed_module_prefixes_ever_global)}")
print(f" Processed new this session: {processed_layers_this_session_count_global}")
print(f" Skipped as identical (this session's scan): {skipped_identical_count_global}")
print(f" Skipped for other reasons (this session's scan, e.g., shape error, opt error): {skipped_other_reason_count_global}")
print(f" Total candidate keys scanned in loop (this session): {keys_scanned_this_run_global}/{total_candidates_to_scan}")
actual_save_path: str
save_to_final_name = False
if main_loop_completed_scan_flag_global:
# Number of layers that were found to be different and optimizable during the full scan of *this session*
num_optimizable_layers_identified_in_scan = total_candidates_to_scan - skipped_identical_count_global - skipped_other_reason_count_global
# Check if all *such* layers are now accounted for in our cumulative set
if len(all_completed_module_prefixes_ever_global) >= num_optimizable_layers_identified_in_scan:
# This implies that all layers that showed a difference in the current model comparison
# are now present in the LoHA state dict (either from resume or processed now).
# We also need to ensure max_layers didn't prematurely stop us if it was less than this count.
if args_global.max_layers is None or processed_layers_this_session_count_global >= args_global.max_layers or len(all_completed_module_prefixes_ever_global) < (len(previously_completed_module_prefixes_global) + args_global.max_layers):
# If max_layers is not set, or if we processed up to max_layers (or didn't need to because all were done),
# and the total count matches the optimizable count from scan, then it's final.
save_to_final_name = True
else: # max_layers was hit, and it's less than total optimizable, so not final.
print(f" Scan completed, but max_layers ({args_global.max_layers}) may have limited processing before all {num_optimizable_layers_identified_in_scan} differing layers were handled.")
else:
print(f" Scan completed, but not all {num_optimizable_layers_identified_in_scan} differing layers are processed yet "
f"(current total: {len(all_completed_module_prefixes_ever_global)}).")
else:
print(" Scan did not complete fully. Saving intermediate state.")
if save_to_final_name:
actual_save_path = args_global.save_to
print(f"\nAll optimizable layers appear to be processed. Saving to final path: {actual_save_path}")
else:
num_layers_for_filename = len(all_completed_module_prefixes_ever_global)
actual_save_path = generate_intermediate_filename(args_global.save_to, num_layers_for_filename)
print(f"\nRun incomplete or not all differing layers processed. Saving intermediate state to: {actual_save_path}")
perform_graceful_save(output_path_to_save=actual_save_path)
if save_to_final_name and actual_save_path == args_global.save_to : # Ensure it's the final path
print("\nCleaning up intermediate resume files...")
cleanup_intermediate_files(args_global.save_to)
else: # Save was attempted by interrupt handler
print("\nProcess was interrupted. Graceful save to an intermediate file was attempted by signal handler.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Extract LoHA parameters by optimizing against weight differences. Saves intermediate files like 'name_resume_L{count}.safetensors'.")
parser.add_argument("base_model_path", type=str, help="Path to the base model state_dict file (.pt, .pth, .safetensors)")
parser.add_argument("ft_model_path", type=str, help="Path to the fine-tuned model state_dict file (.pt, .pth, .safetensors)")
parser.add_argument("save_to", type=str, help="Path to save the FINAL extracted LoHA file (recommended .safetensors). Intermediate files will be based on this name.")
parser.add_argument("--overwrite", action="store_true", help="Ignore and overwrite any existing FINAL LoHA output file and its intermediate files if found at the start. Does not prevent resuming from other intermediate files if the final target does not exist.")
parser.add_argument("--rank", type=int, default=4, help="Default rank for LoHA decomposition (used for linear layers and as fallback for conv).")
parser.add_argument("--conv_rank", type=int, default=None, help="Specific rank for convolutional LoHA layers. Defaults to --rank if not set.")
parser.add_argument("--initial_alpha", type=float, default=None, help="Global initial alpha for optimization. Defaults to 'rank'.")
parser.add_argument("--initial_conv_alpha", type=float, default=None, help="Specific initial alpha for Conv LoHA. Defaults to '--initial_alpha' or conv_rank.")
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate for LoHA optimization per layer.")
parser.add_argument("--max_iterations", type=int, default=1000, help="Maximum number of optimization iterations per layer.")
parser.add_argument("--min_iterations", type=int, default=100, help="Minimum iterations before checking target loss.")
parser.add_argument("--target_loss", type=float, default=None, help="Target MSE loss for early stopping per layer.")
parser.add_argument("--weight_decay", type=float, default=1e-5, help="Weight decay for LoHA optimization.")
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device ('cuda' or 'cpu').")
parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"], help="Optimization precision. Default: fp32.")
parser.add_argument("--save_weights_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"], help="Dtype for saved LoHA weights. Default: bf16.")
parser.add_argument("--atol_fp32_check", type=float, default=1e-6, help="Tolerance for identical weight check.")
parser.add_argument("--use_bias", action="store_true", help="Save differing bias terms.")
parser.add_argument("--dropout", type=float, default=0.0, help="General dropout (metadata only).")
parser.add_argument("--rank_dropout", type=float, default=0.0, help="Rank dropout (metadata only).")
parser.add_argument("--module_dropout", type=float, default=0.0, help="Module dropout (metadata only).")
parser.add_argument("--max_layers", type=int, default=None, help="Max NEW differing layers to process this session. Scan will continue to assess all layers.")
parser.add_argument("--verbose", action="store_true", help="General verbose output.")
parser.add_argument("--verbose_layer_debug", action="store_true", help="Detailed per-iteration optimization debug output.")
parsed_args = parser.parse_args()
if not os.path.exists(parsed_args.base_model_path): print(f"Error: Base model path not found: {parsed_args.base_model_path}"); exit(1)
if not os.path.exists(parsed_args.ft_model_path): print(f"Error: Fine-tuned model path not found: {parsed_args.ft_model_path}"); exit(1)
save_dir = os.path.dirname(parsed_args.save_to)
if save_dir and not os.path.exists(save_dir):
try:
os.makedirs(save_dir, exist_ok=True)
print(f"Created directory: {save_dir}")
except OSError as e:
print(f"Error: Could not create directory {save_dir}: {e}"); exit(1)
if parsed_args.initial_alpha is None: parsed_args.initial_alpha = float(parsed_args.rank)
# Ensure conv_alpha defaults correctly after initial_alpha might have defaulted to rank
if parsed_args.initial_conv_alpha is None:
# If conv_rank is set, use that for default alpha, else use the global initial_alpha (which might itself be rank)
conv_rank_for_alpha_default = parsed_args.conv_rank if parsed_args.conv_rank is not None else parsed_args.rank
parsed_args.initial_conv_alpha = float(conv_rank_for_alpha_default) if parsed_args.conv_rank is not None else parsed_args.initial_alpha
main(parsed_args)

View File

@ -1,33 +1,5 @@
import sys import sys
import os import os
# 1. Add sd-scripts directory to sys.path
# This block can now be potentially removed if no other sd-scripts imports are needed
# OR kept if there's a chance of re-introducing some utilities for other purposes.
# For full removal of the sd-scripts dependency for *this script's execution*,
# ensure no other `from library...` or `from networks...` exist.
# script_dir = os.path.dirname(os.path.abspath(__file__))
# project_root = os.path.dirname(script_dir)
# sd_scripts_dir_path = os.path.join(project_root, "sd-scripts")
# if sd_scripts_dir_path not in sys.path:
# sys.path.insert(0, sd_scripts_dir_path)
# Now you can import from the library package and the networks package
# try:
# # model_util and sdxl_model_util REMOVED from here
# # from library.utils import setup_logging # REMOVED
# # from networks import lora # REMOVED
# except ImportError as e:
# print(f"Error importing from sd-scripts. Please check your sd-scripts folder structure.")
# # print(f"Attempted to load from: {sd_scripts_dir_path}") # If path addition is removed
# print(f"Original error: {e}")
# print("Current sys.path relevant entries:")
# for p in sys.path:
# if "sd-scripts" in p or "kohya_ss" in p: # Adjust if sd_scripts_dir_path is removed
# print(p)
# raise
import argparse import argparse
import json import json
import time import time

View File

@ -1,340 +0,0 @@
import argparse
import json
import os
import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, model_util, sdxl_model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
MIN_SV = 1e-6
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S) - 1))
return index
def index_sv_fro(S, target):
S_squared = S.pow(2)
S_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S) - 1))
return index
def index_sv_ratio(S, target):
max_sv = S[0]
min_sv = max_sv / target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S) - 1))
return index
def index_sv_knee(S):
"""Determine rank using the knee point detection method."""
n = len(S)
if n < 3: # Need at least 3 points to detect a knee
return 1
# Line coefficients from (1, S[0]) to (n, S[-1])
a = S[0] - S[-1]
b = n - 1
c = 1 * S[-1] - n * S[0]
# Compute distances for each k
distances = []
for k in range(1, n + 1):
dist = abs(a * k + b * S[k - 1] + c) / (a**2 + b**2)**0.5
distances.append(dist)
# Find index of maximum distance (add 1 because k starts at 1)
index = torch.argmax(torch.tensor(distances)).item() + 1
index = max(1, min(index, n - 1))
return index
def index_sv_rel_decrease(S, tau=0.1):
"""Determine rank based on relative decrease threshold."""
if len(S) < 2:
return 1
# Compute ratios of consecutive singular values
ratios = S[1:] / S[:-1]
# Find the smallest k where ratio < tau
for k in range(len(ratios)):
if ratios[k] < tau:
return max(1, k + 1) # k + 1 because we want rank after the drop
# If no drop below tau, return max rank
return min(len(S), len(S) - 1)
def save_to_file(file_name, model, state_dict, dtype, metadata=None):
if dtype is not None:
for key in list(state_dict.keys()):
if isinstance(state_dict[key], torch.Tensor):
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name, metadata)
else:
torch.save(model, file_name)
def svd(
model_org=None,
model_tuned=None,
save_to=None,
dim=4,
v2=None,
sdxl=None,
conv_dim=None,
v_parameterization=None,
device=None,
save_precision=None,
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
load_precision=None,
load_original_model_to=None,
load_tuned_model_to=None,
dynamic_method=None,
dynamic_param=None,
verbose=False,
):
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time"
v_parameterization = v2 if v_parameterization is None else v_parameterization
load_dtype = str_to_dtype(load_precision) if load_precision else None
save_dtype = str_to_dtype(save_precision) if save_precision else torch.float
work_device = "cpu"
# Load models
if not sdxl:
logger.info(f"Loading original SD model: {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o]
if load_dtype:
text_encoder_o.to(load_dtype)
unet_o.to(load_dtype)
logger.info(f"Loading tuned SD model: {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t]
if load_dtype:
text_encoder_t.to(load_dtype)
unet_t.to(load_dtype)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else:
device_org = load_original_model_to or "cpu"
device_tuned = load_tuned_model_to or "cpu"
logger.info(f"Loading original SDXL model: {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
)
text_encoders_o = [text_encoder_o1, text_encoder_o2]
if load_dtype:
text_encoder_o1.to(load_dtype)
text_encoder_o2.to(load_dtype)
unet_o.to(load_dtype)
logger.info(f"Loading tuned SDXL model: {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
)
text_encoders_t = [text_encoder_t1, text_encoder_t2]
if load_dtype:
text_encoder_t1.to(load_dtype)
text_encoder_t2.to(load_dtype)
unet_t.to(load_dtype)
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# Create LoRA network
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} if conv_dim else {}
# Define a small initial dimension for memory efficiency
init_dim = 4 # Small value to minimize memory usage
# Create LoRA networks with minimal dimension
lora_network_o = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), "Model versions differ (SD1.x vs SD2.x)"
# Compute differences
diffs = {}
text_encoder_different = False
for lora_o, lora_t in zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras):
lora_name = lora_o.lora_name
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
lora_o.org_module.weight = None
lora_t.org_module.weight = None
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True
logger.info(f"Text encoder differs: max diff {torch.max(torch.abs(diff))} > {min_diff}")
diffs[lora_name] = diff
for text_encoder in text_encoders_t:
del text_encoder
if not text_encoder_different:
logger.warning("Text encoders are identical. Extracting U-Net only.")
lora_network_o.text_encoder_loras = []
diffs.clear()
for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras):
lora_name = lora_o.lora_name
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
lora_o.org_module.weight = None
lora_t.org_module.weight = None
diffs[lora_name] = diff
del lora_network_t, unet_t
# Filter relevant modules
lora_names = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)
# Extract and resize LoRA using SVD
logger.info("Extracting and resizing LoRA via SVD")
lora_weights = {}
with torch.no_grad():
for lora_name in tqdm(lora_names):
mat = diffs[lora_name]
if device:
mat = mat.to(device)
mat = mat.to(torch.float)
conv2d = len(mat.size()) == 4
kernel_size = mat.size()[2:4] if conv2d else None
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = mat.size()[0:2]
if conv2d:
mat = mat.flatten(start_dim=1) if conv2d_3x3 else mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
# Determine rank
max_rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
if dynamic_method:
if S[0] <= MIN_SV:
rank = 1
elif dynamic_method == "sv_ratio":
rank = index_sv_ratio(S, dynamic_param)
elif dynamic_method == "sv_cumulative":
rank = index_sv_cumulative(S, dynamic_param)
elif dynamic_method == "sv_fro":
rank = index_sv_fro(S, dynamic_param)
elif dynamic_method == "sv_knee":
rank = index_sv_knee(S)
elif dynamic_method == "sv_rel_decrease":
rank = index_sv_rel_decrease(S, dynamic_param)
rank = min(rank, max_rank, in_dim, out_dim)
else:
rank = min(max_rank, in_dim, out_dim)
# Truncate SVD components
U = U[:, :rank] @ torch.diag(S[:rank])
Vh = Vh[:rank, :]
# Clamp values
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, clamp_quantile)
U = U.clamp(-hi_val, hi_val)
Vh = Vh.clamp(-hi_val, hi_val)
if conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, *kernel_size)
U = U.to(work_device, dtype=save_dtype).contiguous()
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
lora_weights[lora_name] = (U, Vh)
# Verbose output
if verbose:
s_sum = float(torch.sum(S))
s_rank = float(torch.sum(S[:rank]))
fro = float(torch.sqrt(torch.sum(S.pow(2))))
fro_rank = float(torch.sqrt(torch.sum(S[:rank].pow(2))))
ratio = S[0] / S[rank - 1] if rank > 1 else float('inf')
logger.info(f"{lora_name:75} | sum(S) retained: {s_rank/s_sum:.1%}, fro retained: {fro_rank/fro:.1%}, max ratio: {ratio:.1f}, rank: {rank}")
# Create state dict
lora_sd = {}
for lora_name, (up_weight, down_weight) in lora_weights.items():
lora_sd[lora_name + ".lora_up.weight"] = up_weight
lora_sd[lora_name + ".lora_down.weight"] = down_weight
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype)
# Load and save LoRA
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
lora_network_save.apply_to(text_encoders_o, unet_o)
info = lora_network_save.load_state_dict(lora_sd)
logger.info(f"Loaded extracted and resized LoRA weights: {info}")
os.makedirs(os.path.dirname(save_to), exist_ok=True)
# Metadata
net_kwargs = {"conv_dim": str(conv_dim), "conv_alpha": str(float(conv_dim))} if conv_dim else {}
metadata = {
"ss_v2": str(v2),
"ss_base_model_version": model_version,
"ss_network_module": "networks.lora",
"ss_network_dim": str(dim) if not dynamic_method else "Dynamic",
"ss_network_alpha": str(float(dim)) if not dynamic_method else "Dynamic",
"ss_network_args": json.dumps(net_kwargs),
}
if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0]
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
metadata.update(sai_metadata)
save_to_file(save_to, lora_sd, lora_sd, save_dtype, metadata)
logger.info(f"LoRA saved to: {save_to}")
def setup_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model")
parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2)")
parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model")
parser.add_argument("--load_precision", choices=[None, "float", "fp16", "bf16"], help="Precision for loading models")
parser.add_argument("--save_precision", choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA")
parser.add_argument("--model_org", required=True, help="Original Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--model_tuned", required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--save_to", required=True, help="Output file name (ckpt/safetensors)")
parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers")
parser.add_argument("--conv_dim", type=int, help="Max dimension (rank) of LoRA for Conv2d-3x3")
parser.add_argument("--device", default="cuda", help="Device for computation (e.g., cuda)")
parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights")
parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract")
parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata")
parser.add_argument("--load_original_model_to", help="Device for original model (SDXL only)")
parser.add_argument("--load_tuned_model_to", help="Device for tuned model (SDXL only)")
parser.add_argument("--dynamic_method", choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease"], help="Dynamic rank reduction method")
parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction")
parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info")
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
if args.dynamic_method and not args.dynamic_param:
raise ValueError("Dynamic method requires a dynamic parameter")
svd(**vars(args))

View File

@ -1,432 +0,0 @@
import argparse
import json
import os
import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, model_util, sdxl_model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
MIN_SV = 1e-6
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S) - 1))
return index
def index_sv_fro(S, target):
S_squared = S.pow(2)
S_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S) - 1))
return index
def index_sv_ratio(S, target):
max_sv = S[0]
min_sv = max_sv / target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S) - 1))
return index
def index_sv_knee_improved(S, MIN_SV_KNEE=1e-8): # MIN_SV_KNEE can be same as global MIN_SV or specific
"""
Determine rank using the knee point detection method with normalization.
Normalizes singular values and their indices to the [0,1] range
to make the knee detection scale-invariant.
"""
n = len(S)
if n < 3: # Need at least 3 points to detect a knee
return 1
# S is expected to be sorted in descending order.
s_max, s_min = S[0], S[-1]
# Handle flat or nearly flat singular value spectrum
if s_max - s_min < MIN_SV_KNEE:
# If all singular values are almost the same, a knee is not well-defined.
# Returning 1 is a conservative choice for low rank.
# Alternatively, n // 2 or n - 1 could be chosen depending on desired behavior.
return 1
# Normalize singular values to [0, 1]
# s_normalized[0] will be 1, s_normalized[n-1] will be 0
s_normalized = (S - s_min) / (s_max - s_min)
# Normalize indices to [0, 1]
# x_normalized[0] will be 0, x_normalized[n-1] will be 1
x_normalized = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
# The line in normalized space connects (x_norm[0], s_norm[0]) to (x_norm[n-1], s_norm[n-1])
# This is (0, 1) to (1, 0).
# The equation of the line passing through (0,1) and (1,0) is x + y - 1 = 0.
# So, A=1, B=1, C=-1 for the line equation Ax + By + C = 0.
# Calculate the perpendicular distance from each point (x_normalized[i], s_normalized[i]) to this line.
# Distance = |A*x_i + B*y_i + C| / sqrt(A^2 + B^2)
# Distance = |1*x_normalized + 1*s_normalized - 1| / sqrt(1^2 + 1^2)
# Distance = |x_normalized + s_normalized - 1| / sqrt(2)
# The sqrt(2) denominator is constant and doesn't affect argmax, so it can be omitted for finding the index.
distances = (x_normalized + s_normalized - 1).abs()
# Find the 0-based index of the point with the maximum distance.
knee_index_0based = torch.argmax(distances).item()
# Convert 0-based index to 1-based rank.
rank = knee_index_0based + 1
# Clamp rank similar to original: must be > 0 and <= n-1 (typical for rank reduction)
# If knee_index_0based is n-1 (last point), rank becomes n. min(n, n-1) results in n-1.
rank = max(1, min(rank, n - 1))
return rank
def index_sv_cumulative_knee(S, min_sv_threshold=1e-8):
"""
Determine rank using the knee point detection method on the normalized cumulative sum of singular values.
This method identifies a point where adding more singular values contributes diminishingly to the total sum.
"""
n = len(S)
if n < 3: # Need at least 3 points to detect a knee
return 1
s_sum = torch.sum(S)
# If all singular values are zero or very small, return rank 1.
if s_sum < min_sv_threshold:
return 1
# Calculate cumulative sum of singular values, normalized by the total sum.
# y_values[0] = S[0]/s_sum, ..., y_values[n-1] = 1.0
y_values = torch.cumsum(S, dim=0) / s_sum
# Normalize these y_values (cumulative sums) to the range [0,1] for knee detection.
y_min, y_max = y_values[0], y_values[n-1] # y_max is typically 1.0
# If the normalized cumulative sum curve is very flat (e.g., S[0] captures almost all energy),
# it implies the first few components are dominant.
if y_max - y_min < min_sv_threshold: # Using min_sv_threshold here as a sensitivity for flatness
# This condition means (S[0] + ... + S[n-1]) - S[0] is small relative to sum(S) if n>1
# Effectively, S[1:] components are negligible.
return 1
# y_norm[0] = 0, y_norm[n-1] = 1 (represents the normalized cumulative sum from start to end)
y_norm = (y_values - y_min) / (y_max - y_min)
# x_values are indices, normalized to [0, 1]
# x_norm[0] = 0, ..., x_norm[n-1] = 1
x_norm = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
# The "knee" is the point on the curve (x_norm[i], y_norm[i]) that is farthest
# from the line connecting the start and end of this normalized curve.
# In this normalized space, the line connects (0,0) to (1,1).
# The equation of this line is Y = X, or X - Y = 0.
# The distance from a point (x_i, y_i) to the line X - Y = 0 is |x_i - y_i| / sqrt(1^2 + (-1)^2).
# We can maximize |x_i - y_i| (or |y_i - x_i|) as sqrt(2) is a constant factor.
distances = (y_norm - x_norm).abs() # y_norm is expected to be >= x_norm for a concave cumulative curve.
# Find the 0-based index of the point with the maximum distance.
knee_index_0based = torch.argmax(distances).item()
# Convert 0-based index to 1-based rank.
rank = knee_index_0based + 1
# Clamp rank to be between 1 and n-1 (as n elements give n-1 possible ranks for truncation).
# A rank of n means no truncation. n-1 is the highest sensible rank for reduction.
rank = max(1, min(rank, n - 1))
return rank
def index_sv_rel_decrease(S, tau=0.1):
"""Determine rank based on relative decrease threshold."""
if len(S) < 2:
return 1
# Compute ratios of consecutive singular values
ratios = S[1:] / S[:-1]
# Find the smallest k where ratio < tau
for k in range(len(ratios)):
if ratios[k] < tau:
return max(1, k + 1) # k + 1 because we want rank after the drop
# If no drop below tau, return max rank
return min(len(S), len(S) - 1)
def save_to_file(file_name, model, state_dict, dtype, metadata=None):
if dtype is not None:
for key in list(state_dict.keys()):
if isinstance(state_dict[key], torch.Tensor):
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name, metadata)
else:
torch.save(model, file_name)
def svd(
model_org=None,
model_tuned=None,
save_to=None,
dim=4,
v2=None,
sdxl=None,
conv_dim=None,
v_parameterization=None,
device=None,
save_precision=None,
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
load_precision=None,
load_original_model_to=None,
load_tuned_model_to=None,
dynamic_method=None,
dynamic_param=None,
verbose=False,
):
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time"
v_parameterization = v2 if v_parameterization is None else v_parameterization
load_dtype = str_to_dtype(load_precision) if load_precision else None
save_dtype = str_to_dtype(save_precision) if save_precision else torch.float
work_device = "cpu"
# Load models
if not sdxl:
logger.info(f"Loading original SD model: {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o]
if load_dtype:
text_encoder_o.to(load_dtype)
unet_o.to(load_dtype)
logger.info(f"Loading tuned SD model: {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t]
if load_dtype:
text_encoder_t.to(load_dtype)
unet_t.to(load_dtype)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else:
device_org = load_original_model_to or "cpu"
device_tuned = load_tuned_model_to or "cpu"
logger.info(f"Loading original SDXL model: {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
)
text_encoders_o = [text_encoder_o1, text_encoder_o2]
if load_dtype:
text_encoder_o1.to(load_dtype)
text_encoder_o2.to(load_dtype)
unet_o.to(load_dtype)
logger.info(f"Loading tuned SDXL model: {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
)
text_encoders_t = [text_encoder_t1, text_encoder_t2]
if load_dtype:
text_encoder_t1.to(load_dtype)
text_encoder_t2.to(load_dtype)
unet_t.to(load_dtype)
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# Create LoRA network
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} if conv_dim else {}
# Define a small initial dimension for memory efficiency
init_dim = 4 # Small value to minimize memory usage
# Create LoRA networks with minimal dimension
lora_network_o = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), "Model versions differ (SD1.x vs SD2.x)"
# Compute differences
diffs = {}
text_encoder_different = False
for lora_o, lora_t in zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras):
lora_name = lora_o.lora_name
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
lora_o.org_module.weight = None
lora_t.org_module.weight = None
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True
logger.info(f"Text encoder differs: max diff {torch.max(torch.abs(diff))} > {min_diff}")
diffs[lora_name] = diff
for text_encoder in text_encoders_t:
del text_encoder
if not text_encoder_different:
logger.warning("Text encoders are identical. Extracting U-Net only.")
lora_network_o.text_encoder_loras = []
diffs.clear()
for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras):
lora_name = lora_o.lora_name
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
lora_o.org_module.weight = None
lora_t.org_module.weight = None
diffs[lora_name] = diff
del lora_network_t, unet_t
# Filter relevant modules
lora_names = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)
# Extract and resize LoRA using SVD
logger.info("Extracting and resizing LoRA via SVD")
lora_weights = {}
with torch.no_grad():
for lora_name in tqdm(lora_names):
mat = diffs[lora_name]
if device:
mat = mat.to(device)
mat = mat.to(torch.float)
conv2d = len(mat.size()) == 4
kernel_size = mat.size()[2:4] if conv2d else None
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = mat.size()[0:2]
if conv2d:
mat = mat.flatten(start_dim=1) if conv2d_3x3 else mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
# Determine rank
max_rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
if dynamic_method:
if S[0] <= MIN_SV:
rank = 1
elif dynamic_method == "sv_ratio":
rank = index_sv_ratio(S, dynamic_param)
elif dynamic_method == "sv_cumulative":
rank = index_sv_cumulative(S, dynamic_param)
elif dynamic_method == "sv_fro":
rank = index_sv_fro(S, dynamic_param)
elif dynamic_method == "sv_knee":
rank = index_sv_knee_improved(S, MIN_SV) # Pass MIN_SV or a specific threshold
elif dynamic_method == "sv_cumulative_knee": # New method
rank = index_sv_cumulative_knee(S, MIN_SV) # Pass MIN_SV or a specific threshold
elif dynamic_method == "sv_rel_decrease":
rank = index_sv_rel_decrease(S, dynamic_param)
rank = min(rank, max_rank, in_dim, out_dim)
else:
rank = min(max_rank, in_dim, out_dim)
# Truncate SVD components
U = U[:, :rank] @ torch.diag(S[:rank])
Vh = Vh[:rank, :]
# Clamp values
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, clamp_quantile)
U = U.clamp(-hi_val, hi_val)
Vh = Vh.clamp(-hi_val, hi_val)
if conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, *kernel_size)
U = U.to(work_device, dtype=save_dtype).contiguous()
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
lora_weights[lora_name] = (U, Vh)
# Verbose output
if verbose:
s_sum = float(torch.sum(S))
s_rank = float(torch.sum(S[:rank]))
fro = float(torch.sqrt(torch.sum(S.pow(2))))
fro_rank = float(torch.sqrt(torch.sum(S[:rank].pow(2))))
ratio = S[0] / S[rank - 1] if rank > 1 else float('inf')
logger.info(f"{lora_name:75} | sum(S) retained: {s_rank/s_sum:.1%}, fro retained: {fro_rank/fro:.1%}, max ratio: {ratio:.1f}, rank: {rank}")
# Create state dict
lora_sd = {}
for lora_name, (up_weight, down_weight) in lora_weights.items():
lora_sd[lora_name + ".lora_up.weight"] = up_weight
lora_sd[lora_name + ".lora_down.weight"] = down_weight
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype)
# Load and save LoRA
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
lora_network_save.apply_to(text_encoders_o, unet_o)
info = lora_network_save.load_state_dict(lora_sd)
logger.info(f"Loaded extracted and resized LoRA weights: {info}")
os.makedirs(os.path.dirname(save_to), exist_ok=True)
# Metadata
net_kwargs = {"conv_dim": str(conv_dim), "conv_alpha": str(float(conv_dim))} if conv_dim else {}
metadata = {
"ss_v2": str(v2),
"ss_base_model_version": model_version,
"ss_network_module": "networks.lora",
"ss_network_dim": str(dim) if not dynamic_method else "Dynamic",
"ss_network_alpha": str(float(dim)) if not dynamic_method else "Dynamic",
"ss_network_args": json.dumps(net_kwargs),
}
if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0]
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
metadata.update(sai_metadata)
save_to_file(save_to, lora_sd, lora_sd, save_dtype, metadata)
logger.info(f"LoRA saved to: {save_to}")
def setup_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model")
parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2)")
parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model")
parser.add_argument("--load_precision", choices=[None, "float", "fp16", "bf16"], help="Precision for loading models")
parser.add_argument("--save_precision", choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA")
parser.add_argument("--model_org", required=True, help="Original Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--model_tuned", required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--save_to", required=True, help="Output file name (ckpt/safetensors)")
parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers")
parser.add_argument("--conv_dim", type=int, help="Max dimension (rank) of LoRA for Conv2d-3x3")
parser.add_argument("--device", default="cuda", help="Device for computation (e.g., cuda)")
parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights")
parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract")
parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata")
parser.add_argument("--load_original_model_to", help="Device for original model (SDXL only)")
parser.add_argument("--load_tuned_model_to", help="Device for tuned model (SDXL only)")
parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction")
parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info")
parser.add_argument(
"--dynamic_method",
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"], # Added "sv_cumulative_knee"
help="Dynamic rank reduction method"
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
if args.dynamic_method and not args.dynamic_param:
raise ValueError("Dynamic method requires a dynamic parameter")
svd(**vars(args))

View File

@ -1,545 +0,0 @@
import sys
import os
# 1. Add sd-scripts directory to sys.path
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
sd_scripts_dir_path = os.path.join(project_root, "sd-scripts")
if sd_scripts_dir_path not in sys.path:
sys.path.insert(0, sd_scripts_dir_path)
# Now you can import from the library package and the networks package
try:
from library import sai_model_spec, model_util, sdxl_model_util
from library.utils import setup_logging
from networks import lora # <--- CORRECTED LORA IMPORT
except ImportError as e:
print(f"Error importing from sd-scripts. Please check your sd-scripts folder structure.")
print(f"Attempted to load from: {sd_scripts_dir_path}")
print(f"Original error: {e}")
print("Current sys.path relevant entries:")
for p in sys.path:
if "sd-scripts" in p or "kohya_ss" in p: # Print relevant paths for debugging
print(p)
# Ensure 'networks' directory exists in 'sd-scripts' and contains 'lora.py'
# Also ensure 'sd-scripts/networks/__init__.py' exists.
raise
# --- The rest of your script ---
import argparse
import json
# import os # Already imported
import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
setup_logging()
import logging
logger = logging.getLogger(__name__)
MIN_SV = 1e-6
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S) - 1))
return index
def index_sv_fro(S, target):
S_squared = S.pow(2)
S_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S) - 1))
return index
def index_sv_ratio(S, target):
max_sv = S[0]
min_sv = max_sv / target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S) - 1))
return index
def index_sv_knee_improved(S, MIN_SV_KNEE=1e-8): # MIN_SV_KNEE can be same as global MIN_SV or specific
"""
Determine rank using the knee point detection method with normalization.
Normalizes singular values and their indices to the [0,1] range
to make the knee detection scale-invariant.
"""
n = len(S)
if n < 3: # Need at least 3 points to detect a knee
return 1
# S is expected to be sorted in descending order.
s_max, s_min = S[0], S[-1]
# Handle flat or nearly flat singular value spectrum
if s_max - s_min < MIN_SV_KNEE:
# If all singular values are almost the same, a knee is not well-defined.
# Returning 1 is a conservative choice for low rank.
# Alternatively, n // 2 or n - 1 could be chosen depending on desired behavior.
return 1
# Normalize singular values to [0, 1]
# s_normalized[0] will be 1, s_normalized[n-1] will be 0
s_normalized = (S - s_min) / (s_max - s_min)
# Normalize indices to [0, 1]
# x_normalized[0] will be 0, x_normalized[n-1] will be 1
x_normalized = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
# The line in normalized space connects (x_norm[0], s_norm[0]) to (x_norm[n-1], s_norm[n-1])
# This is (0, 1) to (1, 0).
# The equation of the line passing through (0,1) and (1,0) is x + y - 1 = 0.
# So, A=1, B=1, C=-1 for the line equation Ax + By + C = 0.
# Calculate the perpendicular distance from each point (x_normalized[i], s_normalized[i]) to this line.
# Distance = |A*x_i + B*y_i + C| / sqrt(A^2 + B^2)
# Distance = |1*x_normalized + 1*s_normalized - 1| / sqrt(1^2 + 1^2)
# Distance = |x_normalized + s_normalized - 1| / sqrt(2)
# The sqrt(2) denominator is constant and doesn't affect argmax, so it can be omitted for finding the index.
distances = (x_normalized + s_normalized - 1).abs()
# Find the 0-based index of the point with the maximum distance.
knee_index_0based = torch.argmax(distances).item()
# Convert 0-based index to 1-based rank.
rank = knee_index_0based + 1
# Clamp rank similar to original: must be > 0 and <= n-1 (typical for rank reduction)
# If knee_index_0based is n-1 (last point), rank becomes n. min(n, n-1) results in n-1.
rank = max(1, min(rank, n - 1))
return rank
def index_sv_cumulative_knee(S, min_sv_threshold=1e-8):
"""
Determine rank using the knee point detection method on the normalized cumulative sum of singular values.
This method identifies a point where adding more singular values contributes diminishingly to the total sum.
"""
n = len(S)
if n < 3: # Need at least 3 points to detect a knee
return 1
s_sum = torch.sum(S)
# If all singular values are zero or very small, return rank 1.
if s_sum < min_sv_threshold:
return 1
# Calculate cumulative sum of singular values, normalized by the total sum.
# y_values[0] = S[0]/s_sum, ..., y_values[n-1] = 1.0
y_values = torch.cumsum(S, dim=0) / s_sum
# Normalize these y_values (cumulative sums) to the range [0,1] for knee detection.
y_min, y_max = y_values[0], y_values[n-1] # y_max is typically 1.0
# If the normalized cumulative sum curve is very flat (e.g., S[0] captures almost all energy),
# it implies the first few components are dominant.
if y_max - y_min < min_sv_threshold: # Using min_sv_threshold here as a sensitivity for flatness
# This condition means (S[0] + ... + S[n-1]) - S[0] is small relative to sum(S) if n>1
# Effectively, S[1:] components are negligible.
return 1
# y_norm[0] = 0, y_norm[n-1] = 1 (represents the normalized cumulative sum from start to end)
y_norm = (y_values - y_min) / (y_max - y_min)
# x_values are indices, normalized to [0, 1]
# x_norm[0] = 0, ..., x_norm[n-1] = 1
x_norm = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
# The "knee" is the point on the curve (x_norm[i], y_norm[i]) that is farthest
# from the line connecting the start and end of this normalized curve.
# In this normalized space, the line connects (0,0) to (1,1).
# The equation of this line is Y = X, or X - Y = 0.
# The distance from a point (x_i, y_i) to the line X - Y = 0 is |x_i - y_i| / sqrt(1^2 + (-1)^2).
# We can maximize |x_i - y_i| (or |y_i - x_i|) as sqrt(2) is a constant factor.
distances = (y_norm - x_norm).abs() # y_norm is expected to be >= x_norm for a concave cumulative curve.
# Find the 0-based index of the point with the maximum distance.
knee_index_0based = torch.argmax(distances).item()
# Convert 0-based index to 1-based rank.
rank = knee_index_0based + 1
# Clamp rank to be between 1 and n-1 (as n elements give n-1 possible ranks for truncation).
# A rank of n means no truncation. n-1 is the highest sensible rank for reduction.
rank = max(1, min(rank, n - 1))
return rank
def index_sv_rel_decrease(S, tau=0.1):
"""Determine rank based on relative decrease threshold."""
if len(S) < 2:
# For matrices with fewer than 2 singular values, a relative decrease
# isn't meaningful. Returning 1 is a sensible default.
return 1
# Compute ratios of consecutive singular values
# S is sorted descending, so S[:-1] >= S[1:]
# ratios will be <= 1.0
ratios = S[1:] / S[:-1] # Example: S=[10,1,0.5], ratios=[0.1, 0.5]
# Find the smallest k such that S[k+1]/S[k] < tau.
# The rank would then be k+1, as we include S[k].
for k in range(len(ratios)): # k ranges from 0 to len(S)-2
if ratios[k] < tau:
# We found a significant drop after the k-th singular value.
# So, we keep k+1 singular values (indices 0 to k).
# The rank is k+1. Since k >= 0, k+1 >= 1.
return k + 1
# If no drop below tau was found, it means all relative decreases were >= tau.
# In this case, this method suggests using all available singular values.
# The actual rank will be capped later by args.dim/conv_dim and matrix dimensions.
return len(S)
def save_to_file(file_name, model_to_save, state_dict_content, dtype, metadata=None): # Renamed params for clarity
if dtype is not None:
for key in list(state_dict_content.keys()):
if isinstance(state_dict_content[key], torch.Tensor):
state_dict_content[key] = state_dict_content[key].to(dtype)
# save_file from safetensors expects a state_dict as the first argument if metadata is also passed.
# torch.save would also expect the state_dict.
# The 'model' variable being passed to save_file should be the state_dict itself.
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model_to_save, file_name, metadata=metadata) # Pass metadata correctly
else:
torch.save(model_to_save, file_name)
def svd(
model_org=None,
model_tuned=None,
save_to=None,
dim=4,
v2=None,
sdxl=None,
conv_dim=None,
v_parameterization=None,
device=None,
save_precision=None,
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
load_precision=None,
load_original_model_to=None,
load_tuned_model_to=None,
dynamic_method=None,
dynamic_param=None,
verbose=False,
):
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time"
v_parameterization = v2 if v_parameterization is None else v_parameterization
load_dtype = str_to_dtype(load_precision) if load_precision else None
save_dtype = str_to_dtype(save_precision) if save_precision else torch.float
work_device = "cpu"
# Load models
if not sdxl:
logger.info(f"Loading original SD model: {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o]
if load_dtype:
text_encoder_o.to(load_dtype)
unet_o.to(load_dtype)
logger.info(f"Loading tuned SD model: {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t]
if load_dtype:
text_encoder_t.to(load_dtype)
unet_t.to(load_dtype)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else:
device_org = load_original_model_to or "cpu"
device_tuned = load_tuned_model_to or "cpu"
logger.info(f"Loading original SDXL model: {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
)
text_encoders_o = [text_encoder_o1, text_encoder_o2]
if load_dtype:
text_encoder_o1.to(load_dtype)
text_encoder_o2.to(load_dtype)
unet_o.to(load_dtype)
logger.info(f"Loading tuned SDXL model: {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
)
text_encoders_t = [text_encoder_t1, text_encoder_t2]
if load_dtype:
text_encoder_t1.to(load_dtype)
text_encoder_t2.to(load_dtype)
unet_t.to(load_dtype)
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# Create LoRA network
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} if conv_dim else {}
# Define a small initial dimension for memory efficiency
init_dim = 4 # Small value to minimize memory usage
# Create LoRA networks with minimal dimension
lora_network_o = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), "Model versions differ (SD1.x vs SD2.x)"
# Compute differences
diffs = {}
text_encoder_different = False
for lora_o, lora_t in zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras):
lora_name = lora_o.lora_name
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
lora_o.org_module.weight = None
lora_t.org_module.weight = None
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True
logger.info(f"Text encoder differs: max diff {torch.max(torch.abs(diff))} > {min_diff}")
diffs[lora_name] = diff
for text_encoder in text_encoders_t:
del text_encoder
if not text_encoder_different:
logger.warning("Text encoders are identical. Extracting U-Net only.")
lora_network_o.text_encoder_loras = []
diffs.clear()
for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras):
lora_name = lora_o.lora_name
diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device)
lora_o.org_module.weight = None
lora_t.org_module.weight = None
diffs[lora_name] = diff
del lora_network_t, unet_t
# Filter relevant modules
lora_names = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)
# Extract and resize LoRA using SVD
logger.info("Extracting and resizing LoRA via SVD")
lora_weights = {}
with torch.no_grad():
for lora_name in tqdm(lora_names):
mat = diffs[lora_name]
if device:
mat = mat.to(device)
mat = mat.to(torch.float)
conv2d = len(mat.size()) == 4
kernel_size = mat.size()[2:4] if conv2d else None
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = mat.size()[0:2]
if conv2d:
mat = mat.flatten(start_dim=1) if conv2d_3x3 else mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
# Determine rank
max_rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
if dynamic_method:
if S[0] <= MIN_SV:
rank = 1
elif dynamic_method == "sv_ratio":
rank = index_sv_ratio(S, dynamic_param)
elif dynamic_method == "sv_cumulative":
rank = index_sv_cumulative(S, dynamic_param)
elif dynamic_method == "sv_fro":
rank = index_sv_fro(S, dynamic_param)
elif dynamic_method == "sv_knee":
rank = index_sv_knee_improved(S, MIN_SV) # Pass MIN_SV or a specific threshold
elif dynamic_method == "sv_cumulative_knee": # New method
rank = index_sv_cumulative_knee(S, MIN_SV) # Pass MIN_SV or a specific threshold
elif dynamic_method == "sv_rel_decrease":
rank = index_sv_rel_decrease(S, dynamic_param)
rank = min(rank, max_rank, in_dim, out_dim)
else:
rank = min(max_rank, in_dim, out_dim)
rank = max(1, rank) # Ensure rank is at least 1
# Truncate SVD components and distribute sqrt(S)
S_k = S[:rank]
U_k = U[:, :rank]
Vh_k = Vh[:rank, :]
# Ensure S_k values are non-negative before sqrt to avoid NaN from tiny negative SVD artifacts
S_k_non_negative = torch.clamp(S_k, min=0.0) # Use 0.0 for float tensor
s_sqrt = torch.sqrt(S_k_non_negative)
# Distribute s_sqrt: U_final = U_k * diag(s_sqrt), Vh_final = diag(s_sqrt) * Vh_k
# Using efficient broadcasting for multiplication:
U_final = U_k * s_sqrt.unsqueeze(0) # (out_dim, rank) * (1, rank)
Vh_final = Vh_k * s_sqrt.unsqueeze(1) # (rank, in_dim_effective) * (rank, 1)
# Clamp values (applied to U_final, Vh_final)
# The distribution of values in U_final and Vh_final might be different
# than the original U and Vh, so the effect of clamping might change.
dist = torch.cat([U_final.flatten(), Vh_final.flatten()])
hi_val = torch.quantile(dist, clamp_quantile)
U_clamped = U_final.clamp(-hi_val, hi_val)
Vh_clamped = Vh_final.clamp(-hi_val, hi_val)
if conv2d:
# U_clamped is (out_dim, rank)
U_clamped = U_clamped.reshape(out_dim, rank, 1, 1)
# Vh_clamped is (rank, in_dim * possibly_kernel_dims)
# It needs to be reshaped back to (rank, in_dim, kernel_h, kernel_w)
if conv2d_3x3 : # Original mat was (out_dim, in_dim * k_h * k_w)
Vh_clamped = Vh_clamped.reshape(rank, in_dim, *kernel_size)
else: # Original mat was (out_dim, in_dim) for 1x1 conv, kernel_size is (1,1)
Vh_clamped = Vh_clamped.reshape(rank, in_dim, *kernel_size) # kernel_size is (1,1) here
U_clamped = U_clamped.to(work_device, dtype=save_dtype).contiguous()
Vh_clamped = Vh_clamped.to(work_device, dtype=save_dtype).contiguous()
lora_weights[lora_name] = (U_clamped, Vh_clamped)
# Verbose output (S values are pre-modification for accurate reporting of original SVD properties)
if verbose:
s_sum_total = float(torch.sum(S))
s_sum_rank = float(torch.sum(S[:rank])) # Sum of the singular values actually used for reconstruction
fro_orig_total = float(torch.sqrt(torch.sum(S.pow(2))))
fro_reconstructed_rank = float(torch.sqrt(torch.sum(S[:rank].pow(2)))) # Frobenius norm of the matrix part represented by chosen rank
# Ratio of the largest retained singular value to the smallest retained singular value
# S is sorted, S[0] is max. S[rank-1] is the smallest singular value included if rank > 0.
ratio_sv = S[0] / S[rank - 1] if rank > 0 and S[rank - 1].abs() > MIN_SV else float('inf') # Avoid division by zero or tiny number
# Ensure denominators are not zero for percentages
sum_s_retained_percentage = (s_sum_rank / s_sum_total) if s_sum_total > MIN_SV else 1.0
fro_retained_percentage = (fro_reconstructed_rank / fro_orig_total) if fro_orig_total > MIN_SV else 1.0
logger.info(
f"{lora_name:75} | rank: {rank}, "
f"sum(S) retained: {sum_s_retained_percentage:.2%}, "
f"Frobenius norm retained: {fro_retained_percentage:.2%}, "
f"max_retained_sv/min_retained_sv ratio: {ratio_sv:.2f}"
)
# Create state dict
lora_sd = {}
for lora_name, (up_weight, down_weight) in lora_weights.items():
lora_sd[lora_name + ".lora_up.weight"] = up_weight
lora_sd[lora_name + ".lora_down.weight"] = down_weight
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype) # alpha is rank
# Load and save LoRA
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
lora_network_save.apply_to(text_encoders_o, unet_o) # This applies weights, not strictly necessary if just saving sd
info = lora_network_save.load_state_dict(lora_sd) # This populates the network object with the weights from lora_sd
logger.info(f"Loaded extracted and resized LoRA weights into network object: {info}")
os.makedirs(os.path.dirname(save_to), exist_ok=True)
# Metadata
net_kwargs = {"conv_dim": str(conv_dim), "conv_alpha": str(float(conv_dim))} if conv_dim else {}
# Determine network_dim and network_alpha for metadata based on dynamic method
if dynamic_method:
network_dim_meta = "Dynamic"
network_alpha_meta = "Dynamic" # Alpha is rank, which is dynamic
else:
network_dim_meta = str(dim)
network_alpha_meta = str(float(dim)) # Alpha is rank, which is dim
metadata = {
"ss_v2": str(v2),
"ss_base_model_version": model_version,
"ss_network_module": "networks.lora",
"ss_network_dim": network_dim_meta,
"ss_network_alpha": network_alpha_meta, # Alpha is typically the rank
"ss_network_args": json.dumps(net_kwargs),
"ss_lowram": "False", # Assuming not specifically lowram mode
"ss_num_train_images": "N/A", # Not applicable for extraction
# Add other relevant metadata as per sai_model_spec or conventions
}
if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0]
# Build sai_metadata, ensuring it includes necessary fields like 'ss_sd_model_hash' if possible
# For extraction, some training-specific metadata might not be relevant or available.
sai_metadata = sai_model_spec.build_metadata(
None, # training_info (usually from train_util or fine_tune) - can be None for extraction
v2,
v_parameterization,
sdxl,
True, # is_sd2
False, # is_v_pred_like
time.time(),
title=title,
# model_hash=None, # Original model hash if available
# tuned_model_hash=None # Tuned model hash if available
)
# Filter out None values from sai_metadata if any, or handle them in build_metadata
sai_metadata_cleaned = {k: v for k, v in sai_metadata.items() if v is not None}
metadata.update(sai_metadata_cleaned)
# Use the state_dict 'lora_sd' for saving, not the network object 'lora_network_save'
save_to_file(save_to, lora_sd, lora_sd, save_dtype, metadata) # Pass lora_sd as the model/state_dict to save
logger.info(f"LoRA saved to: {save_to}")
def setup_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model")
parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2)")
parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model")
parser.add_argument("--load_precision", choices=[None, "float", "fp16", "bf16"], help="Precision for loading models")
parser.add_argument("--save_precision", choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA")
parser.add_argument("--model_org", required=True, help="Original Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--model_tuned", required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)")
parser.add_argument("--save_to", required=True, help="Output file name (ckpt/safetensors)")
parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers")
parser.add_argument("--conv_dim", type=int, help="Max dimension (rank) of LoRA for Conv2d-3x3")
parser.add_argument("--device", default="cuda", help="Device for computation (e.g., cuda)")
parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights")
parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract")
parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata")
parser.add_argument("--load_original_model_to", help="Device for original model (SDXL only)")
parser.add_argument("--load_tuned_model_to", help="Device for tuned model (SDXL only)")
parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction")
parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info")
parser.add_argument(
"--dynamic_method",
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"], # Added "sv_cumulative_knee"
help="Dynamic rank reduction method"
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
methods_requiring_param = ["sv_ratio", "sv_fro", "sv_cumulative", "sv_rel_decrease"]
if args.dynamic_method in methods_requiring_param and args.dynamic_param is None:
raise ValueError(f"Dynamic method '{args.dynamic_method}' requires --dynamic_param to be set.")
# Add a check for rank > 0 if not dynamic, or ensure dynamic methods return rank >= 1
if not args.dynamic_method and args.dim <= 0:
raise ValueError(f"--dim (rank) must be > 0. Got {args.dim}")
if args.conv_dim is not None and args.conv_dim <=0:
raise ValueError(f"--conv_dim (rank) must be > 0 if specified. Got {args.conv_dim}")
svd(**vars(args))