mirror of https://github.com/bmaltais/kohya_ss
Fix non sdxl device selection
parent 25f8925aeb
commit 829d5a6af3
@@ -37,12 +37,12 @@ D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-
--save_precision fp16 `
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned E:/models/sdxl/xxxRay_v11.safetensors `
--save_to E:/lora/sdxl/xxxRay_v11_sv_fro_0.85_1024.safetensors `
--save_to E:/lora/sdxl/xxxRay_v11_sv_fro_0.9_1024.safetensors `
--dim 1024 `
--device cuda `
--sdxl `
--dynamic_method sv_fro `
--dynamic_param 0.85 `
--dynamic_param 0.9 `
--verbose

D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
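For context on the first command above: with --dynamic_method sv_fro, the --dynamic_param value is the fraction of the difference matrix's squared Frobenius norm that the kept singular values should cover, capped by --dim. A minimal sketch of that selection rule (illustrative only; pick_rank_sv_fro is a made-up helper, not the script's actual _determine_rank):

# Illustrative sketch of how an "sv_fro" threshold is commonly interpreted
# when picking a rank from singular values. Helper name is hypothetical.
import torch

def pick_rank_sv_fro(S: torch.Tensor, fro_fraction: float, max_rank: int) -> int:
    # Keep the smallest rank whose squared singular values reach the requested
    # fraction of the total squared Frobenius norm of the difference matrix.
    energy = S.pow(2)
    cumulative = torch.cumsum(energy, dim=0) / energy.sum()
    rank = int(torch.searchsorted(cumulative, torch.tensor(fro_fraction)).item()) + 1
    return max(1, min(rank, max_rank))

S = torch.linalg.svdvals(torch.randn(64, 64))
print(pick_rank_sv_fro(S, 0.9, 1024))  # roughly what --dynamic_param 0.9 asks for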
@@ -2,27 +2,31 @@ import sys
import os

# 1. Add sd-scripts directory to sys.path
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
sd_scripts_dir_path = os.path.join(project_root, "sd-scripts")
# This block can now be potentially removed if no other sd-scripts imports are needed
# OR kept if there's a chance of re-introducing some utilities for other purposes.
# For full removal of the sd-scripts dependency for *this script's execution*,
# ensure no other `from library...` or `from networks...` imports exist.
# script_dir = os.path.dirname(os.path.abspath(__file__))
# project_root = os.path.dirname(script_dir)
# sd_scripts_dir_path = os.path.join(project_root, "sd-scripts")

if sd_scripts_dir_path not in sys.path:
    sys.path.insert(0, sd_scripts_dir_path)
# if sd_scripts_dir_path not in sys.path:
#     sys.path.insert(0, sd_scripts_dir_path)

# Now you can import from the library package and the networks package
try:
    # model_util and sdxl_model_util REMOVED from here
    from library.utils import setup_logging
    from networks import lora
except ImportError as e:
    print(f"Error importing from sd-scripts. Please check your sd-scripts folder structure.")
    print(f"Attempted to load from: {sd_scripts_dir_path}")
    print(f"Original error: {e}")
    print("Current sys.path relevant entries:")
    for p in sys.path:
        if "sd-scripts" in p or "kohya_ss" in p:
            print(p)
    raise
# try:
#     # model_util and sdxl_model_util REMOVED from here
#     # from library.utils import setup_logging  # REMOVED
#     # from networks import lora  # REMOVED
# except ImportError as e:
#     print(f"Error importing from sd-scripts. Please check your sd-scripts folder structure.")
#     # print(f"Attempted to load from: {sd_scripts_dir_path}")  # If path addition is removed
#     print(f"Original error: {e}")
#     print("Current sys.path relevant entries:")
#     for p in sys.path:
#         if "sd-scripts" in p or "kohya_ss" in p:  # Adjust if sd_scripts_dir_path is removed
#             print(p)
#     raise

import argparse
import json
@@ -30,32 +34,87 @@ import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import logging  # Import for logging

# NEW: Add diffusers import for model loading
try:
    from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
    from diffusers.utils import load_image  # in case any part needs it, though not used directly by this script
except ImportError:
    print("Diffusers library not found. Please install it: pip install diffusers transformers accelerate")
    raise

setup_logging()
import logging
logger = logging.getLogger(__name__)
# --- Localized Logging Setup ---
def _local_setup_logging(log_level=logging.INFO):
    """
    Sets up basic logging to the console.
    """
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)-8s %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

_local_setup_logging()  # Initialize logging
logger = logging.getLogger(__name__)  # Get logger for this module

MIN_SV = 1e-6

# --- Localized sd-scripts constants and utility functions ---
_LOCAL_MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_v10"  # Common identifier used in sd-scripts for SDXL base
_LOCAL_MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_v10"

def _local_get_model_version_str_for_sd1_sd2(is_v2: bool, is_v_parameterization: bool) -> str:
    """
    Replicates model_util.get_model_version_str_for_sd1_sd2 from sd-scripts.
    Determines a string representation for SD1.x or SD2.x model versions.
    """
    if is_v2:
        return "v2-v" if is_v_parameterization else "v2"
    return "v1"  # Corresponds to SD 1.x
    return "v1"

# --- Localized LoRA Placeholder and Network Creation ---
class LocalLoRAModulePlaceholder:
    def __init__(self, lora_name: str, org_module: torch.nn.Module):
        self.lora_name = lora_name
        self.org_module = org_module
        # Add other attributes if _calculate_module_diffs_and_check needs them,
        # but it primarily uses .lora_name and .org_module.weight

def _local_create_network_placeholders(text_encoders: list, unet: torch.nn.Module, lora_conv_dim_init: int):
    """
    Creates placeholders for LoRA-able modules in the text encoders and UNet.
    Mimics the module identification and naming of sd-scripts' lora.create_network.
    `lora_conv_dim_init`: if > 0, Conv2d layers are considered for LoRA.
    """
    unet_loras = []
    text_encoder_loras = []

    # Target U-Net modules
    for name, module in unet.named_modules():
        lora_name = "lora_unet_" + name.replace(".", "_")
        if isinstance(module, torch.nn.Linear):
            unet_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
        elif isinstance(module, torch.nn.Conv2d):
            if lora_conv_dim_init > 0:  # Only consider conv layers if conv_dim > 0
                # A kernel-size check might be relevant if sd-scripts has specific logic,
                # but for diffing, any conv is a candidate if conv_dim > 0.
                # SVD will later handle rank based on the actual layer type (1x1 vs 3x3).
                unet_loras.append(LocalLoRAModulePlaceholder(lora_name, module))

    # Target Text Encoder modules
    for i, text_encoder in enumerate(text_encoders):
        if text_encoder is None:  # SDXL can have None TEs if not loaded
            continue
        # Determine the prefix based on the number of text encoders (for SDXL compatibility)
        te_prefix = f"lora_te{i+1}_" if len(text_encoders) > 1 else "lora_te_"

        for name, module in text_encoder.named_modules():
            lora_name = te_prefix + name.replace(".", "_")
            if isinstance(module, torch.nn.Linear):
                text_encoder_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
            # Conv2d in text encoders is rare, but check just in case (sd-scripts might)
            elif isinstance(module, torch.nn.Conv2d):
                if lora_conv_dim_init > 0:
                    text_encoder_loras.append(LocalLoRAModulePlaceholder(lora_name, module))

    logger.info(f"Found {len(text_encoder_loras)} LoRA-able placeholder modules in Text Encoders.")
    logger.info(f"Found {len(unet_loras)} LoRA-able placeholder modules in U-Net.")
    return text_encoder_loras, unet_loras
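As a side note, the key-naming convention used by these placeholders can be illustrated with a toy module tree (module names below are made up, not real UNet paths):

# Toy illustration of the "lora_unet_" + name.replace(".", "_") naming convention.
import torch.nn as nn

toy_unet = nn.ModuleDict({
    "down_blocks": nn.ModuleDict({"0": nn.ModuleDict({"to_q": nn.Linear(8, 8)})})
})
for name, module in toy_unet.named_modules():
    if isinstance(module, nn.Linear):
        print("lora_unet_" + name.replace(".", "_"))  # -> lora_unet_down_blocks_0_to_q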

# --- Singular Value Indexing Functions (Unchanged) ---

@@ -158,89 +217,38 @@ def _build_local_sai_metadata(title, creation_time, is_v2_flag, is_v_param_flag,
    return metadata

# --- MODIFIED Helper Functions for Model Loading ---
def _load_sd_model_components(model_path, is_v2_flag, load_dtype_torch):  # Renamed is_v2 to is_v2_flag for clarity
    """
    Loads the Text Encoder and UNet from a Stable Diffusion checkpoint (.ckpt or .safetensors)
    using diffusers.StableDiffusionPipeline.from_single_file.
    The VAE is loaded but then deleted, as it is not used by this script.
    Models are loaded, the dtype is applied, and they end up on CPU (as per the original logic flow).
    """
def _load_sd_model_components(model_path, is_v2_flag, target_device_override, load_dtype_torch):
    logger.info(f"Loading SD model using Diffusers.StableDiffusionPipeline from: {model_path}")

    # Diffusers from_single_file usually loads to CUDA if available by default with certain dtypes.
    # We want to replicate the original flow: load, cast the dtype, and keep weights on CPU for the
    # diff calculation unless the diff-calculation device handles it.
    # The original script's model_util.load_models_from_stable_diffusion_checkpoint loads to CPU.

    # Load with the specified dtype, but this might place the model on the GPU.
    # Forcing a CPU load is tricky with from_single_file if a GPU is available.
    # A common pattern is to load and then move.
    pipeline = StableDiffusionPipeline.from_single_file(
        model_path,
        torch_dtype=load_dtype_torch  # Apply dtype on load
        # load_safety_checker=False,  # REMOVED
        torch_dtype=load_dtype_torch
    )
    # Ensure models are on CPU after loading and dtype casting, before returning.
    # The diff calculation expects them on CPU.
    pipeline.to("cpu")

    text_encoder = pipeline.text_encoder
    # The VAE is loaded by the pipeline but not used further in this script.
    # vae = pipeline.vae
    # del vae
    unet = pipeline.unet

    eff_device = target_device_override if target_device_override else "cpu"
    text_encoder = pipeline.text_encoder.to(eff_device)
    unet = pipeline.unet.to(eff_device)
    text_encoders = [text_encoder]

    # Dtype should be set by torch_dtype in from_single_file.
    # If any component is not on CPU, move it. (pipeline.to("cpu") should handle this.)
    # for te in text_encoders:
    #     if te.device.type != "cpu": te.to("cpu")
    # if unet.device.type != "cpu": unet.to("cpu")
    # And ensure the dtype again if from_single_file's torch_dtype was not fully effective on all parts
    # if load_dtype_torch:
    #     for te in text_encoders: te.to(dtype=load_dtype_torch)
    #     unet.to(dtype=load_dtype_torch)

    # The is_v2_flag is not directly used by from_single_file for loading,
    # as it attempts to infer the model version from the checkpoint.
    # This could be a point of difference if sd-scripts used is_v2 for more subtle loading decisions.
    logger.info(f"Loaded SD model components. UNet device: {unet.device}, TextEncoder device: {text_encoder.device}")
    return text_encoders, unet
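For reference, the new signature means the non-SDXL path now accepts the same device override that the SDXL loader already honored. A hedged sketch of the intended call (file path and device are placeholders, not values from this repo):

# Hypothetical invocation of the updated helper; path and device are placeholders.
text_encoders, unet = _load_sd_model_components(
    "E:/models/sd15/some_model.safetensors",
    is_v2_flag=False,
    target_device_override="cuda",   # previously the non-SDXL loader had no way to receive this
    load_dtype_torch=torch.float16,
)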

def _load_sdxl_model_components(model_path, target_device_override, load_dtype_torch):
    """
    Loads Text Encoders and UNet from an SDXL checkpoint (.ckpt or .safetensors)
    using diffusers.StableDiffusionXLPipeline.from_single_file.
    The VAE is loaded but then deleted.
    Models are loaded to `actual_load_device` (CPU by default, or `target_device_override`).
    """
    actual_load_device = target_device_override if target_device_override else "cpu"
    logger.info(f"Loading SDXL model using Diffusers.StableDiffusionXLPipeline from: {model_path} to device: {actual_load_device}")

    pipeline = StableDiffusionXLPipeline.from_single_file(
        model_path,
        torch_dtype=load_dtype_torch  # Apply dtype on load
        # load_safety_checker=False,  # REMOVED
        torch_dtype=load_dtype_torch
    )
    pipeline.to(actual_load_device)  # Move to the target device after loading

    pipeline.to(actual_load_device)
    text_encoder = pipeline.text_encoder
    text_encoder_2 = pipeline.text_encoder_2
    # vae = pipeline.vae
    # del vae
    unet = pipeline.unet

    text_encoders = [text_encoder, text_encoder_2]

    logger.info(f"Loaded SDXL model components. UNet device: {unet.device}, TextEncoder1 device: {text_encoder.device}, TextEncoder2 device: {text_encoder_2.device}")
    return text_encoders, unet


def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_device, min_diff_thresh, module_type_str):
    diffs_map = {}
    is_different_flag = False
    first_diff_logged = False

    for lora_o, lora_t in zip(module_loras_o, module_loras_t):
        lora_name = lora_o.lora_name
        if lora_o.org_module is None or lora_t.org_module is None or \
@@ -248,15 +256,10 @@ def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_
           not hasattr(lora_t.org_module, 'weight') or lora_t.org_module.weight is None:
            logger.warning(f"Skipping {lora_name} in {module_type_str} due to missing org_module or weight.")
            continue

        weight_o = lora_o.org_module.weight
        weight_t = lora_t.org_module.weight

        if str(weight_o.device) != str(diff_calc_device):
            weight_o = weight_o.to(diff_calc_device)
        if str(weight_t.device) != str(diff_calc_device):
            weight_t = weight_t.to(diff_calc_device)

        if str(weight_o.device) != str(diff_calc_device): weight_o = weight_o.to(diff_calc_device)
        if str(weight_t.device) != str(diff_calc_device): weight_t = weight_t.to(diff_calc_device)
        diff = weight_t - weight_o
        diffs_map[lora_name] = diff
        current_max_diff = torch.max(torch.abs(diff))
@@ -344,7 +347,7 @@ def _prepare_lora_metadata(output_path, is_v2_flag, kohya_base_model_version_str
    final_metadata = {
        "ss_v2": str(is_v2_flag),
        "ss_base_model_version": kohya_base_model_version_str,
        "ss_network_module": "networks.lora",
        "ss_network_module": "networks.lora",  # This remains for compatibility with tools expecting it
        "ss_network_dim": network_dim_meta,
        "ss_network_alpha": network_alpha_meta,
        "ss_network_args": json.dumps(net_kwargs),
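For background on where these ss_* entries end up: safetensors stores them as string metadata in the file header when saving. A small, self-contained sketch (toy tensors and values, not the script's actual state dict):

# Toy example: attaching string metadata to a .safetensors file.
import torch
from safetensors.torch import save_file

lora_sd = {"lora_unet_example.lora_up.weight": torch.zeros(4, 8)}
metadata = {"ss_network_module": "networks.lora", "ss_network_dim": "1024"}
save_file(lora_sd, "example_lora.safetensors", metadata=metadata)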
@@ -363,15 +366,13 @@ def _prepare_lora_metadata(output_path, is_v2_flag, kohya_base_model_version_str

# --- Main SVD Function ---
def svd(
    model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None,  # v2 here is the CLI arg --v2
    model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None,
    conv_dim=None, v_parameterization=None, device=None, save_precision=None,
    clamp_quantile=0.99, min_diff=0.01, no_metadata=False, load_precision=None,
    load_original_model_to=None, load_tuned_model_to=None,
    dynamic_method=None, dynamic_param=None, verbose=False,
):
    # Determine v_parameterization based on v2 flag if not explicitly set (original logic)
    actual_v_parameterization = v2 if v_parameterization is None else v_parameterization

    load_dtype_torch = _str_to_dtype(load_precision)
    save_dtype_torch = _str_to_dtype(save_precision) if save_precision else torch.float
@@ -382,26 +383,36 @@ def svd(
    final_weights_device = torch.device("cpu")

    if not sdxl:
        # Pass the v2 flag from the CLI (named 'v2' in this function's scope)
        text_encoders_o, unet_o = _load_sd_model_components(model_org, v2, load_dtype_torch)
        text_encoders_t, unet_t = _load_sd_model_components(model_tuned, v2, load_dtype_torch)
        # Use the localized function for the version string
        text_encoders_o, unet_o = _load_sd_model_components(model_org, v2, load_original_model_to, load_dtype_torch)
        text_encoders_t, unet_t = _load_sd_model_components(model_tuned, v2, load_tuned_model_to, load_dtype_torch)
        kohya_model_version = _local_get_model_version_str_for_sd1_sd2(v2, actual_v_parameterization)
    else:
        text_encoders_o, unet_o = _load_sdxl_model_components(model_org, load_original_model_to, load_dtype_torch)
        text_encoders_t, unet_t = _load_sdxl_model_components(model_tuned, load_tuned_model_to, load_dtype_torch)
        # Use the localized constant for the SDXL version string
        kohya_model_version = _LOCAL_MODEL_VERSION_SDXL_BASE_V1_0

    init_dim_val = 1
    lora_conv_dim_init = conv_dim if conv_dim is not None else init_dim_val
    kwargs_lora = {"conv_dim": lora_conv_dim_init, "conv_alpha": lora_conv_dim_init}
    # Determine lora_conv_dim_init based on the conv_dim argument for network creation.
    # The original script used init_dim_val (1) if conv_dim was None.
    # Here, conv_dim has already been defaulted to args.dim if None by the main block,
    # so lora_conv_dim_init will be args.conv_dim (which defaults to args.dim).
    # If args.conv_dim was explicitly 0, this would be 0.
    lora_conv_dim_init_val = conv_dim  # conv_dim is args.conv_dim (or args.dim)

    lora_network_o = lora.create_network(1.0, init_dim_val, init_dim_val, None, text_encoders_o, unet_o, **kwargs_lora)
    lora_network_t = lora.create_network(1.0, init_dim_val, init_dim_val, None, text_encoders_t, unet_t, **kwargs_lora)
    # Create LoRA placeholders using the localized function
    text_encoder_loras_o, unet_loras_o = _local_create_network_placeholders(text_encoders_o, unet_o, lora_conv_dim_init_val)
    text_encoder_loras_t, unet_loras_t = _local_create_network_placeholders(text_encoders_t, unet_t, lora_conv_dim_init_val)  # same conv_dim logic for tuned

    # Group LoRA placeholders for easier processing (mimicking the LoraNetwork structure somewhat)
    class LocalLoraNetworkPlaceholder:
        def __init__(self, te_loras, unet_loras_list):
            self.text_encoder_loras = te_loras
            self.unet_loras = unet_loras_list

    lora_network_o = LocalLoraNetworkPlaceholder(text_encoder_loras_o, unet_loras_o)
    lora_network_t = LocalLoraNetworkPlaceholder(text_encoder_loras_t, unet_loras_t)

    assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), \
        f"Model versions differ: {len(lora_network_o.text_encoder_loras)} vs {len(lora_network_t.text_encoder_loras)} TEs"
        f"Model versions (based on identified LoRA-able TE modules) differ: {len(lora_network_o.text_encoder_loras)} vs {len(lora_network_t.text_encoder_loras)} TEs"

    all_diffs = {}
    te_diffs, text_encoder_different = _calculate_module_diffs_and_check(
@@ -413,24 +424,30 @@ def svd(
        all_diffs.update(te_diffs)
    else:
        logger.warning("Text encoders are considered identical based on min_diff. Not extracting TE LoRA.")
        # To prevent processing an empty list later, ensure it's empty if there are no diffs
        lora_network_o.text_encoder_loras = []
    del text_encoders_t
    del text_encoders_t  # Free memory early

    unet_diffs, _ = _calculate_module_diffs_and_check(
        lora_network_o.unet_loras, lora_network_t.unet_loras,
        diff_calculation_device, min_diff, "U-Net"
    )
    all_diffs.update(unet_diffs)
    del lora_network_t, unet_t
    del lora_network_t, unet_t  # Free memory early

    lora_names_to_process = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)
    # Ensure lora_names_to_process only includes modules from lora_network_o
    # that are actually present (e.g., if TEs were skipped)
    lora_names_to_process = set()
    if text_encoder_different:  # Only add TE loras if they were deemed different
        lora_names_to_process.update(p.lora_name for p in lora_network_o.text_encoder_loras)
    lora_names_to_process.update(p.lora_name for p in lora_network_o.unet_loras)

    logger.info("Extracting and resizing LoRA via SVD")
    lora_weights = {}
    with torch.no_grad():
        for lora_name in tqdm(lora_names_to_process):
            if lora_name not in all_diffs:
                logger.warning(f"Skipping {lora_name} as no diff was calculated for it.")
                logger.warning(f"Skipping {lora_name} as no diff was calculated for it (e.g., Text Encoders were identical).")
                continue
            original_diff_tensor = all_diffs[lora_name]
            is_conv2d_layer = len(original_diff_tensor.size()) == 4
@@ -449,10 +466,16 @@ def svd(
            except Exception as e:
                logger.error(f"SVD failed for {lora_name} with shape {mat_for_svd.shape}. Error: {e}")
                continue

            # Max rank for SVD is based on 'dim' for linear and 'conv_dim' for conv3x3.
            # The original `current_max_rank` logic was:
            #     current_max_rank = dim if not is_conv2d_3x3_layer or conv_dim is None else conv_dim
            # Here, `dim` is args.dim and `conv_dim` is args.conv_dim (defaulted to args.dim).
            module_specific_max_rank = conv_dim if is_conv2d_3x3_layer else dim

            eff_out_dim, eff_in_dim = mat_for_svd.shape[0], mat_for_svd.shape[1]
            current_max_rank = dim if not is_conv2d_3x3_layer or conv_dim is None else conv_dim
            rank = _determine_rank(S_full, dynamic_method, dynamic_param,
                                   current_max_rank, eff_in_dim, eff_out_dim, MIN_SV)
                                   module_specific_max_rank, eff_in_dim, eff_out_dim, MIN_SV)
            U_clamped, Vh_clamped = _construct_lora_weights_from_svd_components(
                U_full, S_full, Vh_full, rank, clamp_quantile,
                is_conv2d_layer, is_conv2d_3x3_layer, kernel_s,
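For readers unfamiliar with this step: the general idea behind turning a rank-r SVD of the weight difference into a LoRA up/down pair looks roughly like the sketch below (a generic illustration with made-up shapes; the script's _construct_lora_weights_from_svd_components additionally handles conv reshaping and quantile clamping):

# Generic sketch: low-rank factorization of a weight difference into LoRA pairs.
import torch

delta_w = torch.randn(320, 768)                      # tuned_weight - original_weight
U, S, Vh = torch.linalg.svd(delta_w, full_matrices=False)
rank = 8
up_weight = U[:, :rank] * S[:rank]                   # (out_features, rank), singular values folded in
down_weight = Vh[:rank, :]                           # (rank, in_features)
print(torch.dist(delta_w, up_weight @ down_weight))  # residual outside the top-8 ranks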
@@ -466,24 +489,24 @@ def svd(
    for lora_name, (up_weight, down_weight) in lora_weights.items():
        lora_sd[lora_name + ".lora_up.weight"] = up_weight
        lora_sd[lora_name + ".lora_down.weight"] = down_weight
        # Alpha is set to the rank (the size of down_weight's 0th axis, which is the rank)
        lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype_torch, device=final_weights_device)

    del text_encoders_o, unet_o, lora_network_o, all_diffs
    del text_encoders_o, unet_o, lora_network_o, all_diffs  # Clean up original models and placeholders
    if 'torch' in sys.modules and hasattr(torch, 'cuda') and torch.cuda.is_available():
        torch.cuda.empty_cache()

    if not os.path.exists(os.path.dirname(save_to)) and os.path.dirname(save_to) != "":  # Check if dirname is not empty
    if not os.path.exists(os.path.dirname(save_to)) and os.path.dirname(save_to) != "":
        os.makedirs(os.path.dirname(save_to), exist_ok=True)


    metadata_to_save = _prepare_lora_metadata(
        output_path=save_to,
        is_v2_flag=v2,  # CLI --v2 flag
        is_v2_flag=v2,
        kohya_base_model_version_str=kohya_model_version,
        network_conv_dim_val=conv_dim,
        network_conv_dim_val=conv_dim,  # This is args.conv_dim (defaulted to args.dim)
        use_dynamic_method_flag=bool(dynamic_method),
        network_dim_config_val=dim,
        is_v_param_flag=actual_v_parameterization,  # Use the derived v_param
        network_dim_config_val=dim,  # This is args.dim
        is_v_param_flag=actual_v_parameterization,
        is_sdxl_flag=sdxl,
        skip_sai_meta=no_metadata
    )
@@ -523,24 +546,18 @@ if __name__ == "__main__":
    args = parser.parse_args()

    if args.conv_dim is None:
        args.conv_dim = args.dim
        args.conv_dim = args.dim  # Default conv_dim to dim if not provided
        logger.info(f"--conv_dim not set, using value of --dim: {args.conv_dim}")

    methods_requiring_param = ["sv_ratio", "sv_fro", "sv_cumulative", "sv_rel_decrease"]
    if args.dynamic_method in methods_requiring_param and args.dynamic_param is None:
        parser.error(f"Dynamic method '{args.dynamic_method}' requires --dynamic_param to be set.")

    if not args.dynamic_method:
    if not args.dynamic_method:  # Ranks must be positive if not using a dynamic method
        if args.dim <= 0: parser.error(f"--dim (rank) must be > 0. Got {args.dim}")
        if args.conv_dim <= 0: parser.error(f"--conv_dim (rank) must be > 0. Got {args.conv_dim}")
        if args.conv_dim <= 0: parser.error(f"--conv_dim (rank) must be > 0. Got {args.conv_dim}")  # Check after defaulting

    if MIN_SV <= 0: logger.warning(f"Global MIN_SV ({MIN_SV}) should be positive.")

    # The v_parameterization in args defaults to False.
    # The svd function has the logic: actual_v_parameterization = v2 if v_parameterization is None else v_parameterization
    # This means that if --v_parameterization is not given, it takes the value of --v2;
    # if --v_parameterization is given, it is used.
    # This logic is preserved inside svd().

    svd_args = vars(args).copy()
    svd(**svd_args)
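To make the fallback described in the comments above concrete, the resolution inside svd() behaves like this (toy values, not a run of the script):

# Toy illustration of the --v_parameterization fallback (values are examples).
v2, v_parameterization = True, None
actual_v_parameterization = v2 if v_parameterization is None else v_parameterization
print(actual_v_parameterization)  # True: falls back to --v2 when --v_parameterization is left unset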