pull/3264/head
bmaltais 2025-05-11 00:52:58 -04:00
parent ece84c6d06
commit 926c8f344f
2 changed files with 1283 additions and 242 deletions

View File

@ -0,0 +1,662 @@
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file, load_file
import safetensors # Import the main library to use safetensors.safe_open
from tqdm import tqdm
import math
import json
from collections import OrderedDict
import signal
import sys
import glob
# --- Global variables ---
extracted_loha_state_dict_global = OrderedDict()
layer_optimization_stats_global = []
args_global = None
processed_layers_this_session_count_global = 0
previously_completed_module_prefixes_global = set()
all_completed_module_prefixes_ever_global = set() # Tracks all module prefixes ever completed (resumed + current)
skipped_identical_count_global = 0
skipped_other_reason_count_global = 0
keys_scanned_this_run_global = 0
save_attempted_on_interrupt = False
outer_pbar_global = None
main_loop_completed_scan_flag_global = False # True if the main key loop finished a full scan
# --- optimize_loha_for_layer and get_module_shape_info_from_weight (UNCHANGED) ---
def optimize_loha_for_layer(
layer_name: str, delta_W_target: torch.Tensor, out_dim: int, in_dim_effective: int,
k_h: int, k_w: int, rank: int, initial_alpha_val: float, lr: float = 1e-3,
max_iterations: int = 1000, min_iterations: int = 100, target_loss: float = None,
weight_decay: float = 1e-4, device: str = 'cuda', dtype: torch.dtype = torch.float32,
is_conv: bool = True, verbose_layer_debug: bool = False
):
delta_W_target = delta_W_target.to(device, dtype=dtype)
if is_conv:
k_ops = k_h * k_w
hada_w1_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype)); nn.init.kaiming_uniform_(hada_w1_a, a=math.sqrt(5))
hada_w1_b = nn.Parameter(torch.empty(rank, in_dim_effective * k_ops, device=device, dtype=dtype)); nn.init.normal_(hada_w1_b, std=0.02)
hada_w2_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype)); nn.init.kaiming_uniform_(hada_w2_a, a=math.sqrt(5))
hada_w2_b = nn.Parameter(torch.empty(rank, in_dim_effective * k_ops, device=device, dtype=dtype)); nn.init.normal_(hada_w2_b, std=0.02)
else: # Linear
hada_w1_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype)); nn.init.kaiming_uniform_(hada_w1_a, a=math.sqrt(5))
hada_w1_b = nn.Parameter(torch.empty(rank, in_dim_effective, device=device, dtype=dtype)); nn.init.normal_(hada_w1_b, std=0.02)
hada_w2_a = nn.Parameter(torch.empty(out_dim, rank, device=device, dtype=dtype)); nn.init.kaiming_uniform_(hada_w2_a, a=math.sqrt(5))
hada_w2_b = nn.Parameter(torch.empty(rank, in_dim_effective, device=device, dtype=dtype)); nn.init.normal_(hada_w2_b, std=0.02)
alpha_param = nn.Parameter(torch.tensor(initial_alpha_val, device=device, dtype=dtype))
optimizer = torch.optim.AdamW([hada_w1_a, hada_w1_b, hada_w2_a, hada_w2_b, alpha_param], lr=lr, weight_decay=weight_decay)
patience_epochs = max(10, int(max_iterations * 0.05))
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience_epochs, factor=0.5, min_lr=1e-7, verbose=False)
iter_pbar = tqdm(range(max_iterations), desc=f"Opt: {layer_name}", leave=False, dynamic_ncols=True, position=1, mininterval=0.5)
final_loss = float('inf'); stopped_early_by_loss = False; iterations_actually_done = 0
for i in iter_pbar:
iterations_actually_done = i + 1
if save_attempted_on_interrupt: print(f"\n Interrupt during opt of {layer_name}. Stopping layer after {i} iters."); break
optimizer.zero_grad(); eff_alpha_scale = alpha_param / rank
if is_conv:
term1_flat = hada_w1_a @ hada_w1_b; term1_reshaped = term1_flat.view(out_dim, in_dim_effective, k_h, k_w)
term2_flat = hada_w2_a @ hada_w2_b; term2_reshaped = term2_flat.view(out_dim, in_dim_effective, k_h, k_w)
delta_W_loha = eff_alpha_scale * term1_reshaped * term2_reshaped
else:
term1 = hada_w1_a @ hada_w1_b; term2 = hada_w2_a @ hada_w2_b
delta_W_loha = eff_alpha_scale * term1 * term2
loss = F.mse_loss(delta_W_loha, delta_W_target); final_loss = loss.item()
loss.backward(); optimizer.step(); scheduler.step(loss)
current_lr = optimizer.param_groups[0]['lr']
iter_pbar.set_postfix_str(f"Loss={final_loss:.3e}, AlphaP={alpha_param.item():.2f}, LR={current_lr:.1e}", refresh=True)
if verbose_layer_debug and (i == 0 or (i + 1) % (max_iterations // 10 if max_iterations >= 10 else 1) == 0 or i == max_iterations - 1):
iter_pbar.write(f" Debug {layer_name} - Iter {i+1}/{max_iterations}: Loss: {final_loss:.6e}, LR: {current_lr:.2e}, AlphaP: {alpha_param.item():.4f}")
if target_loss is not None and i >= min_iterations -1 and final_loss <= target_loss:
if verbose_layer_debug or (args_global and args_global.verbose): iter_pbar.write(f" Target loss {target_loss:.2e} reached for {layer_name} at iter {i+1}.")
stopped_early_by_loss = True; break
if not save_attempted_on_interrupt: iter_pbar.set_description_str(f"Opt: {layer_name} (Done)"); iter_pbar.set_postfix_str(f"FinalLoss={final_loss:.2e}, It={iterations_actually_done}{', EarlyStop' if stopped_early_by_loss else ''}")
iter_pbar.close()
if save_attempted_on_interrupt and not stopped_early_by_loss and iterations_actually_done < max_iterations:
return {'final_loss': final_loss, 'stopped_early': False, 'iterations_done': iterations_actually_done, 'interrupted_mid_layer': True}
return {'hada_w1_a': hada_w1_a.data.cpu().contiguous(), 'hada_w1_b': hada_w1_b.data.cpu().contiguous(),
'hada_w2_a': hada_w2_a.data.cpu().contiguous(), 'hada_w2_b': hada_w2_b.data.cpu().contiguous(),
'alpha': alpha_param.data.cpu().contiguous(), 'final_loss': final_loss,
'stopped_early': stopped_early_by_loss, 'iterations_done': iterations_actually_done,
'interrupted_mid_layer': False}
def get_module_shape_info_from_weight(weight_tensor: torch.Tensor):
if len(weight_tensor.shape) == 4: is_conv = True; out_dim, in_dim_effective, k_h, k_w = weight_tensor.shape; groups = 1; return out_dim, in_dim_effective, k_h, k_w, groups, is_conv
elif len(weight_tensor.shape) == 2: is_conv = False; out_dim, in_dim = weight_tensor.shape; return out_dim, in_dim, None, None, 1, is_conv
return None
# --- NEW: Helper function to generate intermediate filenames ---
def generate_intermediate_filename(base_save_path: str, num_total_completed_layers: int) -> str:
base, ext = os.path.splitext(base_save_path)
return f"{base}_resume_L{num_total_completed_layers}{ext}"
# --- NEW: Helper function to find the best file to resume from ---
def find_best_resume_file(intended_final_path: str) -> tuple[str | None, int]:
output_dir = os.path.dirname(intended_final_path)
if not output_dir: output_dir = "." # Current directory if no path part
base_save_name, save_ext = os.path.splitext(os.path.basename(intended_final_path))
potential_files = []
# Check the main intended file first
if os.path.exists(intended_final_path):
potential_files.append(intended_final_path)
# Check for intermediate files
intermediate_pattern = os.path.join(output_dir, f"{base_save_name}_resume_L*{save_ext}")
potential_files.extend(glob.glob(intermediate_pattern))
best_file_path = None
max_completed_modules = -1
if not potential_files:
print(" No existing main LoHA file or intermediate resume files found.")
return None, -1
print(f" Found potential resume files: {potential_files}")
for file_path in potential_files:
try:
if not os.path.exists(file_path): continue # Should not happen with glob but good check
with safetensors.safe_open(file_path, framework="pt", device="cpu") as f:
metadata = f.metadata()
if metadata and "ss_completed_loha_modules" in metadata:
num_completed = len(json.loads(metadata["ss_completed_loha_modules"]))
if num_completed > max_completed_modules:
max_completed_modules = num_completed
best_file_path = file_path
elif num_completed == max_completed_modules and best_file_path != intended_final_path and file_path == intended_final_path:
# Prefer the main file if module count is the same as an intermediate
best_file_path = file_path
elif max_completed_modules == -1: # If no file has metadata, consider the first one (or main one)
# This case handles files without the specific metadata, preferring the main file if it exists.
# It's a basic fallback; files with proper metadata will usually win.
if best_file_path is None or (file_path == intended_final_path and best_file_path != intended_final_path):
best_file_path = file_path # Fallback to considering the file itself if no metadata found yet
max_completed_modules = 0 # Treat as 0 if no metadata, to be potentially overridden
print(f" File {file_path} has no 'ss_completed_loha_modules' metadata. Treating as 0 completed for now.")
except Exception as e:
print(f" Warning: Could not read or parse metadata from {file_path}: {e}")
if best_file_path is None and file_path == intended_final_path and max_completed_modules == -1:
best_file_path = file_path # If primary file is broken, still note it as a candidate if nothing better
max_completed_modules = 0
if best_file_path:
print(f" Selected '{os.path.basename(best_file_path)}' for resume (contains {max_completed_modules} completed modules in metadata).")
elif not potential_files: # Already handled above, but as a safeguard
print(f" No existing LoHA file or intermediate files found matching pattern for: {intended_final_path}")
else:
print(f" Could not determine a best file to resume from among candidates, or no valid metadata found.")
return best_file_path, max_completed_modules
# --- NEW: Helper function to clean up intermediate files ---
def cleanup_intermediate_files(final_intended_path: str):
output_dir = os.path.dirname(final_intended_path)
if not output_dir: output_dir = "."
base_save_name, save_ext = os.path.splitext(os.path.basename(final_intended_path))
intermediate_pattern = os.path.join(output_dir, f"{base_save_name}_resume_L*{save_ext}")
cleaned_count = 0
for file_path in glob.glob(intermediate_pattern):
try:
os.remove(file_path)
if args_global and args_global.verbose: print(f" Cleaned up intermediate file: {file_path}")
cleaned_count +=1
except OSError as e:
print(f" Warning: Could not clean up intermediate file {file_path}: {e}")
if cleaned_count > 0:
print(f" Cleaned up {cleaned_count} intermediate file(s).")
# --- perform_graceful_save (MODIFIED to only require output_path_override) ---
def perform_graceful_save(output_path_to_save: str):
global extracted_loha_state_dict_global, layer_optimization_stats_global, args_global
global processed_layers_this_session_count_global, save_attempted_on_interrupt
global skipped_identical_count_global, skipped_other_reason_count_global, keys_scanned_this_run_global
global all_completed_module_prefixes_ever_global # Use this for metadata
# Ensure all_completed_module_prefixes_ever is up-to-date before saving
# This should already be handled by adding to it when layers are processed or resumed.
current_session_processed_prefixes = {stat['name'] for stat in layer_optimization_stats_global}
# `all_completed_module_prefixes_ever_global` should already include `previously_completed_module_prefixes_global`
# and any newly processed ones.
total_processed_ever = len(all_completed_module_prefixes_ever_global)
if not extracted_loha_state_dict_global and not previously_completed_module_prefixes_global : # Check against all_completed for empty save
# If all_completed is also empty, it means nothing was resumed and nothing new processed
if not all_completed_module_prefixes_ever_global:
print(f"No layers were processed or loaded to save to {output_path_to_save}. Save aborted.")
return
args_to_use = args_global
if not args_to_use: print("Error: Global args not available for saving metadata."); return
final_save_path = output_path_to_save # Use the direct path given
if args_to_use.save_weights_dtype == "fp16": final_save_dtype_torch = torch.float16
elif args_to_use.save_weights_dtype == "bf16": final_save_dtype_torch = torch.bfloat16
else: final_save_dtype_torch = torch.float32
final_state_dict_to_save = OrderedDict()
for k, v_tensor in extracted_loha_state_dict_global.items():
if v_tensor.is_floating_point(): final_state_dict_to_save[k] = v_tensor.to(final_save_dtype_torch)
else: final_state_dict_to_save[k] = v_tensor
# Metadata uses all_completed_module_prefixes_ever_global
print(f"\nAttempting to save LoHA for {total_processed_ever} unique modules in total "
f"({processed_layers_this_session_count_global} new this session) to {final_save_path}")
eff_global_network_alpha_val = args_to_use.initial_alpha; eff_global_network_alpha_str = f"{eff_global_network_alpha_val:.8f}"
global_rank_str = str(args_to_use.rank)
conv_rank_str = str(args_to_use.conv_rank if args_to_use.conv_rank is not None else args_to_use.rank)
eff_conv_alpha_val = args_to_use.initial_conv_alpha; conv_alpha_str = f"{eff_conv_alpha_val:.8f}"
network_args_dict = {
"algo": "loha", "dim": global_rank_str, "alpha": eff_global_network_alpha_str,
"conv_dim": conv_rank_str, "conv_alpha": conv_alpha_str,
"dropout": str(args_to_use.dropout), "rank_dropout": str(args_to_use.rank_dropout), "module_dropout": str(args_to_use.module_dropout),
"use_tucker": "false", "use_scalar": "false", "block_size": "1",}
sf_metadata = {
"ss_network_module": "lycoris.kohya", "ss_network_rank": global_rank_str,
"ss_network_alpha": eff_global_network_alpha_str, "ss_network_algo": "loha",
"ss_network_args": json.dumps(network_args_dict),
"ss_comment": f"Extracted LoHA (Interrupt: {save_attempted_on_interrupt}). OptPrec: {args_to_use.precision}. SaveDtype: {args_to_use.save_weights_dtype}. ATOL: {args_to_use.atol_fp32_check}. Layers: {total_processed_ever}. MaxIter: {args_to_use.max_iterations}. TargetLoss: {args_to_use.target_loss}",
"ss_base_model_name": os.path.splitext(os.path.basename(args_to_use.base_model_path))[0],
"ss_ft_model_name": os.path.splitext(os.path.basename(args_to_use.ft_model_path))[0],
"ss_save_weights_dtype": args_to_use.save_weights_dtype, "ss_optimization_precision": args_to_use.precision,
"ss_completed_loha_modules": json.dumps(list(all_completed_module_prefixes_ever_global)) # Use the global cumulative set
}
json_metadata_for_file = {
"comfyui_lora_type": "LyCORIS_LoHa", "model_name": os.path.splitext(os.path.basename(final_save_path))[0],
"base_model_path": args_to_use.base_model_path, "ft_model_path": args_to_use.ft_model_path,
"loha_extraction_settings": {k: str(v) if isinstance(v, type(os.pathsep)) else v for k,v in vars(args_to_use).items()},
"extraction_summary":{
"processed_layers_in_total_cumulative": total_processed_ever, # Cumulative
"processed_this_session": processed_layers_this_session_count_global,
"skipped_identical_count_this_session": skipped_identical_count_global,
"skipped_other_reason_count_this_session": skipped_other_reason_count_global,
"total_candidate_keys_scanned_in_loop_this_session": keys_scanned_this_run_global,
},
"layer_optimization_details_this_session": layer_optimization_stats_global,
"embedded_safetensors_metadata": sf_metadata,
"interrupted_save": save_attempted_on_interrupt
}
if final_save_path.endswith(".safetensors"):
try:
save_file(final_state_dict_to_save, final_save_path, metadata=sf_metadata)
print(f"LoHA state_dict saved to: {final_save_path}")
except Exception as e:
print(f"Error saving .safetensors file: {e}"); return
metadata_json_file_path = os.path.splitext(final_save_path)[0] + "_extraction_metadata.json"
try:
with open(metadata_json_file_path, 'w') as f: json.dump(json_metadata_for_file, f, indent=4)
print(f"Extended metadata saved to: {metadata_json_file_path}")
except Exception as e: print(f"Could not save extended metadata JSON: {e}")
else:
# Saving to .pt might not be fully robust with this new scheme if JSON metadata is critical
print(f"Saving to .pt not fully supported with extended metadata JSON. Saving basic .pt file.")
torch.save({'state_dict': final_state_dict_to_save, 'metadata': sf_metadata}, final_save_path)
print(f"LoHA state_dict saved to: {final_save_path} (basic .pt save)")
# --- handle_interrupt (MODIFIED to use intermediate filenames) ---
def handle_interrupt(signum, frame):
global save_attempted_on_interrupt, outer_pbar_global, args_global, all_completed_module_prefixes_ever_global
print("\n" + "="*30 + "\nCtrl+C (SIGINT) detected!\n" + "="*30)
if save_attempted_on_interrupt: print("Save already attempted. Force exiting."); os._exit(1); return
save_attempted_on_interrupt = True
if outer_pbar_global: outer_pbar_global.close() # Close main progress bar
# Close any active layer progress bar (it's trickier, this might not catch it if deep in opt)
# For simplicity, we rely on the check within optimize_loha_for_layer
print("Attempting to save progress for processed layers...")
if args_global and args_global.save_to:
num_layers_for_filename = len(all_completed_module_prefixes_ever_global)
interrupt_save_path = generate_intermediate_filename(args_global.save_to, num_layers_for_filename)
print(f"Interrupt save will be to: {interrupt_save_path}")
perform_graceful_save(output_path_to_save=interrupt_save_path)
else:
print("Cannot perform interrupt save: args_global or save_to path not defined.")
print("Graceful save attempt finished. Exiting.")
sys.exit(0)
def main(cli_args):
global args_global, extracted_loha_state_dict_global, layer_optimization_stats_global
global processed_layers_this_session_count_global, save_attempted_on_interrupt, outer_pbar_global
global skipped_identical_count_global, skipped_other_reason_count_global, keys_scanned_this_run_global
global previously_completed_module_prefixes_global, all_completed_module_prefixes_ever_global
global main_loop_completed_scan_flag_global
args_global = cli_args
signal.signal(signal.SIGINT, handle_interrupt)
if args_global.precision == "fp16": target_opt_dtype = torch.float16
elif args_global.precision == "bf16": target_opt_dtype = torch.bfloat16
else: target_opt_dtype = torch.float32
if args_global.save_weights_dtype == "fp16": final_save_dtype = torch.float16
elif args_global.save_weights_dtype == "bf16": final_save_dtype = torch.bfloat16
else: final_save_dtype = torch.float32
print(f"Using device: {args_global.device}, Opt Dtype: {target_opt_dtype}, Save Dtype: {final_save_dtype}")
if args_global.target_loss: print(f"Target Loss: {args_global.target_loss:.2e} (after {args_global.min_iterations} min iters)")
print(f"Max Iters/Layer: {args_global.max_iterations}")
# --- MODIFIED: Loading Existing LoHA for resuming (using find_best_resume_file) ---
chosen_resume_file = None
if not args_global.overwrite:
print(f"\nChecking for existing LoHA file or resume states for: {args_global.save_to}")
chosen_resume_file, num_modules_in_chosen_file = find_best_resume_file(args_global.save_to)
if chosen_resume_file:
print(f" Attempting to resume from: {chosen_resume_file} ({num_modules_in_chosen_file} modules reported in metadata).")
try:
file_metadata = None
with safetensors.safe_open(chosen_resume_file, framework="pt", device="cpu") as f:
file_metadata = f.metadata()
completed_modules_in_file = set()
if file_metadata and "ss_completed_loha_modules" in file_metadata:
try:
completed_modules_in_file = set(json.loads(file_metadata.get("ss_completed_loha_modules")))
# Verify count if possible, though num_modules_in_chosen_file is already from this.
if len(completed_modules_in_file) != num_modules_in_chosen_file and num_modules_in_chosen_file !=0 : # 0 can be if file had no metadata but was chosen
print(f" Warning: Metadata module count ({len(completed_modules_in_file)}) differs from initial scan count ({num_modules_in_chosen_file}). Using parsed set.")
except json.JSONDecodeError:
print(" Warning: Could not parse 'ss_completed_loha_modules' metadata from chosen file. Will not load specific tensors by prefix matching.")
else:
print(" 'ss_completed_loha_modules' not found in chosen file's metadata. Will not load specific tensors by prefix matching (might load all if no prefixes known).")
if completed_modules_in_file: # Only load if we have a list of modules to check against
print(" Loading tensors from chosen resume file...")
loaded_sd_for_resume = load_file(chosen_resume_file, device='cpu')
resumed_tensor_count = 0
for key, tensor_val in loaded_sd_for_resume.items():
module_prefix_for_check = ".".join(key.split('.')[:-1]) # e.g. lora_unet_..._block_0_fc1
is_bias_for_completed_module = key.endswith(".bias") and module_prefix_for_check in completed_modules_in_file
# Check if the tensor belongs to a module marked as completed
# This covers hada_w1_a, hada_w1_b etc. for LoHA layers, and biases.
if module_prefix_for_check in completed_modules_in_file or is_bias_for_completed_module :
extracted_loha_state_dict_global[key] = tensor_val
resumed_tensor_count +=1
previously_completed_module_prefixes_global = completed_modules_in_file
all_completed_module_prefixes_ever_global.update(previously_completed_module_prefixes_global) # Initialize with loaded
print(f" Successfully loaded {len(previously_completed_module_prefixes_global)} module prefixes "
f"with {resumed_tensor_count} tensors for resume from {os.path.basename(chosen_resume_file)}.")
del loaded_sd_for_resume
elif not completed_modules_in_file and num_modules_in_chosen_file == 0 and os.path.exists(chosen_resume_file):
# This case could mean an empty LoRA was found (e.g. from a previous failed start)
# or a file without the specific metadata was chosen by find_best_resume_file.
# We don't load anything specific but acknowledge the file existed.
print(f" Chosen resume file {os.path.basename(chosen_resume_file)} seems empty or has no LoHA module metadata. Starting new layer processing.")
# Optional: Load accompanying JSON metadata if it exists for the chosen_resume_file
resume_metadata_json_path = os.path.splitext(chosen_resume_file)[0] + "_extraction_metadata.json"
if os.path.exists(resume_metadata_json_path):
try:
with open(resume_metadata_json_path, 'r') as f_meta:
loaded_json_meta = json.load(f_meta)
# You could potentially load old layer_optimization_stats_global if needed for some cumulative report
# For now, we just acknowledge it.
print(f" Loaded accompanying metadata from: {os.path.basename(resume_metadata_json_path)}")
except Exception as e_json:
print(f" Could not load or parse JSON metadata from {resume_metadata_json_path}: {e_json}")
except Exception as e:
print(f" Error loading or parsing chosen LoHA file '{chosen_resume_file}': {e}. Starting fresh for new layers.")
extracted_loha_state_dict_global.clear()
previously_completed_module_prefixes_global.clear()
all_completed_module_prefixes_ever_global.clear()
else:
print(" No suitable existing LoHA file found to resume from. Starting fresh.")
# Globals are already empty, so no action needed.
elif args_global.overwrite and os.path.exists(args_global.save_to):
print(f"Overwriting specified output file as per --overwrite: {args_global.save_to}")
print(" Any existing intermediate resume files for this target will NOT be automatically cleaned with --overwrite until a new final save.")
extracted_loha_state_dict_global.clear()
previously_completed_module_prefixes_global.clear()
all_completed_module_prefixes_ever_global.clear()
# Note: We don't clean intermediates here because the user might want to revert.
# Cleanup happens on successful *final* save.
print(f"\nLoading base model: {args_global.base_model_path}")
if args_global.base_model_path.endswith(".safetensors"): base_model_sd = load_file(args_global.base_model_path, device='cpu')
else: base_model_sd = torch.load(args_global.base_model_path, map_location='cpu'); base_model_sd = base_model_sd.get('state_dict', base_model_sd)
print(f"Loading fine-tuned model: {args_global.ft_model_path}")
if args_global.ft_model_path.endswith(".safetensors"): ft_model_sd = load_file(args_global.ft_model_path, device='cpu')
else: ft_model_sd = torch.load(args_global.ft_model_path, map_location='cpu'); ft_model_sd = ft_model_sd.get('state_dict', ft_model_sd)
# Reset session-specific counters
processed_layers_this_session_count_global = 0
skipped_identical_count_global = 0 # For this session's scan
skipped_other_reason_count_global = 0 # For this session's scan
keys_scanned_this_run_global = 0
layer_optimization_stats_global.clear() # For this session's stats
main_loop_completed_scan_flag_global = False
all_candidate_keys = []
for k in base_model_sd.keys():
if k.endswith('.weight') and k in ft_model_sd and (len(base_model_sd[k].shape) == 2 or len(base_model_sd[k].shape) == 4):
all_candidate_keys.append(k)
all_candidate_keys.sort()
total_candidates_to_scan = len(all_candidate_keys)
print(f"Found {total_candidates_to_scan} candidate '.weight' keys common to both models and of suitable shape.")
outer_pbar_global = tqdm(total=total_candidates_to_scan, desc="Scanning Layers", dynamic_ncols=True, position=0)
try:
for key_name in all_candidate_keys:
if save_attempted_on_interrupt: break
keys_scanned_this_run_global += 1
outer_pbar_global.update(1)
original_module_path = key_name[:-len(".weight")]
loha_key_prefix = ""
if original_module_path.startswith("model.diffusion_model."): loha_key_prefix = "lora_unet_" + original_module_path[len("model.diffusion_model."):].replace(".", "_")
elif original_module_path.startswith("conditioner.embedders.0.transformer."): loha_key_prefix = "lora_te1_" + original_module_path[len("conditioner.embedders.0.transformer."):].replace(".", "_")
elif original_module_path.startswith("conditioner.embedders.1.model.transformer."): loha_key_prefix = "lora_te2_" + original_module_path[len("conditioner.embedders.1.model.transformer."):].replace(".", "_")
else: loha_key_prefix = "lora_" + original_module_path.replace(".", "_")
# Check if already processed (either resumed or done in this session earlier if logic allowed re-scanning)
if loha_key_prefix in all_completed_module_prefixes_ever_global:
if args_global.verbose:
if loha_key_prefix in previously_completed_module_prefixes_global:
tqdm.write(f"Skipping {loha_key_prefix} (scan): already processed (loaded from resumed LoHA).")
# else: # This case should not happen if all_completed_module_prefixes_ever_global is managed correctly
# tqdm.write(f"Skipping {loha_key_prefix} (scan): already processed in this session (should be rare).")
outer_pbar_global.set_description_str(f"Scan {keys_scanned_this_run_global}/{total_candidates_to_scan} (Resumed: {len(previously_completed_module_prefixes_global)}, New Opt: {processed_layers_this_session_count_global})")
continue
if args_global.max_layers is not None and args_global.max_layers > 0 and processed_layers_this_session_count_global >= args_global.max_layers:
# Still need to scan all keys to correctly determine if the job is "fully complete" later
# So, we just skip optimization but continue scanning.
if args_global.verbose and processed_layers_this_session_count_global == args_global.max_layers and (keys_scanned_this_run_global - (len(all_completed_module_prefixes_ever_global) - processed_layers_this_session_count_global) - skipped_identical_count_global - skipped_other_reason_count_global) == (args_global.max_layers +1) : # First time hitting this after max_layers
tqdm.write(f"\nReached max_layers limit ({args_global.max_layers}) for new layers this session. Continuing scan only to assess remaining layers.")
outer_pbar_global.set_description_str(f"Scan {keys_scanned_this_run_global}/{total_candidates_to_scan} (Max New Layers Reached, Opt Ths Sess: {processed_layers_this_session_count_global})")
# This key is not skipped due to being identical or other error, but due to max_layers.
# We don't increment skipped_other_reason_count_global here, as it's still a valid candidate for a future run.
continue # Continue scanning
base_W = base_model_sd[key_name].to(dtype=torch.float32)
ft_W = ft_model_sd[key_name].to(dtype=torch.float32)
if base_W.shape != ft_W.shape:
skipped_other_reason_count_global +=1
if args_global.verbose: tqdm.write(f"Skipping {key_name} (shape mismatch).")
continue
shape_info = get_module_shape_info_from_weight(base_W)
if shape_info is None:
skipped_other_reason_count_global +=1
if args_global.verbose: tqdm.write(f"Skipping {key_name} (unsupported shape).")
continue
delta_W_fp32 = (ft_W - base_W)
if torch.allclose(delta_W_fp32, torch.zeros_like(delta_W_fp32), atol=args_global.atol_fp32_check):
if args_global.verbose: tqdm.write(f"Skipping {key_name} (identical weights).")
skipped_identical_count_global += 1
continue
# If we reach here, this layer is a candidate for optimization in this session
max_layers_target_str = f"/{args_global.max_layers}" if args_global.max_layers is not None and args_global.max_layers > 0 else ""
outer_pbar_global.set_description_str(f"Optimizing L{processed_layers_this_session_count_global + 1}{max_layers_target_str} (Scan {keys_scanned_this_run_global}/{total_candidates_to_scan})")
if args_global.verbose: tqdm.write(f"\n Orig: {key_name} -> LoHA: {loha_key_prefix}")
out_dim, in_dim_effective, k_h, k_w, _, is_conv = shape_info
delta_W_target_for_opt = delta_W_fp32.to(dtype=target_opt_dtype)
current_rank = args_global.conv_rank if is_conv and args_global.conv_rank is not None else args_global.rank
current_initial_alpha = args_global.initial_conv_alpha if is_conv else args_global.initial_alpha
tqdm.write(f"Optimizing Layer {processed_layers_this_session_count_global + 1}{max_layers_target_str}: {loha_key_prefix} (Orig: {original_module_path}, Shp: {list(base_W.shape)}, R: {current_rank}, Alpha_init: {current_initial_alpha:.1f})")
try:
opt_results = optimize_loha_for_layer(
layer_name=loha_key_prefix, delta_W_target=delta_W_target_for_opt,
out_dim=out_dim, in_dim_effective=in_dim_effective, k_h=k_h, k_w=k_w, rank=current_rank,
initial_alpha_val=current_initial_alpha, lr=args_global.lr,
max_iterations=args_global.max_iterations, min_iterations=args_global.min_iterations,
target_loss=args_global.target_loss, weight_decay=args_global.weight_decay,
device=args_global.device, dtype=target_opt_dtype, is_conv=is_conv,
verbose_layer_debug=args_global.verbose_layer_debug
)
if not opt_results.get('interrupted_mid_layer'):
for p_name, p_val in opt_results.items():
if p_name not in ['final_loss', 'stopped_early', 'iterations_done', 'interrupted_mid_layer']:
extracted_loha_state_dict_global[f'{loha_key_prefix}.{p_name}'] = p_val.to(final_save_dtype)
layer_optimization_stats_global.append({
"name": loha_key_prefix, "original_name": original_module_path,
"final_loss": opt_results['final_loss'], "iterations_done": opt_results['iterations_done'],
"stopped_early_by_loss_target": opt_results['stopped_early']})
all_completed_module_prefixes_ever_global.add(loha_key_prefix) # Add to cumulative set
tqdm.write(f" Layer {loha_key_prefix} Done. Loss: {opt_results['final_loss']:.4e}, Iters: {opt_results['iterations_done']}{', Stopped by Loss' if opt_results['stopped_early'] else ''}")
if args_global.use_bias:
original_bias_key = f"{original_module_path}.bias"
# Check if bias exists in ft_model and differs from base (or base doesn't have it)
bias_differs = False
if original_bias_key in ft_model_sd:
ft_B = ft_model_sd[original_bias_key].to(dtype=torch.float32)
if original_bias_key in base_model_sd:
base_B = base_model_sd[original_bias_key].to(dtype=torch.float32)
if not torch.allclose(base_B, ft_B, atol=args_global.atol_fp32_check):
bias_differs = True
else: # Bias in FT but not in base
bias_differs = True
if bias_differs:
extracted_loha_state_dict_global[original_bias_key] = ft_B.cpu().to(final_save_dtype)
if args_global.verbose: tqdm.write(f" Saved differing/new bias for {original_bias_key}")
# Note: Bias keys are not added to "loha_key_prefix" sets as they don't have LoHA params.
# They are just carried over if different.
processed_layers_this_session_count_global += 1
else: # Interrupted mid-layer
if args_global.verbose: tqdm.write(f" Opt for {loha_key_prefix} interrupted; not saving params for this layer.")
# Do not add to all_completed_module_prefixes_ever_global or increment processed_layers_this_session_count_global
except Exception as e:
print(f"\nError during optimization for {original_module_path} ({loha_key_prefix}): {e}")
import traceback; traceback.print_exc()
skipped_other_reason_count_global +=1 # Count as skipped due to error during opt
# After the loop finishes (or breaks due to interrupt)
if not save_attempted_on_interrupt and keys_scanned_this_run_global == total_candidates_to_scan:
main_loop_completed_scan_flag_global = True
finally: # This will run whether the try block completes normally or an exception (like interrupt) occurs
if outer_pbar_global:
if not outer_pbar_global.disable and outer_pbar_global.n < outer_pbar_global.total:
outer_pbar_global.update(outer_pbar_global.total - outer_pbar_global.n) # Fill up the bar
outer_pbar_global.close()
# --- Save decision logic ---
if not save_attempted_on_interrupt: # If interrupted, handler already saved
print("\n--- Final Optimization Summary (This Session) ---")
for stat in layer_optimization_stats_global: print(f"Layer: {stat['name']}, Final Loss: {stat['final_loss']:.4e}, Iters: {stat['iterations_done']}{', Stopped by Loss' if stat['stopped_early_by_loss_target'] else ''}")
print(f"\n--- Overall Summary ---")
print(f"Total unique LoHA modules accumulated (resumed + new): {len(all_completed_module_prefixes_ever_global)}")
print(f" Processed new this session: {processed_layers_this_session_count_global}")
print(f" Skipped as identical (this session's scan): {skipped_identical_count_global}")
print(f" Skipped for other reasons (this session's scan, e.g., shape error, opt error): {skipped_other_reason_count_global}")
print(f" Total candidate keys scanned in loop (this session): {keys_scanned_this_run_global}/{total_candidates_to_scan}")
actual_save_path: str
save_to_final_name = False
if main_loop_completed_scan_flag_global:
# Number of layers that were found to be different and optimizable during the full scan of *this session*
num_optimizable_layers_identified_in_scan = total_candidates_to_scan - skipped_identical_count_global - skipped_other_reason_count_global
# Check if all *such* layers are now accounted for in our cumulative set
if len(all_completed_module_prefixes_ever_global) >= num_optimizable_layers_identified_in_scan:
# This implies that all layers that showed a difference in the current model comparison
# are now present in the LoHA state dict (either from resume or processed now).
# We also need to ensure max_layers didn't prematurely stop us if it was less than this count.
if args_global.max_layers is None or processed_layers_this_session_count_global >= args_global.max_layers or len(all_completed_module_prefixes_ever_global) < (len(previously_completed_module_prefixes_global) + args_global.max_layers):
# If max_layers is not set, or if we processed up to max_layers (or didn't need to because all were done),
# and the total count matches the optimizable count from scan, then it's final.
save_to_final_name = True
else: # max_layers was hit, and it's less than total optimizable, so not final.
print(f" Scan completed, but max_layers ({args_global.max_layers}) may have limited processing before all {num_optimizable_layers_identified_in_scan} differing layers were handled.")
else:
print(f" Scan completed, but not all {num_optimizable_layers_identified_in_scan} differing layers are processed yet "
f"(current total: {len(all_completed_module_prefixes_ever_global)}).")
else:
print(" Scan did not complete fully. Saving intermediate state.")
if save_to_final_name:
actual_save_path = args_global.save_to
print(f"\nAll optimizable layers appear to be processed. Saving to final path: {actual_save_path}")
else:
num_layers_for_filename = len(all_completed_module_prefixes_ever_global)
actual_save_path = generate_intermediate_filename(args_global.save_to, num_layers_for_filename)
print(f"\nRun incomplete or not all differing layers processed. Saving intermediate state to: {actual_save_path}")
perform_graceful_save(output_path_to_save=actual_save_path)
if save_to_final_name and actual_save_path == args_global.save_to : # Ensure it's the final path
print("\nCleaning up intermediate resume files...")
cleanup_intermediate_files(args_global.save_to)
else: # Save was attempted by interrupt handler
print("\nProcess was interrupted. Graceful save to an intermediate file was attempted by signal handler.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Extract LoHA parameters by optimizing against weight differences. Saves intermediate files like 'name_resume_L{count}.safetensors'.")
parser.add_argument("base_model_path", type=str, help="Path to the base model state_dict file (.pt, .pth, .safetensors)")
parser.add_argument("ft_model_path", type=str, help="Path to the fine-tuned model state_dict file (.pt, .pth, .safetensors)")
parser.add_argument("save_to", type=str, help="Path to save the FINAL extracted LoHA file (recommended .safetensors). Intermediate files will be based on this name.")
parser.add_argument("--overwrite", action="store_true", help="Ignore and overwrite any existing FINAL LoHA output file and its intermediate files if found at the start. Does not prevent resuming from other intermediate files if the final target does not exist.")
parser.add_argument("--rank", type=int, default=4, help="Default rank for LoHA decomposition (used for linear layers and as fallback for conv).")
parser.add_argument("--conv_rank", type=int, default=None, help="Specific rank for convolutional LoHA layers. Defaults to --rank if not set.")
parser.add_argument("--initial_alpha", type=float, default=None, help="Global initial alpha for optimization. Defaults to 'rank'.")
parser.add_argument("--initial_conv_alpha", type=float, default=None, help="Specific initial alpha for Conv LoHA. Defaults to '--initial_alpha' or conv_rank.")
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate for LoHA optimization per layer.")
parser.add_argument("--max_iterations", type=int, default=1000, help="Maximum number of optimization iterations per layer.")
parser.add_argument("--min_iterations", type=int, default=100, help="Minimum iterations before checking target loss.")
parser.add_argument("--target_loss", type=float, default=None, help="Target MSE loss for early stopping per layer.")
parser.add_argument("--weight_decay", type=float, default=1e-5, help="Weight decay for LoHA optimization.")
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device ('cuda' or 'cpu').")
parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"], help="Optimization precision. Default: fp32.")
parser.add_argument("--save_weights_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"], help="Dtype for saved LoHA weights. Default: bf16.")
parser.add_argument("--atol_fp32_check", type=float, default=1e-6, help="Tolerance for identical weight check.")
parser.add_argument("--use_bias", action="store_true", help="Save differing bias terms.")
parser.add_argument("--dropout", type=float, default=0.0, help="General dropout (metadata only).")
parser.add_argument("--rank_dropout", type=float, default=0.0, help="Rank dropout (metadata only).")
parser.add_argument("--module_dropout", type=float, default=0.0, help="Module dropout (metadata only).")
parser.add_argument("--max_layers", type=int, default=None, help="Max NEW differing layers to process this session. Scan will continue to assess all layers.")
parser.add_argument("--verbose", action="store_true", help="General verbose output.")
parser.add_argument("--verbose_layer_debug", action="store_true", help="Detailed per-iteration optimization debug output.")
parsed_args = parser.parse_args()
if not os.path.exists(parsed_args.base_model_path): print(f"Error: Base model path not found: {parsed_args.base_model_path}"); exit(1)
if not os.path.exists(parsed_args.ft_model_path): print(f"Error: Fine-tuned model path not found: {parsed_args.ft_model_path}"); exit(1)
save_dir = os.path.dirname(parsed_args.save_to)
if save_dir and not os.path.exists(save_dir):
try:
os.makedirs(save_dir, exist_ok=True)
print(f"Created directory: {save_dir}")
except OSError as e:
print(f"Error: Could not create directory {save_dir}: {e}"); exit(1)
if parsed_args.initial_alpha is None: parsed_args.initial_alpha = float(parsed_args.rank)
# Ensure conv_alpha defaults correctly after initial_alpha might have defaulted to rank
if parsed_args.initial_conv_alpha is None:
# If conv_rank is set, use that for default alpha, else use the global initial_alpha (which might itself be rank)
conv_rank_for_alpha_default = parsed_args.conv_rank if parsed_args.conv_rank is not None else parsed_args.rank
parsed_args.initial_conv_alpha = float(conv_rank_for_alpha_default) if parsed_args.conv_rank is not None else parsed_args.initial_alpha
main(parsed_args)

File diff suppressed because it is too large Load Diff