diff --git a/tools/extract_lora_from_models-nw.py b/tools/extract_lora_from_models-nw.py
index 45ff5df..eb1d834 100644
--- a/tools/extract_lora_from_models-nw.py
+++ b/tools/extract_lora_from_models-nw.py
@@ -11,7 +11,8 @@ if sd_scripts_dir_path not in sys.path:
 
 # Now you can import from the library package and the networks package
 try:
-    from library import sai_model_spec, model_util, sdxl_model_util
+    # sai_model_spec is no longer imported; a local metadata builder below replaces it
+    from library import model_util, sdxl_model_util
     from library.utils import setup_logging
     from networks import lora
 except ImportError as e:
@@ -118,12 +119,47 @@ def save_to_file(file_name, state_dict_to_save, dtype, metadata=None):
     else:
         torch.save(state_dict_final, file_name)
 
+# --- NEW LOCAL METADATA FUNCTION ---
+def _build_local_sai_metadata(title, creation_time, is_v2_flag, is_v_param_flag, is_sdxl_flag):
+    """
+    Builds a dictionary of SAI-like metadata from the provided flags,
+    specifically for this LoRA extraction script. Keys follow the Kohya-style
+    ss_* conventions commonly read by Civitai and similar tools.
+    """
+    metadata = {}
+    metadata["ss_sd_model_name"] = str(title)
+    metadata["ss_creation_time"] = str(int(creation_time))  # the original stored an int epoch timestamp
+
+    # Determine the base model version for the SAI-like metadata
+    if is_sdxl_flag:
+        metadata["ss_base_model_version"] = "sdxl_v10"  # identifier used here for SDXL 1.0
+        metadata["ss_sdxl_model_version"] = "1.0"  # specific SDXL version
+        # SDXL base 1.0 uses epsilon prediction, so v-parameterization is
+        # only recorded when the caller explicitly requests it.
+        if is_v_param_flag:
+            metadata["ss_v_parameterization"] = "true"
+    elif is_v2_flag:
+        metadata["ss_base_model_version"] = "sd_v2"  # identifier used here for SD 2.x
+        if is_v_param_flag:  # v-parameterization is essential for SD 2.x 768-v models
+            metadata["ss_v_parameterization"] = "true"
+    else:
+        metadata["ss_base_model_version"] = "sd_v1"  # identifier used here for SD 1.x
+        # v-parameterization is uncommon for SD 1.x but does exist
+        if is_v_param_flag:
+            metadata["ss_v_parameterization"] = "true"
+
+    # The remaining arguments of the original sai_model_spec call were constants
+    # (training_info=None, is_v_pred_like=False, etc.) or are covered by the flags
+    # above, so only keys with meaningful values derived from the inputs are emitted.
+    # For example, "ss_is_v_prediction_like" would only matter if is_v_pred_like were
+    # a parameter; it was hardcoded to False in the original call, so it is omitted.
+    return metadata
+
 # --- Refactored Helper Functions ---
 def _load_sd_model_components(model_path, is_v2, load_dtype_torch):
     logger.info(f"Loading SD model from: {model_path} (to CPU initially)")
-    # model_util usually loads to CPU by default, then we cast dtype
     text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(is_v2, model_path)
-    del vae # Not used
+    del vae
     text_encoders = [text_encoder]
     if load_dtype_torch:
         for te in text_encoders:
@@ -132,17 +168,14 @@ def _load_sd_model_components(model_path, is_v2, load_dtype_torch):
     return text_encoders, unet
 
 def _load_sdxl_model_components(model_path, target_device_override, load_dtype_torch):
-    # Prioritize CPU loading unless target_device_override is explicitly GPU
-    # This 'target_device_override' comes from args.load_original_model_to / args.load_tuned_model_to
     actual_load_device = target_device_override if target_device_override else "cpu"
     logger.info(f"Loading SDXL model from: {model_path} to device: {actual_load_device}")
-
     text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
         sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_path, actual_load_device
     )
-    del vae # Not used
+    del vae
     text_encoders = [text_encoder1, text_encoder2]
-    if load_dtype_torch: # Apply dtype cast after loading to the specified device
+    if load_dtype_torch:
         for te in text_encoders:
             te.to(load_dtype_torch)
         unet.to(load_dtype_torch)
@@ -161,9 +194,6 @@ def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_
             logger.warning(f"Skipping {lora_name} in {module_type_str} due to missing org_module or weight.")
             continue
 
-        # Weights are expected to be on CPU after loading, or on specified load device.
-        # Move them to diff_calc_device ONLY if they are not already there.
-        # diff_calc_device will be CPU in the corrected flow.
         weight_o = lora_o.org_module.weight
         weight_t = lora_t.org_module.weight
 
@@ -172,13 +202,8 @@ def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_
         if str(weight_t.device) != str(diff_calc_device):
             weight_t = weight_t.to(diff_calc_device)
 
-        diff = weight_t - weight_o # Diff happens on diff_calc_device (CPU)
-
-        # No need to set lora_o.org_module.weight to None here, original weights might be reused
-        # by other parts of sd-scripts if this script is integrated.
-        # We will del the entire model objects (unet_t, text_encoders_t) later.
-
-        diffs_map[lora_name] = diff # diff is on diff_calc_device (CPU)
+        diff = weight_t - weight_o
+        diffs_map[lora_name] = diff
         current_max_diff = torch.max(torch.abs(diff))
         if not is_different_flag and current_max_diff > min_diff_thresh:
             is_different_flag = True
@@ -206,9 +231,7 @@ def _construct_lora_weights_from_svd_components(U_full, S_all_values, Vh_full, r
                                                 clamp_quantile_val, is_conv2d, is_conv2d_3x3, conv_kernel_size,
                                                 module_out_channels, module_in_channels,
-                                                # svd_comp_device, # U,S,Vh are on this device
                                                 target_device_for_final_weights, target_dtype_for_final_weights):
-    # U_full, S_all_values, Vh_full are assumed to be on the SVD computation device.
     S_k = S_all_values[:rank]
     U_k = U_full[:, :rank]
     Vh_k = Vh_full[:rank, :]
@@ -216,10 +239,9 @@
     S_k_non_negative = torch.clamp(S_k, min=0.0)
     s_sqrt = torch.sqrt(S_k_non_negative)
 
-    U_final = U_k * s_sqrt.unsqueeze(0) # on svd_comp_device
-    Vh_final = Vh_k * s_sqrt.unsqueeze(1) # on svd_comp_device
+    U_final = U_k * s_sqrt.unsqueeze(0)
+    Vh_final = Vh_k * s_sqrt.unsqueeze(1)
 
-    # Clamping happens on svd_comp_device
     dist = torch.cat([U_final.flatten(), Vh_final.flatten()])
     hi_val = torch.quantile(dist, clamp_quantile_val)
     if hi_val == 0 and torch.max(torch.abs(dist)) > 1e-9:
@@ -228,14 +250,13 @@
     U_clamped = U_final.clamp(-hi_val, hi_val)
     Vh_clamped = Vh_final.clamp(-hi_val, hi_val)
 
-    if is_conv2d: # Reshaping also on svd_comp_device
+    if is_conv2d:
         U_clamped = U_clamped.reshape(module_out_channels, rank, 1, 1)
         if is_conv2d_3x3:
             Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, *conv_kernel_size)
         else:
             Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, 1, 1)
 
-    # Move to final target device and dtype at the very end
     U_clamped = U_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous()
     Vh_clamped = Vh_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous()
     return U_clamped, Vh_clamped
@@ -245,18 +266,14 @@ def _log_svd_stats(lora_module_name, S_all_values, rank_used, min_sv_for_calc=MI
         logger.info(f"{lora_module_name:75} | rank: {rank_used}, SVD not performed (empty singular values).")
         return
 
-    # S_all_values might be on GPU, move to CPU for float conversion and sum if not already
     S_cpu = S_all_values.to('cpu')
-
     s_sum_total = float(torch.sum(S_cpu))
     s_sum_rank = float(torch.sum(S_cpu[:rank_used]))
-
     fro_orig_total = float(torch.sqrt(torch.sum(S_cpu.pow(2))))
     fro_reconstructed_rank = float(torch.sqrt(torch.sum(S_cpu[:rank_used].pow(2))))
-
     ratio_sv = float('inf')
     if rank_used > 0 and S_cpu[rank_used - 1].abs() > min_sv_for_calc:
-        ratio_sv = S_cpu[0] / S_cpu[rank_used - 1] # Ensure S_cpu[0] is also float for division
+        ratio_sv = S_cpu[0] / S_cpu[rank_used - 1]
 
     sum_s_retained_percentage = (s_sum_rank / s_sum_total) if s_sum_total > min_sv_for_calc else 1.0
     fro_retained_percentage = (fro_reconstructed_rank / fro_orig_total) if fro_orig_total > min_sv_for_calc else 1.0
@@ -268,7 +285,7 @@
         f"max_retained_sv/min_retained_sv ratio: {ratio_sv:.2f}"
     )
 
-def _prepare_lora_metadata(output_path, is_v2_flag, base_model_ver, network_conv_dim_val,
+def _prepare_lora_metadata(output_path, is_v2_flag, kohya_base_model_version_str, network_conv_dim_val,
                            use_dynamic_method_flag, network_dim_config_val, is_v_param_flag,
                            is_sdxl_flag, skip_sai_meta):
     net_kwargs = {"conv_dim": str(network_conv_dim_val), "conv_alpha": str(float(network_conv_dim_val))} if network_conv_dim_val is not None else {}
@@ -280,28 +297,34 @@
         network_dim_meta = str(network_dim_config_val)
         network_alpha_meta = str(float(network_dim_config_val))
 
+    # Initial metadata using Kohya's conventions
     final_metadata = {
-        "ss_v2": str(is_v2_flag),
-        "ss_base_model_version": base_model_ver,
+        "ss_v2": str(is_v2_flag),  # Kohya's flag for v2 checkpoint type
+        "ss_base_model_version": kohya_base_model_version_str,  # Kohya's specific base model string
"ss_network_module": "networks.lora", "ss_network_dim": network_dim_meta, "ss_network_alpha": network_alpha_meta, "ss_network_args": json.dumps(net_kwargs), "ss_lowram": "False", - "ss_num_train_images": "N/A", + "ss_num_train_images": "N/A", # This script doesn't involve training } if not skip_sai_meta: title = os.path.splitext(os.path.basename(output_path))[0] - is_sd2_for_meta = True + current_time = time.time() - sai_metadata_content = sai_model_spec.build_metadata( - training_info=None, v2=is_v2_flag, v_parameterization=is_v_param_flag, - sdxl=is_sdxl_flag, is_sd2=is_sd2_for_meta, is_v_pred_like=False, - unet_use_linear_projection_in_v2=False, creation_time=time.time(), title=title, + # Build SAI-like metadata using the local function + sai_metadata_content = _build_local_sai_metadata( + title=title, + creation_time=current_time, + is_v2_flag=is_v2_flag, # Pass the script's v2 context + is_v_param_flag=is_v_param_flag, + is_sdxl_flag=is_sdxl_flag ) - sai_metadata_cleaned = {k: v for k, v in sai_metadata_content.items() if v is not None} - final_metadata.update(sai_metadata_cleaned) + # Update final_metadata. Keys from sai_metadata_content will overwrite + # existing keys if they are the same (e.g., potentially 'ss_base_model_version'). + final_metadata.update(sai_metadata_content) + return final_metadata # --- Main SVD Function --- @@ -316,36 +339,26 @@ def svd( load_dtype_torch = _str_to_dtype(load_precision) save_dtype_torch = _str_to_dtype(save_precision) if save_precision else torch.float - # Device for SVD computation itself. Defaults to CUDA if available, else CPU. svd_computation_device = torch.device(device if device else "cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Using SVD computation device: {svd_computation_device}") - - # Device for calculating weight differences. This should ideally be CPU to avoid GPU->CPU transfers if models loaded to CPU. diff_calculation_device = torch.device("cpu") logger.info(f"Calculating weight differences on: {diff_calculation_device}") - - # Device for final LoRA weights before saving (usually CPU). final_weights_device = torch.device("cpu") - - # Load models if not sdxl: - # _load_sd_model_components loads to CPU, then applies dtype text_encoders_o, unet_o = _load_sd_model_components(model_org, v2, load_dtype_torch) text_encoders_t, unet_t = _load_sd_model_components(model_tuned, v2, load_dtype_torch) - model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) + # This is Kohya's specific model version string + kohya_model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) else: - # _load_sdxl_model_components uses load_original_model_to/load_tuned_model_to if provided, otherwise defaults to CPU. text_encoders_o, unet_o = _load_sdxl_model_components(model_org, load_original_model_to, load_dtype_torch) text_encoders_t, unet_t = _load_sdxl_model_components(model_tuned, load_tuned_model_to, load_dtype_torch) - model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 + # This is Kohya's specific model version string for SDXL + kohya_model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 - # Create LoRA networks (initially with small dim for structure) init_dim_val = 1 - # Conv_dim for network creation should be based on user's conv_dim, or init_dim_val if not set - # This is for the structure of the LoRA network object. 
     lora_conv_dim_init = conv_dim if conv_dim is not None else init_dim_val
-    kwargs_lora = {"conv_dim": lora_conv_dim_init, "conv_alpha": lora_conv_dim_init} # alpha matches dim for init
+    kwargs_lora = {"conv_dim": lora_conv_dim_init, "conv_alpha": lora_conv_dim_init}
 
     lora_network_o = lora.create_network(1.0, init_dim_val, init_dim_val, None, text_encoders_o, unet_o, **kwargs_lora)
     lora_network_t = lora.create_network(1.0, init_dim_val, init_dim_val, None, text_encoders_t, unet_t, **kwargs_lora)
@@ -353,7 +366,6 @@ def svd(
     assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), \
         f"Model versions differ: {len(lora_network_o.text_encoder_loras)} vs {len(lora_network_t.text_encoder_loras)} TEs"
 
-    # Compute differences on diff_calculation_device (CPU)
     all_diffs = {}
     te_diffs, text_encoder_different = _calculate_module_diffs_and_check(
         lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras,
@@ -366,14 +378,14 @@ def svd(
         logger.warning("Text encoders are considered identical based on min_diff. Not extracting TE LoRA.")
         lora_network_o.text_encoder_loras = []
 
-    del text_encoders_t # Free memory
+    del text_encoders_t
 
     unet_diffs, _ = _calculate_module_diffs_and_check(
         lora_network_o.unet_loras, lora_network_t.unet_loras,
         diff_calculation_device, min_diff, "U-Net"
     )
-    all_diffs.update(unet_diffs) # All diffs are now on diff_calculation_device (CPU)
-    del lora_network_t, unet_t # Free memory
+    all_diffs.update(unet_diffs)
+    del lora_network_t, unet_t
 
     lora_names_to_process = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)
@@ -385,15 +397,11 @@ def svd(
             logger.warning(f"Skipping {lora_name} as no diff was calculated for it.")
             continue
 
-        original_diff_tensor = all_diffs[lora_name] # This is on diff_calculation_device (CPU)
-
+        original_diff_tensor = all_diffs[lora_name]
         is_conv2d_layer = len(original_diff_tensor.size()) == 4
         kernel_s = original_diff_tensor.size()[2:4] if is_conv2d_layer else None
         is_conv2d_3x3_layer = is_conv2d_layer and kernel_s != (1, 1)
-
         module_true_out_channels, module_true_in_channels = original_diff_tensor.size()[0:2]
-
-        # Move diff tensor to SVD computation device, ensure it's float32 for SVD
         mat_for_svd = original_diff_tensor.to(svd_computation_device, dtype=torch.float)
 
         if is_conv2d_layer:
@@ -407,7 +415,7 @@ def svd(
             continue
 
         try:
-            U_full, S_full, Vh_full = torch.linalg.svd(mat_for_svd) # SVD on svd_computation_device
+            U_full, S_full, Vh_full = torch.linalg.svd(mat_for_svd)
         except Exception as e:
             logger.error(f"SVD failed for {lora_name} with shape {mat_for_svd.shape}. Error: {e}")
Error: {e}") continue @@ -422,21 +430,19 @@ def svd( U_full, S_full, Vh_full, rank, clamp_quantile, is_conv2d_layer, is_conv2d_3x3_layer, kernel_s, module_true_out_channels, module_true_in_channels, - final_weights_device, save_dtype_torch # Final weights to CPU with target dtype + final_weights_device, save_dtype_torch ) - lora_weights[lora_name] = (U_clamped, Vh_clamped) # U_clamped, Vh_clamped are on final_weights_device (CPU) + lora_weights[lora_name] = (U_clamped, Vh_clamped) if verbose: - _log_svd_stats(lora_name, S_full, rank, MIN_SV) # S_full is on svd_computation_device + _log_svd_stats(lora_name, S_full, rank, MIN_SV) - # Create state dict for LoRA (all components are on final_weights_device (CPU)) lora_sd = {} for lora_name, (up_weight, down_weight) in lora_weights.items(): lora_sd[lora_name + ".lora_up.weight"] = up_weight lora_sd[lora_name + ".lora_down.weight"] = down_weight lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype_torch, device=final_weights_device) - # Clean up original models from memory if they are still around and large (especially if on GPU) del text_encoders_o, unet_o, lora_network_o, all_diffs if 'torch' in sys.modules and hasattr(torch, 'cuda') and torch.cuda.is_available(): torch.cuda.empty_cache() @@ -444,12 +450,18 @@ def svd( os.makedirs(os.path.dirname(save_to), exist_ok=True) metadata_to_save = _prepare_lora_metadata( - save_to, v2, model_version, conv_dim, - bool(dynamic_method), dim, - v_parameterization, sdxl, no_metadata + output_path=save_to, + is_v2_flag=v2, # The script's --v2 flag + kohya_base_model_version_str=kohya_model_version, # The specific version string from model_util / sdxl_model_util + network_conv_dim_val=conv_dim, + use_dynamic_method_flag=bool(dynamic_method), + network_dim_config_val=dim, # 'dim' is the general network dim if not dynamic + is_v_param_flag=v_parameterization, # The script's v_param flag + is_sdxl_flag=sdxl, # The script's --sdxl flag + skip_sai_meta=no_metadata ) - save_to_file(save_to, lora_sd, save_dtype_torch, metadata_to_save) # save_dtype_torch applied again if not None + save_to_file(save_to, lora_sd, save_dtype_torch, metadata_to_save) logger.info(f"LoRA saved to: {save_to}") @@ -498,4 +510,9 @@ if __name__ == "__main__": if MIN_SV <= 0: logger.warning(f"Global MIN_SV ({MIN_SV}) should be positive.") - svd(**vars(args)) + # Pass the correct CLI arguments to the svd function + # Note: 'v2', 'sdxl', 'v_parameterization' are directly from args + # 'kohya_model_version' is determined inside svd() and then passed to _prepare_lora_metadata + svd_args = vars(args).copy() + # No change needed here as svd() internally determines kohya_model_version and passes it correctly + svd(**svd_args) \ No newline at end of file