diff --git a/tools/extract_lora_from_models-nw.py b/tools/extract_lora_from_models-nw.py index 56a6c55..277623b 100644 --- a/tools/extract_lora_from_models-nw.py +++ b/tools/extract_lora_from_models-nw.py @@ -1,20 +1,46 @@ +import sys +import os + +# 1. Add sd-scripts directory to sys.path +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(script_dir) +sd_scripts_dir_path = os.path.join(project_root, "sd-scripts") + +if sd_scripts_dir_path not in sys.path: + sys.path.insert(0, sd_scripts_dir_path) + +# Now you can import from the library package and the networks package +try: + from library import sai_model_spec, model_util, sdxl_model_util + from library.utils import setup_logging + from networks import lora # <--- CORRECTED LORA IMPORT +except ImportError as e: + print(f"Error importing from sd-scripts. Please check your sd-scripts folder structure.") + print(f"Attempted to load from: {sd_scripts_dir_path}") + print(f"Original error: {e}") + print("Current sys.path relevant entries:") + for p in sys.path: + if "sd-scripts" in p or "kohya_ss" in p: # Print relevant paths for debugging + print(p) + # Ensure 'networks' directory exists in 'sd-scripts' and contains 'lora.py' + # Also ensure 'sd-scripts/networks/__init__.py' exists. + raise + +# --- The rest of your script --- import argparse import json -import os +# import os # Already imported import time import torch from safetensors.torch import load_file, save_file from tqdm import tqdm -from library import sai_model_spec, model_util, sdxl_model_util -import lora # Assuming this is your existing lora script/library -from library.utils import setup_logging + setup_logging() import logging logger = logging.getLogger(__name__) MIN_SV = 1e-6 -# ... (Keep all your existing helper functions: index_sv_cumulative, index_sv_fro, etc.) def index_sv_cumulative(S, target): original_sum = float(torch.sum(S)) cumulative_sums = torch.cumsum(S, dim=0) / original_sum @@ -38,72 +64,161 @@ def index_sv_ratio(S, target): return index def index_sv_knee_improved(S, MIN_SV_KNEE=1e-8): # MIN_SV_KNEE can be same as global MIN_SV or specific + """ + Determine rank using the knee point detection method with normalization. + Normalizes singular values and their indices to the [0,1] range + to make the knee detection scale-invariant. + """ n = len(S) - if n < 3: return 1 + if n < 3: # Need at least 3 points to detect a knee + return 1 + + # S is expected to be sorted in descending order. s_max, s_min = S[0], S[-1] - if s_max - s_min < MIN_SV_KNEE: return 1 + + # Handle flat or nearly flat singular value spectrum + if s_max - s_min < MIN_SV_KNEE: + # If all singular values are almost the same, a knee is not well-defined. + # Returning 1 is a conservative choice for low rank. + # Alternatively, n // 2 or n - 1 could be chosen depending on desired behavior. + return 1 + + # Normalize singular values to [0, 1] + # s_normalized[0] will be 1, s_normalized[n-1] will be 0 s_normalized = (S - s_min) / (s_max - s_min) + + # Normalize indices to [0, 1] + # x_normalized[0] will be 0, x_normalized[n-1] will be 1 x_normalized = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype) + + # The line in normalized space connects (x_norm[0], s_norm[0]) to (x_norm[n-1], s_norm[n-1]) + # This is (0, 1) to (1, 0). + # The equation of the line passing through (0,1) and (1,0) is x + y - 1 = 0. + # So, A=1, B=1, C=-1 for the line equation Ax + By + C = 0. + + # Calculate the perpendicular distance from each point (x_normalized[i], s_normalized[i]) to this line. + # Distance = |A*x_i + B*y_i + C| / sqrt(A^2 + B^2) + # Distance = |1*x_normalized + 1*s_normalized - 1| / sqrt(1^2 + 1^2) + # Distance = |x_normalized + s_normalized - 1| / sqrt(2) + + # The sqrt(2) denominator is constant and doesn't affect argmax, so it can be omitted for finding the index. distances = (x_normalized + s_normalized - 1).abs() + + # Find the 0-based index of the point with the maximum distance. knee_index_0based = torch.argmax(distances).item() + + # Convert 0-based index to 1-based rank. rank = knee_index_0based + 1 + + # Clamp rank similar to original: must be > 0 and <= n-1 (typical for rank reduction) + # If knee_index_0based is n-1 (last point), rank becomes n. min(n, n-1) results in n-1. rank = max(1, min(rank, n - 1)) + return rank def index_sv_cumulative_knee(S, min_sv_threshold=1e-8): + """ + Determine rank using the knee point detection method on the normalized cumulative sum of singular values. + This method identifies a point where adding more singular values contributes diminishingly to the total sum. + """ n = len(S) - if n < 3: return 1 + if n < 3: # Need at least 3 points to detect a knee + return 1 + s_sum = torch.sum(S) - if s_sum < min_sv_threshold: return 1 + # If all singular values are zero or very small, return rank 1. + if s_sum < min_sv_threshold: + return 1 + + # Calculate cumulative sum of singular values, normalized by the total sum. + # y_values[0] = S[0]/s_sum, ..., y_values[n-1] = 1.0 y_values = torch.cumsum(S, dim=0) / s_sum - y_min, y_max = y_values[0], y_values[n-1] - if y_max - y_min < min_sv_threshold: return 1 + + # Normalize these y_values (cumulative sums) to the range [0,1] for knee detection. + y_min, y_max = y_values[0], y_values[n-1] # y_max is typically 1.0 + + # If the normalized cumulative sum curve is very flat (e.g., S[0] captures almost all energy), + # it implies the first few components are dominant. + if y_max - y_min < min_sv_threshold: # Using min_sv_threshold here as a sensitivity for flatness + # This condition means (S[0] + ... + S[n-1]) - S[0] is small relative to sum(S) if n>1 + # Effectively, S[1:] components are negligible. + return 1 + + # y_norm[0] = 0, y_norm[n-1] = 1 (represents the normalized cumulative sum from start to end) y_norm = (y_values - y_min) / (y_max - y_min) + + # x_values are indices, normalized to [0, 1] + # x_norm[0] = 0, ..., x_norm[n-1] = 1 x_norm = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype) - distances = (y_norm - x_norm).abs() + + # The "knee" is the point on the curve (x_norm[i], y_norm[i]) that is farthest + # from the line connecting the start and end of this normalized curve. + # In this normalized space, the line connects (0,0) to (1,1). + # The equation of this line is Y = X, or X - Y = 0. + # The distance from a point (x_i, y_i) to the line X - Y = 0 is |x_i - y_i| / sqrt(1^2 + (-1)^2). + # We can maximize |x_i - y_i| (or |y_i - x_i|) as sqrt(2) is a constant factor. + distances = (y_norm - x_norm).abs() # y_norm is expected to be >= x_norm for a concave cumulative curve. + + # Find the 0-based index of the point with the maximum distance. knee_index_0based = torch.argmax(distances).item() + + # Convert 0-based index to 1-based rank. rank = knee_index_0based + 1 + + # Clamp rank to be between 1 and n-1 (as n elements give n-1 possible ranks for truncation). + # A rank of n means no truncation. n-1 is the highest sensible rank for reduction. rank = max(1, min(rank, n - 1)) + return rank def index_sv_rel_decrease(S, tau=0.1): - if len(S) < 2: return 1 - ratios = S[1:] / S[:-1] - for k in range(len(ratios)): - if ratios[k] < tau: - return max(1, k + 1) - return min(len(S), len(S) - 1 if len(S) > 1 else 1) - - -def save_to_file(file_name, model_sd, dtype, metadata=None): # Changed model to model_sd for clarity - if dtype is not None: - for key in list(model_sd.keys()): - if isinstance(model_sd[key], torch.Tensor): - model_sd[key] = model_sd[key].to(dtype) + """Determine rank based on relative decrease threshold.""" + if len(S) < 2: + # For matrices with fewer than 2 singular values, a relative decrease + # isn't meaningful. Returning 1 is a sensible default. + return 1 - # Filter out non-tensor metadata if it accidentally gets into model_sd - final_sd = {k: v for k, v in model_sd.items() if isinstance(v, torch.Tensor)} + # Compute ratios of consecutive singular values + # S is sorted descending, so S[:-1] >= S[1:] + # ratios will be <= 1.0 + ratios = S[1:] / S[:-1] # Example: S=[10,1,0.5], ratios=[0.1, 0.5] + + # Find the smallest k such that S[k+1]/S[k] < tau. + # The rank would then be k+1, as we include S[k]. + for k in range(len(ratios)): # k ranges from 0 to len(S)-2 + if ratios[k] < tau: + # We found a significant drop after the k-th singular value. + # So, we keep k+1 singular values (indices 0 to k). + # The rank is k+1. Since k >= 0, k+1 >= 1. + return k + 1 + + # If no drop below tau was found, it means all relative decreases were >= tau. + # In this case, this method suggests using all available singular values. + # The actual rank will be capped later by args.dim/conv_dim and matrix dimensions. + return len(S) +def save_to_file(file_name, model_to_save, state_dict_content, dtype, metadata=None): # Renamed params for clarity + if dtype is not None: + for key in list(state_dict_content.keys()): + if isinstance(state_dict_content[key], torch.Tensor): + state_dict_content[key] = state_dict_content[key].to(dtype) + + # save_file from safetensors expects a state_dict as the first argument if metadata is also passed. + # torch.save would also expect the state_dict. + # The 'model' variable being passed to save_file should be the state_dict itself. if os.path.splitext(file_name)[1] == ".safetensors": - save_file(final_sd, file_name, metadata=metadata) # Pass metadata here + save_file(model_to_save, file_name, metadata=metadata) # Pass metadata correctly else: - # For .pt, metadata is typically not saved in this manner. - # If you need to save metadata with .pt, you might save a dict like {'state_dict': final_sd, 'metadata': metadata} - torch.save(final_sd, file_name) - if metadata: - logger.warning(".pt format does not standardly support metadata like safetensors. Metadata not saved in file.") + torch.save(model_to_save, file_name) -def svd_decomposition( - model_org_path=None, # Renamed for clarity - model_tuned_path=None, # Renamed for clarity +def svd( + model_org=None, + model_tuned=None, save_to=None, - algo="lora", # New: lora or loha - network_dim=4, # For LoRA: rank. For LoHA: "factor" or "hada_dim" - network_alpha=None, # For LoRA: alpha (often same as rank). For LoHA: "rank_initial" or "hada_alpha" - conv_dim=None, # For LoRA: conv_rank. For LoHA: "conv_factor" - conv_alpha=None, # For LoRA: conv_alpha. For LoHA: "conv_rank_initial" + dim=4, v2=None, sdxl=None, + conv_dim=None, v_parameterization=None, device=None, save_precision=None, @@ -118,9 +233,12 @@ def svd_decomposition( verbose=False, ): def str_to_dtype(p): - if p == "float": return torch.float - if p == "fp16": return torch.float16 - if p == "bf16": return torch.bfloat16 + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 return None assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time" @@ -128,395 +246,300 @@ def svd_decomposition( load_dtype = str_to_dtype(load_precision) if load_precision else None save_dtype = str_to_dtype(save_precision) if save_precision else torch.float - work_device = "cpu" # Perform SVD and weight manipulation on CPU then move - compute_device = device if device else "cpu" - - # Handle default alpha values based on dim values if not provided - if network_alpha is None: network_alpha = network_dim - if conv_dim is None: conv_dim = network_dim # default conv_dim to network_dim - if conv_alpha is None: conv_alpha = conv_dim # default conv_alpha to conv_dim + work_device = "cpu" # Load models if not sdxl: - logger.info(f"Loading original SD model: {model_org_path}") - text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org_path, load_dtype) + logger.info(f"Loading original SD model: {model_org}") + text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoders_o = [text_encoder_o] - - logger.info(f"Loading tuned SD model: {model_tuned_path}") - text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned_path, load_dtype) + if load_dtype: + text_encoder_o.to(load_dtype) + unet_o.to(load_dtype) + + logger.info(f"Loading tuned SD model: {model_tuned}") + text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoders_t = [text_encoder_t] - + if load_dtype: + text_encoder_t.to(load_dtype) + unet_t.to(load_dtype) + model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) - else: # SDXL + else: device_org = load_original_model_to or "cpu" device_tuned = load_tuned_model_to or "cpu" - logger.info(f"Loading original SDXL model: {model_org_path}") + logger.info(f"Loading original SDXL model: {model_org}") text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org_path, device_org, load_dtype + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org ) text_encoders_o = [text_encoder_o1, text_encoder_o2] + if load_dtype: + text_encoder_o1.to(load_dtype) + text_encoder_o2.to(load_dtype) + unet_o.to(load_dtype) - logger.info(f"Loading tuned SDXL model: {model_tuned_path}") + logger.info(f"Loading tuned SDXL model: {model_tuned}") text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned_path, device_tuned, load_dtype + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned ) text_encoders_t = [text_encoder_t1, text_encoder_t2] + if load_dtype: + text_encoder_t1.to(load_dtype) + text_encoder_t2.to(load_dtype) + unet_t.to(load_dtype) + model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 - # Create temporary LoRA network to identify modules and get original weights - # Use a minimal fixed dimension for this stage as we only need module structure and original weights - temp_lora_kwargs = {"conv_dim": 1, "conv_alpha": 1.0} # Minimal conv settings for network creation - lora_network_o = lora.create_network(1.0, 1, 1.0, None, text_encoders_o, unet_o, **temp_lora_kwargs) - lora_network_t = lora.create_network(1.0, 1, 1.0, None, text_encoders_t, unet_t, **temp_lora_kwargs) - assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras) - - diffs = {} - text_encoder_differs = False - # Text Encoders - for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)): - module_key_name = lora_o.lora_name # e.g. "lora_te1_text_model_encoder_layers_0_mlp_fc1" - org_weight = lora_o.org_module.weight.to(device=work_device, dtype=torch.float) - tuned_weight = lora_t.org_module.weight.to(device=work_device, dtype=torch.float) - diff = tuned_weight - org_weight - - if torch.max(torch.abs(diff)) > min_diff: - text_encoder_differs = True - logger.info(f"Text encoder {i+1} module {module_key_name} differs: max diff {torch.max(torch.abs(diff))}") - diffs[module_key_name] = diff - else: - logger.info(f"Text encoder {i+1} module {module_key_name} has no significant difference.") - - # Free memory - lora_o.org_module.weight = None - lora_t.org_module.weight = None - del org_weight, tuned_weight + # Create LoRA network + kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} if conv_dim else {} - # UNet - for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras): - module_key_name = lora_o.lora_name # e.g. "lora_unet_input_blocks_1_1_proj_in" - org_weight = lora_o.org_module.weight.to(device=work_device, dtype=torch.float) - tuned_weight = lora_t.org_module.weight.to(device=work_device, dtype=torch.float) - diff = tuned_weight - org_weight + # Define a small initial dimension for memory efficiency + init_dim = 4 # Small value to minimize memory usage - if torch.max(torch.abs(diff)) > min_diff: - logger.info(f"UNet module {module_key_name} differs: max diff {torch.max(torch.abs(diff))}") - diffs[module_key_name] = diff - else: - logger.info(f"UNet module {module_key_name} has no significant difference.") + # Create LoRA networks with minimal dimension + lora_network_o = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_o, unet_o, **kwargs) + lora_network_t = lora.create_network(1.0, init_dim, init_dim, None, text_encoders_t, unet_t, **kwargs) + + assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), "Model versions differ (SD1.x vs SD2.x)" + # Compute differences + diffs = {} + text_encoder_different = False + for lora_o, lora_t in zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras): + lora_name = lora_o.lora_name + diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device) lora_o.org_module.weight = None lora_t.org_module.weight = None - del org_weight, tuned_weight - if not text_encoder_differs: - logger.warning("Text encoder weights are identical or below min_diff. Text encoder LoRA modules will not be included.") - # Remove text encoder diffs if none were significant - diffs = {k: v for k, v in diffs.items() if "unet" in k} + if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: + text_encoder_different = True + logger.info(f"Text encoder differs: max diff {torch.max(torch.abs(diff))} > {min_diff}") + diffs[lora_name] = diff - del lora_network_o, lora_network_t, text_encoders_o, text_encoders_t, unet_o, unet_t - torch.cuda.empty_cache() + for text_encoder in text_encoders_t: + del text_encoder - lora_module_weights = {} # This will store the final decomposed weights for LoRA/LoHA + if not text_encoder_different: + logger.warning("Text encoders are identical. Extracting U-Net only.") + lora_network_o.text_encoder_loras = [] + diffs.clear() - logger.info(f"Extracting and resizing {algo.upper()} modules via SVD") + for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras): + lora_name = lora_o.lora_name + diff = lora_t.org_module.weight.to(work_device) - lora_o.org_module.weight.to(work_device) + lora_o.org_module.weight = None + lora_t.org_module.weight = None + diffs[lora_name] = diff + del lora_network_t, unet_t + + # Filter relevant modules + lora_names = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras) + + # Extract and resize LoRA using SVD + logger.info("Extracting and resizing LoRA via SVD") + lora_weights = {} with torch.no_grad(): - for module_key_name, mat_diff in tqdm(diffs.items()): - if compute_device != "cpu": # Move to GPU for SVD if specified - mat_diff = mat_diff.to(compute_device) - - # Determine if the layer is convolutional and its properties - is_conv = len(mat_diff.shape) == 4 - kernel_size = None - if is_conv: - kernel_size = mat_diff.shape[2:4] - - # For LoRA, conv_dim/alpha are specific to 3x3 convs. For others, it uses network_dim/alpha. - # For LoHA, we use conv_dim/alpha for any conv layer, and network_dim/alpha for linear. - # This logic can be refined based on how LyCORIS typically handles different conv kernel sizes for LoHA. - # Here, we'll use conv parameters if it's a conv layer, otherwise network parameters. - - current_dim_target = conv_dim if is_conv else network_dim - current_alpha_target = conv_alpha if is_conv else network_alpha + for lora_name in tqdm(lora_names): + mat = diffs[lora_name] + if device: + mat = mat.to(device) + mat = mat.to(torch.float) - # For LoHA, 'alpha_target' is the rank of the first SVD (rank_initial) - # and 'dim_target' is the rank of the second SVD (rank_factor). - # For LoRA, 'dim_target' is the rank of the SVD, and 'alpha_target' is the scaling factor. + conv2d = len(mat.size()) == 4 + kernel_size = mat.size()[2:4] if conv2d else None + conv2d_3x3 = conv2d and kernel_size != (1, 1) + out_dim, in_dim = mat.size()[0:2] - # Reshape convolutional weights for SVD - original_shape = mat_diff.shape - if is_conv: - if kernel_size == (1, 1): - mat_diff = mat_diff.squeeze() # Becomes (out_channels, in_channels) - else: # kernel_size (3,3) or others - mat_diff = mat_diff.flatten(start_dim=1) # Becomes (out_channels, in_channels*k_w*k_h) - - out_features, in_features = mat_diff.shape[0], mat_diff.shape[1] + if conv2d: + mat = mat.flatten(start_dim=1) if conv2d_3x3 else mat.squeeze() - # Perform first SVD - try: - U, S, Vh = torch.linalg.svd(mat_diff) - except Exception as e: - logger.error(f"SVD failed for {module_key_name} with shape {mat_diff.shape}: {e}") - continue - - if compute_device != "cpu": # Move results back to CPU if computation was on GPU - U, S, Vh = U.cpu(), S.cpu(), Vh.cpu() + U, S, Vh = torch.linalg.svd(mat) - # Determine rank for the first SVD (rank_initial for LoHA, rank for LoRA) - max_rank_initial = min(out_features, in_features) # Theoretical max rank - - # Default rank_initial to current_alpha_target (which is network_alpha or conv_alpha) - rank_initial = current_alpha_target - + # Determine rank + max_rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim if dynamic_method: if S[0] <= MIN_SV: - determined_rank = 1 - elif dynamic_method == "sv_ratio": determined_rank = index_sv_ratio(S, dynamic_param) - elif dynamic_method == "sv_cumulative": determined_rank = index_sv_cumulative(S, dynamic_param) - elif dynamic_method == "sv_fro": determined_rank = index_sv_fro(S, dynamic_param) - elif dynamic_method == "sv_knee": determined_rank = index_sv_knee_improved(S, MIN_SV) - elif dynamic_method == "sv_cumulative_knee": determined_rank = index_sv_cumulative_knee(S, MIN_SV) - elif dynamic_method == "sv_rel_decrease": determined_rank = index_sv_rel_decrease(S, dynamic_param) - else: determined_rank = rank_initial # Fallback if dynamic method unknown - rank_initial = min(determined_rank, current_alpha_target, max_rank_initial) + rank = 1 + elif dynamic_method == "sv_ratio": + rank = index_sv_ratio(S, dynamic_param) + elif dynamic_method == "sv_cumulative": + rank = index_sv_cumulative(S, dynamic_param) + elif dynamic_method == "sv_fro": + rank = index_sv_fro(S, dynamic_param) + elif dynamic_method == "sv_knee": + rank = index_sv_knee_improved(S, MIN_SV) # Pass MIN_SV or a specific threshold + elif dynamic_method == "sv_cumulative_knee": # New method + rank = index_sv_cumulative_knee(S, MIN_SV) # Pass MIN_SV or a specific threshold + elif dynamic_method == "sv_rel_decrease": + rank = index_sv_rel_decrease(S, dynamic_param) + rank = min(rank, max_rank, in_dim, out_dim) else: - rank_initial = min(current_alpha_target, max_rank_initial) - rank_initial = max(1, rank_initial) # Ensure rank is at least 1 - - # --- LoRA specific decomposition --- - if algo == 'lora': - lora_down = Vh[:rank_initial, :] - lora_up = U[:, :rank_initial] @ torch.diag(S[:rank_initial]) - - # Clamp values - dist = torch.cat([lora_up.flatten(), lora_down.flatten()]) - hi_val = torch.quantile(dist, clamp_quantile) if clamp_quantile < 1.0 else dist.abs().max() - lora_up = lora_up.clamp(-hi_val, hi_val) - lora_down = lora_down.clamp(-hi_val, hi_val) - - # Reshape for conv layers if necessary - if is_conv: - if kernel_size == (1,1): - # These are already (out_c, rank) and (rank, in_c) - # Some LoRA impls might expect them to be 4D (out_c, rank, 1, 1) and (rank, in_c, 1, 1) - lora_up = lora_up.reshape(out_features, rank_initial, 1, 1) - lora_down = lora_down.reshape(rank_initial, in_features, 1, 1) - else: # e.g. 3x3 conv - # lora_down was (rank, in_c*k_w*k_h), needs to be (rank, in_c, k_w, k_h) - # lora_up was (out_c, rank) - lora_up = lora_up.reshape(out_features, rank_initial, 1, 1) # often up is kept as 1x1 conv like - lora_down = lora_down.reshape(rank_initial, original_shape[1], *kernel_size) - - lora_module_weights[f"{module_key_name}.lora_down.weight"] = lora_down.to(work_device, dtype=save_dtype).contiguous() - lora_module_weights[f"{module_key_name}.lora_up.weight"] = lora_up.to(work_device, dtype=save_dtype).contiguous() - lora_module_weights[f"{module_key_name}.alpha"] = torch.tensor(float(current_alpha_target), dtype=save_dtype) # Use actual alpha target from params - - # --- LoHA specific decomposition --- - elif algo == 'loha': - lora_down_equivalent = Vh[:rank_initial, :] - lora_up_equivalent = U[:, :rank_initial] @ torch.diag(S[:rank_initial]) - - # current_dim_target is the "factor" for LoHA's second SVD - rank_factor = min(current_dim_target, rank_initial) # Factor cannot exceed rank_initial - if is_conv and kernel_size != (1,1): # For conv3x3, factor also limited by in/out features - rank_factor = min(rank_factor, original_shape[1], original_shape[0]) #original_shape[1] is in_channels - else: # Linear or Conv1x1 - rank_factor = min(rank_factor, in_features, out_features) - rank_factor = max(1, rank_factor) - - - # Decompose Lora_Down_Equivalent (shape: rank_initial x in_features_eff) - # Target: hada_w1_b (rank_initial x rank_factor) @ hada_w1_a (rank_factor x in_features_eff) - if compute_device != "cpu": lora_down_equivalent = lora_down_equivalent.to(compute_device) - Ud, Sd, Vhd = torch.linalg.svd(lora_down_equivalent) - if compute_device != "cpu": Ud, Sd, Vhd = Ud.cpu(), Sd.cpu(), Vhd.cpu() - - hada_w1_a = Vhd[:rank_factor, :] - hada_w1_b = Ud[:, :rank_factor] @ torch.diag(Sd[:rank_factor]) - - # Decompose Lora_Up_Equivalent (shape: out_features_eff x rank_initial) - # Target: hada_w2_b (out_features_eff x rank_factor) @ hada_w2_a (rank_factor x rank_initial) - if compute_device != "cpu": lora_up_equivalent = lora_up_equivalent.to(compute_device) - Uu, Su, Vhu = torch.linalg.svd(lora_up_equivalent) - if compute_device != "cpu": Uu, Su, Vhu = Uu.cpu(), Su.cpu(), Vhu.cpu() - - hada_w2_a = Vhu[:rank_factor, :] - hada_w2_b = Uu[:, :rank_factor] @ torch.diag(Su[:rank_factor]) - - # Clamp LoHA components - dist_w1a = hada_w1_a.flatten() - dist_w1b = hada_w1_b.flatten() - dist_w2a = hada_w2_a.flatten() - dist_w2b = hada_w2_b.flatten() - - if clamp_quantile < 1.0: - hi_val_w1a = torch.quantile(dist_w1a.abs(), clamp_quantile) - hi_val_w1b = torch.quantile(dist_w1b.abs(), clamp_quantile) - hi_val_w2a = torch.quantile(dist_w2a.abs(), clamp_quantile) - hi_val_w2b = torch.quantile(dist_w2b.abs(), clamp_quantile) - else: # Use max abs value if quantile is 1.0 or more - hi_val_w1a = dist_w1a.abs().max() - hi_val_w1b = dist_w1b.abs().max() - hi_val_w2a = dist_w2a.abs().max() - hi_val_w2b = dist_w2b.abs().max() - - hada_w1_a = hada_w1_a.clamp(-hi_val_w1a, hi_val_w1a) - hada_w1_b = hada_w1_b.clamp(-hi_val_w1b, hi_val_w1b) - hada_w2_a = hada_w2_a.clamp(-hi_val_w2a, hi_val_w2a) - hada_w2_b = hada_w2_b.clamp(-hi_val_w2b, hi_val_w2b) - - # LoHA weights are typically stored as 2D matrices. - # LyCORIS library handles reshaping or uses 1x1 convs internally. - lora_module_weights[f"{module_key_name}.hada_w1_a"] = hada_w1_a.to(work_device, dtype=save_dtype).contiguous() - lora_module_weights[f"{module_key_name}.hada_w1_b"] = hada_w1_b.to(work_device, dtype=save_dtype).contiguous() - lora_module_weights[f"{module_key_name}.hada_w2_a"] = hada_w2_a.to(work_device, dtype=save_dtype).contiguous() - lora_module_weights[f"{module_key_name}.hada_w2_b"] = hada_w2_b.to(work_device, dtype=save_dtype).contiguous() - lora_module_weights[f"{module_key_name}.alpha"] = torch.tensor(float(current_alpha_target), dtype=save_dtype) # This is rank_initial + rank = min(max_rank, in_dim, out_dim) + rank = max(1, rank) # Ensure rank is at least 1 + + # Truncate SVD components and distribute sqrt(S) + S_k = S[:rank] + U_k = U[:, :rank] + Vh_k = Vh[:rank, :] + + # Ensure S_k values are non-negative before sqrt to avoid NaN from tiny negative SVD artifacts + S_k_non_negative = torch.clamp(S_k, min=0.0) # Use 0.0 for float tensor + s_sqrt = torch.sqrt(S_k_non_negative) + + # Distribute s_sqrt: U_final = U_k * diag(s_sqrt), Vh_final = diag(s_sqrt) * Vh_k + # Using efficient broadcasting for multiplication: + U_final = U_k * s_sqrt.unsqueeze(0) # (out_dim, rank) * (1, rank) + Vh_final = Vh_k * s_sqrt.unsqueeze(1) # (rank, in_dim_effective) * (rank, 1) + + # Clamp values (applied to U_final, Vh_final) + # The distribution of values in U_final and Vh_final might be different + # than the original U and Vh, so the effect of clamping might change. + dist = torch.cat([U_final.flatten(), Vh_final.flatten()]) + hi_val = torch.quantile(dist, clamp_quantile) + U_clamped = U_final.clamp(-hi_val, hi_val) + Vh_clamped = Vh_final.clamp(-hi_val, hi_val) + + if conv2d: + # U_clamped is (out_dim, rank) + U_clamped = U_clamped.reshape(out_dim, rank, 1, 1) + + # Vh_clamped is (rank, in_dim * possibly_kernel_dims) + # It needs to be reshaped back to (rank, in_dim, kernel_h, kernel_w) + if conv2d_3x3 : # Original mat was (out_dim, in_dim * k_h * k_w) + Vh_clamped = Vh_clamped.reshape(rank, in_dim, *kernel_size) + else: # Original mat was (out_dim, in_dim) for 1x1 conv, kernel_size is (1,1) + Vh_clamped = Vh_clamped.reshape(rank, in_dim, *kernel_size) # kernel_size is (1,1) here + + U_clamped = U_clamped.to(work_device, dtype=save_dtype).contiguous() + Vh_clamped = Vh_clamped.to(work_device, dtype=save_dtype).contiguous() + lora_weights[lora_name] = (U_clamped, Vh_clamped) + + # Verbose output (S values are pre-modification for accurate reporting of original SVD properties) if verbose: - s_sum = float(torch.sum(S)) - s_rank_initial = float(torch.sum(S[:rank_initial])) - fro_initial = float(torch.sqrt(torch.sum(S.pow(2)))) - fro_rank_initial = float(torch.sqrt(torch.sum(S[:rank_initial].pow(2)))) - # This verbose output is for the first SVD. Adding verbose for second SVD would be more complex. + s_sum_total = float(torch.sum(S)) + s_sum_rank = float(torch.sum(S[:rank])) # Sum of the singular values actually used for reconstruction + + fro_orig_total = float(torch.sqrt(torch.sum(S.pow(2)))) + fro_reconstructed_rank = float(torch.sqrt(torch.sum(S[:rank].pow(2)))) # Frobenius norm of the matrix part represented by chosen rank + + # Ratio of the largest retained singular value to the smallest retained singular value + # S is sorted, S[0] is max. S[rank-1] is the smallest singular value included if rank > 0. + ratio_sv = S[0] / S[rank - 1] if rank > 0 and S[rank - 1].abs() > MIN_SV else float('inf') # Avoid division by zero or tiny number + + # Ensure denominators are not zero for percentages + sum_s_retained_percentage = (s_sum_rank / s_sum_total) if s_sum_total > MIN_SV else 1.0 + fro_retained_percentage = (fro_reconstructed_rank / fro_orig_total) if fro_orig_total > MIN_SV else 1.0 + logger.info( - f"{module_key_name[:75]:75} | Algo: {algo.upper()}, Rank/Alpha (initial): {rank_initial}, Dim/Factor (final): {rank_factor if algo=='loha' else rank_initial} " - f"| Sum(S) Retained (1st SVD): {s_rank_initial/s_sum if s_sum > 0 else 0:.1%}, Fro Retained (1st SVD): {fro_rank_initial/fro_initial if fro_initial > 0 else 0:.1%}" + f"{lora_name:75} | rank: {rank}, " + f"sum(S) retained: {sum_s_retained_percentage:.2%}, " + f"Frobenius norm retained: {fro_retained_percentage:.2%}, " + f"max_retained_sv/min_retained_sv ratio: {ratio_sv:.2f}" ) - del U, S, Vh, mat_diff - if algo == 'loha': - del lora_down_equivalent, lora_up_equivalent, Ud, Sd, Vhd, Uu, Su, Vhu - del hada_w1_a, hada_w1_b, hada_w2_a, hada_w2_b - else: # lora - del lora_down, lora_up - if compute_device != "cpu": - torch.cuda.empty_cache() + # Create state dict + lora_sd = {} + for lora_name, (up_weight, down_weight) in lora_weights.items(): + lora_sd[lora_name + ".lora_up.weight"] = up_weight + lora_sd[lora_name + ".lora_down.weight"] = down_weight + lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype) # alpha is rank + # Load and save LoRA + lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd) + lora_network_save.apply_to(text_encoders_o, unet_o) # This applies weights, not strictly necessary if just saving sd + info = lora_network_save.load_state_dict(lora_sd) # This populates the network object with the weights from lora_sd + logger.info(f"Loaded extracted and resized LoRA weights into network object: {info}") - if not lora_module_weights: - logger.error("No LoRA/LoHA modules were generated. This might be due to models being too similar or min_diff being too high.") - return os.makedirs(os.path.dirname(save_to), exist_ok=True) # Metadata + net_kwargs = {"conv_dim": str(conv_dim), "conv_alpha": str(float(conv_dim))} if conv_dim else {} + # Determine network_dim and network_alpha for metadata based on dynamic method + if dynamic_method: + network_dim_meta = "Dynamic" + network_alpha_meta = "Dynamic" # Alpha is rank, which is dynamic + else: + network_dim_meta = str(dim) + network_alpha_meta = str(float(dim)) # Alpha is rank, which is dim + metadata = { - "ss_v2": str(v2) if v2 is not None else "false", # More explicit "false" + "ss_v2": str(v2), "ss_base_model_version": model_version, - "ss_sdxl_model_version": "1.0" if sdxl else "null", # LyCORIS convention - "ss_network_module": "networks.lora" if algo == "lora" else "lycoris.kohya", # For LoHA - # For LoRA, dim and alpha are usually the same as what's passed. - # For LoHA, network_dim is the 'factor' (our network_dim/conv_dim), - # and network_alpha is the 'rank_initial' (our network_alpha/conv_alpha). - "ss_network_dim": str(network_dim) if not dynamic_method or algo == "loha" else "Dynamic", # For LoHA, this is 'factor' - "ss_network_alpha": str(network_alpha) if not dynamic_method or algo == "loha" else "Dynamic", # For LoHA, this is 'rank_initial' + "ss_network_module": "networks.lora", + "ss_network_dim": network_dim_meta, + "ss_network_alpha": network_alpha_meta, # Alpha is typically the rank + "ss_network_args": json.dumps(net_kwargs), + "ss_lowram": "False", # Assuming not specifically lowram mode + "ss_num_train_images": "N/A", # Not applicable for extraction + # Add other relevant metadata as per sai_model_spec or conventions } - - net_kwargs = {} - if algo == "lora": - if conv_dim is not None: # Only add if conv_dim was actually specified for LoRA - net_kwargs["conv_dim"] = str(conv_dim) - net_kwargs["conv_alpha"] = str(float(conv_alpha if conv_alpha is not None else conv_dim)) - elif algo == "loha": - net_kwargs["algo"] = "loha" - # For LoHA, conv_dim and conv_alpha are distinct concepts (factor and rank_initial for conv layers) - net_kwargs["conv_dim"] = str(conv_dim) - net_kwargs["conv_alpha"] = str(float(conv_alpha)) - # LyCORIS sometimes uses dropout, but we are not implementing it here. - # net_kwargs["dropout"] = "0" # Example if we had dropout - - metadata["ss_network_args"] = json.dumps(net_kwargs) if net_kwargs else "null" # LyCORIS uses "null" for empty - if not no_metadata: title = os.path.splitext(os.path.basename(save_to))[0] - # sai_model_spec might need adjustment if it expects specific lora types not lycoris - try: - sai_metadata = sai_model_spec.build_metadata( - None, v2, v_parameterization, sdxl, True, - is_lycoris=(algo != "lora"), # Pass if it's LyCORIS - is_lora=(algo == "lora"), # Pass if it's LoRA - creation_time=time.time(), title=title - ) - metadata.update(sai_metadata) - except TypeError as e: - logger.warning(f"Could not generate full SAI metadata, possibly due to outdated sai_model_spec.py or new flags: {e}") - logger.warning("Falling back to basic metadata for SAI fields.") - metadata.update({ # Basic fallback - "sai_model_name": title, - "sai_base_model": model_version, - "sai_is_sdxl": str(sdxl).lower(), - }) + # Build sai_metadata, ensuring it includes necessary fields like 'ss_sd_model_hash' if possible + # For extraction, some training-specific metadata might not be relevant or available. + sai_metadata = sai_model_spec.build_metadata( + None, # training_info (usually from train_util or fine_tune) - can be None for extraction + v2, + v_parameterization, + sdxl, + True, # is_sd2 + False, # is_v_pred_like + time.time(), + title=title, + # model_hash=None, # Original model hash if available + # tuned_model_hash=None # Tuned model hash if available + ) + # Filter out None values from sai_metadata if any, or handle them in build_metadata + sai_metadata_cleaned = {k: v for k, v in sai_metadata.items() if v is not None} + metadata.update(sai_metadata_cleaned) - save_to_file(save_to, lora_module_weights, save_dtype, metadata) - logger.info(f"{algo.upper()} saved to: {save_to}") - + # Use the state_dict 'lora_sd' for saving, not the network object 'lora_network_save' + save_to_file(save_to, lora_sd, lora_sd, save_dtype, metadata) # Pass lora_sd as the model/state_dict to save + logger.info(f"LoRA saved to: {save_to}") def setup_parser(): parser = argparse.ArgumentParser() parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model") - parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2 if applicable)") + parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2)") parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model") - - parser.add_argument("--model_org_path", type=str, required=True, help="Path to the original Stable Diffusion model (ckpt/safetensors)") - parser.add_argument("--model_tuned_path", type=str, required=True, help="Path to the tuned Stable Diffusion model (ckpt/safetensors)") - parser.add_argument("--save_to", type=str, required=True, help="Output file name for the LoRA/LoHA (ckpt/safetensors)") - - parser.add_argument("--algo", type=str, default="lora", choices=["lora", "loha"], help="Algorithm to use: lora or loha") - - parser.add_argument("--network_dim", type=int, default=4, help="Network dimension. For LoRA: rank. For LoHA: 'factor' or 'hada_dim'.") - parser.add_argument("--network_alpha", type=int, default=None, help="Network alpha. For LoRA: alpha (often same as rank). For LoHA: 'rank_initial' or 'hada_alpha'. Defaults to network_dim if not set.") - parser.add_argument("--conv_dim", type=int, default=None, help="Conv dimension for conv layers. For LoRA: rank. For LoHA: 'factor'. Defaults to network_dim if not set.") - parser.add_argument("--conv_alpha", type=int, default=None, help="Conv alpha for conv layers. For LoRA: alpha. For LoHA: 'rank_initial'. Defaults to conv_dim if not set.") - - parser.add_argument("--load_precision", type=str, choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for loading models (None means default float32)") - parser.add_argument("--save_precision", type=str, choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA/LoHA (None means float32)") - - parser.add_argument("--device", type=str, default=None, help="Device for SVD computation (e.g., 'cuda', 'cpu'). If None, defaults to 'cuda' if available, else 'cpu'. SVD results are moved to CPU for storage.") - parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights (0.0 to 1.0). 1.0 means clamp to max abs value.") - parser.add_argument("--min_diff", type=float, default=1e-6, help="Minimum weight difference threshold for a module to be considered for extraction.") # Lowered default - parser.add_argument("--no_metadata", action="store_true", help="Do not save detailed metadata (minimal ss_ metadata will still be saved).") - - parser.add_argument("--load_original_model_to", type=str, default=None, help="Device to load original model to (SDXL only, e.g., 'cpu', 'cuda:0')") - parser.add_argument("--load_tuned_model_to", type=str, default=None, help="Device to load tuned model to (SDXL only, e.g., 'cpu', 'cuda:1')") - - parser.add_argument("--dynamic_method", type=str, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"], default=None, help="Dynamic method to determine rank/alpha (for the first SVD in LoHA). Overrides fixed network_alpha/conv_alpha if set.") - parser.add_argument("--dynamic_param", type=float, default=0.9, help="Parameter for the chosen dynamic_method (e.g., target ratio/cumulative sum/Frobenius norm percentage).") - parser.add_argument("--verbose", action="store_true", help="Show detailed SVD info for each module.") - + parser.add_argument("--load_precision", choices=[None, "float", "fp16", "bf16"], help="Precision for loading models") + parser.add_argument("--save_precision", choices=[None, "float", "fp16", "bf16"], default=None, help="Precision for saving LoRA") + parser.add_argument("--model_org", required=True, help="Original Stable Diffusion model (ckpt/safetensors)") + parser.add_argument("--model_tuned", required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)") + parser.add_argument("--save_to", required=True, help="Output file name (ckpt/safetensors)") + parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers") + parser.add_argument("--conv_dim", type=int, help="Max dimension (rank) of LoRA for Conv2d-3x3") + parser.add_argument("--device", default="cuda", help="Device for computation (e.g., cuda)") + parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights") + parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract") + parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata") + parser.add_argument("--load_original_model_to", help="Device for original model (SDXL only)") + parser.add_argument("--load_tuned_model_to", help="Device for tuned model (SDXL only)") + parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction") + parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info") + parser.add_argument( + "--dynamic_method", + choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"], # Added "sv_cumulative_knee" + help="Dynamic rank reduction method" + ) return parser if __name__ == "__main__": parser = setup_parser() args = parser.parse_args() - - if args.dynamic_method and (args.dynamic_param is None): - # Default dynamic_param for methods if not specified, or raise error - if args.dynamic_method in ["sv_cumulative", "sv_fro"]: - args.dynamic_param = 0.99 # Example: 99% variance/energy - logger.info(f"Dynamic method {args.dynamic_method} chosen, dynamic_param defaulted to {args.dynamic_param}") - elif args.dynamic_method in ["sv_ratio"]: - args.dynamic_param = 1000 # Example: ratio of 1000 - logger.info(f"Dynamic method {args.dynamic_method} chosen, dynamic_param defaulted to {args.dynamic_param}") - elif args.dynamic_method in ["sv_rel_decrease"]: - args.dynamic_param = 0.05 # Example: 5% relative decrease - logger.info(f"Dynamic method {args.dynamic_method} chosen, dynamic_param defaulted to {args.dynamic_param}") - # sv_knee and sv_cumulative_knee do not require dynamic_param in this implementation. - elif args.dynamic_method not in ["sv_knee", "sv_cumulative_knee"]: - parser.error("--dynamic_method requires --dynamic_param for most methods.") - - - # Default device selection - if args.device is None: - args.device = "cuda" if torch.cuda.is_available() else "cpu" - logger.info(f"Using device: {args.device} for SVD computation.") - - # Rename args for clarity before passing to function - func_args = vars(args).copy() - func_args["model_org_path"] = func_args.pop("model_org_path") - func_args["model_tuned_path"] = func_args.pop("model_tuned_path") + methods_requiring_param = ["sv_ratio", "sv_fro", "sv_cumulative", "sv_rel_decrease"] + if args.dynamic_method in methods_requiring_param and args.dynamic_param is None: + raise ValueError(f"Dynamic method '{args.dynamic_method}' requires --dynamic_param to be set.") - svd_decomposition(**func_args) \ No newline at end of file + # Add a check for rank > 0 if not dynamic, or ensure dynamic methods return rank >= 1 + if not args.dynamic_method and args.dim <= 0: + raise ValueError(f"--dim (rank) must be > 0. Got {args.dim}") + if args.conv_dim is not None and args.conv_dim <=0: + raise ValueError(f"--conv_dim (rank) must be > 0 if specified. Got {args.conv_dim}") + + svd(**vars(args)) \ No newline at end of file