mirror of https://github.com/bmaltais/kohya_ss
remove sd-scripts dependencies 1
parent 7767a5a3ec
commit e580ad60e9
@@ -11,7 +11,8 @@ if sd_scripts_dir_path not in sys.path:
 # Now you can import from the library package and the networks package
 try:
-    from library import sai_model_spec, model_util, sdxl_model_util
+    # sai_model_spec REMOVED from here
+    from library import model_util, sdxl_model_util
     from library.utils import setup_logging
     from networks import lora
 except ImportError as e:
@@ -118,12 +119,47 @@ def save_to_file(file_name, state_dict_to_save, dtype, metadata=None):
     else:
         torch.save(state_dict_final, file_name)

+# --- NEW LOCAL METADATA FUNCTION ---
+def _build_local_sai_metadata(title, creation_time, is_v2_flag, is_v_param_flag, is_sdxl_flag):
+    """
+    Creates a dictionary of SAI-like metadata based on the provided arguments,
+    specifically for the context of this LoRA extraction script.
+    Keys are aligned with common Civitai/SAI metadata practices.
+    """
+    metadata = {}
+    metadata["ss_sd_model_name"] = str(title)
+    metadata["ss_creation_time"] = str(int(creation_time))  # Original uses int timestamp
+
+    # Determine base model version for SAI metadata
+    if is_sdxl_flag:
+        metadata["ss_base_model_version"] = "sdxl_v10"  # Standard SAI identifier for SDXL 1.0
+        metadata["ss_sdxl_model_version"] = "1.0"  # Specific SDXL version
+        # In SAI context, SDXL is often implicitly v-parameterization,
+        # but explicitly stating it if requested is good.
+        if is_v_param_flag:
+            metadata["ss_v_parameterization"] = "true"
+    elif is_v2_flag:
+        metadata["ss_base_model_version"] = "sd_v2"  # Standard SAI identifier for SD 2.x
+        if is_v_param_flag:  # v-parameterization is key for some SD2.x models
+            metadata["ss_v_parameterization"] = "true"
+    else:
+        metadata["ss_base_model_version"] = "sd_v1"  # Standard SAI identifier for SD 1.x
+        # v-parameterization is less common for SD1.x but can exist
+        if is_v_param_flag:
+            metadata["ss_v_parameterization"] = "true"
+
+    # Other flags from the original sai_model_spec call were fixed (training_info=None, is_v_pred_like=False etc.)
+    # or their effects are covered by the flags above.
+    # We are only adding keys that are generally present and have meaningful values from the inputs.
+    # Example: "ss_is_v_prediction_like": "false" could be added if is_v_pred_like was a param,
+    # but it's false in the original call, so we omit it for brevity.
+    return metadata
+
 # --- Refactored Helper Functions ---
 def _load_sd_model_components(model_path, is_v2, load_dtype_torch):
     logger.info(f"Loading SD model from: {model_path} (to CPU initially)")
     # model_util usually loads to CPU by default, then we cast dtype
     text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(is_v2, model_path)
-    del vae  # Not used
+    del vae
     text_encoders = [text_encoder]
     if load_dtype_torch:
         for te in text_encoders:
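Note: a minimal sketch of what the new local builder above returns for a hypothetical SDXL extraction (the title and timestamp are made-up inputs; assumes _build_local_sai_metadata is in scope):

    import time

    meta = _build_local_sai_metadata(
        title="my_extracted_lora",   # hypothetical output filename stem
        creation_time=time.time(),
        is_v2_flag=False,
        is_v_param_flag=False,
        is_sdxl_flag=True,
    )
    # Per the branches above, meta now holds:
    # {"ss_sd_model_name": "my_extracted_lora",
    #  "ss_creation_time": "<integer timestamp as string>",
    #  "ss_base_model_version": "sdxl_v10",
    #  "ss_sdxl_model_version": "1.0"}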
@@ -132,17 +168,14 @@ def _load_sd_model_components(model_path, is_v2, load_dtype_torch):
     return text_encoders, unet

 def _load_sdxl_model_components(model_path, target_device_override, load_dtype_torch):
-    # Prioritize CPU loading unless target_device_override is explicitly GPU
-    # This 'target_device_override' comes from args.load_original_model_to / args.load_tuned_model_to
     actual_load_device = target_device_override if target_device_override else "cpu"
     logger.info(f"Loading SDXL model from: {model_path} to device: {actual_load_device}")

     text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
         sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_path, actual_load_device
     )
-    del vae  # Not used
+    del vae
     text_encoders = [text_encoder1, text_encoder2]
-    if load_dtype_torch:  # Apply dtype cast after loading to the specified device
+    if load_dtype_torch:
         for te in text_encoders:
             te.to(load_dtype_torch)
     unet.to(load_dtype_torch)
@@ -161,9 +194,6 @@ def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_
             logger.warning(f"Skipping {lora_name} in {module_type_str} due to missing org_module or weight.")
             continue

-        # Weights are expected to be on CPU after loading, or on specified load device.
-        # Move them to diff_calc_device ONLY if they are not already there.
-        # diff_calc_device will be CPU in the corrected flow.
         weight_o = lora_o.org_module.weight
         weight_t = lora_t.org_module.weight

@@ -172,13 +202,8 @@ def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_
         if str(weight_t.device) != str(diff_calc_device):
             weight_t = weight_t.to(diff_calc_device)

-        diff = weight_t - weight_o  # Diff happens on diff_calc_device (CPU)
-
-        # No need to set lora_o.org_module.weight to None here, original weights might be reused
-        # by other parts of sd-scripts if this script is integrated.
-        # We will del the entire model objects (unet_t, text_encoders_t) later.
-
-        diffs_map[lora_name] = diff  # diff is on diff_calc_device (CPU)
+        diff = weight_t - weight_o
+        diffs_map[lora_name] = diff
         current_max_diff = torch.max(torch.abs(diff))
         if not is_different_flag and current_max_diff > min_diff_thresh:
             is_different_flag = True
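Note: the min_diff gate above boils down to one max-abs comparison per module; a self-contained sketch with arbitrary example tensors:

    import torch

    min_diff_thresh = 0.01               # plays the role of the script's min_diff
    weight_o = torch.zeros(4, 4)         # "original" module weight
    weight_t = torch.full((4, 4), 0.02)  # "tuned" module weight

    diff = weight_t - weight_o           # computed on CPU, as in the hunk
    current_max_diff = torch.max(torch.abs(diff))
    is_different = bool(current_max_diff > min_diff_thresh)  # True: 0.02 > 0.01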
@@ -206,9 +231,7 @@ def _construct_lora_weights_from_svd_components(U_full, S_all_values, Vh_full, r
                                                 clamp_quantile_val, is_conv2d, is_conv2d_3x3,
                                                 conv_kernel_size,
                                                 module_out_channels, module_in_channels,
-                                                # svd_comp_device, # U,S,Vh are on this device
                                                 target_device_for_final_weights, target_dtype_for_final_weights):
-    # U_full, S_all_values, Vh_full are assumed to be on the SVD computation device.
     S_k = S_all_values[:rank]
     U_k = U_full[:, :rank]
     Vh_k = Vh_full[:rank, :]
@@ -216,10 +239,9 @@ def _construct_lora_weights_from_svd_components(U_full, S_all_values, Vh_full, r
     S_k_non_negative = torch.clamp(S_k, min=0.0)
     s_sqrt = torch.sqrt(S_k_non_negative)

-    U_final = U_k * s_sqrt.unsqueeze(0)  # on svd_comp_device
-    Vh_final = Vh_k * s_sqrt.unsqueeze(1)  # on svd_comp_device
+    U_final = U_k * s_sqrt.unsqueeze(0)
+    Vh_final = Vh_k * s_sqrt.unsqueeze(1)

-    # Clamping happens on svd_comp_device
     dist = torch.cat([U_final.flatten(), Vh_final.flatten()])
     hi_val = torch.quantile(dist, clamp_quantile_val)
     if hi_val == 0 and torch.max(torch.abs(dist)) > 1e-9:
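Note: the sqrt-of-singular-values split above is the standard way to turn a truncated SVD into two LoRA factors; a self-contained illustration on a random matrix (rank 4 is an arbitrary choice):

    import torch

    rank = 4
    W_delta = torch.randn(64, 32)  # stand-in for a weight difference

    U_full, S_full, Vh_full = torch.linalg.svd(W_delta)
    S_k, U_k, Vh_k = S_full[:rank], U_full[:, :rank], Vh_full[:rank, :]

    s_sqrt = torch.sqrt(torch.clamp(S_k, min=0.0))
    U_final = U_k * s_sqrt.unsqueeze(0)    # becomes the up weight
    Vh_final = Vh_k * s_sqrt.unsqueeze(1)  # becomes the down weight

    # The product recovers the best rank-4 approximation of W_delta:
    assert torch.allclose(U_final @ Vh_final, U_k @ torch.diag(S_k) @ Vh_k, atol=1e-5)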
@@ -228,14 +250,13 @@ def _construct_lora_weights_from_svd_components(U_full, S_all_values, Vh_full, r
     U_clamped = U_final.clamp(-hi_val, hi_val)
     Vh_clamped = Vh_final.clamp(-hi_val, hi_val)

-    if is_conv2d:  # Reshaping also on svd_comp_device
+    if is_conv2d:
         U_clamped = U_clamped.reshape(module_out_channels, rank, 1, 1)
         if is_conv2d_3x3:
             Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, *conv_kernel_size)
         else:
             Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, 1, 1)

-    # Move to final target device and dtype at the very end
     U_clamped = U_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous()
     Vh_clamped = Vh_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous()
     return U_clamped, Vh_clamped
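Note: quantile clamping, as used above, bounds outliers in both factors with a single shared percentile threshold; a self-contained sketch:

    import torch

    clamp_quantile_val = 0.99
    U_final = torch.randn(64, 4)
    Vh_final = torch.randn(4, 32)

    # One threshold from the pooled entries of both factors
    dist = torch.cat([U_final.flatten(), Vh_final.flatten()])
    hi_val = torch.quantile(dist, clamp_quantile_val)

    U_clamped = U_final.clamp(-hi_val, hi_val)
    Vh_clamped = Vh_final.clamp(-hi_val, hi_val)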
@@ -245,18 +266,14 @@ def _log_svd_stats(lora_module_name, S_all_values, rank_used, min_sv_for_calc=MI
         logger.info(f"{lora_module_name:75} | rank: {rank_used}, SVD not performed (empty singular values).")
         return

-    # S_all_values might be on GPU, move to CPU for float conversion and sum if not already
     S_cpu = S_all_values.to('cpu')
-
     s_sum_total = float(torch.sum(S_cpu))
     s_sum_rank = float(torch.sum(S_cpu[:rank_used]))
-
     fro_orig_total = float(torch.sqrt(torch.sum(S_cpu.pow(2))))
     fro_reconstructed_rank = float(torch.sqrt(torch.sum(S_cpu[:rank_used].pow(2))))
-
     ratio_sv = float('inf')
     if rank_used > 0 and S_cpu[rank_used - 1].abs() > min_sv_for_calc:
-        ratio_sv = S_cpu[0] / S_cpu[rank_used - 1]  # Ensure S_cpu[0] is also float for division
+        ratio_sv = S_cpu[0] / S_cpu[rank_used - 1]

     sum_s_retained_percentage = (s_sum_rank / s_sum_total) if s_sum_total > min_sv_for_calc else 1.0
     fro_retained_percentage = (fro_reconstructed_rank / fro_orig_total) if fro_orig_total > min_sv_for_calc else 1.0
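Note: the retained-energy percentages logged above come straight from the singular values; a worked sketch with made-up values:

    import torch

    S = torch.tensor([4.0, 2.0, 1.0, 0.5])  # example singular values
    rank_used = 2

    sum_retained = float(S[:rank_used].sum() / S.sum())
    # (4 + 2) / (4 + 2 + 1 + 0.5) = 0.80

    fro_retained = float(torch.sqrt((S[:rank_used] ** 2).sum() / (S ** 2).sum()))
    # sqrt(20 / 21.25) ≈ 0.97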
@@ -268,7 +285,7 @@ def _log_svd_stats(lora_module_name, S_all_values, rank_used, min_sv_for_calc=MI
             f"max_retained_sv/min_retained_sv ratio: {ratio_sv:.2f}"
         )

-def _prepare_lora_metadata(output_path, is_v2_flag, base_model_ver, network_conv_dim_val,
+def _prepare_lora_metadata(output_path, is_v2_flag, kohya_base_model_version_str, network_conv_dim_val,
                            use_dynamic_method_flag, network_dim_config_val,
                            is_v_param_flag, is_sdxl_flag, skip_sai_meta):
     net_kwargs = {"conv_dim": str(network_conv_dim_val), "conv_alpha": str(float(network_conv_dim_val))} if network_conv_dim_val is not None else {}
@@ -280,28 +297,34 @@ def _prepare_lora_metadata(output_path, is_v2_flag, base_model_ver, network_conv
         network_dim_meta = str(network_dim_config_val)
         network_alpha_meta = str(float(network_dim_config_val))

+    # Initial metadata using Kohya's conventions
     final_metadata = {
-        "ss_v2": str(is_v2_flag),
-        "ss_base_model_version": base_model_ver,
+        "ss_v2": str(is_v2_flag),  # Kohya's flag for v2 checkpoint type
+        "ss_base_model_version": kohya_base_model_version_str,  # Kohya's specific base model string
         "ss_network_module": "networks.lora",
         "ss_network_dim": network_dim_meta,
         "ss_network_alpha": network_alpha_meta,
         "ss_network_args": json.dumps(net_kwargs),
         "ss_lowram": "False",
-        "ss_num_train_images": "N/A",
+        "ss_num_train_images": "N/A",  # This script doesn't involve training
     }

     if not skip_sai_meta:
         title = os.path.splitext(os.path.basename(output_path))[0]
-        is_sd2_for_meta = True
+        current_time = time.time()

-        sai_metadata_content = sai_model_spec.build_metadata(
-            training_info=None, v2=is_v2_flag, v_parameterization=is_v_param_flag,
-            sdxl=is_sdxl_flag, is_sd2=is_sd2_for_meta, is_v_pred_like=False,
-            unet_use_linear_projection_in_v2=False, creation_time=time.time(), title=title,
+        # Build SAI-like metadata using the local function
+        sai_metadata_content = _build_local_sai_metadata(
+            title=title,
+            creation_time=current_time,
+            is_v2_flag=is_v2_flag,  # Pass the script's v2 context
+            is_v_param_flag=is_v_param_flag,
+            is_sdxl_flag=is_sdxl_flag
         )
-        sai_metadata_cleaned = {k: v for k, v in sai_metadata_content.items() if v is not None}
-        final_metadata.update(sai_metadata_cleaned)
+        # Update final_metadata. Keys from sai_metadata_content will overwrite
+        # existing keys if they are the same (e.g., potentially 'ss_base_model_version').
+        final_metadata.update(sai_metadata_content)

     return final_metadata

 # --- Main SVD Function ---
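Note on the metadata merge in the previous hunk: dict.update gives the SAI-like keys precedence over the initial Kohya keys. A tiny illustration of the 'ss_base_model_version' collision mentioned in the comment (values are illustrative only):

    final_metadata = {"ss_base_model_version": "sdxl_base_v1-0", "ss_v2": "False"}
    sai_metadata_content = {"ss_base_model_version": "sdxl_v10"}

    final_metadata.update(sai_metadata_content)
    # The SAI-style value wins:
    # {"ss_base_model_version": "sdxl_v10", "ss_v2": "False"}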
@@ -316,36 +339,26 @@ def svd(
     load_dtype_torch = _str_to_dtype(load_precision)
     save_dtype_torch = _str_to_dtype(save_precision) if save_precision else torch.float

-    # Device for SVD computation itself. Defaults to CUDA if available, else CPU.
     svd_computation_device = torch.device(device if device else "cuda" if torch.cuda.is_available() else "cpu")
     logger.info(f"Using SVD computation device: {svd_computation_device}")

-    # Device for calculating weight differences. This should ideally be CPU to avoid GPU->CPU transfers if models loaded to CPU.
     diff_calculation_device = torch.device("cpu")
     logger.info(f"Calculating weight differences on: {diff_calculation_device}")

-    # Device for final LoRA weights before saving (usually CPU).
     final_weights_device = torch.device("cpu")
-

     # Load models
     if not sdxl:
-        # _load_sd_model_components loads to CPU, then applies dtype
         text_encoders_o, unet_o = _load_sd_model_components(model_org, v2, load_dtype_torch)
         text_encoders_t, unet_t = _load_sd_model_components(model_tuned, v2, load_dtype_torch)
-        model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
+        # This is Kohya's specific model version string
+        kohya_model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
     else:
-        # _load_sdxl_model_components uses load_original_model_to/load_tuned_model_to if provided, otherwise defaults to CPU.
         text_encoders_o, unet_o = _load_sdxl_model_components(model_org, load_original_model_to, load_dtype_torch)
         text_encoders_t, unet_t = _load_sdxl_model_components(model_tuned, load_tuned_model_to, load_dtype_torch)
-        model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
+        # This is Kohya's specific model version string for SDXL
+        kohya_model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0

     # Create LoRA networks (initially with small dim for structure)
     init_dim_val = 1
-    # Conv_dim for network creation should be based on user's conv_dim, or init_dim_val if not set
-    # This is for the structure of the LoRA network object.
     lora_conv_dim_init = conv_dim if conv_dim is not None else init_dim_val
-    kwargs_lora = {"conv_dim": lora_conv_dim_init, "conv_alpha": lora_conv_dim_init}  # alpha matches dim for init
+    kwargs_lora = {"conv_dim": lora_conv_dim_init, "conv_alpha": lora_conv_dim_init}

     lora_network_o = lora.create_network(1.0, init_dim_val, init_dim_val, None, text_encoders_o, unet_o, **kwargs_lora)
     lora_network_t = lora.create_network(1.0, init_dim_val, init_dim_val, None, text_encoders_t, unet_t, **kwargs_lora)
@@ -353,7 +366,6 @@ def svd(
     assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), \
         f"Model versions differ: {len(lora_network_o.text_encoder_loras)} vs {len(lora_network_t.text_encoder_loras)} TEs"

-    # Compute differences on diff_calculation_device (CPU)
     all_diffs = {}
     te_diffs, text_encoder_different = _calculate_module_diffs_and_check(
         lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras,
@@ -366,14 +378,14 @@ def svd(
         logger.warning("Text encoders are considered identical based on min_diff. Not extracting TE LoRA.")
         lora_network_o.text_encoder_loras = []

-    del text_encoders_t  # Free memory
+    del text_encoders_t

     unet_diffs, _ = _calculate_module_diffs_and_check(
         lora_network_o.unet_loras, lora_network_t.unet_loras,
         diff_calculation_device, min_diff, "U-Net"
     )
-    all_diffs.update(unet_diffs)  # All diffs are now on diff_calculation_device (CPU)
-    del lora_network_t, unet_t  # Free memory
+    all_diffs.update(unet_diffs)
+    del lora_network_t, unet_t

     lora_names_to_process = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)

@@ -385,15 +397,11 @@ def svd(
             logger.warning(f"Skipping {lora_name} as no diff was calculated for it.")
             continue

-        original_diff_tensor = all_diffs[lora_name]  # This is on diff_calculation_device (CPU)
-
+        original_diff_tensor = all_diffs[lora_name]
         is_conv2d_layer = len(original_diff_tensor.size()) == 4
         kernel_s = original_diff_tensor.size()[2:4] if is_conv2d_layer else None
         is_conv2d_3x3_layer = is_conv2d_layer and kernel_s != (1, 1)
-
         module_true_out_channels, module_true_in_channels = original_diff_tensor.size()[0:2]
-
-        # Move diff tensor to SVD computation device, ensure it's float32 for SVD
         mat_for_svd = original_diff_tensor.to(svd_computation_device, dtype=torch.float)

         if is_conv2d_layer:
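Note: the hunk cuts off before the conv handling, but flattening a 4-D conv diff into a 2-D matrix is the usual prerequisite for torch.linalg.svd. A hedged sketch of that step (the script's exact reshape is not shown in this diff):

    import torch

    # Hypothetical 3x3 conv diff: (out_channels, in_channels, kH, kW)
    original_diff_tensor = torch.randn(32, 16, 3, 3)

    out_ch = original_diff_tensor.size(0)
    # Fold in_channels and the kernel into one axis -> (out_ch, in_ch * kH * kW)
    mat_for_svd = original_diff_tensor.reshape(out_ch, -1).to(torch.float)
    U, S, Vh = torch.linalg.svd(mat_for_svd)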
@@ -407,7 +415,7 @@ def svd(
             continue

         try:
-            U_full, S_full, Vh_full = torch.linalg.svd(mat_for_svd)  # SVD on svd_computation_device
+            U_full, S_full, Vh_full = torch.linalg.svd(mat_for_svd)
         except Exception as e:
             logger.error(f"SVD failed for {lora_name} with shape {mat_for_svd.shape}. Error: {e}")
             continue
@@ -422,21 +430,19 @@ def svd(
             U_full, S_full, Vh_full, rank, clamp_quantile,
             is_conv2d_layer, is_conv2d_3x3_layer, kernel_s,
             module_true_out_channels, module_true_in_channels,
-            final_weights_device, save_dtype_torch  # Final weights to CPU with target dtype
+            final_weights_device, save_dtype_torch
         )
-        lora_weights[lora_name] = (U_clamped, Vh_clamped)  # U_clamped, Vh_clamped are on final_weights_device (CPU)
+        lora_weights[lora_name] = (U_clamped, Vh_clamped)

         if verbose:
-            _log_svd_stats(lora_name, S_full, rank, MIN_SV)  # S_full is on svd_computation_device
+            _log_svd_stats(lora_name, S_full, rank, MIN_SV)

-    # Create state dict for LoRA (all components are on final_weights_device (CPU))
     lora_sd = {}
     for lora_name, (up_weight, down_weight) in lora_weights.items():
         lora_sd[lora_name + ".lora_up.weight"] = up_weight
         lora_sd[lora_name + ".lora_down.weight"] = down_weight
         lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype_torch, device=final_weights_device)

-    # Clean up original models from memory if they are still around and large (especially if on GPU)
     del text_encoders_o, unet_o, lora_network_o, all_diffs
     if 'torch' in sys.modules and hasattr(torch, 'cuda') and torch.cuda.is_available():
         torch.cuda.empty_cache()
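Note: the state dict above uses the standard kohya LoRA key layout. The delta a module contributes at merge time can be reconstructed as below (a sketch; the alpha/rank scale is the usual LoRA convention, assumed here):

    import torch

    rank = 4
    up_weight = torch.randn(64, rank)    # "<module>.lora_up.weight"
    down_weight = torch.randn(rank, 32)  # "<module>.lora_down.weight"
    alpha = float(rank)                  # "<module>.alpha" is set to the rank above

    # Usual merge rule: delta_W = (alpha / rank) * up @ down.
    # Since this script stores alpha == rank, the scale factor is exactly 1.0.
    delta_W = (alpha / rank) * (up_weight @ down_weight)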
@@ -444,12 +450,18 @@ def svd(
     os.makedirs(os.path.dirname(save_to), exist_ok=True)

     metadata_to_save = _prepare_lora_metadata(
-        save_to, v2, model_version, conv_dim,
-        bool(dynamic_method), dim,
-        v_parameterization, sdxl, no_metadata
+        output_path=save_to,
+        is_v2_flag=v2,  # The script's --v2 flag
+        kohya_base_model_version_str=kohya_model_version,  # The specific version string from model_util / sdxl_model_util
+        network_conv_dim_val=conv_dim,
+        use_dynamic_method_flag=bool(dynamic_method),
+        network_dim_config_val=dim,  # 'dim' is the general network dim if not dynamic
+        is_v_param_flag=v_parameterization,  # The script's v_param flag
+        is_sdxl_flag=sdxl,  # The script's --sdxl flag
+        skip_sai_meta=no_metadata
     )

-    save_to_file(save_to, lora_sd, save_dtype_torch, metadata_to_save)  # save_dtype_torch applied again if not None
+    save_to_file(save_to, lora_sd, save_dtype_torch, metadata_to_save)
     logger.info(f"LoRA saved to: {save_to}")
@@ -498,4 +510,9 @@ if __name__ == "__main__":

     if MIN_SV <= 0: logger.warning(f"Global MIN_SV ({MIN_SV}) should be positive.")

-    svd(**vars(args))
+    # Pass the correct CLI arguments to the svd function
+    # Note: 'v2', 'sdxl', 'v_parameterization' are directly from args
+    # 'kohya_model_version' is determined inside svd() and then passed to _prepare_lora_metadata
+    svd_args = vars(args).copy()
+    # No change needed here as svd() internally determines kohya_model_version and passes it correctly
+    svd(**svd_args)
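Note: a hypothetical direct call, mirroring what the __main__ block does via argparse; file paths and values are made-up, and only parameters visible in this diff are assumed (the full signature is not shown):

    svd(
        model_org="base_sdxl.safetensors",
        model_tuned="finetuned_sdxl.safetensors",
        save_to="extracted_lora.safetensors",
        sdxl=True, v2=False, v_parameterization=False,
        dim=32, conv_dim=None, device="cuda",
        load_precision=None, save_precision="fp16",
        min_diff=0.01, clamp_quantile=0.99,
        dynamic_method=None, verbose=True, no_metadata=False,
        load_original_model_to=None, load_tuned_model_to=None,
    )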