remove sd-scripts dependencies 1

pull/3264/head
bmaltais 2025-05-19 10:52:14 -04:00
parent 7767a5a3ec
commit e580ad60e9
1 changed file with 93 additions and 76 deletions

View File

@ -11,7 +11,8 @@ if sd_scripts_dir_path not in sys.path:
# Now you can import from the library package and the networks package
try:
from library import sai_model_spec, model_util, sdxl_model_util
# sai_model_spec REMOVED from here
from library import model_util, sdxl_model_util
from library.utils import setup_logging
from networks import lora
except ImportError as e:
@ -118,12 +119,47 @@ def save_to_file(file_name, state_dict_to_save, dtype, metadata=None):
else:
torch.save(state_dict_final, file_name)
# --- NEW LOCAL METADATA FUNCTION ---
def _build_local_sai_metadata(title, creation_time, is_v2_flag, is_v_param_flag, is_sdxl_flag):
"""
Creates a dictionary of SAI-like metadata based on the provided arguments,
specifically for the context of this LoRA extraction script.
Keys are aligned with common Civitai/SAI metadata practices.
"""
metadata = {}
metadata["ss_sd_model_name"] = str(title)
metadata["ss_creation_time"] = str(int(creation_time)) # Original uses int timestamp
# Determine base model version for SAI metadata
if is_sdxl_flag:
metadata["ss_base_model_version"] = "sdxl_v10" # Standard SAI identifier for SDXL 1.0
metadata["ss_sdxl_model_version"] = "1.0" # Specific SDXL version
# In SAI context, SDXL is often implicitly v-parameterization,
# but explicitly stating it if requested is good.
if is_v_param_flag:
metadata["ss_v_parameterization"] = "true"
elif is_v2_flag:
metadata["ss_base_model_version"] = "sd_v2" # Standard SAI identifier for SD 2.x
if is_v_param_flag: # v-parameterization is key for some SD2.x models
metadata["ss_v_parameterization"] = "true"
else:
metadata["ss_base_model_version"] = "sd_v1" # Standard SAI identifier for SD 1.x
# v-parameterization is less common for SD1.x but can exist
if is_v_param_flag:
metadata["ss_v_parameterization"] = "true"
# Other flags from the original sai_model_spec call were fixed (training_info=None, is_v_pred_like=False etc.)
# or their effects are covered by the flags above.
# We are only adding keys that are generally present and have meaningful values from the inputs.
# Example: "ss_is_v_prediction_like": "false" could be added if is_v_pred_like was a param,
# but it's false in the original call, so we omit it for brevity.
return metadata
# --- Refactored Helper Functions ---
def _load_sd_model_components(model_path, is_v2, load_dtype_torch):
logger.info(f"Loading SD model from: {model_path} (to CPU initially)")
# model_util usually loads to CPU by default, then we cast dtype
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(is_v2, model_path)
del vae # Not used
del vae
text_encoders = [text_encoder]
if load_dtype_torch:
for te in text_encoders:
@ -132,17 +168,14 @@ def _load_sd_model_components(model_path, is_v2, load_dtype_torch):
return text_encoders, unet
def _load_sdxl_model_components(model_path, target_device_override, load_dtype_torch):
# Prioritize CPU loading unless target_device_override is explicitly GPU
# This 'target_device_override' comes from args.load_original_model_to / args.load_tuned_model_to
actual_load_device = target_device_override if target_device_override else "cpu"
logger.info(f"Loading SDXL model from: {model_path} to device: {actual_load_device}")
text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_path, actual_load_device
)
del vae # Not used
del vae
text_encoders = [text_encoder1, text_encoder2]
if load_dtype_torch: # Apply dtype cast after loading to the specified device
if load_dtype_torch:
for te in text_encoders:
te.to(load_dtype_torch)
unet.to(load_dtype_torch)
@ -161,9 +194,6 @@ def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_
logger.warning(f"Skipping {lora_name} in {module_type_str} due to missing org_module or weight.")
continue
# Weights are expected to be on CPU after loading, or on specified load device.
# Move them to diff_calc_device ONLY if they are not already there.
# diff_calc_device will be CPU in the corrected flow.
weight_o = lora_o.org_module.weight
weight_t = lora_t.org_module.weight
@ -172,13 +202,8 @@ def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_
if str(weight_t.device) != str(diff_calc_device):
weight_t = weight_t.to(diff_calc_device)
diff = weight_t - weight_o # Diff happens on diff_calc_device (CPU)
# No need to set lora_o.org_module.weight to None here, original weights might be reused
# by other parts of sd-scripts if this script is integrated.
# We will del the entire model objects (unet_t, text_encoders_t) later.
diffs_map[lora_name] = diff # diff is on diff_calc_device (CPU)
diff = weight_t - weight_o
diffs_map[lora_name] = diff
current_max_diff = torch.max(torch.abs(diff))
if not is_different_flag and current_max_diff > min_diff_thresh:
is_different_flag = True
@ -206,9 +231,7 @@ def _construct_lora_weights_from_svd_components(U_full, S_all_values, Vh_full, r
clamp_quantile_val, is_conv2d, is_conv2d_3x3,
conv_kernel_size,
module_out_channels, module_in_channels,
# svd_comp_device, # U,S,Vh are on this device
target_device_for_final_weights, target_dtype_for_final_weights):
# U_full, S_all_values, Vh_full are assumed to be on the SVD computation device.
S_k = S_all_values[:rank]
U_k = U_full[:, :rank]
Vh_k = Vh_full[:rank, :]
@ -216,10 +239,9 @@ def _construct_lora_weights_from_svd_components(U_full, S_all_values, Vh_full, r
S_k_non_negative = torch.clamp(S_k, min=0.0)
s_sqrt = torch.sqrt(S_k_non_negative)
U_final = U_k * s_sqrt.unsqueeze(0) # on svd_comp_device
Vh_final = Vh_k * s_sqrt.unsqueeze(1) # on svd_comp_device
U_final = U_k * s_sqrt.unsqueeze(0)
Vh_final = Vh_k * s_sqrt.unsqueeze(1)
# Clamping happens on svd_comp_device
dist = torch.cat([U_final.flatten(), Vh_final.flatten()])
hi_val = torch.quantile(dist, clamp_quantile_val)
if hi_val == 0 and torch.max(torch.abs(dist)) > 1e-9:
@ -228,14 +250,13 @@ def _construct_lora_weights_from_svd_components(U_full, S_all_values, Vh_full, r
U_clamped = U_final.clamp(-hi_val, hi_val)
Vh_clamped = Vh_final.clamp(-hi_val, hi_val)
if is_conv2d: # Reshaping also on svd_comp_device
if is_conv2d:
U_clamped = U_clamped.reshape(module_out_channels, rank, 1, 1)
if is_conv2d_3x3:
Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, *conv_kernel_size)
else:
Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, 1, 1)
# Move to final target device and dtype at the very end
U_clamped = U_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous()
Vh_clamped = Vh_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous()
return U_clamped, Vh_clamped
@ -245,18 +266,14 @@ def _log_svd_stats(lora_module_name, S_all_values, rank_used, min_sv_for_calc=MI
logger.info(f"{lora_module_name:75} | rank: {rank_used}, SVD not performed (empty singular values).")
return
# S_all_values might be on GPU, move to CPU for float conversion and sum if not already
S_cpu = S_all_values.to('cpu')
s_sum_total = float(torch.sum(S_cpu))
s_sum_rank = float(torch.sum(S_cpu[:rank_used]))
fro_orig_total = float(torch.sqrt(torch.sum(S_cpu.pow(2))))
fro_reconstructed_rank = float(torch.sqrt(torch.sum(S_cpu[:rank_used].pow(2))))
ratio_sv = float('inf')
if rank_used > 0 and S_cpu[rank_used - 1].abs() > min_sv_for_calc:
ratio_sv = S_cpu[0] / S_cpu[rank_used - 1] # Ensure S_cpu[0] is also float for division
ratio_sv = S_cpu[0] / S_cpu[rank_used - 1]
sum_s_retained_percentage = (s_sum_rank / s_sum_total) if s_sum_total > min_sv_for_calc else 1.0
fro_retained_percentage = (fro_reconstructed_rank / fro_orig_total) if fro_orig_total > min_sv_for_calc else 1.0
@ -268,7 +285,7 @@ def _log_svd_stats(lora_module_name, S_all_values, rank_used, min_sv_for_calc=MI
f"max_retained_sv/min_retained_sv ratio: {ratio_sv:.2f}"
)
def _prepare_lora_metadata(output_path, is_v2_flag, base_model_ver, network_conv_dim_val,
def _prepare_lora_metadata(output_path, is_v2_flag, kohya_base_model_version_str, network_conv_dim_val,
use_dynamic_method_flag, network_dim_config_val,
is_v_param_flag, is_sdxl_flag, skip_sai_meta):
net_kwargs = {"conv_dim": str(network_conv_dim_val), "conv_alpha": str(float(network_conv_dim_val))} if network_conv_dim_val is not None else {}
@ -280,28 +297,34 @@ def _prepare_lora_metadata(output_path, is_v2_flag, base_model_ver, network_conv
network_dim_meta = str(network_dim_config_val)
network_alpha_meta = str(float(network_dim_config_val))
# Initial metadata using Kohya's conventions
final_metadata = {
"ss_v2": str(is_v2_flag),
"ss_base_model_version": base_model_ver,
"ss_v2": str(is_v2_flag), # Kohya's flag for v2 checkpoint type
"ss_base_model_version": kohya_base_model_version_str, # Kohya's specific base model string
"ss_network_module": "networks.lora",
"ss_network_dim": network_dim_meta,
"ss_network_alpha": network_alpha_meta,
"ss_network_args": json.dumps(net_kwargs),
"ss_lowram": "False",
"ss_num_train_images": "N/A",
"ss_num_train_images": "N/A", # This script doesn't involve training
}
if not skip_sai_meta:
title = os.path.splitext(os.path.basename(output_path))[0]
is_sd2_for_meta = True
current_time = time.time()
sai_metadata_content = sai_model_spec.build_metadata(
training_info=None, v2=is_v2_flag, v_parameterization=is_v_param_flag,
sdxl=is_sdxl_flag, is_sd2=is_sd2_for_meta, is_v_pred_like=False,
unet_use_linear_projection_in_v2=False, creation_time=time.time(), title=title,
# Build SAI-like metadata using the local function
sai_metadata_content = _build_local_sai_metadata(
title=title,
creation_time=current_time,
is_v2_flag=is_v2_flag, # Pass the script's v2 context
is_v_param_flag=is_v_param_flag,
is_sdxl_flag=is_sdxl_flag
)
sai_metadata_cleaned = {k: v for k, v in sai_metadata_content.items() if v is not None}
final_metadata.update(sai_metadata_cleaned)
# Update final_metadata. Keys from sai_metadata_content will overwrite
# existing keys if they are the same (e.g., potentially 'ss_base_model_version').
final_metadata.update(sai_metadata_content)
return final_metadata
# --- Main SVD Function ---
@ -316,36 +339,26 @@ def svd(
load_dtype_torch = _str_to_dtype(load_precision)
save_dtype_torch = _str_to_dtype(save_precision) if save_precision else torch.float
# Device for SVD computation itself. Defaults to CUDA if available, else CPU.
svd_computation_device = torch.device(device if device else "cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using SVD computation device: {svd_computation_device}")
# Device for calculating weight differences. This should ideally be CPU to avoid GPU->CPU transfers if models loaded to CPU.
diff_calculation_device = torch.device("cpu")
logger.info(f"Calculating weight differences on: {diff_calculation_device}")
# Device for final LoRA weights before saving (usually CPU).
final_weights_device = torch.device("cpu")
# Load models
if not sdxl:
# _load_sd_model_components loads to CPU, then applies dtype
text_encoders_o, unet_o = _load_sd_model_components(model_org, v2, load_dtype_torch)
text_encoders_t, unet_t = _load_sd_model_components(model_tuned, v2, load_dtype_torch)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
# This is Kohya's specific model version string
kohya_model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else:
# _load_sdxl_model_components uses load_original_model_to/load_tuned_model_to if provided, otherwise defaults to CPU.
text_encoders_o, unet_o = _load_sdxl_model_components(model_org, load_original_model_to, load_dtype_torch)
text_encoders_t, unet_t = _load_sdxl_model_components(model_tuned, load_tuned_model_to, load_dtype_torch)
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# This is Kohya's specific model version string for SDXL
kohya_model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# Create LoRA networks (initially with small dim for structure)
init_dim_val = 1
# Conv_dim for network creation should be based on user's conv_dim, or init_dim_val if not set
# This is for the structure of the LoRA network object.
lora_conv_dim_init = conv_dim if conv_dim is not None else init_dim_val
kwargs_lora = {"conv_dim": lora_conv_dim_init, "conv_alpha": lora_conv_dim_init} # alpha matches dim for init
kwargs_lora = {"conv_dim": lora_conv_dim_init, "conv_alpha": lora_conv_dim_init}
lora_network_o = lora.create_network(1.0, init_dim_val, init_dim_val, None, text_encoders_o, unet_o, **kwargs_lora)
lora_network_t = lora.create_network(1.0, init_dim_val, init_dim_val, None, text_encoders_t, unet_t, **kwargs_lora)
@ -353,7 +366,6 @@ def svd(
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), \
f"Model versions differ: {len(lora_network_o.text_encoder_loras)} vs {len(lora_network_t.text_encoder_loras)} TEs"
# Compute differences on diff_calculation_device (CPU)
all_diffs = {}
te_diffs, text_encoder_different = _calculate_module_diffs_and_check(
lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras,
@ -366,14 +378,14 @@ def svd(
logger.warning("Text encoders are considered identical based on min_diff. Not extracting TE LoRA.")
lora_network_o.text_encoder_loras = []
del text_encoders_t # Free memory
del text_encoders_t
unet_diffs, _ = _calculate_module_diffs_and_check(
lora_network_o.unet_loras, lora_network_t.unet_loras,
diff_calculation_device, min_diff, "U-Net"
)
all_diffs.update(unet_diffs) # All diffs are now on diff_calculation_device (CPU)
del lora_network_t, unet_t # Free memory
all_diffs.update(unet_diffs)
del lora_network_t, unet_t
lora_names_to_process = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)
@ -385,15 +397,11 @@ def svd(
logger.warning(f"Skipping {lora_name} as no diff was calculated for it.")
continue
original_diff_tensor = all_diffs[lora_name] # This is on diff_calculation_device (CPU)
original_diff_tensor = all_diffs[lora_name]
is_conv2d_layer = len(original_diff_tensor.size()) == 4
kernel_s = original_diff_tensor.size()[2:4] if is_conv2d_layer else None
is_conv2d_3x3_layer = is_conv2d_layer and kernel_s != (1, 1)
module_true_out_channels, module_true_in_channels = original_diff_tensor.size()[0:2]
# Move diff tensor to SVD computation device, ensure it's float32 for SVD
mat_for_svd = original_diff_tensor.to(svd_computation_device, dtype=torch.float)
if is_conv2d_layer:
@ -407,7 +415,7 @@ def svd(
continue
try:
U_full, S_full, Vh_full = torch.linalg.svd(mat_for_svd) # SVD on svd_computation_device
U_full, S_full, Vh_full = torch.linalg.svd(mat_for_svd)
except Exception as e:
logger.error(f"SVD failed for {lora_name} with shape {mat_for_svd.shape}. Error: {e}")
continue
@ -422,21 +430,19 @@ def svd(
U_full, S_full, Vh_full, rank, clamp_quantile,
is_conv2d_layer, is_conv2d_3x3_layer, kernel_s,
module_true_out_channels, module_true_in_channels,
final_weights_device, save_dtype_torch # Final weights to CPU with target dtype
final_weights_device, save_dtype_torch
)
lora_weights[lora_name] = (U_clamped, Vh_clamped) # U_clamped, Vh_clamped are on final_weights_device (CPU)
lora_weights[lora_name] = (U_clamped, Vh_clamped)
if verbose:
_log_svd_stats(lora_name, S_full, rank, MIN_SV) # S_full is on svd_computation_device
_log_svd_stats(lora_name, S_full, rank, MIN_SV)
# Create state dict for LoRA (all components are on final_weights_device (CPU))
lora_sd = {}
for lora_name, (up_weight, down_weight) in lora_weights.items():
lora_sd[lora_name + ".lora_up.weight"] = up_weight
lora_sd[lora_name + ".lora_down.weight"] = down_weight
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype_torch, device=final_weights_device)
# Clean up original models from memory if they are still around and large (especially if on GPU)
del text_encoders_o, unet_o, lora_network_o, all_diffs
if 'torch' in sys.modules and hasattr(torch, 'cuda') and torch.cuda.is_available():
torch.cuda.empty_cache()
@ -444,12 +450,18 @@ def svd(
os.makedirs(os.path.dirname(save_to), exist_ok=True)
metadata_to_save = _prepare_lora_metadata(
save_to, v2, model_version, conv_dim,
bool(dynamic_method), dim,
v_parameterization, sdxl, no_metadata
output_path=save_to,
is_v2_flag=v2, # The script's --v2 flag
kohya_base_model_version_str=kohya_model_version, # The specific version string from model_util / sdxl_model_util
network_conv_dim_val=conv_dim,
use_dynamic_method_flag=bool(dynamic_method),
network_dim_config_val=dim, # 'dim' is the general network dim if not dynamic
is_v_param_flag=v_parameterization, # The script's v_param flag
is_sdxl_flag=sdxl, # The script's --sdxl flag
skip_sai_meta=no_metadata
)
save_to_file(save_to, lora_sd, save_dtype_torch, metadata_to_save) # save_dtype_torch applied again if not None
save_to_file(save_to, lora_sd, save_dtype_torch, metadata_to_save)
logger.info(f"LoRA saved to: {save_to}")
@ -498,4 +510,9 @@ if __name__ == "__main__":
if MIN_SV <= 0: logger.warning(f"Global MIN_SV ({MIN_SV}) should be positive.")
svd(**vars(args))
# Pass the correct CLI arguments to the svd function
# Note: 'v2', 'sdxl', 'v_parameterization' are directly from args
# 'kohya_model_version' is determined inside svd() and then passed to _prepare_lora_metadata
svd_args = vars(args).copy()
# No change needed here as svd() internally determines kohya_model_version and passes it correctly
svd(**svd_args)