pull/3264/head
bmaltais 2025-05-13 10:51:18 -04:00
parent 8a50bf2dca
commit 235a059bb1
1 changed files with 226 additions and 155 deletions

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file, load_file
import safetensors
import safetensors
from tqdm import tqdm
import math
import json
@ -12,8 +12,8 @@ from collections import OrderedDict
import signal
import sys
import glob
import traceback
import re
import traceback
import re
from enum import Enum, auto
# --- Global variables ---
@ -54,7 +54,7 @@ class LogType(Enum):
def log_layer_optimization_event(log_type: LogType, layer_name: str, **kwargs):
# This function will only be called after args_global is set in main()
if not (args_global and args_global.verbose):
if not (args_global and args_global.verbose):
return
if not args_global.verbose_layer_debug:
@ -120,7 +120,7 @@ def log_layer_optimization_event(log_type: LogType, layer_name: str, **kwargs):
elif kwargs.get('reason_type') == 'insufficient_progress': reason_detail = "insufficient raw progress"
elif kwargs.get('reason_type') == 'max_iterations_no_target': reason_detail = f"max iters (Loss {kwargs['current_loss']:.2e}, Target not met)"
elif kwargs.get('reason_type') == 'max_iterations_no_target_set': reason_detail = f"max iters (Loss {kwargs['current_loss']:.2e})"
if reason_detail:
msg = f"Att {kwargs['attempt']}(R:{kwargs['rank']}) ended: {reason_detail}. Will try next rank..."
elif log_type == LogType.NO_VALID_OPTIMIZATION_RESULT:
@ -146,7 +146,7 @@ def optimize_loha_for_layer(
lr: float = 1e-3, max_iterations: int = 1000, min_iterations: int = 100,
target_loss: float = None, weight_decay: float = 1e-4,
device: str = 'cuda', dtype: torch.dtype = torch.float32,
is_conv: bool = True, verbose_layer_debug: bool = False,
is_conv: bool = True, verbose_layer_debug: bool = False,
max_rank_retries: int = 0,
rank_increase_factor: float = 1.25,
existing_loha_layer_parameters: dict | None = None
@ -189,7 +189,7 @@ def optimize_loha_for_layer(
current_rank_for_this_attempt = max(rank_base_for_next_increase + 1, increased_rank)
original_alpha_to_rank_ratio = initial_alpha_for_layer / float(initial_rank_for_layer) if initial_rank_for_layer > 0 else 1.0
alpha_init_for_this_attempt = original_alpha_to_rank_ratio * float(current_rank_for_this_attempt)
prev_rank_for_warm_start_log = best_result_so_far.get('final_rank_used')
if 'hada_w1_a' in best_result_so_far and not args_global.no_warm_start:
if prev_rank_for_warm_start_log < current_rank_for_this_attempt:
@ -221,7 +221,7 @@ def optimize_loha_for_layer(
getattr(locals()[f"{p_name}_p"], 'data').copy_(existing_loha_layer_parameters[p_name].to(device, dtype))
log_layer_optimization_event(LogType.INITIAL_PARAMS_LOADED, layer_name, rank=current_rank_for_this_attempt)
initialized_from_external_or_warm_start = True
except Exception: pass
except Exception: pass
if not initialized_from_external_or_warm_start and attempt_idx > 0 and warm_start_status_for_log == 'applied':
prev_params = {k: best_result_so_far[k].to(device, dtype) for k in ['hada_w1_a', 'hada_w1_b', 'hada_w2_a', 'hada_w2_b']}
@ -231,7 +231,7 @@ def optimize_loha_for_layer(
for p_slice in [hada_w1_a_p.data[:, prev_rank:], hada_w2_a_p.data[:, prev_rank:]]: nn.init.kaiming_uniform_(p_slice, a=math.sqrt(5))
for p_slice in [hada_w1_b_p.data[prev_rank:, :], hada_w2_b_p.data[prev_rank:, :]]: nn.init.normal_(p_slice, std=0.02)
initialized_from_external_or_warm_start = True
if not initialized_from_external_or_warm_start:
log_layer_optimization_event(LogType.INITIAL_PARAMS_KAIMING_NORMAL, layer_name, rank=current_rank_for_this_attempt, attempt=attempt_idx + 1)
for p in [hada_w1_a_p, hada_w2_a_p]: nn.init.kaiming_uniform_(p.data, a=math.sqrt(5))
@ -244,14 +244,13 @@ def optimize_loha_for_layer(
iter_pbar_desc = f"Opt Att {attempt_idx+1}/{max_rank_retries+1} (R:{current_rank_for_this_attempt}){' [LastRank]' if is_last_rank_attempt else ''}: {layer_name}"
iter_pbar = tqdm(range(max_iterations), desc=iter_pbar_desc, leave=False, dynamic_ncols=True, position=1, mininterval=0.5)
current_attempt_final_loss = float('inf'); current_attempt_stopped_early_by_loss = False
current_attempt_insufficient_progress = False; current_attempt_stopped_by_projection = False
current_attempt_projection_type = "none"; current_attempt_iterations_done = 0
loss_at_start_of_current_window = float('inf'); progress_window_started_for_attempt = False
relative_improvement_history = []; final_projected_loss_if_failed = None
ema_loss_history = []; current_ema_loss_value = None
# latest_reliable_R_ema_for_pbar, latest_reliable_decay_factor_ema removed as pbar projection removed
for i in iter_pbar:
current_attempt_iterations_done = i + 1
@ -270,7 +269,7 @@ def optimize_loha_for_layer(
delta_W_loha = eff_alpha_scale * term1_flat.view(out_dim, in_dim_effective, k_h, k_w) * term2_flat.view(out_dim, in_dim_effective, k_h, k_w)
else:
delta_W_loha = eff_alpha_scale * term1_flat * term2_flat
loss = F.mse_loss(delta_W_loha, delta_W_target)
raw_current_loss_item = loss.item()
if i == 0 and progress_window_started_for_attempt and loss_at_start_of_current_window == float('inf'):
@ -292,21 +291,21 @@ def optimize_loha_for_layer(
if prog_check_interval_val > 0 and progress_window_started_for_attempt and \
(current_attempt_iterations_done >= iter_to_begin_first_progress_window + prog_check_interval_val) and \
((current_attempt_iterations_done - iter_to_begin_first_progress_window) % prog_check_interval_val == 0):
perform_stop_checks = not is_last_rank_attempt
raw_rel_imprv = (loss_at_start_of_current_window - current_attempt_final_loss) / loss_at_start_of_current_window if loss_at_start_of_current_window > 1e-12 and loss_at_start_of_current_window > current_attempt_final_loss else 0.0
if (target_loss is None or current_attempt_final_loss > target_loss * 1.01) and raw_rel_imprv < min_prog_ratio_val:
log_event_type = LogType.INSUFFICIENT_PROGRESS_STOP if perform_stop_checks else LogType.INSUFFICIENT_PROGRESS_LOG_ONLY
log_layer_optimization_event(log_event_type, layer_name, attempt=attempt_idx+1, rank=current_rank_for_this_attempt, rel_imprv=raw_rel_imprv, min_ratio=min_prog_ratio_val, current_loss=current_attempt_final_loss)
if perform_stop_checks: current_attempt_insufficient_progress = True; break
if target_loss is not None and current_attempt_final_loss > target_loss and not current_attempt_insufficient_progress:
use_ema = len(ema_loss_history) >= proj_min_ema_hist
if not use_ema : log_layer_optimization_event(LogType.EMA_PROJECTION_SKIPPED_HISTORY, layer_name, attempt=attempt_idx+1, rank=current_rank_for_this_attempt, hist_len=len(ema_loss_history), min_hist=proj_min_ema_hist)
req_iters_to_target = float('inf'); proj_type_at_check = "none"; temp_proj_loss = None
if use_ema:
ema_curr_iter, ema_curr_loss = ema_loss_history[-1]
ema_start_iter, ema_start_loss = _get_closest_ema_value_before_iter(current_attempt_iterations_done - prog_check_interval_val, ema_loss_history)
@ -314,7 +313,7 @@ def optimize_loha_for_layer(
if smooth_ema_imprv > 1e-9:
relative_improvement_history.append(smooth_ema_imprv); relative_improvement_history = relative_improvement_history[-2:]
if len(relative_improvement_history) >= 2 and relative_improvement_history[-2] > 1e-9:
if len(relative_improvement_history) >= 2 and relative_improvement_history[-2] > 1e-9:
proj_type_at_check = "advanced_ema"
decay_R = max(adv_proj_decay_cap_min, min(adv_proj_decay_cap_max, smooth_ema_imprv / relative_improvement_history[-2]))
sim_loss, sim_R, sim_iters = ema_curr_loss, smooth_ema_imprv, 0
@ -328,15 +327,15 @@ def optimize_loha_for_layer(
sim_R = max(1e-7, sim_R * decay_R)
req_iters_to_target = sim_iters if sim_loss <= target_loss else float('inf')
else: req_iters_to_target = 0
else:
else:
proj_type_at_check = "simple_ema" + (" (fallback_adv)" if len(relative_improvement_history) >=2 else "")
if ema_curr_loss > target_loss:
try: req_iters_to_target = math.ceil(math.log(target_loss/ema_curr_loss) / math.log(1.0-smooth_ema_imprv)) * prog_check_interval_val; temp_proj_loss = None
except: req_iters_to_target = float('inf'); temp_proj_loss = ema_curr_loss * ((1.0-max(0,smooth_ema_imprv))**((max_iterations-current_attempt_iterations_done)//prog_check_interval_val +1 ))
else: req_iters_to_target = 0
elif use_ema: proj_type_at_check = "stalled_ema"; temp_proj_loss = ema_curr_loss
if req_iters_to_target == float('inf') and not proj_type_at_check.endswith("_ema"):
if req_iters_to_target == float('inf') and not proj_type_at_check.endswith("_ema"):
if use_ema: log_layer_optimization_event(LogType.EMA_PROJECTION_INCONCLUSIVE_FALLBACK_RAW, layer_name, attempt=attempt_idx+1, rank=current_rank_for_this_attempt)
if raw_rel_imprv > 1e-9 and current_attempt_final_loss > target_loss:
proj_type_at_check = "simple_raw_fallback"
@ -344,20 +343,20 @@ def optimize_loha_for_layer(
except: req_iters_to_target = float('inf'); temp_proj_loss = current_attempt_final_loss * ((1.0-max(0,raw_rel_imprv))**((max_iterations-current_attempt_iterations_done)//prog_check_interval_val+1))
elif current_attempt_final_loss <= target_loss: req_iters_to_target = 0
else: proj_type_at_check = "stalled_raw_fallback"; temp_proj_loss = current_attempt_final_loss
if req_iters_to_target > (max_iterations - current_attempt_iterations_done):
log_event_type = LogType.PROJECTION_STOP if perform_stop_checks else LogType.PROJECTION_LOG_ONLY
log_layer_optimization_event(log_event_type, layer_name, attempt=attempt_idx+1, rank=current_rank_for_this_attempt, proj_type=proj_type_at_check,
iters_needed=req_iters_to_target, proj_final_loss=temp_proj_loss, target_loss=target_loss, avail_iters=(max_iterations - current_attempt_iterations_done))
if perform_stop_checks: current_attempt_stopped_by_projection = True; final_projected_loss_if_failed = temp_proj_loss; current_attempt_projection_type = proj_type_at_check; break
else: final_projected_loss_if_failed = temp_proj_loss; current_attempt_projection_type = proj_type_at_check
else: final_projected_loss_if_failed = temp_proj_loss; current_attempt_projection_type = proj_type_at_check
loss_at_start_of_current_window = current_attempt_final_loss
if proj_type_at_check != "none": current_attempt_projection_type = proj_type_at_check
iter_pbar.close()
if current_attempt_final_loss < best_result_so_far['final_loss'] or \
(current_attempt_final_loss == best_result_so_far['final_loss'] and current_rank_for_this_attempt < best_result_so_far['final_rank_used']):
log_layer_optimization_event(LogType.NEW_BEST_RESULT_FOR_LAYER, layer_name, attempt=attempt_idx+1, rank=current_rank_for_this_attempt, loss=current_attempt_final_loss)
@ -370,25 +369,25 @@ def optimize_loha_for_layer(
'stopped_by_projection': current_attempt_stopped_by_projection,
'projection_type_used': current_attempt_projection_type,
'iterations_done': current_attempt_iterations_done, 'final_rank_used': current_rank_for_this_attempt,
'interrupted_mid_layer': False,
'interrupted_mid_layer': False,
'final_projected_loss_on_stop': final_projected_loss_if_failed if current_attempt_stopped_by_projection else None
})
rank_base_for_next_increase = current_rank_for_this_attempt
if current_attempt_stopped_early_by_loss:
log_layer_optimization_event(LogType.TARGET_LOSS_MET_STOP_ALL_RETRIES, layer_name)
break
break
if current_attempt_iterations_done < max_iterations and not any([current_attempt_insufficient_progress, current_attempt_stopped_by_projection, current_attempt_stopped_early_by_loss]):
if not is_last_rank_attempt:
log_layer_optimization_event(LogType.ATTEMPT_EARLY_FINISH_NO_STOP_FLAG, layer_name, attempt=attempt_idx+1, rank=current_rank_for_this_attempt, iters_done=current_attempt_iterations_done, max_iters=max_iterations)
if is_last_rank_attempt:
log_layer_optimization_event(LogType.LAST_RANK_ATTEMPT_SUMMARY, layer_name, target_loss=target_loss,
final_loss_for_layer=best_result_so_far['final_loss'], final_rank_for_layer=best_result_so_far['final_rank_used'])
break
break
if not current_attempt_stopped_early_by_loss :
if not current_attempt_stopped_early_by_loss :
reason_kwargs = {'attempt': attempt_idx + 1, 'rank': current_rank_for_this_attempt, 'is_last_rank_attempt': is_last_rank_attempt}
if current_attempt_stopped_by_projection:
reason_kwargs.update({'reason_type': 'projection_unreachable', 'target_loss': target_loss, 'proj_final_loss': final_projected_loss_if_failed, 'proj_type': current_attempt_projection_type})
@ -396,16 +395,16 @@ def optimize_loha_for_layer(
reason_kwargs.update({'reason_type': 'insufficient_progress'})
elif current_attempt_iterations_done >= max_iterations:
reason_kwargs.update({'reason_type': 'max_iterations_no_target' if target_loss else 'max_iterations_no_target_set', 'current_loss': current_attempt_final_loss})
if 'reason_type' in reason_kwargs :
if 'reason_type' in reason_kwargs :
log_layer_optimization_event(LogType.ATTEMPT_ENDED_WILL_RETRY, layer_name, **reason_kwargs)
if 'hada_w1_a' not in best_result_so_far:
log_layer_optimization_event(LogType.NO_VALID_OPTIMIZATION_RESULT, layer_name)
return {'final_loss': float('inf'), 'interrupted_mid_layer': True, 'final_rank_used': initial_rank_for_layer, 'iterations_done':0}
for key, default_val in [('stopped_early_by_loss', False), ('stopped_by_insufficient_progress', False),
('stopped_by_projection', False), ('projection_type_used', 'none'),
for key, default_val in [('stopped_early_by_loss', False), ('stopped_by_insufficient_progress', False),
('stopped_by_projection', False), ('projection_type_used', 'none'),
('interrupted_mid_layer', False), ('final_projected_loss_on_stop', None),
('final_rank_used', initial_rank_for_layer)]:
best_result_so_far.setdefault(key, default_val)
@ -438,7 +437,7 @@ def find_best_resume_file(intended_final_path: str) -> tuple[str | None, int]:
if metadata and "ss_completed_loha_modules" in metadata:
num_completed = len(json.loads(metadata["ss_completed_loha_modules"]))
if num_completed > max_completed_modules: max_completed_modules, best_file_path = num_completed, file_path
elif num_completed == max_completed_modules:
elif num_completed == max_completed_modules:
if file_path == intended_final_path or (base_save_name+"_resume_L" in os.path.basename(file_path) and best_file_path != intended_final_path):
best_file_path = file_path
elif max_completed_modules == -1 and (best_file_path is None or (file_path == intended_final_path and best_file_path != intended_final_path)):
@ -463,7 +462,7 @@ def cleanup_intermediate_files(final_intended_path: str, for_resume_management:
files_to_consider.sort(key=lambda x: x['l_count']) # Sorts oldest to newest by L_count
files_to_delete = files_to_consider[:-keep_n] # Keeps the last 'keep_n' items (newest)
if args_global and args_global.verbose: print(f" Resume Manager: Found {len(files_to_consider)} files. Deleting {len(files_to_delete)} oldest to keep {keep_n}.")
else:
else:
files_to_delete = files_to_consider
if args_global and args_global.verbose: print(f" Cleaning ALL {len(files_to_delete)} intermediate files...")
@ -482,30 +481,30 @@ def perform_graceful_save(output_path_to_save: str):
total_processed_ever = len(all_completed_module_prefixes_ever_global)
if not extracted_loha_state_dict_global and not total_processed_ever: print(f"No layers to save to {output_path_to_save}. Aborted."); return False
if not args_global: print("Error: Global args not for saving metadata."); return False
save_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(args_global.save_weights_dtype, torch.bfloat16)
final_sd = OrderedDict((k, v.to(save_dtype) if hasattr(v, 'is_floating_point') and v.is_floating_point() else v) for k, v in extracted_loha_state_dict_global.items())
print(f"\nSaving LoHA for {total_processed_ever} modules ({processed_layers_this_session_count_global} this session) to {output_path_to_save}")
net_alpha = f"{args_global.initial_alpha:.8f}" if args_global.initial_alpha is not None else str(args_global.rank)
conv_alpha_val = args_global.initial_conv_alpha if args_global.initial_conv_alpha is not None else (args_global.conv_rank or args_global.rank)
conv_alpha = f"{conv_alpha_val:.8f}" if isinstance(conv_alpha_val, float) else str(conv_alpha_val)
network_args = {"algo": "loha", "dim": str(args_global.rank), "alpha": net_alpha,
"conv_dim": str(args_global.conv_rank or args_global.rank), "conv_alpha": conv_alpha,
network_args = {"algo": "loha", "dim": str(args_global.rank), "alpha": net_alpha,
"conv_dim": str(args_global.conv_rank or args_global.rank), "conv_alpha": conv_alpha,
**{k: str(getattr(args_global, k)) for k in ["dropout", "rank_dropout", "module_dropout"]}}
sf_meta = {
"ss_network_module": "lycoris.kohya", "ss_network_rank": str(args_global.rank), "ss_network_alpha": net_alpha,
"ss_network_algo": "loha", "ss_network_args": json.dumps(network_args),
"ss_comment": f"Extracted LoHA (Int: {save_attempted_on_interrupt}). OptPrec: {args_global.precision}. SaveDtype: {args_global.save_weights_dtype}. Layers: {total_processed_ever}.",
"ss_base_model_name": os.path.splitext(os.path.basename(args_global.base_model_path))[0],
"ss_base_model_name": os.path.splitext(os.path.basename(args_global.base_model_path))[0],
"ss_ft_model_name": os.path.splitext(os.path.basename(args_global.ft_model_path))[0],
"ss_save_weights_dtype": args_global.save_weights_dtype, "ss_optimization_precision": args_global.precision,
"ss_completed_loha_modules": json.dumps(list(all_completed_module_prefixes_ever_global))
"ss_save_weights_dtype": args_global.save_weights_dtype, "ss_optimization_precision": args_global.precision,
"ss_completed_loha_modules": json.dumps(list(all_completed_module_prefixes_ever_global))
}
serializable_args = {k: str(v) if not isinstance(v, (str, int, float, bool, list, dict, type(None))) else v for k, v in vars(args_global).items()}
json_meta = {
"comfyui_lora_type": "LyCORIS_LoHa", "model_name": os.path.splitext(os.path.basename(output_path_to_save))[0],
@ -514,27 +513,56 @@ def perform_graceful_save(output_path_to_save: str):
"extraction_summary": {"total_cumulative": total_processed_ever, "this_session": processed_layers_this_session_count_global,
"skipped_identical": skipped_identical_count_global, "skipped_other": skipped_other_reason_count_global,
"skipped_good_initial": skipped_good_initial_loss_count_global, "scanned_keys": keys_scanned_this_run_global},
"layer_optimization_details_this_session": [{k: (float(v) if isinstance(v, (torch.Tensor, float)) and k == 'final_loss' else v) for k, v in stat.items()} for stat in layer_optimization_stats_global],
"layer_optimization_details_this_session": layer_optimization_stats_global,
"embedded_safetensors_metadata": sf_meta, "interrupted_save": save_attempted_on_interrupt
}
try:
# --- MODIFICATION START ---
temp_sf_path = None
temp_json_path = None
try:
if output_path_to_save.endswith(".safetensors"):
save_file(final_sd, output_path_to_save, metadata=sf_meta)
json_path = os.path.splitext(output_path_to_save)[0] + "_extraction_metadata.json"
with open(json_path, 'w') as f: json.dump(json_meta, f, indent=4)
print(f"Saved: {output_path_to_save} and {json_path}")
else:
# Define temporary paths
temp_sf_path = output_path_to_save + ".part"
final_json_path = os.path.splitext(output_path_to_save)[0] + "_extraction_metadata.json"
temp_json_path = final_json_path + ".part"
# Save safetensors to temporary file
save_file(final_sd, temp_sf_path, metadata=sf_meta)
# Save JSON metadata to temporary file
with open(temp_json_path, 'w') as f:
json.dump(json_meta, f, indent=4)
# If both temporary saves are successful, replace original files
os.replace(temp_sf_path, output_path_to_save) # Atomically replaces if target exists
os.replace(temp_json_path, final_json_path) # Atomically replaces if target exists
print(f"Saved: {output_path_to_save} and {final_json_path}")
else:
# For non-safetensors (e.g., .pt), torch.save often handles atomicity well,
# but we can apply a similar pattern if needed. Sticking to original for now for this path.
torch.save({'state_dict': final_sd, '__metadata__': sf_meta, '__extended_metadata__': json_meta}, output_path_to_save)
print(f"Saved (basic .pt): {output_path_to_save}")
return True
except Exception as e: print(f"Error saving to {output_path_to_save}: {e}"); traceback.print_exc(); return False
except Exception as e:
print(f"Error saving to {output_path_to_save}: {e}")
traceback.print_exc()
# Attempt to clean up .part files if they exist from a failed save
if temp_sf_path and os.path.exists(temp_sf_path):
try: os.remove(temp_sf_path)
except OSError: pass # Best effort
if temp_json_path and os.path.exists(temp_json_path):
try: os.remove(temp_json_path)
except OSError: pass # Best effort
return False
# --- MODIFICATION END ---
def handle_interrupt(signum, frame):
global save_attempted_on_interrupt, outer_pbar_global, args_global, all_completed_module_prefixes_ever_global
print("\n" + "="*30 + "\nCtrl+C Detected!\n" + "="*30)
if save_attempted_on_interrupt: print("Save already attempted. Exiting."); return
save_attempted_on_interrupt = True
if save_attempted_on_interrupt: print("Save already attempted. Exiting."); return
save_attempted_on_interrupt = True
if outer_pbar_global: outer_pbar_global.close()
if args_global and args_global.save_to:
save_path = generate_intermediate_filename(args_global.save_to, len(all_completed_module_prefixes_ever_global))
@ -543,7 +571,55 @@ def handle_interrupt(signum, frame):
cleanup_intermediate_files(args_global.save_to, True, args_global.keep_n_resume_files)
else: print("Cannot perform interrupt save: args not defined.")
print("Exiting.")
sys.exit(0)
sys.exit(0)
def setup_and_print_configuration(current_args: argparse.Namespace):
"""
Sets up derived configuration values and prints the run configuration.
Modifies current_args in place for 'progress_check_start_iter'.
"""
if current_args.progress_check_start_iter is None:
current_args.progress_check_start_iter = max(1, current_args.progress_check_interval) if current_args.progress_check_interval > 0 else current_args.max_iterations + 1
elif current_args.progress_check_interval <= 0 :
current_args.progress_check_start_iter = current_args.max_iterations + 1
opt_dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
target_opt_dtype = opt_dtype_map.get(current_args.precision, torch.float32)
final_save_dtype_torch = opt_dtype_map.get(current_args.save_weights_dtype, torch.bfloat16)
print(f"Device: {current_args.device}, Opt Dtype: {target_opt_dtype}, Save Dtype: {final_save_dtype_torch}")
if current_args.target_loss: print(f"Target Loss: {current_args.target_loss:.2e} (min iters: {current_args.min_iterations} for target check)")
else: print(f"No Target Loss. Min iters for any early stop: {current_args.min_iterations}.")
print(f"Max Iters/Layer: {current_args.max_iterations}, Max Rank Retries: {current_args.max_rank_retries}, Rank Incr Factor: {current_args.rank_increase_factor}")
if current_args.save_every_n_layers > 0: print(f"Save every {current_args.save_every_n_layers} processed layers enabled.")
if current_args.keep_n_resume_files > 0: print(f"Keeping the {current_args.keep_n_resume_files} most recent resume files.")
if current_args.progress_check_interval > 0:
first_eval_iter = current_args.progress_check_start_iter + current_args.progress_check_interval
print(f"Progress Check: Enabled. Interval: {current_args.progress_check_interval} iters, Min Rel. Loss Decrease: {current_args.min_progress_loss_ratio:.1e}.")
print(f" Progress window starts at iter: {current_args.progress_check_start_iter}, first evaluation at iter: {first_eval_iter}.")
if current_args.target_loss is not None:
print(f" Projection Check: Enabled (if target loss specified). Decay Caps: min={getattr(current_args, 'advanced_projection_decay_cap_min', 'N/A')}, max={getattr(current_args, 'advanced_projection_decay_cap_max', 'N/A')}")
else: print("Progress Check: Disabled (and Projection Check disabled).")
return current_args
def load_models(base_model_path: str, ft_model_path: str) -> tuple[OrderedDict, OrderedDict]:
"""Loads the base and fine-tuned models from the given paths."""
print(f"\nLoading base model: {base_model_path}")
try:
base_sd_raw = load_file(base_model_path, device='cpu') if base_model_path.endswith(".safetensors") else torch.load(base_model_path, map_location='cpu')
base_model_sd = base_sd_raw.get('state_dict', base_sd_raw) if not isinstance(base_sd_raw, OrderedDict) and hasattr(base_sd_raw, 'get') else base_sd_raw
except Exception as e:
print(f"Error loading base model: {e}"); traceback.print_exc(); sys.exit(1)
print(f"Loading fine-tuned model: {ft_model_path}")
try:
ft_sd_raw = load_file(ft_model_path, device='cpu') if ft_model_path.endswith(".safetensors") else torch.load(ft_model_path, map_location='cpu')
ft_model_sd = ft_sd_raw.get('state_dict', ft_sd_raw) if not isinstance(ft_sd_raw, OrderedDict) and hasattr(ft_sd_raw, 'get') else ft_sd_raw
except Exception as e:
print(f"Error loading fine-tuned model: {e}"); traceback.print_exc(); sys.exit(1)
return base_model_sd, ft_model_sd
def main(cli_args):
global args_global, extracted_loha_state_dict_global, layer_optimization_stats_global, \
@ -552,37 +628,19 @@ def main(cli_args):
previously_completed_module_prefixes_global, all_completed_module_prefixes_ever_global, \
main_loop_completed_scan_flag_global, params_to_seed_optimizer_global, skipped_good_initial_loss_count_global
args_global = cli_args # Set global args object
signal.signal(signal.SIGINT, handle_interrupt)
args_global = cli_args
signal.signal(signal.SIGINT, handle_interrupt)
for g_list_or_dict in [extracted_loha_state_dict_global, layer_optimization_stats_global, params_to_seed_optimizer_global]: g_list_or_dict.clear()
for g_set in [previously_completed_module_prefixes_global, all_completed_module_prefixes_ever_global]: g_set.clear()
processed_layers_this_session_count_global = skipped_identical_count_global = skipped_other_reason_count_global = skipped_good_initial_loss_count_global = keys_scanned_this_run_global = 0
main_loop_completed_scan_flag_global = False; save_attempted_on_interrupt = False # Reset interrupt flag
main_loop_completed_scan_flag_global = False; save_attempted_on_interrupt = False
if args_global.progress_check_start_iter is None:
args_global.progress_check_start_iter = max(1, args_global.progress_check_interval) if args_global.progress_check_interval > 0 else args_global.max_iterations + 1
elif args_global.progress_check_interval <= 0 :
args_global.progress_check_start_iter = args_global.max_iterations + 1
args_global = setup_and_print_configuration(args_global)
opt_dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
target_opt_dtype = opt_dtype_map.get(args_global.precision, torch.float32)
final_save_dtype_torch = opt_dtype_map.get(args_global.save_weights_dtype, torch.bfloat16)
print(f"Device: {args_global.device}, Opt Dtype: {target_opt_dtype}, Save Dtype: {final_save_dtype_torch}")
if args_global.target_loss: print(f"Target Loss: {args_global.target_loss:.2e} (min iters: {args_global.min_iterations} for target check)")
else: print(f"No Target Loss. Min iters for any early stop: {args_global.min_iterations}.") # Clarified
print(f"Max Iters/Layer: {args_global.max_iterations}, Max Rank Retries: {args_global.max_rank_retries}, Rank Incr Factor: {args_global.rank_increase_factor}")
if args_global.save_every_n_layers > 0: print(f"Save every {args_global.save_every_n_layers} processed layers enabled.")
if args_global.keep_n_resume_files > 0: print(f"Keeping the {args_global.keep_n_resume_files} most recent resume files.")
target_opt_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}.get(args_global.precision, torch.float32)
final_save_dtype_torch = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}.get(args_global.save_weights_dtype, torch.bfloat16)
if args_global.progress_check_interval > 0:
first_eval_iter = args_global.progress_check_start_iter + args_global.progress_check_interval
print(f"Progress Check: Enabled. Interval: {args_global.progress_check_interval} iters, Min Rel. Loss Decrease: {args_global.min_progress_loss_ratio:.1e}.")
print(f" Progress window starts at iter: {args_global.progress_check_start_iter}, first evaluation at iter: {first_eval_iter}.")
if args_global.target_loss is not None:
print(f" Projection Check: Enabled (if target loss specified). Decay Caps: min={getattr(args_global, 'advanced_projection_decay_cap_min', 'N/A')}, max={getattr(args_global, 'advanced_projection_decay_cap_max', 'N/A')}")
else: print("Progress Check: Disabled (and Projection Check disabled).")
if args_global.continue_training_from_loha:
print(f"\nMode: Continue/Refine from LoHA: {args_global.continue_training_from_loha}")
@ -602,7 +660,7 @@ def main(cli_args):
if os.path.exists(args_global.save_to) and not args_global.overwrite: print(f" Warning: Output {args_global.save_to} exists and may be overwritten.")
elif os.path.exists(args_global.save_to) and args_global.overwrite: print(f" Info: Output {args_global.save_to} will be overwritten due to --overwrite.")
except Exception as e: print(f" Error loading LoHA: {e}."); traceback.print_exc(); sys.exit(1)
elif not args_global.overwrite:
elif not args_global.overwrite:
print(f"\nMode: Standard extraction. Checking resume states for: {args_global.save_to}")
resume_file, num_modules_resume = find_best_resume_file(args_global.save_to)
if resume_file:
@ -614,42 +672,33 @@ def main(cli_args):
if meta and "ss_completed_loha_modules" in meta: completed_in_file = set(json.loads(meta["ss_completed_loha_modules"]))
loaded_sd_resume = load_file(resume_file, device='cpu')
if not completed_in_file and loaded_sd_resume: completed_in_file = {".".join(k.split('.')[:-1]) for k in loaded_sd_resume if k.endswith(".hada_w1_a")}
res_tensor_count = 0
if completed_in_file:
for k, v in loaded_sd_resume.items():
if ".".join(k.split('.')[:-1]) in completed_in_file or k.endswith(".bias"): extracted_loha_state_dict_global[k] = v; res_tensor_count += 1
previously_completed_module_prefixes_global.update(completed_in_file); all_completed_module_prefixes_ever_global.update(completed_in_file)
print(f" Loaded {len(previously_completed_module_prefixes_global)} module prefixes, {res_tensor_count} tensors for resume.")
elif loaded_sd_resume: extracted_loha_state_dict_global.update(loaded_sd_resume) # Load all if no specific list
elif loaded_sd_resume: extracted_loha_state_dict_global.update(loaded_sd_resume)
del loaded_sd_resume
except Exception as e: print(f" Error loading resume file '{resume_file}': {e}. Starting fresh."); extracted_loha_state_dict_global.clear(); previously_completed_module_prefixes_global.clear(); all_completed_module_prefixes_ever_global.clear()
else: print(" No suitable existing LoHA to resume from. Starting fresh.")
elif args_global.overwrite: print(f"\nMode: Standard extraction with --overwrite. Final output {args_global.save_to} will be overwritten.")
print(f"\nLoading base model: {args_global.base_model_path}")
try:
base_sd_raw = load_file(args_global.base_model_path, device='cpu') if args_global.base_model_path.endswith(".safetensors") else torch.load(args_global.base_model_path, map_location='cpu')
base_model_sd = base_sd_raw.get('state_dict', base_sd_raw) if not isinstance(base_sd_raw, OrderedDict) and hasattr(base_sd_raw, 'get') else base_sd_raw
except Exception as e: print(f"Error loading base model: {e}"); traceback.print_exc(); sys.exit(1)
print(f"Loading fine-tuned model: {args_global.ft_model_path}")
try:
ft_sd_raw = load_file(args_global.ft_model_path, device='cpu') if args_global.ft_model_path.endswith(".safetensors") else torch.load(args_global.ft_model_path, map_location='cpu')
ft_model_sd = ft_sd_raw.get('state_dict', ft_sd_raw) if not isinstance(ft_sd_raw, OrderedDict) and hasattr(ft_sd_raw, 'get') else ft_sd_raw
except Exception as e: print(f"Error loading fine-tuned model: {e}"); traceback.print_exc(); sys.exit(1)
base_model_sd, ft_model_sd = load_models(args_global.base_model_path, args_global.ft_model_path)
all_candidate_keys = sorted([k for k in base_model_sd if k.endswith('.weight') and k in ft_model_sd and base_model_sd[k].shape == ft_model_sd[k].shape and (len(base_model_sd[k].shape) in [2,4])])
total_candidates_to_scan = len(all_candidate_keys)
print(f"Found {total_candidates_to_scan} candidate '.weight' keys for LoHA extraction.")
outer_pbar_global = tqdm(total=total_candidates_to_scan, desc="Scanning Layers", dynamic_ncols=True, position=0)
skipped_vae_layers_count = 0
try:
skipped_vae_layers_count = 0
try:
for key_name in all_candidate_keys:
if save_attempted_on_interrupt: break
if save_attempted_on_interrupt: break
keys_scanned_this_run_global += 1; outer_pbar_global.update(1)
original_module_path = key_name[:-len(".weight")]
loha_key_prefix = "lora_" + original_module_path.replace(".", "_")
if "model.diffusion_model." in original_module_path: loha_key_prefix = "lora_unet_" + original_module_path.split("model.diffusion_model.")[-1].replace(".", "_")
@ -662,13 +711,13 @@ def main(cli_args):
is_reopt_target = args_global.continue_training_from_loha and loha_key_prefix in params_to_seed_optimizer_global
if loha_key_prefix in all_completed_module_prefixes_ever_global and not is_reopt_target:
if args_global.verbose_layer_debug: tqdm.write(f" Skipping {loha_key_prefix} (already processed/resumed, not re-opt).")
continue
continue
if args_global.max_layers is not None and args_global.max_layers > 0 and processed_layers_this_session_count_global >= args_global.max_layers:
if args_global.verbose and processed_layers_this_session_count_global == args_global.max_layers and not (loha_key_prefix in all_completed_module_prefixes_ever_global and not is_reopt_target) :
tqdm.write(f"\nMax_layers ({args_global.max_layers}) for new/re-optimized hit. Scan continues.")
outer_pbar_global.set_description_str(f"Scan {keys_scanned_this_run_global}/{total_candidates_to_scan} (Max Layers Reached)")
continue
continue
base_W = base_model_sd[key_name].to(dtype=torch.float32)
ft_W = ft_model_sd[key_name].to(dtype=torch.float32)
@ -683,7 +732,7 @@ def main(cli_args):
if is_reopt_target and args_global.target_loss is not None:
seed_data = params_to_seed_optimizer_global[loha_key_prefix]
loaded_params_cpu = seed_data['params']; loaded_rank_check = seed_data['rank']; loaded_alpha_check = seed_data['alpha']
if all(k in loaded_params_cpu for k in ['hada_w1_a', 'hada_w1_b', 'hada_w2_a', 'hada_w2_b']):
if all(k_ in loaded_params_cpu for k_ in ['hada_w1_a', 'hada_w1_b', 'hada_w2_a', 'hada_w2_b']):
try:
with torch.no_grad():
w1a,w1b,w2a,w2b = (loaded_params_cpu[p].to(args_global.device, target_opt_dtype) for p in ['hada_w1_a','hada_w1_b','hada_w2_a','hada_w2_b'])
@ -693,18 +742,28 @@ def main(cli_args):
init_loss_c = F.mse_loss(init_loha_d, delta_W_target_c).item()
if init_loss_c <= args_global.target_loss:
tqdm.write(f" Skip Re-Opt {loha_key_prefix}: Loaded (R:{loaded_rank_check}, A:{loaded_alpha_check:.2f}) meets target. Loss: {init_loss_c:.4e} <= {args_global.target_loss:.4e}")
layer_optimization_stats_global.append({"name": loha_key_prefix, "original_name": original_module_path, "initial_rank_attempted": loaded_rank_check, "final_rank_used": loaded_rank_check, "rank_was_increased": False, "final_loss": init_loss_c, "iterations_done": 0, "stopped_early_by_loss_target": True, "skipped_reopt_due_to_initial_good_loss": True })
stat_entry_skip = {
"name": str(loha_key_prefix), "original_name": str(original_module_path),
"initial_rank_attempted": int(loaded_rank_check), "final_rank_used": int(loaded_rank_check),
"rank_was_increased": False, "final_loss": float(init_loss_c),
"alpha_final": float(loaded_alpha_check), "iterations_done": 0,
"stopped_early_by_loss_target": True, "stopped_by_insufficient_progress": False,
"stopped_by_projection": False, "projection_type_used": "none",
"final_projected_loss_on_stop": None,
"skipped_reopt_due_to_initial_good_loss": True, "interrupted_mid_layer": False
}
layer_optimization_stats_global.append(stat_entry_skip)
all_completed_module_prefixes_ever_global.add(loha_key_prefix); processed_layers_this_session_count_global += 1; skipped_good_initial_loss_count_global += 1
should_skip_due_to_pre_existing_good_loss = True
elif args_global.verbose_layer_debug: tqdm.write(f" Initial loss for loaded {loha_key_prefix}: {init_loss_c:.4e}. Re-optimizing.")
except Exception as e_c: tqdm.write(f" Warn: Pre-opt loss check failed for {loha_key_prefix}: {e_c}. Optimizing.");
if should_skip_due_to_pre_existing_good_loss:
outer_pbar_global.set_description_str(f"Scan {keys_scanned_this_run_global}/{total_candidates_to_scan} (New/ReOpt: {processed_layers_this_session_count_global - skipped_good_initial_loss_count_global}, SkipGood:{skipped_good_initial_loss_count_global})")
continue
continue
current_op_mode_str = "ReOpt" if is_reopt_target else "NewOpt"
if args_global.verbose: tqdm.write(f"\n--- {current_op_mode_str} Layer {processed_layers_this_session_count_global + 1 - skipped_good_initial_loss_count_global}: {loha_key_prefix} (Orig: {original_module_path}) ---")
initial_rank_opt = args_global.conv_rank if is_conv and args_global.conv_rank is not None else args_global.rank
initial_alpha_opt = args_global.initial_conv_alpha if is_conv else args_global.initial_alpha
existing_params_init = None; max_retries_layer = args_global.max_rank_retries
@ -714,44 +773,57 @@ def main(cli_args):
initial_rank_opt, initial_alpha_opt = seed_data['rank'], seed_data['alpha']
existing_params_init = seed_data['params']
base_rank_est = args_global.conv_rank if is_conv and args_global.conv_rank is not None else args_global.rank
if initial_rank_opt > base_rank_est and args_global.max_rank_retries > 0 :
if initial_rank_opt > base_rank_est and args_global.max_rank_retries > 0 :
est_retries_used = 0; cur_sim_rank = float(base_rank_est)
for _ in range(args_global.max_rank_retries + 10):
if cur_sim_rank >= initial_rank_opt: break
cur_sim_rank = max(math.ceil(cur_sim_rank * args_global.rank_increase_factor), cur_sim_rank + 1); est_retries_used += 1
max_retries_layer = max(0, args_global.max_rank_retries - est_retries_used)
if args_global.verbose: tqdm.write(f" Using loaded R:{initial_rank_opt}, A:{initial_alpha_opt:.1f}. Max further retries for layer: {max_retries_layer}.")
outer_pbar_global.set_description_str(f"{current_op_mode_str} L{processed_layers_this_session_count_global + 1 - skipped_good_initial_loss_count_global} (Scan {keys_scanned_this_run_global}/{total_candidates_to_scan}, SkipGood:{skipped_good_initial_loss_count_global})")
opt_results = optimize_loha_for_layer(
loha_key_prefix, delta_W_fp32, out_dim, in_dim_effective, k_h, k_w,
loha_key_prefix, delta_W_fp32, out_dim, in_dim_effective, k_h, k_w,
initial_rank_opt, initial_alpha_opt,
args_global.lr, args_global.max_iterations, args_global.min_iterations,
args_global.target_loss, args_global.weight_decay, args_global.device, target_opt_dtype,
is_conv, args_global.verbose_layer_debug, max_retries_layer,
args_global.lr, args_global.max_iterations, args_global.min_iterations,
args_global.target_loss, args_global.weight_decay, args_global.device, target_opt_dtype,
is_conv, args_global.verbose_layer_debug, max_retries_layer,
args_global.rank_increase_factor, existing_params_init
)
if not opt_results.get('interrupted_mid_layer') and 'hada_w1_a' in opt_results :
for p_name, p_val in opt_results.items():
if p_name not in ['final_loss', 'stopped_early_by_loss', 'stopped_by_insufficient_progress', 'stopped_by_projection', 'projection_type_used', 'iterations_done', 'final_rank_used', 'interrupted_mid_layer', 'final_projected_loss_on_stop']:
if torch.is_tensor(p_val): extracted_loha_state_dict_global[f'{loha_key_prefix}.{p_name}'] = p_val.to(final_save_dtype_torch)
final_rank_used = opt_results['final_rank_used']
opt_results["name"] = loha_key_prefix; opt_results["original_name"] = original_module_path
opt_results["initial_rank_attempted"] = initial_rank_opt # Rank it started this optimize_loha_for_layer call with
opt_results["rank_was_increased"] = final_rank_used > initial_rank_opt
opt_results.setdefault("skipped_reopt_due_to_initial_good_loss", False)
layer_optimization_stats_global.append(opt_results)
all_completed_module_prefixes_ever_global.add(loha_key_prefix)
stat_entry = {
"name": str(loha_key_prefix),
"original_name": str(original_module_path),
"initial_rank_attempted": int(initial_rank_opt),
"final_rank_used": int(final_rank_used),
"rank_was_increased": bool(final_rank_used > initial_rank_opt),
"final_loss": float(opt_results['final_loss']),
"alpha_final": float(opt_results['alpha'].item()) if isinstance(opt_results.get('alpha'), torch.Tensor) else float(opt_results.get('alpha', 0.0)),
"iterations_done": int(opt_results['iterations_done']),
"stopped_early_by_loss_target": bool(opt_results['stopped_early_by_loss']),
"stopped_by_insufficient_progress": bool(opt_results.get('stopped_by_insufficient_progress', False)),
"stopped_by_projection": bool(opt_results.get('stopped_by_projection', False)),
"projection_type_used": str(opt_results.get('projection_type_used', 'none')),
"final_projected_loss_on_stop": float(l_val) if (l_val := opt_results.get('final_projected_loss_on_stop')) is not None else None,
"skipped_reopt_due_to_initial_good_loss": bool(opt_results.get('skipped_reopt_due_to_initial_good_loss', False)), # Should be False here
"interrupted_mid_layer": bool(opt_results.get('interrupted_mid_layer', False))
}
layer_optimization_stats_global.append(stat_entry)
all_completed_module_prefixes_ever_global.add(loha_key_prefix)
stop_reason_short = ""
if opt_results['stopped_early_by_loss']: stop_reason_short = ", Stop:LossTarget"
elif opt_results.get('stopped_by_projection', False): stop_reason_short = f", Stop:Proj({opt_results.get('projection_type_used','?')})"
elif opt_results['stopped_by_insufficient_progress']: stop_reason_short = ", Stop:RawProg"
tqdm.write(f" Layer {loha_key_prefix} Opt. Done. R_used: {final_rank_used}, FinalLoss: {opt_results['final_loss']:.4e}, Iters: {opt_results['iterations_done']}{stop_reason_short}")
if args_global.use_bias:
bias_key = f"{original_module_path}.bias"
if bias_key in ft_model_sd and (bias_key not in base_model_sd or not torch.allclose(base_model_sd[bias_key], ft_model_sd[bias_key], atol=args_global.atol_fp32_check)):
@ -764,80 +836,79 @@ def main(cli_args):
tqdm.write(f"\n--- Periodic Save: Processed {processed_layers_this_session_count_global} layers. Saving to {periodic_save_path} ---")
if perform_graceful_save(periodic_save_path) and args_global.keep_n_resume_files > 0:
cleanup_intermediate_files(args_global.save_to, True, args_global.keep_n_resume_files)
else:
else:
tqdm.write(f" Optimization for {loha_key_prefix} did not yield saveable results (Interrupt: {opt_results.get('interrupted_mid_layer', 'N/A')}, Loss: {opt_results.get('final_loss', 'N/A')})")
if not opt_results.get('interrupted_mid_layer', False) and 'hada_w1_a' not in opt_results :
skipped_other_reason_count_global += 1; all_completed_module_prefixes_ever_global.add(loha_key_prefix)
if not save_attempted_on_interrupt and keys_scanned_this_run_global == total_candidates_to_scan:
if not opt_results.get('interrupted_mid_layer', False) and 'hada_w1_a' not in opt_results :
skipped_other_reason_count_global += 1; all_completed_module_prefixes_ever_global.add(loha_key_prefix)
if not save_attempted_on_interrupt and keys_scanned_this_run_global == total_candidates_to_scan:
main_loop_completed_scan_flag_global = True
finally:
if outer_pbar_global: outer_pbar_global.close()
if not save_attempted_on_interrupt:
print("\n--- Final Optimization Summary (This Session) ---")
for stat in layer_optimization_stats_global:
for stat in layer_optimization_stats_global: # Already serializable dicts
rank_info = f"InitialR: {stat['initial_rank_attempted']}, FinalR: {stat['final_rank_used']}"
if stat['rank_was_increased']: rank_info += " (Increased)"
proj_loss_info = f" (Proj.FinalLoss ~{stat.get('final_projected_loss_on_stop'):.2e})" if stat.get('final_projected_loss_on_stop') is not None else ""
stop_info = ""
if stat.get('skipped_reopt_due_to_initial_good_loss'): stop_info = ", SkipReOpt:GoodInitialLoss"
elif stat['stopped_early_by_loss_target']: stop_info = ", Stop:LossTarget"
elif stat.get('stopped_early_by_loss_target'): stop_info = ", Stop:LossTarget" # Key matched from stat_entry
elif stat.get('stopped_by_projection', False): stop_info = f", Stop:Proj({stat.get('projection_type_used','?')})" + proj_loss_info
elif stat.get('stopped_by_insufficient_progress', False): stop_info = ", Stop:RawProg"
print(f"Layer: {stat['name']}, {rank_info}, Loss: {stat['final_loss']:.4e}, Iters: {stat['iterations_done']}{stop_info}")
print(f"Layer: {stat['name']}, {rank_info}, Loss: {stat['final_loss']:.4e}, Alpha: {stat['alpha_final']:.2f}, Iters: {stat['iterations_done']}{stop_info}")
print(f"\n--- Overall Summary ---")
print(f"Total unique LoHA modules in final state: {len(all_completed_module_prefixes_ever_global)}")
# ... (other summary prints) ...
print(f" Processed (new/re-opt/skipped-good) this session: {processed_layers_this_session_count_global}")
print(f" Skipped identical (this session): {skipped_identical_count_global}")
print(f" Skipped re-opt due to good initial loss (this session): {skipped_good_initial_loss_count_global}")
print(f" Skipped re-opt due to good initial loss (this session): {skipped_good_initial_loss_count_global}")
print(f" Skipped other reasons (this session, VAE, opt error): {skipped_other_reason_count_global} (incl. {skipped_vae_layers_count} VAE)")
print(f" Total candidate keys scanned (this session): {keys_scanned_this_run_global}/{total_candidates_to_scan}")
save_to_final_name = main_loop_completed_scan_flag_global and len(all_completed_module_prefixes_ever_global) >= total_candidates_to_scan
actual_save_path = args_global.save_to if save_to_final_name else generate_intermediate_filename(args_global.save_to, len(all_completed_module_prefixes_ever_global))
reason = "Saving to final path" if save_to_final_name else \
("Run incomplete or --max_layers hit." if not main_loop_completed_scan_flag_global else "Full scan done, but not all layers processed/accounted for.")
print(f"\n{reason}: {actual_save_path}")
if perform_graceful_save(output_path_to_save=actual_save_path):
if args_global.keep_n_resume_files > 0 and not save_to_final_name :
if args_global.keep_n_resume_files > 0 and not save_to_final_name :
cleanup_intermediate_files(args_global.save_to, True, args_global.keep_n_resume_files)
if save_to_final_name and actual_save_path == args_global.save_to :
print("\nCleaning up ALL intermediate resume files (from this script's previous runs)...")
cleanup_intermediate_files(args_global.save_to, False)
cleanup_intermediate_files(args_global.save_to, False)
else: print("\nProcess interrupted. Graceful save to intermediate file attempted.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Extract LoHA parameters. Saves intermediate files like 'name_resume_L{count}.safetensors'.")
parser.add_argument("base_model_path", type=str, help="Path to base model (.pt, .pth, .safetensors)")
parser.add_argument("ft_model_path", type=str, help="Path to fine-tuned model (.pt, .pth, .safetensors)")
parser.add_argument("save_to", type=str, help="Path for FINAL LoHA output (recommended .safetensors).")
parser.add_argument("--overwrite", action="store_true", help="Overwrite existing FINAL LoHA. Does NOT clean intermediates until successful final save.")
parser.add_argument("--continue_training_from_loha", type=str, default=None, help="Path to existing LoHA to load and continue optimizing.")
parser.add_argument("--rank", type=int, default=4, help="Default rank for LoHA.")
parser.add_argument("--conv_rank", type=int, default=None, help="Specific rank for Conv LoHA. Defaults to --rank.")
parser.add_argument("--initial_alpha", type=float, default=None, help="Global initial alpha. Defaults to 'rank'.")
parser.add_argument("--initial_conv_alpha", type=float, default=None, help="Specific initial alpha for Conv LoHA. Defaults to '--initial_alpha' or conv_rank.")
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate per layer.")
parser.add_argument("--max_iterations", type=int, default=1000, help="Max optimization iterations per layer/attempt.")
parser.add_argument("--min_iterations", type=int, default=100, help="Min iterations before target_loss check per attempt.")
parser.add_argument("--target_loss", type=float, default=None, help="Target MSE loss for early stopping. Also for pre-re-opt check.")
parser.add_argument("--weight_decay", type=float, default=1e-5, help="Weight decay for optimization.")
parser.add_argument("--max_rank_retries", type=int, default=0, help="Rank increase retries if target_loss not met (0 for no retries).")
parser.add_argument("--rank_increase_factor", type=float, default=1.25, help="Factor to increase rank on retry.")
parser.add_argument("--progress_check_interval", type=int, default=100, help="Check loss improvement every N iterations (0 to disable).")
parser.add_argument("--min_progress_loss_ratio", type=float, default=0.001, help="Min relative loss decrease over interval.")
parser.add_argument("--progress_check_start_iter", type=int, default=None, help="Iteration for start of first progress window. Default: 'progress_check_interval'.")
@ -848,7 +919,7 @@ if __name__ == "__main__":
parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"], help="Optimization precision.")
parser.add_argument("--save_weights_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"], help="Dtype for saved LoHA weights.")
parser.add_argument("--atol_fp32_check", type=float, default=1e-6, help="Tolerance for identical weight check.")
parser.add_argument("--no_warm_start", action="store_true", help="Disable warm-starting higher rank attempts from previous best.")
parser.add_argument("--no_warm_start", action="store_true", help="Disable warm-starting higher rank attempts from previous best.")
parser.add_argument("--use_bias", action="store_true", help="Save differing bias terms into LoHA.")
parser.add_argument("--dropout", type=float, default=0.0, help="General dropout (metadata only).")
@ -857,24 +928,24 @@ if __name__ == "__main__":
parser.add_argument("--max_layers", type=int, default=None, help="Max NEW differing layers to process this session.")
parser.add_argument("--verbose", action="store_true", help="General verbose output.")
parser.add_argument("--verbose_layer_debug", action="store_true", help="Detailed per-iteration debug output (implies --verbose).")
parser.add_argument("--projection_sample_interval", type=int, default=20, help="Loss sample interval for EMA (iterations).")
parser.add_argument("--projection_ema_alpha", type=float, default=0.1, help="Smoothing factor for EMA.")
parser.add_argument("--projection_min_ema_history", type=int, default=5, help="Min EMA samples for EMA-based projection.")
parser.add_argument("--save_every_n_layers", type=int, default=0, help="Save intermediate LoHA every N processed layers (0 to disable).")
parser.add_argument("--keep_n_resume_files", type=int, default=0, help="Keep only N most recent intermediate resume files (0 to keep all).")
parsed_args = parser.parse_args() # This is line 690 based on previous context
parsed_args = parser.parse_args()
if parsed_args.verbose_layer_debug: parsed_args.verbose = True
if not os.path.exists(parsed_args.base_model_path): print(f"Error: Base model not found: {parsed_args.base_model_path}"); sys.exit(1)
if not os.path.exists(parsed_args.ft_model_path): print(f"Error: FT model not found: {parsed_args.ft_model_path}"); sys.exit(1)
save_dir = os.path.dirname(parsed_args.save_to)
if save_dir and not os.path.exists(save_dir):
try: os.makedirs(save_dir, exist_ok=True);
try: os.makedirs(save_dir, exist_ok=True);
except OSError as e: print(f"Error creating dir {save_dir}: {e}"); sys.exit(1)
if parsed_args.initial_alpha is None: parsed_args.initial_alpha = float(parsed_args.rank)
if parsed_args.initial_conv_alpha is None:
conv_r_alpha_def = parsed_args.conv_rank if parsed_args.conv_rank is not None else parsed_args.rank