Fix non-SDXL device selection

pull/3264/head
bmaltais 2025-05-19 11:18:00 -04:00
parent 25f8925aeb
commit 829d5a6af3
2 changed files with 151 additions and 134 deletions

View File

@@ -37,12 +37,12 @@ D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-
--save_precision fp16 `
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
--model_tuned E:/models/sdxl/xxxRay_v11.safetensors `
--save_to E:/lora/sdxl/xxxRay_v11_sv_fro_0.85_1024.safetensors `
--save_to E:/lora/sdxl/xxxRay_v11_sv_fro_0.9_1024.safetensors `
--dim 1024 `
--device cuda `
--sdxl `
--dynamic_method sv_fro `
--dynamic_param 0.85 `
--dynamic_param 0.9 `
--verbose
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
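As a usage sketch (not part of the commit): with this fix the SD1.x/SD2.x path also honors the load_original_model_to / load_tuned_model_to overrides, so the script's svd() entry point can be driven directly as below. Paths and the cuda target are placeholders, and importlib is used only because the hyphenated filename is not a valid module name.

import importlib.util

# Load the hyphen-named script as a module (path is a placeholder).
spec = importlib.util.spec_from_file_location(
    "extract_lora_nw", r"D:\kohya_ss\tools\extract_lora_from_models-nw.py"
)
extract_lora_nw = importlib.util.module_from_spec(spec)
spec.loader.exec_module(extract_lora_nw)

extract_lora_nw.svd(
    model_org="E:/models/sd15/base.safetensors",          # placeholder paths
    model_tuned="E:/models/sd15/finetuned.safetensors",
    save_to="E:/lora/sd15/finetuned_extracted.safetensors",
    dim=64,
    conv_dim=64,
    v2=False,
    sdxl=False,
    device="cuda",                   # device used for the diff/SVD math
    load_original_model_to="cuda",   # now also respected on the non-SDXL path
    load_tuned_model_to="cuda",
    load_precision="fp16",
    save_precision="fp16",
    dynamic_method="sv_fro",
    dynamic_param=0.9,
    verbose=True,
)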

View File

@@ -2,27 +2,31 @@ import sys
import os
# 1. Add sd-scripts directory to sys.path
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
sd_scripts_dir_path = os.path.join(project_root, "sd-scripts")
# This block can now be removed if no other sd-scripts imports are needed,
# or kept in case some of its utilities are re-introduced later.
# For full removal of the sd-scripts dependency for *this script's execution*,
# ensure no other `from library...` or `from networks...` exist.
# script_dir = os.path.dirname(os.path.abspath(__file__))
# project_root = os.path.dirname(script_dir)
# sd_scripts_dir_path = os.path.join(project_root, "sd-scripts")
if sd_scripts_dir_path not in sys.path:
sys.path.insert(0, sd_scripts_dir_path)
# if sd_scripts_dir_path not in sys.path:
# sys.path.insert(0, sd_scripts_dir_path)
# Now you can import from the library package and the networks package
try:
# model_util and sdxl_model_util REMOVED from here
from library.utils import setup_logging
from networks import lora
except ImportError as e:
print(f"Error importing from sd-scripts. Please check your sd-scripts folder structure.")
print(f"Attempted to load from: {sd_scripts_dir_path}")
print(f"Original error: {e}")
print("Current sys.path relevant entries:")
for p in sys.path:
if "sd-scripts" in p or "kohya_ss" in p:
print(p)
raise
# try:
# # model_util and sdxl_model_util REMOVED from here
# # from library.utils import setup_logging # REMOVED
# # from networks import lora # REMOVED
# except ImportError as e:
# print(f"Error importing from sd-scripts. Please check your sd-scripts folder structure.")
# # print(f"Attempted to load from: {sd_scripts_dir_path}") # If path addition is removed
# print(f"Original error: {e}")
# print("Current sys.path relevant entries:")
# for p in sys.path:
# if "sd-scripts" in p or "kohya_ss" in p: # Adjust if sd_scripts_dir_path is removed
# print(p)
# raise
import argparse
import json
@@ -30,32 +34,87 @@ import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import logging # Import for logging
# NEW: Add diffusers import for model loading
try:
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from diffusers.utils import load_image # In case any part needs it, though not directly by your script
except ImportError:
print("Diffusers library not found. Please install it: pip install diffusers transformers accelerate")
raise
setup_logging()
import logging
logger = logging.getLogger(__name__)
# --- Localized Logging Setup ---
def _local_setup_logging(log_level=logging.INFO):
"""
Sets up basic logging to console.
"""
logging.basicConfig(
level=log_level,
format="%(asctime)s %(levelname)-8s %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
_local_setup_logging() # Initialize logging
logger = logging.getLogger(__name__) # Get logger for this module
MIN_SV = 1e-6
# --- Localized sd-scripts constants and utility functions ---
_LOCAL_MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_v10" # Common identifier used in sd-scripts for SDXL base
_LOCAL_MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_v10"
def _local_get_model_version_str_for_sd1_sd2(is_v2: bool, is_v_parameterization: bool) -> str:
"""
Replicates model_util.get_model_version_str_for_sd1_sd2 from sd-scripts.
Determines a string representation for SD1.x or SD2.x model versions.
"""
if is_v2:
return "v2-v" if is_v_parameterization else "v2"
return "v1" # Corresponds to SD 1.x
return "v1"
# --- Localized LoRA Placeholder and Network Creation ---
class LocalLoRAModulePlaceholder:
def __init__(self, lora_name: str, org_module: torch.nn.Module):
self.lora_name = lora_name
self.org_module = org_module
# Add other attributes if _calculate_module_diffs_and_check needs them,
# but it primarily uses .lora_name and .org_module.weight
def _local_create_network_placeholders(text_encoders: list, unet: torch.nn.Module, lora_conv_dim_init: int):
"""
Creates placeholders for LoRA-able modules in text encoders and UNet.
Mimics the module identification and naming of sd-scripts' lora.create_network.
`lora_conv_dim_init`: If > 0, Conv2d layers are considered for LoRA.
"""
unet_loras = []
text_encoder_loras = []
# Target U-Net modules
for name, module in unet.named_modules():
lora_name = "lora_unet_" + name.replace(".", "_")
if isinstance(module, torch.nn.Linear):
unet_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
elif isinstance(module, torch.nn.Conv2d):
if lora_conv_dim_init > 0: # Only consider conv layers if conv_dim > 0
# Kernel size check might be relevant if sd-scripts has specific logic,
# but for diffing, any conv is a candidate if conv_dim > 0.
# SVD will later handle rank based on actual layer type (1x1 vs 3x3).
unet_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
# Target Text Encoder modules
for i, text_encoder in enumerate(text_encoders):
if text_encoder is None: # SDXL can have None TEs if not loaded
continue
# Determine prefix based on number of text encoders (for SDXL compatibility)
te_prefix = f"lora_te{i+1}_" if len(text_encoders) > 1 else "lora_te_"
for name, module in text_encoder.named_modules():
lora_name = te_prefix + name.replace(".", "_")
if isinstance(module, torch.nn.Linear):
text_encoder_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
# Conv2d in text encoders is rare, but check just in case (sd-scripts may target them too)
elif isinstance(module, torch.nn.Conv2d):
if lora_conv_dim_init > 0:
text_encoder_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
logger.info(f"Found {len(text_encoder_loras)} LoRA-able placeholder modules in Text Encoders.")
logger.info(f"Found {len(unet_loras)} LoRA-able placeholder modules in U-Net.")
return text_encoder_loras, unet_loras
# --- Singular Value Indexing Functions (Unchanged) ---
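For illustration (not part of the commit), a toy example of the naming convention the placeholder builder produces: module paths from named_modules() are flattened with dots replaced by underscores and prefixed per component (lora_unet_, and lora_te_ or lora_te1_ / lora_te2_ for the text encoders).

import torch

# Toy stand-in for a U-Net; the nested attribute path becomes the LoRA name.
class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = torch.nn.Linear(8, 8)
        self.conv1 = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)

class TinyUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.down_blocks = torch.nn.ModuleList([TinyBlock()])

unet = TinyUNet()
for name, module in unet.named_modules():
    if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
        # Same scheme as _local_create_network_placeholders:
        print("lora_unet_" + name.replace(".", "_"))
# -> lora_unet_down_blocks_0_proj_in
# -> lora_unet_down_blocks_0_conv1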
@@ -158,89 +217,38 @@ def _build_local_sai_metadata(title, creation_time, is_v2_flag, is_v_param_flag,
return metadata
# --- MODIFIED Helper Functions for Model Loading ---
def _load_sd_model_components(model_path, is_v2_flag, load_dtype_torch): # Renamed is_v2 to is_v2_flag for clarity
"""
Loads Text Encoder and UNet from a Stable Diffusion checkpoint (.ckpt or .safetensors)
using diffusers.StableDiffusionPipeline.from_single_file.
The VAE is loaded but then deleted as it's not used by this script.
Models are loaded to CPU first, then dtype is applied, then moved to CPU (as per original logic flow).
"""
def _load_sd_model_components(model_path, is_v2_flag, target_device_override, load_dtype_torch):
logger.info(f"Loading SD model using Diffusers.StableDiffusionPipeline from: {model_path}")
# Diffusers from_single_file usually loads to CUDA if available by default with certain dtypes.
# We want to replicate: load, then cast dtype, ensure on CPU for diff calculation if not handled by diff calc device.
# The original script's model_util.load_models_from_stable_diffusion_checkpoint loads to CPU.
# Load with specified dtype, but this might place it on GPU.
# Forcing CPU load is tricky with from_single_file if a GPU is available.
# A common pattern is to load then move.
pipeline = StableDiffusionPipeline.from_single_file(
model_path,
torch_dtype=load_dtype_torch # Apply dtype on load
# load_safety_checker=False, # REMOVED
torch_dtype=load_dtype_torch
)
# Ensure models are on CPU after loading and dtype casting, before returning.
# The diff calculation expects them on CPU.
pipeline.to("cpu")
text_encoder = pipeline.text_encoder
# VAE is loaded by pipeline but not used further in this script.
# vae = pipeline.vae
# del vae
unet = pipeline.unet
eff_device = target_device_override if target_device_override else "cpu"
text_encoder = pipeline.text_encoder.to(eff_device)
unet = pipeline.unet.to(eff_device)
text_encoders = [text_encoder]
# Dtype should be set by torch_dtype in from_single_file.
# If any component is not on CPU, move it. (pipeline.to("cpu") should handle this)
# for te in text_encoders:
# if te.device.type != "cpu": te.to("cpu")
# if unet.device.type != "cpu": unet.to("cpu")
# And ensure dtype again if from_single_file's torch_dtype was not fully effective on all parts
# if load_dtype_torch:
# for te in text_encoders: te.to(dtype=load_dtype_torch)
# unet.to(dtype=load_dtype_torch)
# The is_v2_flag is not directly used by from_single_file for loading,
# as it attempts to infer the model version from the checkpoint.
# This could be a point of difference if sd-scripts used is_v2 for more subtle loading decisions.
logger.info(f"Loaded SD model components. UNet device: {unet.device}, TextEncoder device: {text_encoder.device}")
return text_encoders, unet
def _load_sdxl_model_components(model_path, target_device_override, load_dtype_torch):
"""
Loads Text Encoders and UNet from an SDXL checkpoint (.ckpt or .safetensors)
using diffusers.StableDiffusionXLPipeline.from_single_file.
The VAE is loaded but then deleted.
Models are loaded to `actual_load_device` (CPU by default, or `target_device_override`).
"""
actual_load_device = target_device_override if target_device_override else "cpu"
logger.info(f"Loading SDXL model using Diffusers.StableDiffusionXLPipeline from: {model_path} to device: {actual_load_device}")
pipeline = StableDiffusionXLPipeline.from_single_file(
model_path,
torch_dtype=load_dtype_torch # Apply dtype on load
# load_safety_checker=False, # REMOVED
torch_dtype=load_dtype_torch
)
pipeline.to(actual_load_device) # Move to the target device after loading
pipeline.to(actual_load_device)
text_encoder = pipeline.text_encoder
text_encoder_2 = pipeline.text_encoder_2
# vae = pipeline.vae
# del vae
unet = pipeline.unet
text_encoders = [text_encoder, text_encoder_2]
logger.info(f"Loaded SDXL model components. UNet device: {unet.device}, TextEncoder1 device: {text_encoder.device}, TextEncoder2 device: {text_encoder_2.device}")
return text_encoders, unet
def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_device, min_diff_thresh, module_type_str):
diffs_map = {}
is_different_flag = False
first_diff_logged = False
for lora_o, lora_t in zip(module_loras_o, module_loras_t):
lora_name = lora_o.lora_name
if lora_o.org_module is None or lora_t.org_module is None or \
@@ -248,15 +256,10 @@ def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_
not hasattr(lora_t.org_module, 'weight') or lora_t.org_module.weight is None:
logger.warning(f"Skipping {lora_name} in {module_type_str} due to missing org_module or weight.")
continue
weight_o = lora_o.org_module.weight
weight_t = lora_t.org_module.weight
if str(weight_o.device) != str(diff_calc_device):
weight_o = weight_o.to(diff_calc_device)
if str(weight_t.device) != str(diff_calc_device):
weight_t = weight_t.to(diff_calc_device)
if str(weight_o.device) != str(diff_calc_device): weight_o = weight_o.to(diff_calc_device)
if str(weight_t.device) != str(diff_calc_device): weight_t = weight_t.to(diff_calc_device)
diff = weight_t - weight_o
diffs_map[lora_name] = diff
current_max_diff = torch.max(torch.abs(diff))
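A toy sketch (mirroring the idea of _calculate_module_diffs_and_check, not the function itself) of the min_diff gate that decides whether a component is extracted at all: a component counts as different only if at least one of its modules has a weight diff whose largest absolute entry exceeds the threshold.

import torch

min_diff = 0.01  # same default as the script's --min_diff
diffs = {
    "lora_te_text_model_encoder_layers_0_mlp_fc1": torch.full((4, 4), 1e-5),
    "lora_te_text_model_encoder_layers_0_mlp_fc2": torch.full((4, 4), 5e-2),
}
# True here: the second module's max |diff| (0.05) exceeds min_diff.
is_different = any(d.abs().max().item() > min_diff for d in diffs.values())
print(is_different)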
@@ -344,7 +347,7 @@ def _prepare_lora_metadata(output_path, is_v2_flag, kohya_base_model_version_str
final_metadata = {
"ss_v2": str(is_v2_flag),
"ss_base_model_version": kohya_base_model_version_str,
"ss_network_module": "networks.lora",
"ss_network_module": "networks.lora", # This remains for compatibility with tools expecting it
"ss_network_dim": network_dim_meta,
"ss_network_alpha": network_alpha_meta,
"ss_network_args": json.dumps(net_kwargs),
@@ -363,15 +366,13 @@ def _prepare_lora_metadata(output_path, is_v2_flag, kohya_base_model_version_str
# --- Main SVD Function ---
def svd(
model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, # v2 here is the CLI arg --v2
model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None,
conv_dim=None, v_parameterization=None, device=None, save_precision=None,
clamp_quantile=0.99, min_diff=0.01, no_metadata=False, load_precision=None,
load_original_model_to=None, load_tuned_model_to=None,
dynamic_method=None, dynamic_param=None, verbose=False,
):
# Determine v_parameterization based on v2 flag if not explicitly set (original logic)
actual_v_parameterization = v2 if v_parameterization is None else v_parameterization
load_dtype_torch = _str_to_dtype(load_precision)
save_dtype_torch = _str_to_dtype(save_precision) if save_precision else torch.float
@@ -382,26 +383,36 @@ def svd(
final_weights_device = torch.device("cpu")
if not sdxl:
# Pass the v2 flag from CLI (named 'v2' in this function's scope)
text_encoders_o, unet_o = _load_sd_model_components(model_org, v2, load_dtype_torch)
text_encoders_t, unet_t = _load_sd_model_components(model_tuned, v2, load_dtype_torch)
# Use the localized function for version string
text_encoders_o, unet_o = _load_sd_model_components(model_org, v2, load_original_model_to, load_dtype_torch)
text_encoders_t, unet_t = _load_sd_model_components(model_tuned, v2, load_tuned_model_to, load_dtype_torch)
kohya_model_version = _local_get_model_version_str_for_sd1_sd2(v2, actual_v_parameterization)
else:
text_encoders_o, unet_o = _load_sdxl_model_components(model_org, load_original_model_to, load_dtype_torch)
text_encoders_t, unet_t = _load_sdxl_model_components(model_tuned, load_tuned_model_to, load_dtype_torch)
# Use the localized constant for SDXL version string
kohya_model_version = _LOCAL_MODEL_VERSION_SDXL_BASE_V1_0
# Determine lora_conv_dim_init based on conv_dim argument for network creation
# The original script used init_dim_val (1) if conv_dim was None.
# Here, conv_dim is already defaulted to args.dim if None by the main block.
# So, lora_conv_dim_init will be args.conv_dim (which defaults to args.dim).
# If args.conv_dim was explicitly 0, this would be 0.
lora_conv_dim_init_val = conv_dim # conv_dim is args.conv_dim (or args.dim)
init_dim_val = 1
lora_conv_dim_init = conv_dim if conv_dim is not None else init_dim_val
kwargs_lora = {"conv_dim": lora_conv_dim_init, "conv_alpha": lora_conv_dim_init}
# Create LoRA placeholders using the localized function
text_encoder_loras_o, unet_loras_o = _local_create_network_placeholders(text_encoders_o, unet_o, lora_conv_dim_init_val)
text_encoder_loras_t, unet_loras_t = _local_create_network_placeholders(text_encoders_t, unet_t, lora_conv_dim_init_val) # same conv_dim logic for tuned
lora_network_o = lora.create_network(1.0, init_dim_val, init_dim_val, None, text_encoders_o, unet_o, **kwargs_lora)
lora_network_t = lora.create_network(1.0, init_dim_val, init_dim_val, None, text_encoders_t, unet_t, **kwargs_lora)
# Group LoRA placeholders for easier processing (mimicking LoraNetwork structure somewhat)
class LocalLoraNetworkPlaceholder:
def __init__(self, te_loras, unet_loras_list):
self.text_encoder_loras = te_loras
self.unet_loras = unet_loras_list
lora_network_o = LocalLoraNetworkPlaceholder(text_encoder_loras_o, unet_loras_o)
lora_network_t = LocalLoraNetworkPlaceholder(text_encoder_loras_t, unet_loras_t)
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), \
f"Model versions differ: {len(lora_network_o.text_encoder_loras)} vs {len(lora_network_t.text_encoder_loras)} TEs"
f"Model versions (based on identified LoRA-able TE modules) differ: {len(lora_network_o.text_encoder_loras)} vs {len(lora_network_t.text_encoder_loras)} TEs"
all_diffs = {}
te_diffs, text_encoder_different = _calculate_module_diffs_and_check(
@@ -413,24 +424,30 @@ def svd(
all_diffs.update(te_diffs)
else:
logger.warning("Text encoders are considered identical based on min_diff. Not extracting TE LoRA.")
lora_network_o.text_encoder_loras = []
del text_encoders_t
# To prevent processing empty list later, ensure it's empty if no diffs
lora_network_o.text_encoder_loras = []
del text_encoders_t # Free memory early
unet_diffs, _ = _calculate_module_diffs_and_check(
lora_network_o.unet_loras, lora_network_t.unet_loras,
diff_calculation_device, min_diff, "U-Net"
)
all_diffs.update(unet_diffs)
del lora_network_t, unet_t
del lora_network_t, unet_t # Free memory early
lora_names_to_process = set(lora.lora_name for lora in lora_network_o.text_encoder_loras + lora_network_o.unet_loras)
# Ensure lora_names_to_process only includes modules from lora_network_o
# that are actually present (e.g., if TEs were skipped)
lora_names_to_process = set()
if text_encoder_different: # Only add TE loras if they were deemed different
lora_names_to_process.update(p.lora_name for p in lora_network_o.text_encoder_loras)
lora_names_to_process.update(p.lora_name for p in lora_network_o.unet_loras)
logger.info("Extracting and resizing LoRA via SVD")
lora_weights = {}
with torch.no_grad():
for lora_name in tqdm(lora_names_to_process):
if lora_name not in all_diffs:
logger.warning(f"Skipping {lora_name} as no diff was calculated for it.")
logger.warning(f"Skipping {lora_name} as no diff was calculated for it (e.g., Text Encoders were identical).")
continue
original_diff_tensor = all_diffs[lora_name]
is_conv2d_layer = len(original_diff_tensor.size()) == 4
@@ -449,10 +466,16 @@ def svd(
except Exception as e:
logger.error(f"SVD failed for {lora_name} with shape {mat_for_svd.shape}. Error: {e}")
continue
# Max rank for SVD is based on 'dim' for linear and 'conv_dim' for conv3x3
# The original `current_max_rank` logic was:
# current_max_rank = dim if not is_conv2d_3x3_layer or conv_dim is None else conv_dim
# Here, `dim` is args.dim and `conv_dim` is args.conv_dim (defaulted to args.dim)
module_specific_max_rank = conv_dim if is_conv2d_3x3_layer else dim
eff_out_dim, eff_in_dim = mat_for_svd.shape[0], mat_for_svd.shape[1]
current_max_rank = dim if not is_conv2d_3x3_layer or conv_dim is None else conv_dim
rank = _determine_rank(S_full, dynamic_method, dynamic_param,
current_max_rank, eff_in_dim, eff_out_dim, MIN_SV)
module_specific_max_rank, eff_in_dim, eff_out_dim, MIN_SV)
U_clamped, Vh_clamped = _construct_lora_weights_from_svd_components(
U_full, S_full, Vh_full, rank, clamp_quantile,
is_conv2d_layer, is_conv2d_3x3_layer, kernel_s,
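A sketch of how sv_fro rank selection is commonly implemented in the kohya tooling (an assumption here, since _determine_rank is not shown in this hunk and the exact convention may differ): keep the smallest rank whose leading singular values retain the requested share of the diff's squared Frobenius norm, capped at the module-specific max rank from --dim / --conv_dim. This is also why raising --dynamic_param from 0.85 to 0.9 in the example command above tends to produce the same or larger per-module ranks.

import torch

def pick_rank_sv_fro(S: torch.Tensor, dynamic_param: float, max_rank: int) -> int:
    # Fraction of squared Frobenius norm retained by the first k singular values.
    energy = S.pow(2)
    retained = torch.cumsum(energy, dim=0) / energy.sum()
    rank = int(torch.searchsorted(retained, torch.tensor(dynamic_param))) + 1
    return max(1, min(rank, max_rank))

S = torch.tensor([10.0, 5.0, 1.0, 0.5, 0.1])   # placeholder singular values
print(pick_rank_sv_fro(S, dynamic_param=0.90, max_rank=1024))   # -> 2
print(pick_rank_sv_fro(S, dynamic_param=0.999, max_rank=1024))  # -> 4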
@@ -466,24 +489,24 @@ def svd(
for lora_name, (up_weight, down_weight) in lora_weights.items():
lora_sd[lora_name + ".lora_up.weight"] = up_weight
lora_sd[lora_name + ".lora_down.weight"] = down_weight
# Alpha is set to the rank (dim of down_weight's 0th axis, which is rank)
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype_torch, device=final_weights_device)
del text_encoders_o, unet_o, lora_network_o, all_diffs
del text_encoders_o, unet_o, lora_network_o, all_diffs # Clean up original models and placeholders
if 'torch' in sys.modules and hasattr(torch, 'cuda') and torch.cuda.is_available():
torch.cuda.empty_cache()
if not os.path.exists(os.path.dirname(save_to)) and os.path.dirname(save_to) != "": # Check if dirname is not empty
if not os.path.exists(os.path.dirname(save_to)) and os.path.dirname(save_to) != "":
os.makedirs(os.path.dirname(save_to), exist_ok=True)
metadata_to_save = _prepare_lora_metadata(
output_path=save_to,
is_v2_flag=v2, # CLI --v2 flag
is_v2_flag=v2,
kohya_base_model_version_str=kohya_model_version,
network_conv_dim_val=conv_dim,
network_conv_dim_val=conv_dim, # This is args.conv_dim (defaulted to args.dim)
use_dynamic_method_flag=bool(dynamic_method),
network_dim_config_val=dim,
is_v_param_flag=actual_v_parameterization, # Use the derived v_param
network_dim_config_val=dim, # This is args.dim
is_v_param_flag=actual_v_parameterization,
is_sdxl_flag=sdxl,
skip_sai_meta=no_metadata
)
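For context (not part of the diff): the keys written above follow the usual kohya LoRA layout, and because alpha is stored as the per-module rank, consumers that apply delta_W = (alpha / rank) * up @ down end up with a scale of 1.0. A minimal sketch for one Linear module, with placeholder names and shapes:

import torch

rank, in_dim, out_dim = 8, 320, 320
key = "lora_unet_mid_block_attentions_0_proj_in"       # placeholder module name
lora_sd = {
    key + ".lora_down.weight": torch.randn(rank, in_dim),
    key + ".lora_up.weight": torch.randn(out_dim, rank),
    key + ".alpha": torch.tensor(float(rank)),
}

down = lora_sd[key + ".lora_down.weight"]
up = lora_sd[key + ".lora_up.weight"]
scale = lora_sd[key + ".alpha"] / down.shape[0]        # alpha == rank, so scale == 1.0
delta_w = scale * (up @ down)                          # (out_dim, in_dim) weight delta
base_w = torch.zeros(out_dim, in_dim)                  # placeholder original weight
merged_w = base_w + delta_w
print(delta_w.shape, float(scale))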
@@ -523,24 +546,18 @@ if __name__ == "__main__":
args = parser.parse_args()
if args.conv_dim is None:
args.conv_dim = args.dim
args.conv_dim = args.dim # Default conv_dim to dim if not provided
logger.info(f"--conv_dim not set, using value of --dim: {args.conv_dim}")
methods_requiring_param = ["sv_ratio", "sv_fro", "sv_cumulative", "sv_rel_decrease"]
if args.dynamic_method in methods_requiring_param and args.dynamic_param is None:
parser.error(f"Dynamic method '{args.dynamic_method}' requires --dynamic_param to be set.")
if not args.dynamic_method:
if not args.dynamic_method: # Ranks must be positive if not using dynamic method
if args.dim <= 0: parser.error(f"--dim (rank) must be > 0. Got {args.dim}")
if args.conv_dim <=0: parser.error(f"--conv_dim (rank) must be > 0. Got {args.conv_dim}")
if args.conv_dim <=0: parser.error(f"--conv_dim (rank) must be > 0. Got {args.conv_dim}") # Check after defaulting
if MIN_SV <= 0: logger.warning(f"Global MIN_SV ({MIN_SV}) should be positive.")
# The v_parameterization in args defaults to False.
# The svd function has logic: actual_v_parameterization = v2 if v_parameterization is None else v_parameterization
# This means if --v_parameterization is not given, it takes the value of --v2.
# If --v_parameterization is given, it's used.
# This logic is preserved inside svd().
svd_args = vars(args).copy()
svd(**svd_args)
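A quick sanity check after extraction (a usage sketch; the path below is taken from the command example at the top of this commit) is to reload the file with safetensors and inspect the stored metadata and per-module ranks:

from safetensors import safe_open

path = "E:/lora/sdxl/xxxRay_v11_sv_fro_0.9_1024.safetensors"
with safe_open(path, framework="pt", device="cpu") as f:
    meta = f.metadata() or {}
    print("ss_base_model_version:", meta.get("ss_base_model_version"))
    print("ss_network_dim:", meta.get("ss_network_dim"))
    for k in f.keys():
        if k.endswith(".lora_down.weight"):
            # With a dynamic method, each module can have its own rank.
            print(k.removesuffix(".lora_down.weight"), "rank:", f.get_tensor(k).shape[0])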