diff --git a/tools/analyse_loha.py b/tools/analyse_loha.py new file mode 100644 index 0000000..9449ad0 --- /dev/null +++ b/tools/analyse_loha.py @@ -0,0 +1,159 @@ +import safetensors.torch +import json +from collections import OrderedDict +import sys # To redirect stdout +import traceback + +class Logger(object): + def __init__(self, filename="loha_analysis_output.txt"): + self.terminal = sys.stdout + self.log = open(filename, "w", encoding='utf-8') + + def write(self, message): + self.terminal.write(message) + self.log.write(message) + + def flush(self): + # This flush method is needed for python 3 compatibility. + # This handles the flush command, which shutil.copytree or os.system uses. + self.terminal.flush() + self.log.flush() + + def close(self): + self.log.close() + +def analyze_safetensors_file(filepath, output_filename="loha_analysis_output.txt"): + """ + Analyzes a .safetensors file to extract and print its metadata + and tensor information (keys, shapes, dtypes) to a file. + """ + original_stdout = sys.stdout + logger = Logger(filename=output_filename) + sys.stdout = logger + + try: + print(f"--- Analyzing: {filepath} ---\n") + print(f"--- Output will be saved to: {output_filename} ---\n") + + # Load the tensors to get their structure + state_dict = safetensors.torch.load_file(filepath, device="cpu") # Load to CPU to avoid potential CUDA issues + + print("--- Tensor Information ---") + if not state_dict: + print("No tensors found in the state dictionary.") + else: + # Sort keys for consistent output + sorted_keys = sorted(state_dict.keys()) + current_module_prefix = "" + + # First, identify all unique module prefixes for better grouping + module_prefixes = sorted(list(set([".".join(key.split(".")[:-1]) for key in sorted_keys if "." in key]))) + + for prefix in module_prefixes: + if not prefix: # Skip keys that don't seem to be part of a module (e.g. global metadata tensors if any) + continue + print(f"\nModule: {prefix}") + for key in sorted_keys: + if key.startswith(prefix + "."): + tensor = state_dict[key] + print(f" - Key: {key}") + print(f" Shape: {list(tensor.shape)}, Dtype: {tensor.dtype}") # Output shape as list for clarity + if key.endswith((".alpha", ".dim")): + try: + value = tensor.item() + # Check if value is float and format if it is + if isinstance(value, float): + print(f" Value: {value:.8f}") # Format float to a certain precision + else: + print(f" Value: {value}") + except Exception as e: + print(f" Value: Could not extract scalar value ({tensor}, error: {e})") + elif tensor.numel() < 10: # Print small tensors' values + print(f" Values (first few): {tensor.flatten()[:10].tolist()}") + + + # Print keys that might not fit the module pattern (e.g., older formats or single tensors) + print("\n--- Other Tensor Keys (if any, not fitting typical module.parameter pattern) ---") + other_keys_found = False + for key in sorted_keys: + if not any(key.startswith(p + ".") for p in module_prefixes if p): + other_keys_found = True + tensor = state_dict[key] + print(f" - Key: {key}") + print(f" Shape: {list(tensor.shape)}, Dtype: {tensor.dtype}") + if key.endswith((".alpha", ".dim")) or tensor.numel() == 1: + try: + value = tensor.item() + if isinstance(value, float): + print(f" Value: {value:.8f}") + else: + print(f" Value: {value}") + except Exception as e: + print(f" Value: Could not extract scalar value ({tensor}, error: {e})") + + if not other_keys_found: + print("No other keys found.") + + print(f"\nTotal tensor keys found: {len(state_dict)}") + + print("\n--- Metadata (from safetensors header) ---") + metadata_content = OrderedDict() + malformed_metadata_keys = [] + try: + # Use safe_open to access the metadata separately + with safetensors.safe_open(filepath, framework="pt", device="cpu") as f: + metadata_keys = f.metadata() + if metadata_keys is None: + print("No metadata dictionary found in the file header (f.metadata() returned None).") + else: + for k in metadata_keys.keys(): + try: + metadata_content[k] = metadata_keys.get(k) + except Exception as e: + malformed_metadata_keys.append((k, str(e))) + metadata_content[k] = f"[Error reading value: {e}]" + except Exception as e: + print(f"Could not open or read metadata using safe_open: {e}") + traceback.print_exc(file=sys.stdout) + + if not metadata_content and not malformed_metadata_keys: + print("No metadata content extracted.") + else: + for key, value in metadata_content.items(): + print(f"- {key}: {value}") + if key == "ss_network_args" and value and not value.startswith("[Error"): + try: + parsed_args = json.loads(value) + print(" Parsed ss_network_args:") + for arg_key, arg_value in parsed_args.items(): + print(f" - {arg_key}: {arg_value}") + except json.JSONDecodeError: + print(" (ss_network_args is not a valid JSON string)") + if malformed_metadata_keys: + print("\n--- Malformed Metadata Keys (could not be read) ---") + for key, error_msg in malformed_metadata_keys: + print(f"- {key}: Error: {error_msg}") + + print("\n--- End of Analysis ---") + + except Exception as e: + print(f"\n!!! An error occurred during analysis !!!") + print(str(e)) + traceback.print_exc(file=sys.stdout) # Print full traceback to the log file + finally: + sys.stdout = original_stdout # Restore standard output + logger.close() + print(f"\nAnalysis complete. Output saved to: {output_filename}") + + +if __name__ == "__main__": + input_file_path = input("Enter the path to your working LoHA .safetensors file: ") + output_file_name = "loha_analysis_results.txt" # You can change this default + + # Suggest a default output name based on input file if desired + # import os + # base_name = os.path.splitext(os.path.basename(input_file_path))[0] + # output_file_name = f"{base_name}_analysis.txt" + + print(f"The analysis will be saved to: {output_file_name}") + analyze_safetensors_file(input_file_path, output_filename=output_file_name) \ No newline at end of file diff --git a/tools/dummy_loha.py b/tools/dummy_loha.py new file mode 100644 index 0000000..db9f059 --- /dev/null +++ b/tools/dummy_loha.py @@ -0,0 +1,204 @@ +import torch +from safetensors.torch import save_file +from collections import OrderedDict +import json + +# --- Script Configuration --- +# This script generates a minimal, non-functional LoHA (LyCORIS Hadamard Product Adaptation) +# .safetensors file, designed to be structurally compatible with ComfyUI and +# based on the analysis of a working SDXL LoHA file. + +# --- Global LoHA Parameters (mimicking metadata from your working file) --- +# These can be overridden per layer if needed for more complex dummies. +# From your metadata: ss_network_dim: 32, ss_network_alpha: 32.0 +DEFAULT_RANK = 32 +DEFAULT_ALPHA = 32.0 +CONV_RANK = 8 # From your ss_network_args: "conv_dim": "8" +CONV_ALPHA = 4.0 # From your ss_network_args: "conv_alpha": "4" + +# Define example target layers. +# We'll use names and dimensions that are representative of SDXL and your analysis. +# Format: (layer_name, in_dim, out_dim, rank, alpha) +# Note: For Conv2d, in_dim = in_channels, out_dim = out_channels. +# The hada_wX_b for conv will have shape (rank, in_channels * kernel_h * kernel_w) +# For simplicity in this dummy, we'll primarily focus on linear/attention +# layers first, and then add one representative conv-like layer. + +# Layer that previously caused error: +# "ERROR loha diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight shape '[640, 640]' is invalid..." +# This corresponds to lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_v +# In your working LoHA, similar attention layers (e.g., *_attn1_to_k) have out_dim=640, in_dim=640, rank=32, alpha=32.0 + +EXAMPLE_LAYERS_CONFIG = [ + # UNet Attention Layers (mimicking typical SDXL structure) + { + "name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_q", # Query + "in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False + }, + { + "name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_k", # Key + "in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False + }, + { + "name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_v", # Value - this one errored previously + "in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False + }, + { + "name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_out_0", # Output Projection + "in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False + }, + # A deeper UNet attention block + { + "name": "lora_unet_middle_block_1_transformer_blocks_0_attn1_to_q", + "in_dim": 1280, "out_dim": 1280, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False + }, + { + "name": "lora_unet_middle_block_1_transformer_blocks_0_attn1_to_out_0", + "in_dim": 1280, "out_dim": 1280, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False + }, + # Example UNet "Convolutional" LoHA (e.g., for a ResBlock's conv layer) + # Based on your lora_unet_input_blocks_1_0_in_layers_2 which had rank 8, alpha 4 + # Assuming original conv was Conv2d(320, 320, kernel_size=3, padding=1) + { + "name": "lora_unet_input_blocks_1_0_in_layers_2", + "in_dim": 320, # in_channels + "out_dim": 320, # out_channels + "rank": CONV_RANK, + "alpha": CONV_ALPHA, + "is_conv": True, + "kernel_size": 3 # Assume 3x3 kernel for this example + }, + # Example Text Encoder Layer (CLIP-L, first one from your list) + # lora_te1_text_model_encoder_layers_0_mlp_fc1 (original Linear(768, 3072)) + { + "name": "lora_te1_text_model_encoder_layers_0_mlp_fc1", + "in_dim": 768, "out_dim": 3072, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False + }, +] + +# Use bfloat16 as seen in the analysis +DTYPE = torch.bfloat16 + +# --- Main Script --- +def create_dummy_loha_file(filepath="dummy_loha_corrected.safetensors"): + """ + Creates and saves a dummy LoHA .safetensors file with corrected structure + and metadata based on analysis of a working file. + """ + state_dict = OrderedDict() + metadata = OrderedDict() + + print(f"Generating dummy LoHA with default rank={DEFAULT_RANK}, default alpha={DEFAULT_ALPHA}") + print(f"Targeting DTYPE: {DTYPE}") + + for layer_config in EXAMPLE_LAYERS_CONFIG: + layer_name = layer_config["name"] + in_dim = layer_config["in_dim"] + out_dim = layer_config["out_dim"] + rank = layer_config["rank"] + alpha = layer_config["alpha"] + is_conv = layer_config["is_conv"] + + print(f"Processing layer: {layer_name} (in: {in_dim}, out: {out_dim}, rank: {rank}, alpha: {alpha}, conv: {is_conv})") + + # --- LoHA Tensor Shapes Correction based on analysis --- + # hada_wX_a (maps to original layer's out_features): (out_dim, rank) + # hada_wX_b (maps from original layer's in_features): (rank, in_dim) + # For Convolutions, in_dim refers to in_channels, out_dim to out_channels. + # For hada_wX_b in conv, the effective input dimension includes kernel size. + + if is_conv: + kernel_size = layer_config.get("kernel_size", 3) # Default to 3x3 if not specified + # This is for LoHA types that decompose the full kernel (e.g. LyCORIS full conv): + # (rank, in_channels * kernel_h * kernel_w) + # For simpler conv LoHA (like applying to 1x1 equivalent), it might just be (rank, in_channels) + # The analysis for `lora_unet_input_blocks_1_0_in_layers_2` showed hada_w1_b as [8, 2880] + # where in_dim=320, rank=8. 2880 = 320 * 9 (i.e., in_channels * kernel_h * kernel_w for 3x3) + # This indicates a full kernel decomposition. + eff_in_dim_conv_b = in_dim * kernel_size * kernel_size + + hada_w1_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01 + hada_w1_b = torch.randn(rank, eff_in_dim_conv_b, dtype=DTYPE) * 0.01 + hada_w2_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01 + hada_w2_b = torch.randn(rank, eff_in_dim_conv_b, dtype=DTYPE) * 0.01 + else: # Linear layers + hada_w1_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01 + hada_w1_b = torch.randn(rank, in_dim, dtype=DTYPE) * 0.01 + hada_w2_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01 + hada_w2_b = torch.randn(rank, in_dim, dtype=DTYPE) * 0.01 + + state_dict[f"{layer_name}.hada_w1_a"] = hada_w1_a + state_dict[f"{layer_name}.hada_w1_b"] = hada_w1_b + state_dict[f"{layer_name}.hada_w2_a"] = hada_w2_a + state_dict[f"{layer_name}.hada_w2_b"] = hada_w2_b + + # Alpha tensor (scalar) + state_dict[f"{layer_name}.alpha"] = torch.tensor(float(alpha), dtype=DTYPE) + + # IMPORTANT: No per-module ".dim" tensor, as per analysis of working file. + # Rank is implicit in weight shapes and global metadata. + + # --- Metadata (mimicking the working LoHA file) --- + metadata["ss_network_module"] = "lycoris.kohya" + metadata["ss_network_dim"] = str(DEFAULT_RANK) # Global/default rank + metadata["ss_network_alpha"] = str(DEFAULT_ALPHA) # Global/default alpha + metadata["ss_network_algo"] = "loha" # Also specified inside ss_network_args by convention + + # Mimic ss_network_args from your file + network_args = { + "conv_dim": str(CONV_RANK), + "conv_alpha": str(CONV_ALPHA), + "algo": "loha", + # Add other args from your file if they seem relevant for loading structure, + # but these are the most critical for type/rank. + "dropout": "0.0", # From your file, though value might not matter for dummy + "rank_dropout": "0", # from your file + "module_dropout": "0", # from your file + "use_tucker": "False", # from your file + "use_scalar": "False", # from your file + "rank_dropout_scale": "False", # from your file + "train_norm": "False" # from your file + } + metadata["ss_network_args"] = json.dumps(network_args) + + # Other potentially useful metadata from your working file (optional for basic loading) + metadata["ss_sd_model_name"] = "sd_xl_base_1.0.safetensors" # Example base model + metadata["ss_resolution"] = "(1024,1024)" # Example, format might vary + metadata["modelspec.sai_model_spec"] = "1.0.0" + metadata["modelspec.implementation"] = "https_//github.com/Stability-AI/generative-models" # fixed typo + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base/lora" # Even for LoHA, this is often used + metadata["ss_mixed_precision"] = "bf16" + metadata["ss_note"] = "Dummy LoHA (corrected) for ComfyUI validation. Not trained." + + + # --- Save the State Dictionary with Metadata --- + try: + save_file(state_dict, filepath, metadata=metadata) + print(f"\nSuccessfully saved dummy LoHA file to: {filepath}") + print("\nFile structure (tensor keys):") + for key in state_dict.keys(): + print(f"- {key}: shape {state_dict[key].shape}, dtype {state_dict[key].dtype}") + print("\nMetadata:") + for key, value in metadata.items(): + print(f"- {key}: {value}") + + except Exception as e: + print(f"\nError saving file: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + create_dummy_loha_file() + + # --- Verification Note for ComfyUI --- + # 1. Place `dummy_loha_corrected.safetensors` into `ComfyUI/models/loras/`. + # 2. Load an SDXL base model in ComfyUI. + # 3. Add a "Load LoRA" node and select `dummy_loha_corrected.safetensors`. + # 4. Connect the LoRA node between the checkpoint loader and the KSampler. + # + # Expected outcome: + # - ComfyUI should load the file without "key not loaded" or "dimension mismatch" errors + # for the layers defined in EXAMPLE_LAYERS_CONFIG. + # - The LoRA node should correctly identify it as a LoHA/LyCORIS model. + # - If you have layers in your SDXL model that match the names in EXAMPLE_LAYERS_CONFIG, + # ComfyUI will attempt to apply these (random) weights. \ No newline at end of file diff --git a/tools/extract loha and lora examples.txt b/tools/extract loha and lora examples.txt new file mode 100644 index 0000000..108b525 --- /dev/null +++ b/tools/extract loha and lora examples.txt @@ -0,0 +1,65 @@ +D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py ` +--save_precision fp16 ` +--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors ` +--model_tuned E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors ` +--save_to E:/lora/sdxl/dreamshaperXL_alpha2Xl10_sv_fro_0.9_1024.safetensors ` +--dim 1024 ` +--device cuda ` +--sdxl ` +--dynamic_method sv_fro ` +--dynamic_param 0.9 ` +--verbose + +D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py ` +--save_precision fp16 ` +--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors ` +--model_tuned E:/models/sdxl/proteus_v06.safetensors ` +--save_to E:/lora/sdxl/proteus_v06_sv_cumulative_knee_1024.safetensors ` +--dim 1024 ` +--device cuda ` +--sdxl ` +--dynamic_method sv_cumulative_knee ` +--verbose + +D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\lr_finder.py ` +E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors ` +E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors ` + --lr_finder_num_layers 16 ` + --lr_finder_min_lr 1e-8 ` + --lr_finder_max_lr 0.2 ` + --lr_finder_num_steps 120 ` + --lr_finder_iters_per_step 40 ` + --rank 8 ` + --initial_alpha 8.0 ` + --precision bf16 ` + --device cuda ` + --lr_finder_plot ` + --lr_finder_show_plot + +D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py ` +E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors ` +E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors ` +E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_1e-7.safetensors ` +--rank 2 ` +--initial_alpha 2 ` +--max_rank_retries 7 ` +--rank_increase_factor 2 ` +--max_iterations 8000 ` +--min_iterations 400 ` +--target_loss 1e-7 ` +--lr 1e-01 ` +--device cuda ` +--precision fp32 ` +--verbose ` +--save_weights_dtype bf16 ` +--progress_check_interval 100 ` +--save_every_n_layers 10 ` +--keep_n_resume_files 10 ` +--skip_delta_threshold 1e-7 ` +--rank_search_strategy binary_search_min_rank ` +--probe_aggressive_early_stop + +D:\kohya_ss\venv\Scripts\python.exe D:\kohya_ss\tools\model_diff_report.py ` +E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors ` +E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors ` +--top_n_diff 15 --plot_histograms --plot_histograms_top_n 3 --output_dir ./analysis_results \ No newline at end of file diff --git a/tools/extract_loha_from_model.py b/tools/extract_loha_from_model.py new file mode 100644 index 0000000..d1addf9 --- /dev/null +++ b/tools/extract_loha_from_model.py @@ -0,0 +1,1076 @@ +import argparse +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file, load_file +import safetensors +from tqdm import tqdm +import math +import json +from collections import OrderedDict +import signal +import sys +import glob +import traceback +import re +from enum import Enum, auto + +# --- Global variables --- +extracted_loha_state_dict_global = OrderedDict() +layer_optimization_stats_global = [] +args_global = None +processed_layers_this_session_count_global = 0 +previously_completed_module_prefixes_global = set() +all_completed_module_prefixes_ever_global = set() +skipped_identical_count_global = 0 +skipped_other_reason_count_global = 0 +skipped_good_initial_loss_count_global = 0 +keys_scanned_this_run_global = 0 +save_attempted_on_interrupt = False +outer_pbar_global = None +main_loop_completed_scan_flag_global = False +params_to_seed_optimizer_global = {} +skipped_vae_layers_count = 0 # Ensure this is a global if accessed in main and other places + +# --- Logging Helper --- +class LogType(Enum): + RANK_RETRY_STARTING = auto() + RANK_INCREASED_INFO = auto() + INITIAL_PARAMS_LOADED = auto() + INITIAL_PARAMS_KAIMING_NORMAL = auto() + INSUFFICIENT_PROGRESS_STOP = auto() + PROJECTION_STOP = auto() + INSUFFICIENT_PROGRESS_LOG_ONLY = auto() + PROJECTION_LOG_ONLY = auto() + EMA_PROJECTION_SKIPPED_HISTORY = auto() + EMA_PROJECTION_INCONCLUSIVE_FALLBACK_RAW = auto() + NEW_BEST_RESULT_FOR_LAYER = auto() + TARGET_LOSS_REACHED_IN_ATTEMPT = auto() + TARGET_LOSS_MET_STOP_ALL_RETRIES = auto() + ATTEMPT_EARLY_FINISH_NO_STOP_FLAG = auto() + LAST_RANK_ATTEMPT_SUMMARY = auto() + ATTEMPT_ENDED_WILL_RETRY = auto() + NO_VALID_OPTIMIZATION_RESULT = auto() + +def log_layer_optimization_event(log_type: LogType, layer_name: str, **kwargs): + if not (args_global and args_global.verbose): + return + if not args_global.verbose_layer_debug: + if log_type in [ + LogType.INITIAL_PARAMS_LOADED, LogType.INITIAL_PARAMS_KAIMING_NORMAL, + LogType.EMA_PROJECTION_SKIPPED_HISTORY, LogType.EMA_PROJECTION_INCONCLUSIVE_FALLBACK_RAW + ]: + return + prefix = f" {layer_name}: " + msg = "" + if log_type == LogType.RANK_RETRY_STARTING: + msg = f"Retrying (PrevR: {kwargs.get('prev_rank', 'N/A')}, BestLoss: {kwargs.get('prev_best_loss', float('inf')):.2e}) -> Increasing rank..." + elif log_type == LogType.RANK_INCREASED_INFO: + warm_start_msg = "" + status = kwargs.get('warm_start_status') + if status == 'applied': warm_start_msg = f" Warm-start from R: {kwargs['prev_rank_for_warm_start']}." + elif status == 'skipped_no_warm_start_arg': warm_start_msg = " Warm-start skipped (--no_warm_start)." + elif status == 'skipped_cannot_warm_start': warm_start_msg = f" Cannot warm-start (PrevR {kwargs['prev_rank_for_warm_start']} !< NewR {kwargs['new_rank']})." + elif status == 'no_prior_params_for_warm_start': warm_start_msg = " No prior best params to warm-start." + msg = f" Increased Rank: {kwargs['new_rank']}, Alpha: {kwargs['new_alpha']:.2f}.{warm_start_msg}" + elif log_type == LogType.INITIAL_PARAMS_LOADED: + msg = f" R:{kwargs['rank']} Initialized from existing LoHA." + elif log_type == LogType.INITIAL_PARAMS_KAIMING_NORMAL: + msg = f" R:{kwargs['rank']} Initialized Kaiming/Normal (Attempt {kwargs.get('attempt', 1)})." + elif log_type == LogType.INSUFFICIENT_PROGRESS_STOP: + msg = f"Att {kwargs['attempt']}(R:{kwargs['rank']}): Stop - RawProg Low (Imprv: {kwargs['rel_imprv']:.1e} < {kwargs['min_ratio']:.1e}; Loss: {kwargs['current_loss']:.2e})." + elif log_type == LogType.PROJECTION_STOP: + details = f"Est. {kwargs.get('iters_needed', 'inf'):.0f} iters" + if kwargs.get('proj_final_loss') is not None: details += f", ProjLoss: ~{kwargs['proj_final_loss']:.2e}" + else: details += ", Target Unreachable" + msg = f"Att {kwargs['attempt']}(R:{kwargs['rank']}): Stop - Proj ({kwargs['proj_type']}). {details} vs Target: {kwargs['target_loss']:.2e} (Avail: {kwargs['avail_iters']})." + elif log_type == LogType.INSUFFICIENT_PROGRESS_LOG_ONLY: + msg = f"Att {kwargs['attempt']}(R:{kwargs['rank']}) [LastRankLog]: RawProg Low (Imprv: {kwargs['rel_imprv']:.1e} < {kwargs['min_ratio']:.1e}; Loss: {kwargs['current_loss']:.2e})." + elif log_type == LogType.PROJECTION_LOG_ONLY: + details = f"Est. {kwargs.get('iters_needed', 'inf'):.0f} iters" + if kwargs.get('proj_final_loss') is not None: details += f", ProjLoss: ~{kwargs['proj_final_loss']:.2e}" + else: details += ", Target Unreachable" + msg = f"Att {kwargs['attempt']}(R:{kwargs['rank']}) [LastRankLog]: Proj ({kwargs['proj_type']}). {details} vs Target: {kwargs['target_loss']:.2e} (Avail: {kwargs['avail_iters']})." + elif log_type == LogType.EMA_PROJECTION_SKIPPED_HISTORY: + msg = f" Att {kwargs['attempt']}(R:{kwargs['rank']}): EMA proj. skipped (Hist {kwargs['hist_len']}/{kwargs['min_hist']})." + elif log_type == LogType.EMA_PROJECTION_INCONCLUSIVE_FALLBACK_RAW: + msg = f" Att {kwargs['attempt']}(R:{kwargs['rank']}): EMA proj. inconclusive, using raw." + elif log_type == LogType.NEW_BEST_RESULT_FOR_LAYER: + msg = f"Att {kwargs['attempt']}(R:{kwargs['rank']}): New Best -> Loss {kwargs['loss']:.2e}." + elif log_type == LogType.TARGET_LOSS_REACHED_IN_ATTEMPT: + msg = f"Att {kwargs['attempt']}(R:{kwargs['rank']}): Target loss {kwargs['target_loss']:.2e} met at iter {kwargs['iter']}." + elif log_type == LogType.TARGET_LOSS_MET_STOP_ALL_RETRIES: + msg = "Target loss met. Halting rank retries for this layer." + elif log_type == LogType.ATTEMPT_EARLY_FINISH_NO_STOP_FLAG: + msg = f"Att {kwargs['attempt']}(R:{kwargs['rank']}): Finished early ({kwargs['iters_done']}/{kwargs['max_iters']}), no stop/target. Using result." + elif log_type == LogType.LAST_RANK_ATTEMPT_SUMMARY: + reason = "All rank attempts completed." + if kwargs.get('target_loss') is not None and kwargs.get('final_loss_for_layer',0) > kwargs['target_loss']: + reason = f"Last rank (R:{kwargs['final_rank_for_layer']}) finished, target {kwargs['target_loss']:.2e} not met." + msg = f"{reason} Best: Loss {kwargs.get('final_loss_for_layer',0):.2e}, Rank {kwargs['final_rank_for_layer']}." + elif log_type == LogType.ATTEMPT_ENDED_WILL_RETRY: + reason_detail = "" + if kwargs.get('reason_type') == 'projection_unreachable': + proj_loss_info = f" (ProjL: ~{kwargs['proj_final_loss']:.2e})" if kwargs.get('proj_final_loss') is not None else "" + reason_detail = f"target {kwargs['target_loss']:.2e} proj. unreachable{proj_loss_info} ({kwargs['proj_type']})" + elif kwargs.get('reason_type') == 'insufficient_progress': reason_detail = "insufficient raw progress" + elif kwargs.get('reason_type') == 'max_iterations_no_target': reason_detail = f"max iters (Loss {kwargs['current_loss']:.2e}, Target not met)" + elif kwargs.get('reason_type') == 'max_iterations_no_target_set': reason_detail = f"max iters (Loss {kwargs['current_loss']:.2e})" + if reason_detail: + msg = f"Att {kwargs['attempt']}(R:{kwargs['rank']}) ended: {reason_detail}. Will try next rank..." + elif log_type == LogType.NO_VALID_OPTIMIZATION_RESULT: + msg = "No valid optimization result (likely interrupted)." + if msg: + tqdm.write(prefix + msg) + +def _get_closest_ema_value_before_iter(target_iter: int, ema_history: list[tuple[int, float]]) -> tuple[int | None, float | None]: + if not ema_history: return None, None + best_match_iter, best_match_loss = None, None + for hist_iter, hist_loss in reversed(ema_history): + if hist_iter <= target_iter: + if best_match_iter is None or hist_iter > best_match_iter : + best_match_iter, best_match_loss = hist_iter, hist_loss + if best_match_iter is not None and target_iter - hist_iter > getattr(args_global, 'projection_sample_interval', 20) * 2: break + return (best_match_iter, best_match_loss) if best_match_iter is not None else (ema_history[0] if ema_history else (None, None)) + +def initialize_loha_parameters( + out_dim: int, current_rank: int, in_dim_effective_k_ops: int, + device: str, dtype: torch.dtype, layer_name: str, attempt_idx: int, + is_continuation_attempt: bool, + existing_params_to_load: dict | None = None, + warm_start_status: str | None = None, + prev_rank_for_warm_start: int | None = None +): + hada_w1_a_p = nn.Parameter(torch.empty(out_dim, current_rank, device=device, dtype=dtype)) + hada_w1_b_p = nn.Parameter(torch.empty(current_rank, in_dim_effective_k_ops, device=device, dtype=dtype)) + hada_w2_a_p = nn.Parameter(torch.empty(out_dim, current_rank, device=device, dtype=dtype)) + hada_w2_b_p = nn.Parameter(torch.empty(current_rank, in_dim_effective_k_ops, device=device, dtype=dtype)) + initialized_from_external_or_warm_start = False + with torch.no_grad(): + if is_continuation_attempt and attempt_idx == 0 and existing_params_to_load: + try: + if existing_params_to_load['hada_w1_a'].shape[1] == current_rank: + for p_name_suffix in ['hada_w1_a', 'hada_w1_b', 'hada_w2_a', 'hada_w2_b']: + target_param = locals()[f"{p_name_suffix}_p"] + target_param.data.copy_(existing_params_to_load[p_name_suffix].to(device, dtype)) + log_layer_optimization_event(LogType.INITIAL_PARAMS_LOADED, layer_name, rank=current_rank) + initialized_from_external_or_warm_start = True + except Exception as e: + if args_global.verbose_layer_debug: + tqdm.write(f" {layer_name}: Error loading initial existing params: {e}. Falling back.") + pass + if not initialized_from_external_or_warm_start and warm_start_status == 'applied' and existing_params_to_load and prev_rank_for_warm_start is not None: + prev_params_cpu = {k: existing_params_to_load[k] for k in ['hada_w1_a', 'hada_w1_b', 'hada_w2_a', 'hada_w2_b']} + hada_w1_a_p.data[:, :prev_rank_for_warm_start] = prev_params_cpu['hada_w1_a'].to(device, dtype) + hada_w1_b_p.data[:prev_rank_for_warm_start, :] = prev_params_cpu['hada_w1_b'].to(device, dtype) + hada_w2_a_p.data[:, :prev_rank_for_warm_start] = prev_params_cpu['hada_w2_a'].to(device, dtype) + hada_w2_b_p.data[:prev_rank_for_warm_start, :] = prev_params_cpu['hada_w2_b'].to(device, dtype) + if current_rank > prev_rank_for_warm_start: + for p_slice in [hada_w1_a_p.data[:, prev_rank_for_warm_start:], hada_w2_a_p.data[:, prev_rank_for_warm_start:]]: + nn.init.kaiming_uniform_(p_slice, a=math.sqrt(5)) + for p_slice in [hada_w1_b_p.data[prev_rank_for_warm_start:, :], hada_w2_b_p.data[prev_rank_for_warm_start:, :]]: + nn.init.normal_(p_slice, std=0.02) + initialized_from_external_or_warm_start = True + if not initialized_from_external_or_warm_start: + log_layer_optimization_event(LogType.INITIAL_PARAMS_KAIMING_NORMAL, layer_name, rank=current_rank, attempt=attempt_idx + 1) + for p in [hada_w1_a_p, hada_w2_a_p]: nn.init.kaiming_uniform_(p.data, a=math.sqrt(5)) + for p in [hada_w1_b_p, hada_w2_b_p]: nn.init.normal_(p.data, std=0.02) + return hada_w1_a_p, hada_w1_b_p, hada_w2_a_p, hada_w2_b_p + +def check_insufficient_progress( + current_loss: float, + loss_at_window_start: float, + min_progress_ratio: float, + target_loss: float | None, + perform_stop_check: bool, + layer_name: str, + attempt_idx: int, + rank: int +) -> tuple[bool, dict | None]: + raw_rel_imprv = 0.0 + if loss_at_window_start > 1e-12 and loss_at_window_start > current_loss: + raw_rel_imprv = (loss_at_window_start - current_loss) / loss_at_window_start + log_details = { + 'attempt': attempt_idx + 1, 'rank': rank, + 'rel_imprv': raw_rel_imprv, 'min_ratio': min_progress_ratio, + 'current_loss': current_loss + } + if (target_loss is None or current_loss > target_loss * 1.01) and raw_rel_imprv < min_progress_ratio: + log_event_type = LogType.INSUFFICIENT_PROGRESS_STOP if perform_stop_check else LogType.INSUFFICIENT_PROGRESS_LOG_ONLY + log_layer_optimization_event(log_event_type, layer_name, **log_details) + if perform_stop_check: + return True, log_details + return False, None + +def check_loss_projection( + ema_loss_history: list[tuple[int, float]], + current_raw_loss: float, + raw_rel_imprv_for_fallback: float, + target_loss: float, + max_total_iters_for_attempt: int, + current_iter_num: int, + prog_check_interval: int, + proj_min_ema_hist: int, + adv_proj_decay_cap_min: float, + adv_proj_decay_cap_max: float, + perform_stop_check: bool, + layer_name: str, + attempt_idx: int, + rank: int, + relative_improvement_history_ref: list[float] +) -> tuple[bool, dict | None]: + use_ema = len(ema_loss_history) >= proj_min_ema_hist + if not use_ema: + log_layer_optimization_event(LogType.EMA_PROJECTION_SKIPPED_HISTORY, layer_name, + attempt=attempt_idx + 1, rank=rank, + hist_len=len(ema_loss_history), min_hist=proj_min_ema_hist) + req_iters_to_target = float('inf') + proj_type_at_check = "none" + projected_final_loss_val = None + iters_remaining_in_attempt = max_total_iters_for_attempt - current_iter_num + num_intervals_remaining = iters_remaining_in_attempt // prog_check_interval if prog_check_interval > 0 else 0 + + if use_ema: + ema_curr_iter, ema_curr_loss = ema_loss_history[-1] + ema_start_iter, ema_start_loss = _get_closest_ema_value_before_iter(current_iter_num - prog_check_interval, ema_loss_history) + smooth_ema_imprv = 0.0 + if ema_start_loss and ema_start_loss > ema_curr_loss and ema_start_iter is not None and ema_curr_iter > ema_start_iter: + smooth_ema_imprv = (ema_start_loss - ema_curr_loss) / ema_start_loss + if smooth_ema_imprv > 1e-9: + relative_improvement_history_ref.append(smooth_ema_imprv) + relative_improvement_history_ref[:] = relative_improvement_history_ref[-2:] + if len(relative_improvement_history_ref) >= 2 and relative_improvement_history_ref[-2] > 1e-9: + proj_type_at_check = "advanced_ema" + decay_R_factor = max(adv_proj_decay_cap_min, min(adv_proj_decay_cap_max, smooth_ema_imprv / relative_improvement_history_ref[-2])) + sim_loss, sim_R_per_interval, sim_iters_elapsed = ema_curr_loss, smooth_ema_imprv, 0 + if sim_loss > target_loss: + for _window_idx in range(num_intervals_remaining + 1): + sim_loss *= (1.0 - max(0, sim_R_per_interval)) + sim_iters_elapsed += prog_check_interval + if sim_loss <= target_loss: break + if sim_iters_elapsed >= iters_remaining_in_attempt + prog_check_interval: break + sim_R_per_interval = max(1e-7, sim_R_per_interval * decay_R_factor) + req_iters_to_target = sim_iters_elapsed if sim_loss <= target_loss else float('inf') + if req_iters_to_target == float('inf'): projected_final_loss_val = sim_loss + else: + proj_type_at_check = "simple_ema" + (" (fallback_adv)" if len(relative_improvement_history_ref) >=2 else "") + if ema_curr_loss > target_loss: + try: + intervals_needed = math.log(target_loss / ema_curr_loss) / math.log(1.0 - smooth_ema_imprv) + req_iters_to_target = math.ceil(intervals_needed) * prog_check_interval + except (ValueError, OverflowError, ZeroDivisionError): + req_iters_to_target = float('inf') + projected_final_loss_val = ema_curr_loss * ((1.0 - max(0, smooth_ema_imprv))**(num_intervals_remaining + 1)) + else: req_iters_to_target = 0 + elif use_ema: + proj_type_at_check = "stalled_ema"; projected_final_loss_val = ema_curr_loss + if req_iters_to_target == float('inf') and not proj_type_at_check.startswith("stalled_"): + if use_ema and not proj_type_at_check.endswith("_ema"): + log_layer_optimization_event(LogType.EMA_PROJECTION_INCONCLUSIVE_FALLBACK_RAW, layer_name, attempt=attempt_idx+1, rank=rank) + if raw_rel_imprv_for_fallback > 1e-9 and current_raw_loss > target_loss: + proj_type_at_check = "simple_raw_fallback" + try: + intervals_needed = math.log(target_loss / current_raw_loss) / math.log(1.0 - raw_rel_imprv_for_fallback) + req_iters_to_target = math.ceil(intervals_needed) * prog_check_interval + projected_final_loss_val = None + except (ValueError, OverflowError, ZeroDivisionError): + req_iters_to_target = float('inf') + projected_final_loss_val = current_raw_loss * ((1.0 - max(0, raw_rel_imprv_for_fallback))**(num_intervals_remaining + 1)) + elif current_raw_loss <= target_loss: + proj_type_at_check = "raw_target_met_fallback"; req_iters_to_target = 0; projected_final_loss_val = None + elif raw_rel_imprv_for_fallback <= 1e-9 : + proj_type_at_check = "stalled_raw_fallback"; projected_final_loss_val = current_raw_loss + log_details = { + 'attempt': attempt_idx + 1, 'rank': rank, 'proj_type': proj_type_at_check, + 'iters_needed': req_iters_to_target, 'proj_final_loss': projected_final_loss_val, + 'target_loss': target_loss, 'avail_iters': iters_remaining_in_attempt + } + if req_iters_to_target > iters_remaining_in_attempt : + log_event_type = LogType.PROJECTION_STOP if perform_stop_check else LogType.PROJECTION_LOG_ONLY + log_layer_optimization_event(log_event_type, layer_name, **log_details) + if perform_stop_check: + return True, log_details + return False, log_details if proj_type_at_check != "none" else None + +def get_module_shape_info_from_weight(weight_tensor: torch.Tensor): + if len(weight_tensor.shape) == 4: is_conv = True; out_dim, in_dim_effective, k_h, k_w = weight_tensor.shape; return out_dim, in_dim_effective, k_h, k_w, True + elif len(weight_tensor.shape) == 2: is_conv = False; out_dim, in_dim = weight_tensor.shape; return out_dim, in_dim, None, None, False + return None + +def generate_intermediate_filename(base_save_path: str, num_total_completed_layers: int) -> str: + base, ext = os.path.splitext(base_save_path) + return f"{base}_resume_L{num_total_completed_layers}{ext}" + +def prepare_save_metadata( + script_args: argparse.Namespace, + output_filename: str, + total_completed_modules: int, + processed_this_run: int, + skipped_identical_this_run: int, + skipped_other_this_run: int, + skipped_good_initial_this_run: int, + scanned_keys_this_run: int, + all_completed_module_prefixes_list: list[str], + layer_opt_stats_this_run: list[dict], + is_interrupted_save: bool +) -> tuple[dict, dict]: + net_alpha_str = f"{script_args.initial_alpha:.8f}" if script_args.initial_alpha is not None else str(script_args.rank) + conv_alpha_val = script_args.initial_conv_alpha if script_args.initial_conv_alpha is not None else (script_args.conv_rank or script_args.rank) + conv_alpha_str = f"{conv_alpha_val:.8f}" if isinstance(conv_alpha_val, float) else str(conv_alpha_val) + network_args_data = { + "algo": "loha", "dim": str(script_args.rank), "alpha": net_alpha_str, + "conv_dim": str(script_args.conv_rank or script_args.rank), "conv_alpha": conv_alpha_str, + **{k: str(getattr(script_args, k)) for k in ["dropout", "rank_dropout", "module_dropout"]} + } + sf_meta = { + "ss_network_module": "lycoris.kohya", "ss_network_rank": str(script_args.rank), + "ss_network_alpha": net_alpha_str, "ss_network_algo": "loha", + "ss_network_args": json.dumps(network_args_data), + "ss_comment": f"Extracted LoHA (Int: {is_interrupted_save}). OptPrec: {script_args.precision}. SaveDtype: {script_args.save_weights_dtype}. Layers: {total_completed_modules}.", + "ss_base_model_name": os.path.splitext(os.path.basename(script_args.base_model_path))[0], + "ss_ft_model_name": os.path.splitext(os.path.basename(script_args.ft_model_path))[0], + "ss_save_weights_dtype": script_args.save_weights_dtype, + "ss_optimization_precision": script_args.precision, + "ss_completed_loha_modules": json.dumps(all_completed_module_prefixes_list) + } + serializable_script_args = { + k: str(v) if not isinstance(v, (str, int, float, bool, list, dict, type(None))) else v + for k, v in vars(script_args).items() + } + extraction_summary_data = { + "total_cumulative": total_completed_modules, "this_session": processed_this_run, + "skipped_identical": skipped_identical_this_run, "skipped_other": skipped_other_this_run, + "skipped_good_initial": skipped_good_initial_this_run, "scanned_keys": scanned_keys_this_run + } + json_meta = { + "comfyui_lora_type": "LyCORIS_LoHa", "model_name": os.path.splitext(output_filename)[0], + "base_model_path": script_args.base_model_path, "ft_model_path": script_args.ft_model_path, + "loha_extraction_settings": serializable_script_args, + "extraction_summary": extraction_summary_data, + "layer_optimization_details_this_session": layer_opt_stats_this_run, + "embedded_safetensors_metadata": sf_meta, "interrupted_save": is_interrupted_save + } + return sf_meta, json_meta + +def perform_graceful_save(output_path_to_save: str): + global extracted_loha_state_dict_global, layer_optimization_stats_global, args_global, \ + processed_layers_this_session_count_global, save_attempted_on_interrupt, \ + skipped_identical_count_global, skipped_other_reason_count_global, keys_scanned_this_run_global, \ + all_completed_module_prefixes_ever_global, skipped_good_initial_loss_count_global + + total_processed_ever = len(all_completed_module_prefixes_ever_global) + if not extracted_loha_state_dict_global and not total_processed_ever: + print(f"No layers to save to {output_path_to_save}. Aborted.") + return False + if not args_global: + print("Error: Global args not set for saving metadata.") + return False + + save_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(args_global.save_weights_dtype, torch.bfloat16) + final_sd = OrderedDict( + (k, v.to(save_dtype) if hasattr(v, 'is_floating_point') and v.is_floating_point() else v) + for k, v in extracted_loha_state_dict_global.items() + ) + print(f"\nSaving LoHA for {total_processed_ever} modules ({processed_layers_this_session_count_global} this session) to {output_path_to_save}") + sf_meta, json_meta = prepare_save_metadata( + script_args=args_global, output_filename=os.path.basename(output_path_to_save), + total_completed_modules=total_processed_ever, + processed_this_run=processed_layers_this_session_count_global, + skipped_identical_this_run=skipped_identical_count_global, + skipped_other_this_run=skipped_other_reason_count_global, + skipped_good_initial_this_run=skipped_good_initial_loss_count_global, + scanned_keys_this_run=keys_scanned_this_run_global, + all_completed_module_prefixes_list=list(all_completed_module_prefixes_ever_global), + layer_opt_stats_this_run=layer_optimization_stats_global, + is_interrupted_save=save_attempted_on_interrupt + ) + temp_sf_path, temp_json_path = None, None + try: + if output_path_to_save.endswith(".safetensors"): + temp_sf_path = output_path_to_save + ".part" + final_json_path = os.path.splitext(output_path_to_save)[0] + "_extraction_metadata.json" + temp_json_path = final_json_path + ".part" + save_file(final_sd, temp_sf_path, metadata=sf_meta) + with open(temp_json_path, 'w') as f: json.dump(json_meta, f, indent=4) + os.replace(temp_sf_path, output_path_to_save) + os.replace(temp_json_path, final_json_path) + print(f"Saved: {output_path_to_save} and {final_json_path}") + else: + torch.save({'state_dict': final_sd, '__metadata__': sf_meta, '__extended_metadata__': json_meta}, output_path_to_save) + print(f"Saved (basic .pt): {output_path_to_save}") + return True + except Exception as e: + print(f"Error saving to {output_path_to_save}: {e}"); traceback.print_exc() + if temp_sf_path and os.path.exists(temp_sf_path): + try: os.remove(temp_sf_path) + except OSError: pass + if temp_json_path and os.path.exists(temp_json_path): + try: os.remove(temp_json_path) + except OSError: pass + return False + +def cleanup_intermediate_files(final_intended_path: str, for_resume_management: bool = False, keep_n: int = 0): + output_dir = os.path.dirname(final_intended_path); base_name, save_ext = os.path.splitext(os.path.basename(final_intended_path)) + if not output_dir: output_dir = "." + intermediate_pattern = os.path.join(output_dir, f"{base_name}_resume_L*{save_ext}") + files_to_consider = [{'path': fp, 'l_count': int(m.group(1))} for fp in glob.glob(intermediate_pattern) if (m := re.search(r'_resume_L(\d+)', os.path.basename(fp)))] + if not files_to_consider: return + cleaned_count = 0 + if for_resume_management: + if keep_n <= 0 or len(files_to_consider) <= keep_n: return + files_to_consider.sort(key=lambda x: x['l_count']) + files_to_delete = files_to_consider[:-keep_n] + if args_global and args_global.verbose: print(f" Resume Manager: Found {len(files_to_consider)} files. Deleting {len(files_to_delete)} oldest to keep {keep_n}.") + else: + files_to_delete = files_to_consider + if args_global and args_global.verbose: print(f" Cleaning ALL {len(files_to_delete)} intermediate files...") + for file_info in files_to_delete: + try: + os.remove(file_info['path']) + if args_global and args_global.verbose: print(f" Cleaned: {file_info['path']}") + cleaned_count += 1 + json_path = os.path.splitext(file_info['path'])[0] + "_extraction_metadata.json" + if os.path.exists(json_path): os.remove(json_path) + except OSError as e: print(f" Warning: Could not clean {file_info['path']}: {e}") + if cleaned_count > 0: print(f" Cleaned {cleaned_count} file(s).") + +def find_best_resume_file(intended_final_path: str) -> tuple[str | None, int]: + output_dir = os.path.dirname(intended_final_path) + if not output_dir: output_dir = "." + base_save_name, save_ext = os.path.splitext(os.path.basename(intended_final_path)) + potential_files = [] + if os.path.exists(intended_final_path): potential_files.append(intended_final_path) + intermediate_pattern = os.path.join(output_dir, f"{base_save_name}_resume_L*{save_ext}") + potential_files.extend(glob.glob(intermediate_pattern)) + best_file_path = None; max_completed_modules = -1 + if not potential_files: return None, -1 + for file_path in sorted(potential_files): + try: + if not os.path.exists(file_path): continue + with safetensors.safe_open(file_path, framework="pt", device="cpu") as f: metadata = f.metadata() + if metadata and "ss_completed_loha_modules" in metadata: + num_completed = len(json.loads(metadata["ss_completed_loha_modules"])) + if num_completed > max_completed_modules: max_completed_modules, best_file_path = num_completed, file_path + elif num_completed == max_completed_modules: + if file_path == intended_final_path or (base_save_name+"_resume_L" in os.path.basename(file_path) and best_file_path != intended_final_path): + best_file_path = file_path + elif max_completed_modules == -1 and (best_file_path is None or (file_path == intended_final_path and best_file_path != intended_final_path)): + best_file_path, max_completed_modules = file_path, 0 + if args_global and args_global.verbose: print(f" File {os.path.basename(file_path)} no metadata. Treating as 0 completed.") + except Exception as e: + print(f" Warning: Could not read metadata from {file_path}: {e}") + if best_file_path is None and file_path == intended_final_path and max_completed_modules == -1 : best_file_path, max_completed_modules = file_path, 0 + if best_file_path: print(f" Selected '{os.path.basename(best_file_path)}' for resume (est. {max_completed_modules} modules).") + return best_file_path, max_completed_modules + +def handle_resume_or_continue_loha( + current_args: argparse.Namespace, + extracted_sd_ref: OrderedDict, + params_to_seed_opt_ref: dict, + prev_completed_prefixes_ref: set, + all_completed_prefixes_ref: set +): + if current_args.continue_training_from_loha: + print(f"\nMode: Continue/Refine from LoHA: {current_args.continue_training_from_loha}") + if not os.path.exists(current_args.continue_training_from_loha): + print(f" Error: LoHA not found: {current_args.continue_training_from_loha}"); sys.exit(1) + try: + loaded_sd = load_file(current_args.continue_training_from_loha, device='cpu') + extracted_sd_ref.update(loaded_sd) + module_prefixes = {".".join(k.split('.')[:-1]) for k in loaded_sd if ".hada_w1_a" in k} + loaded_count = 0 + for prefix in module_prefixes: + params = {p: loaded_sd.get(f"{prefix}.{p}") for p in ['hada_w1_a', 'hada_w1_b', 'hada_w2_a', 'hada_w2_b']} + alpha_t = loaded_sd.get(f"{prefix}.alpha") + if all(params.values()) and alpha_t is not None: + params_to_seed_opt_ref[prefix] = {'params': params, 'rank': params['hada_w1_a'].shape[1], 'alpha': alpha_t.item()} + loaded_count += 1 + elif current_args.verbose: + tqdm.write(f" Info: Module {prefix} from LoHA missing components. Will treat as new if encountered.") + print(f" Loaded {len(extracted_sd_ref)} tensors. Identified {loaded_count} full LoHA modules for re-optimization.") + del loaded_sd + if os.path.exists(current_args.save_to) and not current_args.overwrite: + print(f" Warning: Output {current_args.save_to} exists and may be overwritten.") + elif os.path.exists(current_args.save_to) and current_args.overwrite: + print(f" Info: Output {current_args.save_to} will be overwritten due to --overwrite.") + except Exception as e: + print(f" Error loading LoHA for continuation: {e}."); traceback.print_exc(); sys.exit(1) + elif not current_args.overwrite: + print(f"\nMode: Standard extraction. Checking resume states for: {current_args.save_to}") + resume_file, num_modules_resume = find_best_resume_file(current_args.save_to) + if resume_file: + print(f" Attempting resume from: {resume_file} (est. {num_modules_resume} modules).") + try: + completed_in_file = set() + with safetensors.safe_open(resume_file, framework="pt", device="cpu") as f: + meta = f.metadata() + if meta and "ss_completed_loha_modules" in meta: + completed_in_file = set(json.loads(meta["ss_completed_loha_modules"])) + loaded_sd_resume = load_file(resume_file, device='cpu') + if not completed_in_file and loaded_sd_resume: + completed_in_file = {".".join(k.split('.')[:-1]) for k in loaded_sd_resume if k.endswith(".hada_w1_a")} + res_tensor_count = 0 + if completed_in_file: + for k, v in loaded_sd_resume.items(): + module_prefix_of_key = ".".join(k.split('.')[:-1]) + if module_prefix_of_key in completed_in_file or k.endswith(".bias"): + extracted_sd_ref[k] = v + res_tensor_count += 1 + prev_completed_prefixes_ref.update(completed_in_file) + all_completed_prefixes_ref.update(completed_in_file) + print(f" Loaded {len(prev_completed_prefixes_ref)} module prefixes, {res_tensor_count} tensors for resume.") + elif loaded_sd_resume: + extracted_sd_ref.update(loaded_sd_resume) + inferred_completed = {".".join(k.split('.')[:-1]) for k in loaded_sd_resume if k.endswith(".hada_w1_a")} + prev_completed_prefixes_ref.update(inferred_completed) + all_completed_prefixes_ref.update(inferred_completed) + print(f" Loaded all {len(loaded_sd_resume)} tensors from resume file (metadata for completed modules missing/empty, inferred {len(inferred_completed)}).") + del loaded_sd_resume + except Exception as e: + print(f" Error loading resume file '{resume_file}': {e}. Starting fresh.") + extracted_sd_ref.clear() + prev_completed_prefixes_ref.clear() + all_completed_prefixes_ref.clear() + else: + print(" No suitable existing LoHA to resume from. Starting fresh.") + elif current_args.overwrite: + print(f"\nMode: Standard extraction with --overwrite. Final output {current_args.save_to} will be overwritten.") + +def print_script_summary( + layer_stats: list[dict], + all_completed_prefixes: set, + processed_session: int, + skipped_identical_session: int, + skipped_good_initial_session: int, + skipped_other_session: int, + skipped_vae_session: int, + keys_scanned_session: int, + total_candidates: int +): + print("\n--- Final Optimization Summary (This Session) ---") + if not layer_stats: + print(" No layers were optimized in this session.") + for stat in layer_stats: + rank_info = f"InitialR: {stat['initial_rank_attempted']}, FinalR: {stat['final_rank_used']}" + if stat.get('rank_was_increased', False): rank_info += " (Increased)" + proj_loss_info = "" + if stat.get('final_projected_loss_on_stop') is not None: + proj_loss_info = f" (Proj.FinalLoss ~{stat['final_projected_loss_on_stop']:.2e})" + stop_info = "" + if stat.get('skipped_reopt_due_to_initial_good_loss'): stop_info = ", SkipReOpt:GoodInitialLoss" + elif stat.get('stopped_early_by_loss_target'): stop_info = ", Stop:LossTarget" + elif stat.get('stopped_by_projection', False): + stop_info = f", Stop:Proj({stat.get('projection_type_used','?')})" + proj_loss_info + elif stat.get('stopped_by_insufficient_progress', False): + stop_info = ", Stop:RawProg" + alpha_str = f", Alpha: {stat.get('alpha_final', 'N/A'):.2f}" if 'alpha_final' in stat else "" + print(f"Layer: {stat['name']}, {rank_info}{alpha_str}, Loss: {stat['final_loss']:.4e}, Iters: {stat['iterations_done']}{stop_info}") + + print(f"\n--- Overall Summary ---") + print(f"Total unique LoHA modules in final state (processed, skipped, errored): {len(all_completed_prefixes)}") + print(f" Processed (new/re-opt/skipped-good) this session: {processed_session}") + print(f" Skipped identical (this session): {skipped_identical_session}") + print(f" Skipped re-opt due to good initial loss (this session): {skipped_good_initial_session}") + print(f" Skipped other reasons (this session, VAE, opt error): {skipped_other_session} (incl. {skipped_vae_session} VAE)") + print(f" Total candidate keys scanned (this session): {keys_scanned_session}/{total_candidates}") + +def setup_and_print_configuration(current_args: argparse.Namespace): + if current_args.progress_check_start_iter is None: + current_args.progress_check_start_iter = max(1, current_args.progress_check_interval) if current_args.progress_check_interval > 0 else current_args.max_iterations + 1 + elif current_args.progress_check_interval <= 0 : + current_args.progress_check_start_iter = current_args.max_iterations + 1 + opt_dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32} + target_opt_dtype = opt_dtype_map.get(current_args.precision, torch.float32) + final_save_dtype_torch = opt_dtype_map.get(current_args.save_weights_dtype, torch.bfloat16) + print(f"Device: {current_args.device}, Opt Dtype: {target_opt_dtype}, Save Dtype: {final_save_dtype_torch}") + if current_args.target_loss: print(f"Target Loss: {current_args.target_loss:.2e} (min iters: {current_args.min_iterations} for target check)") + else: print(f"No Target Loss. Min iters for any early stop: {current_args.min_iterations}.") + print(f"Max Iters/Layer: {current_args.max_iterations}, Max Rank Retries: {current_args.max_rank_retries}, Rank Incr Factor: {current_args.rank_increase_factor}") + if current_args.save_every_n_layers > 0: print(f"Save every {current_args.save_every_n_layers} processed layers enabled.") + if current_args.keep_n_resume_files > 0: print(f"Keeping the {current_args.keep_n_resume_files} most recent resume files.") + if current_args.progress_check_interval > 0: + first_eval_iter = current_args.progress_check_start_iter + current_args.progress_check_interval + print(f"Progress Check: Enabled. Interval: {current_args.progress_check_interval} iters, Min Rel. Loss Decrease: {current_args.min_progress_loss_ratio:.1e}.") + print(f" Progress window starts at iter: {current_args.progress_check_start_iter}, first evaluation at iter: {first_eval_iter}.") + if current_args.target_loss is not None: + print(f" Projection Check: Enabled (if target loss specified). Decay Caps: min={getattr(current_args, 'advanced_projection_decay_cap_min', 'N/A')}, max={getattr(current_args, 'advanced_projection_decay_cap_max', 'N/A')}") + else: print("Progress Check: Disabled (and Projection Check disabled).") + return current_args + +def load_models(base_model_path: str, ft_model_path: str) -> tuple[OrderedDict, OrderedDict]: + print(f"\nLoading base model: {base_model_path}") + try: + base_sd_raw = load_file(base_model_path, device='cpu') if base_model_path.endswith(".safetensors") else torch.load(base_model_path, map_location='cpu') + base_model_sd = base_sd_raw.get('state_dict', base_sd_raw) if not isinstance(base_sd_raw, OrderedDict) and hasattr(base_sd_raw, 'get') else base_sd_raw + except Exception as e: + print(f"Error loading base model: {e}"); traceback.print_exc(); sys.exit(1) + print(f"Loading fine-tuned model: {ft_model_path}") + try: + ft_sd_raw = load_file(ft_model_path, device='cpu') if ft_model_path.endswith(".safetensors") else torch.load(ft_model_path, map_location='cpu') + ft_model_sd = ft_sd_raw.get('state_dict', ft_sd_raw) if not isinstance(ft_sd_raw, OrderedDict) and hasattr(ft_sd_raw, 'get') else ft_sd_raw + except Exception as e: + print(f"Error loading fine-tuned model: {e}"); traceback.print_exc(); sys.exit(1) + return base_model_sd, ft_model_sd + +def optimize_loha_for_layer( + layer_name: str, delta_W_target: torch.Tensor, out_dim: int, in_dim_effective: int, + k_h: int, k_w: int, initial_rank_for_layer: int, initial_alpha_for_layer: float, + lr: float = 1e-3, max_iterations: int = 1000, min_iterations: int = 100, + target_loss: float = None, weight_decay: float = 1e-4, + device: str = 'cuda', dtype: torch.dtype = torch.float32, + is_conv: bool = True, verbose_layer_debug: bool = False, + max_rank_retries: int = 0, + rank_increase_factor: float = 1.25, + existing_loha_layer_parameters: dict | None = None +): + delta_W_target = delta_W_target.to(device, dtype=dtype) + is_initial_call_with_existing_params = existing_loha_layer_parameters is not None + best_result_so_far = { + 'final_loss': float('inf'), 'stopped_early_by_loss': False, 'stopped_by_insufficient_progress': False, + 'stopped_by_projection': False, 'projection_type_used': 'none', 'iterations_done': 0, + 'final_rank_used': initial_rank_for_layer, 'interrupted_mid_layer': False, 'final_projected_loss_on_stop': None + } + current_rank_for_this_attempt = initial_rank_for_layer + alpha_init_for_this_attempt = initial_alpha_for_layer + rank_base_for_next_increase = initial_rank_for_layer + + prog_check_interval_val = args_global.progress_check_interval + min_prog_ratio_val = args_global.min_progress_loss_ratio + iter_to_begin_first_progress_window = args_global.progress_check_start_iter + adv_proj_decay_cap_min_val = getattr(args_global, 'advanced_projection_decay_cap_min', 0.5) + adv_proj_decay_cap_max_val = getattr(args_global, 'advanced_projection_decay_cap_max', 1.05) + proj_sample_interval_val = getattr(args_global, 'projection_sample_interval', 20) + proj_ema_alpha_val = getattr(args_global, 'projection_ema_alpha', 0.1) + proj_min_ema_hist_val = getattr(args_global, 'projection_min_ema_history', 5) + + for attempt_idx in range(max_rank_retries + 1): + is_last_rank_attempt = (attempt_idx == max_rank_retries) + if save_attempted_on_interrupt: + return {**best_result_so_far, 'interrupted_mid_layer': True, 'iterations_done': 0 if attempt_idx == 0 else best_result_so_far['iterations_done'], 'projection_type_used': 'interrupted'} + + current_warm_start_status = None + params_for_initialization = None + prev_rank_for_warm_start_log = None + if attempt_idx > 0: + log_layer_optimization_event(LogType.RANK_RETRY_STARTING, layer_name, prev_rank=rank_base_for_next_increase, prev_best_loss=best_result_so_far.get('final_loss', float('inf'))) + new_rank_float = rank_base_for_next_increase * rank_increase_factor + increased_rank = math.ceil(new_rank_float) + current_rank_for_this_attempt = max(rank_base_for_next_increase + 1, increased_rank) + original_alpha_to_rank_ratio = initial_alpha_for_layer / float(initial_rank_for_layer) if initial_rank_for_layer > 0 else 1.0 + alpha_init_for_this_attempt = original_alpha_to_rank_ratio * float(current_rank_for_this_attempt) + prev_rank_for_warm_start_log = best_result_so_far.get('final_rank_used') + if 'hada_w1_a' in best_result_so_far and not args_global.no_warm_start: + if prev_rank_for_warm_start_log < current_rank_for_this_attempt: current_warm_start_status = 'applied'; params_for_initialization = best_result_so_far + else: current_warm_start_status = 'skipped_cannot_warm_start' + elif args_global.no_warm_start and 'hada_w1_a' in best_result_so_far: current_warm_start_status = 'skipped_no_warm_start_arg' + elif 'hada_w1_a' not in best_result_so_far: current_warm_start_status = 'no_prior_params_for_warm_start' + log_layer_optimization_event(LogType.RANK_INCREASED_INFO, layer_name, new_rank=current_rank_for_this_attempt, new_alpha=alpha_init_for_this_attempt, warm_start_status=current_warm_start_status, prev_rank_for_warm_start=prev_rank_for_warm_start_log) + elif is_initial_call_with_existing_params: + params_for_initialization = existing_loha_layer_parameters + + k_ops = k_h * k_w if is_conv else 1 + hada_w1_a_p, hada_w1_b_p, hada_w2_a_p, hada_w2_b_p = initialize_loha_parameters( + out_dim, current_rank_for_this_attempt, (in_dim_effective * k_ops), device, dtype, layer_name, attempt_idx, + (attempt_idx == 0 and is_initial_call_with_existing_params), + params_for_initialization, current_warm_start_status, + prev_rank_for_warm_start_log if current_warm_start_status == 'applied' else None + ) + alpha_param = nn.Parameter(torch.tensor(alpha_init_for_this_attempt, device=device, dtype=dtype)) + params_to_optimize = [hada_w1_a_p, hada_w1_b_p, hada_w2_a_p, hada_w2_b_p, alpha_param] + optimizer = torch.optim.AdamW(params_to_optimize, lr=lr, weight_decay=weight_decay) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=max(10, int(max_iterations * 0.05)), factor=0.5, min_lr=max(1e-7, lr * 0.001)) + iter_pbar_desc = f"Opt Att {attempt_idx+1}/{max_rank_retries+1} (R:{current_rank_for_this_attempt}){' [LastRank]' if is_last_rank_attempt else ''}: {layer_name}" + iter_pbar = tqdm(range(max_iterations), desc=iter_pbar_desc, leave=False, dynamic_ncols=True, position=1, mininterval=0.5) + + current_attempt_final_loss = float('inf'); current_attempt_stopped_early_by_loss = False + current_attempt_insufficient_progress = False; current_attempt_stopped_by_projection = False + current_attempt_projection_type = "none"; current_attempt_iterations_done = 0 + loss_at_start_of_current_window = float('inf'); progress_window_started_for_attempt = False + relative_improvement_history_for_proj = [] + final_projected_loss_if_failed = None + ema_loss_history = [] + current_ema_loss_value = None + + for i in iter_pbar: + current_attempt_iterations_done = i + 1 + if save_attempted_on_interrupt: + iter_pbar.close() + best_result_so_far.update({'interrupted_mid_layer': True, 'projection_type_used': 'interrupted', 'iterations_done': i}) + return best_result_so_far + + if prog_check_interval_val > 0 and not progress_window_started_for_attempt and current_attempt_iterations_done >= iter_to_begin_first_progress_window: + loss_at_start_of_current_window = current_attempt_final_loss + progress_window_started_for_attempt = True + + optimizer.zero_grad() + eff_alpha_scale = alpha_param / current_rank_for_this_attempt + term1_flat = hada_w1_a_p @ hada_w1_b_p; term2_flat = hada_w2_a_p @ hada_w2_b_p + delta_W_loha = eff_alpha_scale * (term1_flat.view(out_dim, in_dim_effective, k_h, k_w) * term2_flat.view(out_dim, in_dim_effective, k_h, k_w)) if is_conv else (eff_alpha_scale * term1_flat * term2_flat) + + loss = F.mse_loss(delta_W_loha, delta_W_target) + raw_current_loss_item = loss.item() + if i == 0 and progress_window_started_for_attempt and loss_at_start_of_current_window == float('inf'): + loss_at_start_of_current_window = raw_current_loss_item + current_attempt_final_loss = raw_current_loss_item + loss.backward(); optimizer.step(); scheduler.step(raw_current_loss_item) + + if target_loss is not None and prog_check_interval_val > 0 and (i + 1) % proj_sample_interval_val == 0: + current_ema_loss_value = proj_ema_alpha_val * raw_current_loss_item + (1 - proj_ema_alpha_val) * current_ema_loss_value if current_ema_loss_value is not None else raw_current_loss_item + ema_loss_history.append((current_attempt_iterations_done, current_ema_loss_value)) + + iter_pbar.set_postfix_str(f"Loss={current_attempt_final_loss:.3e}, AlphaP={alpha_param.item():.2f}, LR={optimizer.param_groups[0]['lr']:.1e}", refresh=True) + + if target_loss is not None and current_attempt_iterations_done >= min_iterations and current_attempt_final_loss <= target_loss: + log_layer_optimization_event(LogType.TARGET_LOSS_REACHED_IN_ATTEMPT, layer_name, attempt=attempt_idx+1, rank=current_rank_for_this_attempt, target_loss=target_loss, iter=current_attempt_iterations_done) + current_attempt_stopped_early_by_loss = True; break + + if prog_check_interval_val > 0 and progress_window_started_for_attempt and \ + (current_attempt_iterations_done >= iter_to_begin_first_progress_window + prog_check_interval_val) and \ + (((current_attempt_iterations_done - iter_to_begin_first_progress_window) % prog_check_interval_val) == 0): + + perform_early_stop_checks = not is_last_rank_attempt + + stop_insufficient_prog, _ = check_insufficient_progress( + current_attempt_final_loss, loss_at_start_of_current_window, min_prog_ratio_val, target_loss, + perform_early_stop_checks, layer_name, attempt_idx, current_rank_for_this_attempt + ) + if stop_insufficient_prog: + current_attempt_insufficient_progress = True; break + + if target_loss is not None and current_attempt_final_loss > target_loss: + raw_rel_imprv_for_fallback_proj = (loss_at_start_of_current_window - current_attempt_final_loss) / loss_at_start_of_current_window if loss_at_start_of_current_window > 1e-12 and loss_at_start_of_current_window > current_attempt_final_loss else 0.0 + stop_projection, proj_details = check_loss_projection( + ema_loss_history, current_attempt_final_loss, raw_rel_imprv_for_fallback_proj, target_loss, + max_iterations, current_attempt_iterations_done, prog_check_interval_val, + proj_min_ema_hist_val, adv_proj_decay_cap_min_val, adv_proj_decay_cap_max_val, + perform_early_stop_checks, layer_name, attempt_idx, current_rank_for_this_attempt, + relative_improvement_history_for_proj + ) + if proj_details: + current_attempt_projection_type = proj_details.get('proj_type', 'none') + if stop_projection : final_projected_loss_if_failed = proj_details.get('proj_final_loss') + + if stop_projection: + current_attempt_stopped_by_projection = True; break + + loss_at_start_of_current_window = current_attempt_final_loss + iter_pbar.close() + + if current_attempt_final_loss < best_result_so_far['final_loss'] or \ + (current_attempt_final_loss == best_result_so_far['final_loss'] and current_rank_for_this_attempt < best_result_so_far['final_rank_used']): + log_layer_optimization_event(LogType.NEW_BEST_RESULT_FOR_LAYER, layer_name, attempt=attempt_idx+1, rank=current_rank_for_this_attempt, loss=current_attempt_final_loss) + best_result_so_far.update({ + 'hada_w1_a': hada_w1_a_p.data.cpu().contiguous(), 'hada_w1_b': hada_w1_b_p.data.cpu().contiguous(), + 'hada_w2_a': hada_w2_a_p.data.cpu().contiguous(), 'hada_w2_b': hada_w2_b_p.data.cpu().contiguous(), + 'alpha': alpha_param.data.cpu().contiguous(), 'final_loss': current_attempt_final_loss, + 'stopped_early_by_loss': current_attempt_stopped_early_by_loss, + 'stopped_by_insufficient_progress': current_attempt_insufficient_progress, + 'stopped_by_projection': current_attempt_stopped_by_projection, + 'projection_type_used': current_attempt_projection_type, + 'iterations_done': current_attempt_iterations_done, 'final_rank_used': current_rank_for_this_attempt, + 'interrupted_mid_layer': False, + 'final_projected_loss_on_stop': final_projected_loss_if_failed if current_attempt_stopped_by_projection else None + }) + rank_base_for_next_increase = current_rank_for_this_attempt + + if current_attempt_stopped_early_by_loss: + log_layer_optimization_event(LogType.TARGET_LOSS_MET_STOP_ALL_RETRIES, layer_name) + break + + if current_attempt_iterations_done < max_iterations and not any([current_attempt_insufficient_progress, current_attempt_stopped_by_projection, current_attempt_stopped_early_by_loss]): + if not is_last_rank_attempt: + log_layer_optimization_event(LogType.ATTEMPT_EARLY_FINISH_NO_STOP_FLAG, layer_name, attempt=attempt_idx+1, rank=current_rank_for_this_attempt, iters_done=current_attempt_iterations_done, max_iters=max_iterations) + + if is_last_rank_attempt: + if not (current_attempt_stopped_early_by_loss or current_attempt_insufficient_progress or current_attempt_stopped_by_projection): + _, prog_log_details = check_insufficient_progress(current_attempt_final_loss, loss_at_start_of_current_window, min_prog_ratio_val, target_loss, False, layer_name, attempt_idx, current_rank_for_this_attempt) + if target_loss and current_attempt_final_loss > target_loss and prog_check_interval_val > 0: + raw_rel_imprv_for_fallback_proj_last = (loss_at_start_of_current_window - current_attempt_final_loss) / loss_at_start_of_current_window if loss_at_start_of_current_window > 1e-12 and loss_at_start_of_current_window > current_attempt_final_loss else 0.0 + _, proj_log_details_last = check_loss_projection( + ema_loss_history, current_attempt_final_loss, raw_rel_imprv_for_fallback_proj_last, target_loss, + max_iterations, current_attempt_iterations_done, prog_check_interval_val, + proj_min_ema_hist_val, adv_proj_decay_cap_min_val, adv_proj_decay_cap_max_val, + False, layer_name, attempt_idx, current_rank_for_this_attempt, + relative_improvement_history_for_proj + ) + if proj_log_details_last: + best_result_so_far['projection_type_used'] = proj_log_details_last.get('proj_type', best_result_so_far['projection_type_used']) + best_result_so_far['final_projected_loss_on_stop'] = proj_log_details_last.get('proj_final_loss', best_result_so_far['final_projected_loss_on_stop']) + log_layer_optimization_event(LogType.LAST_RANK_ATTEMPT_SUMMARY, layer_name, target_loss=target_loss, + final_loss_for_layer=best_result_so_far['final_loss'], final_rank_for_layer=best_result_so_far['final_rank_used']) + break + + if not current_attempt_stopped_early_by_loss : + reason_kwargs = {'attempt': attempt_idx + 1, 'rank': current_rank_for_this_attempt, 'is_last_rank_attempt': is_last_rank_attempt} + if current_attempt_stopped_by_projection: + reason_kwargs.update({'reason_type': 'projection_unreachable', 'target_loss': target_loss, 'proj_final_loss': final_projected_loss_if_failed, 'proj_type': current_attempt_projection_type}) + elif current_attempt_insufficient_progress: + reason_kwargs.update({'reason_type': 'insufficient_progress'}) + elif current_attempt_iterations_done >= max_iterations: + reason_kwargs.update({'reason_type': 'max_iterations_no_target' if target_loss else 'max_iterations_no_target_set', 'current_loss': current_attempt_final_loss}) + + if 'reason_type' in reason_kwargs : + log_layer_optimization_event(LogType.ATTEMPT_ENDED_WILL_RETRY, layer_name, **reason_kwargs) + + if 'hada_w1_a' not in best_result_so_far: + log_layer_optimization_event(LogType.NO_VALID_OPTIMIZATION_RESULT, layer_name) + return {'final_loss': float('inf'), 'interrupted_mid_layer': True, 'final_rank_used': initial_rank_for_layer, 'iterations_done':0} + + for key, default_val in [('stopped_early_by_loss', False), ('stopped_by_insufficient_progress', False), + ('stopped_by_projection', False), ('projection_type_used', 'none'), + ('interrupted_mid_layer', False), ('final_projected_loss_on_stop', None), + ('final_rank_used', initial_rank_for_layer)]: + best_result_so_far.setdefault(key, default_val) + return best_result_so_far + + +def handle_interrupt(signum, frame): + # ... (remains the same) ... + global save_attempted_on_interrupt, outer_pbar_global, args_global, all_completed_module_prefixes_ever_global + print("\n" + "="*30 + "\nCtrl+C Detected!\n" + "="*30) + if save_attempted_on_interrupt: print("Save already attempted. Exiting."); return + save_attempted_on_interrupt = True + if outer_pbar_global: outer_pbar_global.close() + if args_global and args_global.save_to: + save_path = generate_intermediate_filename(args_global.save_to, len(all_completed_module_prefixes_ever_global)) + print(f"Attempting interrupt save to: {save_path}") + if perform_graceful_save(save_path) and args_global.keep_n_resume_files > 0: + cleanup_intermediate_files(args_global.save_to, True, args_global.keep_n_resume_files) + else: print("Cannot perform interrupt save: args not defined.") + print("Exiting.") + sys.exit(0) + + +def main(cli_args): + global args_global, extracted_loha_state_dict_global, layer_optimization_stats_global, \ + processed_layers_this_session_count_global, save_attempted_on_interrupt, outer_pbar_global, \ + skipped_identical_count_global, skipped_other_reason_count_global, keys_scanned_this_run_global, \ + previously_completed_module_prefixes_global, all_completed_module_prefixes_ever_global, \ + main_loop_completed_scan_flag_global, params_to_seed_optimizer_global, skipped_good_initial_loss_count_global, \ + skipped_vae_layers_count + + args_global = cli_args + signal.signal(signal.SIGINT, handle_interrupt) + extracted_loha_state_dict_global.clear(); layer_optimization_stats_global.clear() + params_to_seed_optimizer_global.clear(); previously_completed_module_prefixes_global.clear() + all_completed_module_prefixes_ever_global.clear() + processed_layers_this_session_count_global = skipped_identical_count_global = skipped_other_reason_count_global = 0 + skipped_good_initial_loss_count_global = keys_scanned_this_run_global = skipped_vae_layers_count = 0 + main_loop_completed_scan_flag_global = False; save_attempted_on_interrupt = False + + args_global = setup_and_print_configuration(args_global) + target_opt_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}.get(args_global.precision, torch.float32) + final_save_dtype_torch = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}.get(args_global.save_weights_dtype, torch.bfloat16) + handle_resume_or_continue_loha( + args_global, extracted_loha_state_dict_global, params_to_seed_optimizer_global, + previously_completed_module_prefixes_global, all_completed_module_prefixes_ever_global + ) + base_model_sd, ft_model_sd = load_models(args_global.base_model_path, args_global.ft_model_path) + all_candidate_keys = sorted([k for k in base_model_sd if k.endswith('.weight') and k in ft_model_sd and base_model_sd[k].shape == ft_model_sd[k].shape and (len(base_model_sd[k].shape) in [2,4])]) + total_candidates_to_scan = len(all_candidate_keys) + print(f"Found {total_candidates_to_scan} candidate '.weight' keys for LoHA extraction.") + outer_pbar_global = tqdm(total=total_candidates_to_scan, desc="Scanning Layers", dynamic_ncols=True, position=0) + + try: + for key_name in all_candidate_keys: + if save_attempted_on_interrupt: break + keys_scanned_this_run_global += 1; outer_pbar_global.update(1) + original_module_path = key_name[:-len(".weight")] + loha_key_prefix = "lora_" + original_module_path.replace(".", "_") + if "model.diffusion_model." in original_module_path: loha_key_prefix = "lora_unet_" + original_module_path.split("model.diffusion_model.")[-1].replace(".", "_") + elif "first_stage_model." in original_module_path: loha_key_prefix = "lora_vae_" + original_module_path.split("first_stage_model.")[-1].replace(".", "_") + if any(vp in original_module_path for vp in [".encoder.", ".decoder.", ".quant_conv."]) and any(tp in original_module_path for tp in ["first_stage_model.", "autoencoder."]): + if args_global.verbose_layer_debug: tqdm.write(f" Skipping VAE layer: {original_module_path}") + skipped_vae_layers_count += 1; skipped_other_reason_count_global += 1 + all_completed_module_prefixes_ever_global.add(loha_key_prefix) + continue + is_reopt_target = args_global.continue_training_from_loha and loha_key_prefix in params_to_seed_optimizer_global + if loha_key_prefix in all_completed_module_prefixes_ever_global and not is_reopt_target: + if args_global.verbose_layer_debug: tqdm.write(f" Skipping {loha_key_prefix} (already processed/resumed, not re-opt target).") + continue + if args_global.max_layers is not None and args_global.max_layers > 0 and processed_layers_this_session_count_global >= args_global.max_layers: + if args_global.verbose and processed_layers_this_session_count_global == args_global.max_layers: + tqdm.write(f"\nMax_layers ({args_global.max_layers}) for new/re-optimized hit. Scan continues to find all identical/skipped layers.") + outer_pbar_global.set_description_str(f"Scan {keys_scanned_this_run_global}/{total_candidates_to_scan} (Max Layers Reached)") + continue + base_W = base_model_sd[key_name].to(dtype=torch.float32) + ft_W = ft_model_sd[key_name].to(dtype=torch.float32) + out_dim, in_dim_effective, k_h, k_w, is_conv = get_module_shape_info_from_weight(base_W) + delta_W_fp32 = (ft_W - base_W) + if torch.allclose(delta_W_fp32, torch.zeros_like(delta_W_fp32), atol=args_global.atol_fp32_check): + if args_global.verbose_layer_debug: tqdm.write(f" Skipping {loha_key_prefix} (weights identical atol={args_global.atol_fp32_check:.1e}).") + skipped_identical_count_global += 1 + all_completed_module_prefixes_ever_global.add(loha_key_prefix) + continue + current_key_processed_or_skipped_good = False + should_skip_due_to_pre_existing_good_loss = False + if is_reopt_target and args_global.target_loss is not None: + seed_data = params_to_seed_optimizer_global[loha_key_prefix] + loaded_params_cpu = seed_data['params']; loaded_rank_check = seed_data['rank']; loaded_alpha_check = seed_data['alpha'] + if all(k_ in loaded_params_cpu for k_ in ['hada_w1_a', 'hada_w1_b', 'hada_w2_a', 'hada_w2_b']): + try: + with torch.no_grad(): + w1a,w1b,w2a,w2b = (loaded_params_cpu[p].to(args_global.device, target_opt_dtype) for p in ['hada_w1_a','hada_w1_b','hada_w2_a','hada_w2_b']) + alpha_v = torch.tensor(loaded_alpha_check, device=args_global.device, dtype=target_opt_dtype) + eff_a_s = alpha_v / loaded_rank_check; delta_W_target_c = delta_W_fp32.to(args_global.device, target_opt_dtype) + init_loha_d = eff_a_s * (w1a@w1b).view(out_dim,in_dim_effective,k_h,k_w) * (w2a@w2b).view(out_dim,in_dim_effective,k_h,k_w) if is_conv else eff_a_s * (w1a@w1b) * (w2a@w2b) + init_loss_c = F.mse_loss(init_loha_d, delta_W_target_c).item() + if init_loss_c <= args_global.target_loss: + tqdm.write(f" Skip Re-Opt {loha_key_prefix}: Loaded (R:{loaded_rank_check}, A:{loaded_alpha_check:.2f}) meets target. Loss: {init_loss_c:.4e} <= {args_global.target_loss:.4e}") + stat_entry_skip = {"name": str(loha_key_prefix), "original_name": str(original_module_path),"initial_rank_attempted": int(loaded_rank_check), "final_rank_used": int(loaded_rank_check),"rank_was_increased": False, "final_loss": float(init_loss_c),"alpha_final": float(loaded_alpha_check), "iterations_done": 0,"stopped_early_by_loss_target": True, "stopped_by_insufficient_progress": False,"stopped_by_projection": False, "projection_type_used": "none","final_projected_loss_on_stop": None,"skipped_reopt_due_to_initial_good_loss": True, "interrupted_mid_layer": False} + layer_optimization_stats_global.append(stat_entry_skip) + all_completed_module_prefixes_ever_global.add(loha_key_prefix) + processed_layers_this_session_count_global += 1; skipped_good_initial_loss_count_global += 1 + should_skip_due_to_pre_existing_good_loss = True; current_key_processed_or_skipped_good = True + outer_pbar_global.set_description_str(f"Scan {keys_scanned_this_run_global}/{total_candidates_to_scan} (New/ReOpt: {processed_layers_this_session_count_global - skipped_good_initial_loss_count_global}, SkipGood:{skipped_good_initial_loss_count_global})") + elif args_global.verbose_layer_debug: tqdm.write(f" Initial loss for loaded {loha_key_prefix}: {init_loss_c:.4e}. Re-optimizing.") + except Exception as e_c: tqdm.write(f" Warn: Pre-opt loss check failed for {loha_key_prefix}: {e_c}. Optimizing."); + if not should_skip_due_to_pre_existing_good_loss: + current_op_mode_str = "ReOpt" if is_reopt_target else "NewOpt" + if args_global.verbose: tqdm.write(f"\n--- {current_op_mode_str} Layer {processed_layers_this_session_count_global + 1 - skipped_good_initial_loss_count_global}: {loha_key_prefix} (Orig: {original_module_path}) ---") + initial_rank_opt = args_global.conv_rank if is_conv and args_global.conv_rank is not None else args_global.rank + initial_alpha_opt = args_global.initial_conv_alpha if is_conv else args_global.initial_alpha + existing_params_init = None; max_retries_layer = args_global.max_rank_retries + if is_reopt_target: + seed_data = params_to_seed_optimizer_global[loha_key_prefix] + initial_rank_opt, initial_alpha_opt = seed_data['rank'], seed_data['alpha'] + existing_params_init = seed_data['params'] + base_rank_est = args_global.conv_rank if is_conv and args_global.conv_rank is not None else args_global.rank + if initial_rank_opt > base_rank_est and args_global.max_rank_retries > 0 : + est_retries_used = 0; cur_sim_rank = float(base_rank_est) + for _ in range(args_global.max_rank_retries + 10): + if cur_sim_rank >= initial_rank_opt: break + cur_sim_rank = max(math.ceil(cur_sim_rank * args_global.rank_increase_factor), cur_sim_rank + 1); est_retries_used += 1 + max_retries_layer = max(0, args_global.max_rank_retries - est_retries_used) + if args_global.verbose: tqdm.write(f" Using loaded R:{initial_rank_opt}, A:{initial_alpha_opt:.1f}. Max further retries for layer: {max_retries_layer}.") + outer_pbar_global.set_description_str(f"{current_op_mode_str} L{processed_layers_this_session_count_global + 1 - skipped_good_initial_loss_count_global} (Scan {keys_scanned_this_run_global}/{total_candidates_to_scan}, SkipGood:{skipped_good_initial_loss_count_global})") + opt_results = optimize_loha_for_layer(loha_key_prefix, delta_W_fp32, out_dim, in_dim_effective, k_h, k_w, initial_rank_opt, initial_alpha_opt, args_global.lr, args_global.max_iterations, args_global.min_iterations, args_global.target_loss, args_global.weight_decay, args_global.device, target_opt_dtype, is_conv, args_global.verbose_layer_debug, max_retries_layer, args_global.rank_increase_factor, existing_params_init) + if not opt_results.get('interrupted_mid_layer') and 'hada_w1_a' in opt_results : + for p_name, p_val in opt_results.items(): + if p_name not in ['final_loss', 'stopped_early_by_loss', 'stopped_by_insufficient_progress', 'stopped_by_projection', 'projection_type_used', 'iterations_done', 'final_rank_used', 'interrupted_mid_layer', 'final_projected_loss_on_stop']: + if torch.is_tensor(p_val): extracted_loha_state_dict_global[f'{loha_key_prefix}.{p_name}'] = p_val.to(final_save_dtype_torch) + final_rank_used = opt_results['final_rank_used'] + stat_entry = {"name": str(loha_key_prefix),"original_name": str(original_module_path),"initial_rank_attempted": int(initial_rank_opt),"final_rank_used": int(final_rank_used),"rank_was_increased": bool(final_rank_used > initial_rank_opt),"final_loss": float(opt_results['final_loss']),"alpha_final": float(opt_results['alpha'].item()) if isinstance(opt_results.get('alpha'), torch.Tensor) else float(opt_results.get('alpha', 0.0)),"iterations_done": int(opt_results['iterations_done']),"stopped_early_by_loss_target": bool(opt_results['stopped_early_by_loss']),"stopped_by_insufficient_progress": bool(opt_results.get('stopped_by_insufficient_progress', False)),"stopped_by_projection": bool(opt_results.get('stopped_by_projection', False)),"projection_type_used": str(opt_results.get('projection_type_used', 'none')),"final_projected_loss_on_stop": float(l_val) if (l_val := opt_results.get('final_projected_loss_on_stop')) is not None else None,"skipped_reopt_due_to_initial_good_loss": False,"interrupted_mid_layer": bool(opt_results.get('interrupted_mid_layer', False))} + layer_optimization_stats_global.append(stat_entry) + all_completed_module_prefixes_ever_global.add(loha_key_prefix) + stop_reason_short = "" + if opt_results['stopped_early_by_loss']: stop_reason_short = ", Stop:LossTarget" + elif opt_results.get('stopped_by_projection', False): stop_reason_short = f", Stop:Proj({opt_results.get('projection_type_used','?')})" + elif opt_results['stopped_by_insufficient_progress']: stop_reason_short = ", Stop:RawProg" + tqdm.write(f" Layer {loha_key_prefix} Opt. Done. R_used: {final_rank_used}, FinalLoss: {opt_results['final_loss']:.4e}, Iters: {opt_results['iterations_done']}{stop_reason_short}") + if args_global.use_bias: + bias_key = f"{original_module_path}.bias" + if bias_key in ft_model_sd and (bias_key not in base_model_sd or not torch.allclose(base_model_sd[bias_key], ft_model_sd[bias_key], atol=args_global.atol_fp32_check)): + extracted_loha_state_dict_global[bias_key] = ft_model_sd[bias_key].cpu().to(final_save_dtype_torch) + if args_global.verbose: tqdm.write(f" Saved differing/new bias for {bias_key}") + processed_layers_this_session_count_global += 1 + current_key_processed_or_skipped_good = True + else: + tqdm.write(f" Optimization for {loha_key_prefix} did not yield saveable results (Interrupt: {opt_results.get('interrupted_mid_layer', 'N/A')}, Loss: {opt_results.get('final_loss', 'N/A')})") + if not opt_results.get('interrupted_mid_layer', False) and 'hada_w1_a' not in opt_results : + skipped_other_reason_count_global += 1 + all_completed_module_prefixes_ever_global.add(loha_key_prefix) + if current_key_processed_or_skipped_good: + if args_global.save_every_n_layers > 0 and processed_layers_this_session_count_global > 0 and processed_layers_this_session_count_global % args_global.save_every_n_layers == 0 and keys_scanned_this_run_global < total_candidates_to_scan: + periodic_save_path = generate_intermediate_filename(args_global.save_to, len(all_completed_module_prefixes_ever_global)) + tqdm.write(f"\n--- Periodic Save: Processed {processed_layers_this_session_count_global} layers this session. Saving to {periodic_save_path} ---") + if perform_graceful_save(periodic_save_path) and args_global.keep_n_resume_files > 0: + cleanup_intermediate_files(args_global.save_to, True, args_global.keep_n_resume_files) + if not save_attempted_on_interrupt and keys_scanned_this_run_global == total_candidates_to_scan: + main_loop_completed_scan_flag_global = True + finally: + if outer_pbar_global: outer_pbar_global.close() + + if not save_attempted_on_interrupt: + print_script_summary( + layer_stats=layer_optimization_stats_global, + all_completed_prefixes=all_completed_module_prefixes_ever_global, + processed_session=processed_layers_this_session_count_global, + skipped_identical_session=skipped_identical_count_global, + skipped_good_initial_session=skipped_good_initial_loss_count_global, + skipped_other_session=skipped_other_reason_count_global, + skipped_vae_session=skipped_vae_layers_count, + keys_scanned_session=keys_scanned_this_run_global, + total_candidates=total_candidates_to_scan + ) + + is_fully_complete = main_loop_completed_scan_flag_global and len(all_completed_module_prefixes_ever_global) >= total_candidates_to_scan + actual_save_path = args_global.save_to if is_fully_complete else generate_intermediate_filename(args_global.save_to, len(all_completed_module_prefixes_ever_global)) + reason_for_save_path = "Saving to final path (all candidates processed/skipped)" if is_fully_complete else \ + ("Run incomplete (scan not finished or --max_layers hit)." if not main_loop_completed_scan_flag_global else "Full scan done, but not all layers processed/accounted for (e.g. new errors).") + print(f"\n{reason_for_save_path}: {actual_save_path}") + if perform_graceful_save(output_path_to_save=actual_save_path): + if args_global.keep_n_resume_files > 0 and not is_fully_complete : + cleanup_intermediate_files(args_global.save_to, True, args_global.keep_n_resume_files) + if is_fully_complete and actual_save_path == args_global.save_to : + print("\nCleaning up ALL intermediate resume files (from this script's previous runs)...") + cleanup_intermediate_files(args_global.save_to, False) + else: print("\nProcess interrupted. Graceful save to intermediate file attempted.") + +def post_process_cli_args(parsed_args: argparse.Namespace) -> argparse.Namespace: + if parsed_args.verbose_layer_debug: + parsed_args.verbose = True + if not os.path.exists(parsed_args.base_model_path): + print(f"Error: Base model not found: {parsed_args.base_model_path}"); sys.exit(1) + if not os.path.exists(parsed_args.ft_model_path): + print(f"Error: FT model not found: {parsed_args.ft_model_path}"); sys.exit(1) + save_dir = os.path.dirname(parsed_args.save_to) + if save_dir and not os.path.exists(save_dir): + try: + os.makedirs(save_dir, exist_ok=True) + except OSError as e: + print(f"Error creating directory {save_dir}: {e}"); sys.exit(1) + if parsed_args.initial_alpha is None: + parsed_args.initial_alpha = float(parsed_args.rank) + if parsed_args.initial_conv_alpha is None: + if parsed_args.conv_rank is not None: + parsed_args.initial_conv_alpha = float(parsed_args.conv_rank) + else: + parsed_args.initial_conv_alpha = parsed_args.initial_alpha + return parsed_args + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Extract LoHA parameters. Saves intermediate files like 'name_resume_L{count}.safetensors'.") + parser.add_argument("base_model_path", type=str, help="Path to base model (.pt, .pth, .safetensors)") + parser.add_argument("ft_model_path", type=str, help="Path to fine-tuned model (.pt, .pth, .safetensors)") + parser.add_argument("save_to", type=str, help="Path for FINAL LoHA output (recommended .safetensors).") + parser.add_argument("--overwrite", action="store_true", help="Overwrite existing FINAL LoHA. Does NOT clean intermediates until successful final save.") + parser.add_argument("--continue_training_from_loha", type=str, default=None, help="Path to existing LoHA to load and continue optimizing.") + parser.add_argument("--rank", type=int, default=4, help="Default rank for LoHA.") + parser.add_argument("--conv_rank", type=int, default=None, help="Specific rank for Conv LoHA. Defaults to --rank.") + parser.add_argument("--initial_alpha", type=float, default=None, help="Global initial alpha. Defaults to 'rank'.") + parser.add_argument("--initial_conv_alpha", type=float, default=None, help="Specific initial alpha for Conv LoHA. Defaults to conv_rank or initial_alpha.") + parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate per layer.") + parser.add_argument("--max_iterations", type=int, default=1000, help="Max optimization iterations per layer/attempt.") + parser.add_argument("--min_iterations", type=int, default=100, help="Min iterations before target_loss check per attempt.") + parser.add_argument("--target_loss", type=float, default=None, help="Target MSE loss for early stopping. Also for pre-re-opt check.") + parser.add_argument("--weight_decay", type=float, default=1e-5, help="Weight decay for optimization.") + parser.add_argument("--max_rank_retries", type=int, default=0, help="Rank increase retries if target_loss not met (0 for no retries).") + parser.add_argument("--rank_increase_factor", type=float, default=1.25, help="Factor to increase rank on retry.") + parser.add_argument("--progress_check_interval", type=int, default=100, help="Check loss improvement every N iterations (0 to disable).") + parser.add_argument("--min_progress_loss_ratio", type=float, default=0.001, help="Min relative loss decrease over interval.") + parser.add_argument("--progress_check_start_iter", type=int, default=None, help="Iteration for start of first progress window. Default: 'progress_check_interval'.") + parser.add_argument("--advanced_projection_decay_cap_min", type=float, default=0.5, help="Min cap for decay factor in advanced projection.") + parser.add_argument("--advanced_projection_decay_cap_max", type=float, default=1.05, help="Max cap for decay factor in advanced projection.") + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device ('cuda' or 'cpu').") + parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"], help="Optimization precision.") + parser.add_argument("--save_weights_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"], help="Dtype for saved LoHA weights.") + parser.add_argument("--atol_fp32_check", type=float, default=1e-6, help="Tolerance for identical weight check.") + parser.add_argument("--no_warm_start", action="store_true", help="Disable warm-starting higher rank attempts from previous best.") + parser.add_argument("--use_bias", action="store_true", help="Save differing bias terms into LoHA.") + parser.add_argument("--dropout", type=float, default=0.0, help="General dropout (metadata only).") + parser.add_argument("--rank_dropout", type=float, default=0.0, help="Rank dropout (metadata only).") + parser.add_argument("--module_dropout", type=float, default=0.0, help="Module dropout (metadata only).") + parser.add_argument("--max_layers", type=int, default=None, help="Max NEW differing layers to process this session.") + parser.add_argument("--verbose", action="store_true", help="General verbose output.") + parser.add_argument("--verbose_layer_debug", action="store_true", help="Detailed per-iteration debug output (implies --verbose).") + parser.add_argument("--projection_sample_interval", type=int, default=20, help="Loss sample interval for EMA (iterations).") + parser.add_argument("--projection_ema_alpha", type=float, default=0.1, help="Smoothing factor for EMA.") + parser.add_argument("--projection_min_ema_history", type=int, default=5, help="Min EMA samples for EMA-based projection.") + parser.add_argument("--save_every_n_layers", type=int, default=0, help="Save intermediate LoHA every N processed layers (0 to disable).") + parser.add_argument("--keep_n_resume_files", type=int, default=0, help="Keep only N most recent intermediate resume files (0 to keep all).") + + raw_parsed_args = parser.parse_args() + processed_args = post_process_cli_args(raw_parsed_args) + main(processed_args) \ No newline at end of file diff --git a/tools/extract_lora_from_models-new.py b/tools/extract_lora_from_models-new.py new file mode 100644 index 0000000..62e2f6a --- /dev/null +++ b/tools/extract_lora_from_models-new.py @@ -0,0 +1,535 @@ +import sys +import os +import argparse +import json +import time +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm +import logging # Import for logging + +# NEW: Add diffusers import for model loading +try: + from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline +except ImportError: + print("Diffusers library not found. Please install it: pip install diffusers transformers accelerate") + raise + +# --- Localized Logging Setup --- +def _local_setup_logging(log_level=logging.INFO): + """ + Sets up basic logging to console. + """ + logging.basicConfig( + level=log_level, + format="%(asctime)s %(levelname)-8s %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + +_local_setup_logging() # Initialize logging +logger = logging.getLogger(__name__) # Get logger for this module + +MIN_SV = 1e-6 + +# --- Localized sd-scripts constants and utility functions --- +_LOCAL_MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_v10" + +def _local_get_model_version_str_for_sd1_sd2(is_v2: bool, is_v_parameterization: bool) -> str: + if is_v2: + return "v2-v" if is_v_parameterization else "v2" + return "v1" + +# --- Localized LoRA Placeholder and Network Creation --- +class LocalLoRAModulePlaceholder: + def __init__(self, lora_name: str, org_module: torch.nn.Module): + self.lora_name = lora_name + self.org_module = org_module + # Add other attributes if _calculate_module_diffs_and_check needs them, + # but it primarily uses .lora_name and .org_module.weight + +def _local_create_network_placeholders(text_encoders: list, unet: torch.nn.Module, lora_conv_dim_init: int): + """ + Creates placeholders for LoRA-able modules in text encoders and UNet. + Mimics the module identification and naming of sd-scripts' lora.create_network. + `lora_conv_dim_init`: If > 0, Conv2d layers are considered for LoRA. + """ + unet_loras = [] + text_encoder_loras = [] + + # Target U-Net modules + for name, module in unet.named_modules(): + lora_name = "lora_unet_" + name.replace(".", "_") + if isinstance(module, torch.nn.Linear): + unet_loras.append(LocalLoRAModulePlaceholder(lora_name, module)) + elif isinstance(module, torch.nn.Conv2d): + if lora_conv_dim_init > 0: # Only consider conv layers if conv_dim > 0 + # Kernel size check might be relevant if sd-scripts has specific logic, + # but for diffing, any conv is a candidate if conv_dim > 0. + # SVD will later handle rank based on actual layer type (1x1 vs 3x3). + unet_loras.append(LocalLoRAModulePlaceholder(lora_name, module)) + + # Target Text Encoder modules + for i, text_encoder in enumerate(text_encoders): + if text_encoder is None: # SDXL can have None TEs if not loaded + continue + # Determine prefix based on number of text encoders (for SDXL compatibility) + te_prefix = f"lora_te{i+1}_" if len(text_encoders) > 1 else "lora_te_" + + for name, module in text_encoder.named_modules(): + lora_name = te_prefix + name.replace(".", "_") + if isinstance(module, torch.nn.Linear): + text_encoder_loras.append(LocalLoRAModulePlaceholder(lora_name, module)) + # Conv2d in text encoders is rare but check just in case (sd-scripts might) + elif isinstance(module, torch.nn.Conv2d): + if lora_conv_dim_init > 0: + text_encoder_loras.append(LocalLoRAModulePlaceholder(lora_name, module)) + + logger.info(f"Found {len(text_encoder_loras)} LoRA-able placeholder modules in Text Encoders.") + logger.info(f"Found {len(unet_loras)} LoRA-able placeholder modules in U-Net.") + return text_encoder_loras, unet_loras + + +# --- Singular Value Indexing Functions (Unchanged) --- +def index_sv_cumulative(S, target): + original_sum = float(torch.sum(S)) + cumulative_sums = torch.cumsum(S, dim=0) / original_sum + index = int(torch.searchsorted(cumulative_sums, target)) + 1 + index = max(1, min(index, len(S) - 1)) + return index + +def index_sv_fro(S, target): + S_squared = S.pow(2) + S_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq + index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 + index = max(1, min(index, len(S) - 1)) + return index + +def index_sv_ratio(S, target): + max_sv = S[0] + min_sv = max_sv / target + index = int(torch.sum(S > min_sv).item()) + index = max(1, min(index, len(S) - 1)) + return index + +def index_sv_knee(S, MIN_SV_KNEE=1e-8): + n = len(S) + if n < 3: return 1 + s_max, s_min = S[0], S[-1] + if s_max - s_min < MIN_SV_KNEE: return 1 + s_normalized = (S - s_min) / (s_max - s_min) + x_normalized = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype) + distances = (x_normalized + s_normalized - 1).abs() + knee_index_0based = torch.argmax(distances).item() + rank = knee_index_0based + 1 + rank = max(1, min(rank, n - 1)) + return rank + +def index_sv_cumulative_knee(S, min_sv_threshold=1e-8): + n = len(S) + if n < 3: return 1 + s_sum = torch.sum(S) + if s_sum < min_sv_threshold: return 1 + y_values = torch.cumsum(S, dim=0) / s_sum + y_min, y_max = y_values[0], y_values[n-1] + if y_max - y_min < min_sv_threshold: return 1 + y_norm = (y_values - y_min) / (y_max - y_min) + x_norm = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype) + distances = (y_norm - x_norm).abs() + knee_index_0based = torch.argmax(distances).item() + rank = knee_index_0based + 1 + rank = max(1, min(rank, n - 1)) + return rank + +def index_sv_rel_decrease(S, tau=0.1): + if len(S) < 2: return 1 + ratios = S[1:] / S[:-1] + for k in range(len(ratios)): + if ratios[k] < tau: + return k + 1 + return len(S) + +# --- Utility Functions --- +def _str_to_dtype(p): + if p == "float": return torch.float + if p == "fp16": return torch.float16 + if p == "bf16": return torch.bfloat16 + return None + +def save_to_file(file_name, state_dict_to_save, dtype, metadata=None): + state_dict_final = {} + for key, value in state_dict_to_save.items(): + if isinstance(value, torch.Tensor) and dtype is not None: + state_dict_final[key] = value.to(dtype) + else: + state_dict_final[key] = value + + if os.path.splitext(file_name)[1] == ".safetensors": + save_file(state_dict_final, file_name, metadata=metadata) + else: + torch.save(state_dict_final, file_name) + +def _build_local_sai_metadata(title, creation_time, is_v2_flag, is_v_param_flag, is_sdxl_flag): + metadata = {} + metadata["ss_sd_model_name"] = str(title) + metadata["ss_creation_time"] = str(int(creation_time)) + if is_sdxl_flag: + metadata["ss_base_model_version"] = "sdxl_v10" + metadata["ss_sdxl_model_version"] = "1.0" + if is_v_param_flag: + metadata["ss_v_parameterization"] = "true" + elif is_v2_flag: + metadata["ss_base_model_version"] = "sd_v2" + if is_v_param_flag: + metadata["ss_v_parameterization"] = "true" + else: + metadata["ss_base_model_version"] = "sd_v1" + if is_v_param_flag: + metadata["ss_v_parameterization"] = "true" + return metadata + +# --- MODIFIED Helper Functions for Model Loading --- +def _load_sd_model_components(model_path, is_v2_flag, target_device_override, load_dtype_torch): + logger.info(f"Loading SD model using Diffusers.StableDiffusionPipeline from: {model_path}") + pipeline = StableDiffusionPipeline.from_single_file( + model_path, + torch_dtype=load_dtype_torch + ) + eff_device = target_device_override if target_device_override else "cpu" + text_encoder = pipeline.text_encoder.to(eff_device) + unet = pipeline.unet.to(eff_device) + text_encoders = [text_encoder] + logger.info(f"Loaded SD model components. UNet device: {unet.device}, TextEncoder device: {text_encoder.device}") + return text_encoders, unet + +def _load_sdxl_model_components(model_path, target_device_override, load_dtype_torch): + actual_load_device = target_device_override if target_device_override else "cpu" + logger.info(f"Loading SDXL model using Diffusers.StableDiffusionXLPipeline from: {model_path} to device: {actual_load_device}") + pipeline = StableDiffusionXLPipeline.from_single_file( + model_path, + torch_dtype=load_dtype_torch + ) + pipeline.to(actual_load_device) + text_encoder = pipeline.text_encoder + text_encoder_2 = pipeline.text_encoder_2 + unet = pipeline.unet + text_encoders = [text_encoder, text_encoder_2] + logger.info(f"Loaded SDXL model components. UNet device: {unet.device}, TextEncoder1 device: {text_encoder.device}, TextEncoder2 device: {text_encoder_2.device}") + return text_encoders, unet + +def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_device, min_diff_thresh, module_type_str): + diffs_map = {} + is_different_flag = False + first_diff_logged = False + for lora_o, lora_t in zip(module_loras_o, module_loras_t): + lora_name = lora_o.lora_name + if lora_o.org_module is None or lora_t.org_module is None or \ + not hasattr(lora_o.org_module, 'weight') or lora_o.org_module.weight is None or \ + not hasattr(lora_t.org_module, 'weight') or lora_t.org_module.weight is None: + logger.warning(f"Skipping {lora_name} in {module_type_str} due to missing org_module or weight.") + continue + weight_o = lora_o.org_module.weight + weight_t = lora_t.org_module.weight + if str(weight_o.device) != str(diff_calc_device): weight_o = weight_o.to(diff_calc_device) + if str(weight_t.device) != str(diff_calc_device): weight_t = weight_t.to(diff_calc_device) + diff = weight_t - weight_o + diffs_map[lora_name] = diff + current_max_diff = torch.max(torch.abs(diff)) + if not is_different_flag and current_max_diff > min_diff_thresh: + is_different_flag = True + if not first_diff_logged: + logger.info(f"{module_type_str} '{lora_name}' differs: max diff {current_max_diff} > {min_diff_thresh}") + first_diff_logged = True + return diffs_map, is_different_flag + +def _determine_rank(S_values, dynamic_method_name, dynamic_param_value, max_rank_limit, + module_eff_in_dim, module_eff_out_dim, min_sv_threshold=MIN_SV): + if not S_values.numel() or S_values[0] <= min_sv_threshold: return 1 + rank = 0 + if dynamic_method_name == "sv_ratio": rank = index_sv_ratio(S_values, dynamic_param_value) + elif dynamic_method_name == "sv_cumulative": rank = index_sv_cumulative(S_values, dynamic_param_value) + elif dynamic_method_name == "sv_fro": rank = index_sv_fro(S_values, dynamic_param_value) + elif dynamic_method_name == "sv_knee": rank = index_sv_knee(S_values, min_sv_threshold) + elif dynamic_method_name == "sv_cumulative_knee": rank = index_sv_cumulative_knee(S_values, min_sv_threshold) + elif dynamic_method_name == "sv_rel_decrease": rank = index_sv_rel_decrease(S_values, dynamic_param_value) + else: rank = max_rank_limit + rank = min(rank, max_rank_limit, module_eff_in_dim, module_eff_out_dim, len(S_values)) + rank = max(1, rank) + return rank + +def _construct_lora_weights_from_svd_components(U_full, S_all_values, Vh_full, rank, + clamp_quantile_val, is_conv2d, is_conv2d_3x3, + conv_kernel_size, + module_out_channels, module_in_channels, + target_device_for_final_weights, target_dtype_for_final_weights): + S_k = S_all_values[:rank] + U_k = U_full[:, :rank] + Vh_k = Vh_full[:rank, :] + S_k_non_negative = torch.clamp(S_k, min=0.0) + s_sqrt = torch.sqrt(S_k_non_negative) + U_final = U_k * s_sqrt.unsqueeze(0) + Vh_final = Vh_k * s_sqrt.unsqueeze(1) + dist = torch.cat([U_final.flatten(), Vh_final.flatten()]) + hi_val = torch.quantile(dist, clamp_quantile_val) + if hi_val == 0 and torch.max(torch.abs(dist)) > 1e-9: + logger.debug(f"Clamping hi_val is zero for non-zero distribution. Max abs val: {torch.max(torch.abs(dist))}. Quantile: {clamp_quantile_val}") + U_clamped = U_final.clamp(-hi_val, hi_val) + Vh_clamped = Vh_final.clamp(-hi_val, hi_val) + if is_conv2d: + U_clamped = U_clamped.reshape(module_out_channels, rank, 1, 1) + if is_conv2d_3x3: + Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, *conv_kernel_size) + else: + Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, 1, 1) + U_clamped = U_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous() + Vh_clamped = Vh_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous() + return U_clamped, Vh_clamped + +def _log_svd_stats(lora_module_name, S_all_values, rank_used, min_sv_for_calc=MIN_SV): + if not S_all_values.numel(): + logger.info(f"{lora_module_name:75} | rank: {rank_used}, SVD not performed (empty singular values).") + return + S_cpu = S_all_values.to('cpu') + s_sum_total = float(torch.sum(S_cpu)) + s_sum_rank = float(torch.sum(S_cpu[:rank_used])) + fro_orig_total = float(torch.sqrt(torch.sum(S_cpu.pow(2)))) + fro_reconstructed_rank = float(torch.sqrt(torch.sum(S_cpu[:rank_used].pow(2)))) + ratio_sv = float('inf') + if rank_used > 0 and S_cpu[rank_used - 1].abs() > min_sv_for_calc: + ratio_sv = S_cpu[0] / S_cpu[rank_used - 1] + sum_s_retained_percentage = (s_sum_rank / s_sum_total) if s_sum_total > min_sv_for_calc else 1.0 + fro_retained_percentage = (fro_reconstructed_rank / fro_orig_total) if fro_orig_total > min_sv_for_calc else 1.0 + logger.info( + f"{lora_module_name:75} | rank: {rank_used}, " + f"sum(S) retained: {sum_s_retained_percentage:.2%}, " + f"Frobenius norm retained: {fro_retained_percentage:.2%}, " + f"max_retained_sv/min_retained_sv ratio: {ratio_sv:.2f}" + ) + +def _prepare_lora_metadata(output_path, is_v2_flag, kohya_base_model_version_str, network_conv_dim_val, + use_dynamic_method_flag, network_dim_config_val, + is_v_param_flag, is_sdxl_flag, skip_sai_meta): + net_kwargs = {"conv_dim": str(network_conv_dim_val), "conv_alpha": str(float(network_conv_dim_val))} if network_conv_dim_val is not None else {} + if use_dynamic_method_flag: + network_dim_meta = "Dynamic" + network_alpha_meta = "Dynamic" + else: + network_dim_meta = str(network_dim_config_val) + network_alpha_meta = str(float(network_dim_config_val)) + final_metadata = { + "ss_v2": str(is_v2_flag), + "ss_base_model_version": kohya_base_model_version_str, + "ss_network_module": "networks.lora", # This remains for compatibility with tools expecting it + "ss_network_dim": network_dim_meta, + "ss_network_alpha": network_alpha_meta, + "ss_network_args": json.dumps(net_kwargs), + "ss_lowram": "False", + "ss_num_train_images": "N/A", + } + if not skip_sai_meta: + title = os.path.splitext(os.path.basename(output_path))[0] + current_time = time.time() + sai_metadata_content = _build_local_sai_metadata( + title=title, creation_time=current_time, is_v2_flag=is_v2_flag, + is_v_param_flag=is_v_param_flag, is_sdxl_flag=is_sdxl_flag + ) + final_metadata.update(sai_metadata_content) + return final_metadata + +# --- Main SVD Function --- +def svd( + model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, + conv_dim=None, v_parameterization=None, device=None, save_precision=None, + clamp_quantile=0.99, min_diff=0.01, no_metadata=False, load_precision=None, + load_original_model_to=None, load_tuned_model_to=None, + dynamic_method=None, dynamic_param=None, verbose=False, +): + actual_v_parameterization = v2 if v_parameterization is None else v_parameterization + load_dtype_torch = _str_to_dtype(load_precision) + save_dtype_torch = _str_to_dtype(save_precision) if save_precision else torch.float + + svd_computation_device = torch.device(device if device else "cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Using SVD computation device: {svd_computation_device}") + diff_calculation_device = torch.device("cpu") + logger.info(f"Calculating weight differences on: {diff_calculation_device}") + final_weights_device = torch.device("cpu") + + if not sdxl: + text_encoders_o, unet_o = _load_sd_model_components(model_org, v2, load_original_model_to, load_dtype_torch) + text_encoders_t, unet_t = _load_sd_model_components(model_tuned, v2, load_tuned_model_to, load_dtype_torch) + kohya_model_version = _local_get_model_version_str_for_sd1_sd2(v2, actual_v_parameterization) + else: + text_encoders_o, unet_o = _load_sdxl_model_components(model_org, load_original_model_to, load_dtype_torch) + text_encoders_t, unet_t = _load_sdxl_model_components(model_tuned, load_tuned_model_to, load_dtype_torch) + kohya_model_version = _LOCAL_MODEL_VERSION_SDXL_BASE_V1_0 + + # Determine lora_conv_dim_init based on conv_dim argument for network creation + # The original script used init_dim_val (1) if conv_dim was None. + # Here, conv_dim is already defaulted to args.dim if None by the main block. + # So, lora_conv_dim_init will be args.conv_dim (which defaults to args.dim). + # If args.conv_dim was explicitly 0, this would be 0. + lora_conv_dim_init_val = conv_dim # conv_dim is args.conv_dim (or args.dim) + + # Create LoRA placeholders using the localized function + text_encoder_loras_o, unet_loras_o = _local_create_network_placeholders(text_encoders_o, unet_o, lora_conv_dim_init_val) + text_encoder_loras_t, unet_loras_t = _local_create_network_placeholders(text_encoders_t, unet_t, lora_conv_dim_init_val) # same conv_dim logic for tuned + + # Group LoRA placeholders for easier processing (mimicking LoraNetwork structure somewhat) + class LocalLoraNetworkPlaceholder: + def __init__(self, te_loras, unet_loras_list): + self.text_encoder_loras = te_loras + self.unet_loras = unet_loras_list + + lora_network_o = LocalLoraNetworkPlaceholder(text_encoder_loras_o, unet_loras_o) + lora_network_t = LocalLoraNetworkPlaceholder(text_encoder_loras_t, unet_loras_t) + + assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), \ + f"Model versions (based on identified LoRA-able TE modules) differ: {len(lora_network_o.text_encoder_loras)} vs {len(lora_network_t.text_encoder_loras)} TEs" + + all_diffs = {} + te_diffs, text_encoder_different = _calculate_module_diffs_and_check( + lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras, + diff_calculation_device, min_diff, "Text Encoder" + ) + + if text_encoder_different: + all_diffs.update(te_diffs) + else: + logger.warning("Text encoders are considered identical based on min_diff. Not extracting TE LoRA.") + # To prevent processing empty list later, ensure it's empty if no diffs + lora_network_o.text_encoder_loras = [] + del text_encoders_t # Free memory early + + unet_diffs, _ = _calculate_module_diffs_and_check( + lora_network_o.unet_loras, lora_network_t.unet_loras, + diff_calculation_device, min_diff, "U-Net" + ) + all_diffs.update(unet_diffs) + del lora_network_t, unet_t # Free memory early + + # Ensure lora_names_to_process only includes modules from lora_network_o + # that are actually present (e.g., if TEs were skipped) + lora_names_to_process = set() + if text_encoder_different: # Only add TE loras if they were deemed different + lora_names_to_process.update(p.lora_name for p in lora_network_o.text_encoder_loras) + lora_names_to_process.update(p.lora_name for p in lora_network_o.unet_loras) + + logger.info("Extracting and resizing LoRA via SVD") + lora_weights = {} + with torch.no_grad(): + for lora_name in tqdm(lora_names_to_process): + if lora_name not in all_diffs: + logger.warning(f"Skipping {lora_name} as no diff was calculated for it (e.g., Text Encoders were identical).") + continue + original_diff_tensor = all_diffs[lora_name] + is_conv2d_layer = len(original_diff_tensor.size()) == 4 + kernel_s = original_diff_tensor.size()[2:4] if is_conv2d_layer else None + is_conv2d_3x3_layer = is_conv2d_layer and kernel_s != (1, 1) + module_true_out_channels, module_true_in_channels = original_diff_tensor.size()[0:2] + mat_for_svd = original_diff_tensor.to(svd_computation_device, dtype=torch.float) + if is_conv2d_layer: + if is_conv2d_3x3_layer: mat_for_svd = mat_for_svd.flatten(start_dim=1) + else: mat_for_svd = mat_for_svd.squeeze() + if mat_for_svd.numel() == 0 or mat_for_svd.shape[0] == 0 or mat_for_svd.shape[1] == 0 : + logger.warning(f"Skipping SVD for {lora_name} due to empty/invalid shape: {mat_for_svd.shape}") + continue + try: + U_full, S_full, Vh_full = torch.linalg.svd(mat_for_svd) + except Exception as e: + logger.error(f"SVD failed for {lora_name} with shape {mat_for_svd.shape}. Error: {e}") + continue + + # Max rank for SVD is based on 'dim' for linear and 'conv_dim' for conv3x3 + # The original `current_max_rank` logic was: + # current_max_rank = dim if not is_conv2d_3x3_layer or conv_dim is None else conv_dim + # Here, `dim` is args.dim and `conv_dim` is args.conv_dim (defaulted to args.dim) + module_specific_max_rank = conv_dim if is_conv2d_3x3_layer else dim + + eff_out_dim, eff_in_dim = mat_for_svd.shape[0], mat_for_svd.shape[1] + rank = _determine_rank(S_full, dynamic_method, dynamic_param, + module_specific_max_rank, eff_in_dim, eff_out_dim, MIN_SV) + U_clamped, Vh_clamped = _construct_lora_weights_from_svd_components( + U_full, S_full, Vh_full, rank, clamp_quantile, + is_conv2d_layer, is_conv2d_3x3_layer, kernel_s, + module_true_out_channels, module_true_in_channels, + final_weights_device, save_dtype_torch + ) + lora_weights[lora_name] = (U_clamped, Vh_clamped) + if verbose: _log_svd_stats(lora_name, S_full, rank, MIN_SV) + + lora_sd = {} + for lora_name, (up_weight, down_weight) in lora_weights.items(): + lora_sd[lora_name + ".lora_up.weight"] = up_weight + lora_sd[lora_name + ".lora_down.weight"] = down_weight + # Alpha is set to the rank (dim of down_weight's 0th axis, which is rank) + lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype_torch, device=final_weights_device) + + del text_encoders_o, unet_o, lora_network_o, all_diffs # Clean up original models and placeholders + if 'torch' in sys.modules and hasattr(torch, 'cuda') and torch.cuda.is_available(): + torch.cuda.empty_cache() + + if not os.path.exists(os.path.dirname(save_to)) and os.path.dirname(save_to) != "": + os.makedirs(os.path.dirname(save_to), exist_ok=True) + + metadata_to_save = _prepare_lora_metadata( + output_path=save_to, + is_v2_flag=v2, + kohya_base_model_version_str=kohya_model_version, + network_conv_dim_val=conv_dim, # This is args.conv_dim (defaulted to args.dim) + use_dynamic_method_flag=bool(dynamic_method), + network_dim_config_val=dim, # This is args.dim + is_v_param_flag=actual_v_parameterization, + is_sdxl_flag=sdxl, + skip_sai_meta=no_metadata + ) + + save_to_file(save_to, lora_sd, save_dtype_torch, metadata_to_save) + logger.info(f"LoRA saved to: {save_to}") + +def setup_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model") + parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2 if --v2 is set)") + parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model") + parser.add_argument("--load_precision", type=str, choices=["float", "fp16", "bf16"], default=None, help="Precision for loading models (applied after initial load)") + parser.add_argument("--save_precision", type=str, choices=["float", "fp16", "bf16"], default="float", help="Precision for saving LoRA weights") + parser.add_argument("--model_org", type=str, required=True, help="Original Stable Diffusion model (ckpt/safetensors)") + parser.add_argument("--model_tuned", type=str, required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)") + parser.add_argument("--save_to", type=str, required=True, help="Output file name (ckpt/safetensors)") + parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers") + parser.add_argument("--conv_dim", type=int, default=None, help="Max dimension (rank) of LoRA for Conv2d-3x3. Defaults to 'dim' if not set.") + parser.add_argument("--device", type=str, default=None, help="Device for SVD computation (e.g., cuda, cpu). Defaults to cuda if available, else cpu.") + parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights") + parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract LoRA for a module") + parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata from SAI and Kohya_ss") + parser.add_argument("--load_original_model_to", type=str, default=None, help="Device for original model (e.g. 'cpu', 'cuda:0'). Defaults to CPU for SD1/2, honored for SDXL.") + parser.add_argument("--load_tuned_model_to", type=str, default=None, help="Device for tuned model (e.g. 'cpu', 'cuda:0'). Defaults to CPU for SD1/2, honored for SDXL.") + parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction") + parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info for each module") + parser.add_argument( + "--dynamic_method", type=str, + choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"], + default=None, help="Dynamic rank reduction method" + ) + return parser + +if __name__ == "__main__": + parser = setup_parser() + args = parser.parse_args() + + if args.conv_dim is None: + args.conv_dim = args.dim # Default conv_dim to dim if not provided + logger.info(f"--conv_dim not set, using value of --dim: {args.conv_dim}") + + methods_requiring_param = ["sv_ratio", "sv_fro", "sv_cumulative", "sv_rel_decrease"] + if args.dynamic_method in methods_requiring_param and args.dynamic_param is None: + parser.error(f"Dynamic method '{args.dynamic_method}' requires --dynamic_param to be set.") + + if not args.dynamic_method: # Ranks must be positive if not using dynamic method + if args.dim <= 0: parser.error(f"--dim (rank) must be > 0. Got {args.dim}") + if args.conv_dim <=0: parser.error(f"--conv_dim (rank) must be > 0. Got {args.conv_dim}") # Check after defaulting + + if MIN_SV <= 0: logger.warning(f"Global MIN_SV ({MIN_SV}) should be positive.") + + svd_args = vars(args).copy() + svd(**svd_args) \ No newline at end of file diff --git a/tools/extract_lora_from_models-nw.py b/tools/extract_lora_from_models-nw.py deleted file mode 100644 index 43c1d00..0000000 --- a/tools/extract_lora_from_models-nw.py +++ /dev/null @@ -1,360 +0,0 @@ -# extract approximating LoRA by svd from two SD models -# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py -# Thanks to cloneofsimo! - -import argparse -import json -import os -import time -import torch -from safetensors.torch import load_file, save_file -from tqdm import tqdm -from library import sai_model_spec, model_util, sdxl_model_util -import lora -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -# CLAMP_QUANTILE = 0.99 -# MIN_DIFF = 1e-1 - - -def save_to_file(file_name, model, state_dict, dtype): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - - if os.path.splitext(file_name)[1] == ".safetensors": - save_file(model, file_name) - else: - torch.save(model, file_name) - - -def svd( - model_org=None, - model_tuned=None, - save_to=None, - dim=4, - v2=None, - sdxl=None, - conv_dim=None, - v_parameterization=None, - device=None, - save_precision=None, - clamp_quantile=0.99, - min_diff=0.01, - no_metadata=False, - load_precision=None, - load_original_model_to=None, - load_tuned_model_to=None, -): - def str_to_dtype(p): - if p == "float": - return torch.float - if p == "fp16": - return torch.float16 - if p == "bf16": - return torch.bfloat16 - return None - - assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" - if v_parameterization is None: - v_parameterization = v2 - - load_dtype = str_to_dtype(load_precision) if load_precision else None - save_dtype = str_to_dtype(save_precision) - work_device = "cpu" - - # load models - if not sdxl: - logger.info(f"loading original SD model : {model_org}") - text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) - text_encoders_o = [text_encoder_o] - if load_dtype is not None: - text_encoder_o = text_encoder_o.to(load_dtype) - unet_o = unet_o.to(load_dtype) - - logger.info(f"loading tuned SD model : {model_tuned}") - text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) - text_encoders_t = [text_encoder_t] - if load_dtype is not None: - text_encoder_t = text_encoder_t.to(load_dtype) - unet_t = unet_t.to(load_dtype) - - model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) - else: - device_org = load_original_model_to if load_original_model_to else "cpu" - device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu" - - logger.info(f"loading original SDXL model : {model_org}") - text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org - ) - text_encoders_o = [text_encoder_o1, text_encoder_o2] - if load_dtype is not None: - text_encoder_o1 = text_encoder_o1.to(load_dtype) - text_encoder_o2 = text_encoder_o2.to(load_dtype) - unet_o = unet_o.to(load_dtype) - - logger.info(f"loading original SDXL model : {model_tuned}") - text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned - ) - text_encoders_t = [text_encoder_t1, text_encoder_t2] - if load_dtype is not None: - text_encoder_t1 = text_encoder_t1.to(load_dtype) - text_encoder_t2 = text_encoder_t2.to(load_dtype) - unet_t = unet_t.to(load_dtype) - - model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 - - # create LoRA network to extract weights: Use dim (rank) as alpha - if conv_dim is None: - kwargs = {} - else: - kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} - - lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs) - lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs) - assert len(lora_network_o.text_encoder_loras) == len( - lora_network_t.text_encoder_loras - ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " - - # get diffs - diffs = {} - text_encoder_different = False - for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)): - lora_name = lora_o.lora_name - module_o = lora_o.org_module - module_t = lora_t.org_module - diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) - - # clear weight to save memory - module_o.weight = None - module_t.weight = None - - # Text Encoder might be same - if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: - text_encoder_different = True - logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") - - diffs[lora_name] = diff - - # clear target Text Encoder to save memory - for text_encoder in text_encoders_t: - del text_encoder - - if not text_encoder_different: - logger.warning("Text encoder is same. Extract U-Net only.") - lora_network_o.text_encoder_loras = [] - diffs = {} # clear diffs - - for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): - lora_name = lora_o.lora_name - module_o = lora_o.org_module - module_t = lora_t.org_module - diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) - - # clear weight to save memory - module_o.weight = None - module_t.weight = None - - diffs[lora_name] = diff - - # clear LoRA network, target U-Net to save memory - del lora_network_o - del lora_network_t - del unet_t - - # make LoRA with svd - logger.info("calculating by svd") - lora_weights = {} - with torch.no_grad(): - for lora_name, mat in tqdm(list(diffs.items())): - if args.device: - mat = mat.to(args.device) - mat = mat.to(torch.float) # calc by float - - # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3 - conv2d = len(mat.size()) == 4 - kernel_size = None if not conv2d else mat.size()[2:4] - conv2d_3x3 = conv2d and kernel_size != (1, 1) - - rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim - out_dim, in_dim = mat.size()[0:2] - - if device: - mat = mat.to(device) - - # logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) - rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim - - if conv2d: - if conv2d_3x3: - mat = mat.flatten(start_dim=1) - else: - mat = mat.squeeze() - - U, S, Vh = torch.linalg.svd(mat) - - U = U[:, :rank] - S = S[:rank] - U = U @ torch.diag(S) - - Vh = Vh[:rank, :] - - dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, clamp_quantile) - low_val = -hi_val - - U = U.clamp(low_val, hi_val) - Vh = Vh.clamp(low_val, hi_val) - - if conv2d: - U = U.reshape(out_dim, rank, 1, 1) - Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) - - U = U.to(work_device, dtype=save_dtype).contiguous() - Vh = Vh.to(work_device, dtype=save_dtype).contiguous() - - lora_weights[lora_name] = (U, Vh) - - # make state dict for LoRA - lora_sd = {} - for lora_name, (up_weight, down_weight) in lora_weights.items(): - lora_sd[lora_name + ".lora_up.weight"] = up_weight - lora_sd[lora_name + ".lora_down.weight"] = down_weight - lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) - - # load state dict to LoRA and save it - lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd) - lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict - - info = lora_network_save.load_state_dict(lora_sd) - logger.info(f"Loading extracted LoRA weights: {info}") - - dir_name = os.path.dirname(save_to) - if dir_name and not os.path.exists(dir_name): - os.makedirs(dir_name, exist_ok=True) - - # minimum metadata - net_kwargs = {} - if conv_dim is not None: - net_kwargs["conv_dim"] = str(conv_dim) - net_kwargs["conv_alpha"] = str(float(conv_dim)) - - metadata = { - "ss_v2": str(v2), - "ss_base_model_version": model_version, - "ss_network_module": "networks.lora", - "ss_network_dim": str(dim), - "ss_network_alpha": str(float(dim)), - "ss_network_args": json.dumps(net_kwargs), - } - - if not no_metadata: - title = os.path.splitext(os.path.basename(save_to))[0] - sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title) - metadata.update(sai_metadata) - - lora_network_save.save_weights(save_to, save_dtype, metadata) - logger.info(f"LoRA weights are saved to: {save_to}") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") - parser.add_argument( - "--v_parameterization", - action="store_true", - default=None, - help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)", - ) - parser.add_argument( - "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む" - ) - parser.add_argument( - "--load_precision", - type=str, - default=None, - choices=[None, "float", "fp16", "bf16"], - help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる" - ) - parser.add_argument( - "--save_precision", - type=str, - default=None, - choices=[None, "float", "fp16", "bf16"], - help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat", - ) - parser.add_argument( - "--model_org", - type=str, - default=None, - required=True, - help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors", - ) - parser.add_argument( - "--model_tuned", - type=str, - default=None, - required=True, - help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", - ) - parser.add_argument( - "--save_to", - type=str, - default=None, - required=True, - help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", - ) - parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") - parser.add_argument( - "--conv_dim", - type=int, - default=None, - help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)", - ) - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") - parser.add_argument( - "--clamp_quantile", - type=float, - default=0.99, - help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99", - ) - parser.add_argument( - "--min_diff", - type=float, - default=0.01, - help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /" - + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01", - ) - parser.add_argument( - "--no_metadata", - action="store_true", - help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " - + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", - ) - parser.add_argument( - "--load_original_model_to", - type=str, - default=None, - help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", - ) - parser.add_argument( - "--load_tuned_model_to", - type=str, - default=None, - help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - svd(**vars(args)) diff --git a/tools/extract_model_difference.py b/tools/extract_model_difference.py new file mode 100644 index 0000000..34744d1 --- /dev/null +++ b/tools/extract_model_difference.py @@ -0,0 +1,175 @@ +import torch +from safetensors.torch import load_file, save_file +from collections import OrderedDict +import os +import argparse # Import argparse + +def extract_model_differences(base_model_path, finetuned_model_path, output_delta_path=None, save_dtype_str="float32"): + """ + Calculates the difference between the state dictionaries of a fine-tuned model + and a base model. + + Args: + base_model_path (str): Path to the base model .safetensors file. + finetuned_model_path (str): Path to the fine-tuned model .safetensors file. + output_delta_path (str, optional): Path to save the resulting delta weights + .safetensors file. If None, not saved. + save_dtype_str (str, optional): Data type to save the delta weights ('float32', 'float16', 'bfloat16'). + Defaults to 'float32'. + Returns: + OrderedDict: A state dictionary containing the delta weights. + Returns None if loading fails or other critical errors. + """ + print(f"Loading base model from: {base_model_path}") + try: + # Ensure model is loaded to CPU to avoid CUDA issues if not needed for diffing + base_state_dict = load_file(base_model_path, device="cpu") + print(f"Base model loaded. Found {len(base_state_dict)} tensors.") + except Exception as e: + print(f"Error loading base model: {e}") + return None + + print(f"\nLoading fine-tuned model from: {finetuned_model_path}") + try: + finetuned_state_dict = load_file(finetuned_model_path, device="cpu") + print(f"Fine-tuned model loaded. Found {len(finetuned_state_dict)} tensors.") + except Exception as e: + print(f"Error loading fine-tuned model: {e}") + return None + + delta_state_dict = OrderedDict() + diff_count = 0 + skipped_count = 0 + error_count = 0 + unique_to_finetuned_count = 0 + unique_to_base_count = 0 + + print("\nCalculating differences...") + + # Keys in finetuned model + finetuned_keys = set(finetuned_state_dict.keys()) + base_keys = set(base_state_dict.keys()) + + common_keys = finetuned_keys.intersection(base_keys) + keys_only_in_finetuned = finetuned_keys - base_keys + keys_only_in_base = base_keys - finetuned_keys + + for key in common_keys: + ft_tensor = finetuned_state_dict[key] + base_tensor = base_state_dict[key] + + if not (ft_tensor.is_floating_point() and base_tensor.is_floating_point()): + # print(f"Skipping key '{key}': Non-floating point tensors (FT: {ft_tensor.dtype}, Base: {base_tensor.dtype}).") + skipped_count += 1 + continue + + if ft_tensor.shape != base_tensor.shape: + print(f"Skipping key '{key}': Shape mismatch (FT: {ft_tensor.shape}, Base: {base_tensor.shape}).") + skipped_count += 1 + continue + + try: + # Calculate difference in float32 for precision, then cast to save_dtype + delta_tensor = ft_tensor.to(dtype=torch.float32) - base_tensor.to(dtype=torch.float32) + delta_state_dict[key] = delta_tensor + diff_count += 1 + except Exception as e: + print(f"Error calculating difference for key '{key}': {e}") + error_count += 1 + + for key in keys_only_in_finetuned: + print(f"Warning: Key '{key}' (Shape: {finetuned_state_dict[key].shape}, Dtype: {finetuned_state_dict[key].dtype}) is present in fine-tuned model but not in base model. Storing as is.") + delta_state_dict[key] = finetuned_state_dict[key] # Store the tensor from the finetuned model + unique_to_finetuned_count += 1 + + if keys_only_in_base: + print(f"\nWarning: {len(keys_only_in_base)} key(s) are present only in the base model and will not be in the delta file.") + for key in list(keys_only_in_base)[:5]: # Print first 5 as examples + print(f" - Example key only in base: {key}") + if len(keys_only_in_base) > 5: + print(f" ... and {len(keys_only_in_base) - 5} more.") + + + print(f"\nDifference calculation complete.") + print(f" {diff_count} layers successfully diffed.") + print(f" {unique_to_finetuned_count} layers unique to fine-tuned model (added as is).") + print(f" {skipped_count} common layers skipped (shape/type mismatch).") + print(f" {error_count} common layers had errors during diffing.") + + if output_delta_path and delta_state_dict: + save_dtype = torch.float32 # Default + if save_dtype_str == "float16": + save_dtype = torch.float16 + elif save_dtype_str == "bfloat16": + save_dtype = torch.bfloat16 + elif save_dtype_str != "float32": + print(f"Warning: Invalid save_dtype '{save_dtype_str}'. Defaulting to float32.") + save_dtype_str = "float32" # for print message + + print(f"\nPreparing to save delta weights with dtype: {save_dtype_str}") + + final_save_dict = OrderedDict() + for k, v_tensor in delta_state_dict.items(): + if v_tensor.is_floating_point(): + final_save_dict[k] = v_tensor.to(dtype=save_dtype) + else: + final_save_dict[k] = v_tensor # Keep non-float as is (e.g. int tensors if any) + + try: + save_file(final_save_dict, output_delta_path) + print(f"Delta weights saved to: {output_delta_path}") + except Exception as e: + print(f"Error saving delta weights: {e}") + import traceback + traceback.print_exc() + + + return delta_state_dict + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Extract weight differences between a fine-tuned and a base SDXL model.") + parser.add_argument("base_model_path", type=str, help="File path for the BASE SDXL model (.safetensors).") + parser.add_argument("finetuned_model_path", type=str, help="File path for the FINE-TUNED SDXL model (.safetensors).") + parser.add_argument("--output_path", type=str, default=None, + help="Optional: File path to save the delta weights (.safetensors). " + "If not provided, defaults to 'model_deltas/delta_[finetuned_model_name].safetensors'.") + parser.add_argument("--save_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"], + help="Data type for saving the delta weights. Choose from 'float32', 'float16', 'bfloat16'. " + "Defaults to 'float32'.") + + args = parser.parse_args() + + print("--- Model Difference Extraction Script ---") + + if not os.path.exists(args.base_model_path): + print(f"Error: Base model file not found at {args.base_model_path}") + exit(1) + if not os.path.exists(args.finetuned_model_path): + print(f"Error: Fine-tuned model file not found at {args.finetuned_model_path}") + exit(1) + + output_delta_file = args.output_path + if output_delta_file is None: + output_dir = "model_deltas" + os.makedirs(output_dir, exist_ok=True) + finetuned_basename = os.path.splitext(os.path.basename(args.finetuned_model_path))[0] + output_delta_file = os.path.join(output_dir, f"delta_{finetuned_basename}.safetensors") + + # Ensure the output directory exists if a full path is given + if output_delta_file: + output_dir_for_file = os.path.dirname(output_delta_file) + if output_dir_for_file and not os.path.exists(output_dir_for_file): + os.makedirs(output_dir_for_file, exist_ok=True) + + + differences = extract_model_differences( + args.base_model_path, + args.finetuned_model_path, + output_delta_path=output_delta_file, + save_dtype_str=args.save_dtype + ) + + if differences: + print(f"\nExtraction process finished. {len(differences)} total keys in the delta state_dict.") + else: + print("\nCould not extract differences due to errors during model loading.") \ No newline at end of file