d8ahazard 2025-09-17 16:59:03 -05:00
parent 25ddb10977
commit 4244f6b2cf
1 changed files with 4 additions and 11 deletions

View File

@ -1444,21 +1444,20 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
# Get the path to a temporary directory
del s_pipeline
logger.debug(f"Loading image pipeline from {weights_dir}...")
# Build preview pipeline on CPU to avoid any overlap with training models on GPU
if args.model_type == "SDXL":
s_pipeline = StableDiffusionXLPipeline.from_pretrained(
weights_dir,
vae=vae,
revision=args.revision,
torch_dtype=weight_dtype,
torch_dtype=torch.float32,
low_cpu_mem_usage=False,
device_map=None,
)
else:
s_pipeline = StableDiffusionPipeline.from_pretrained(
weights_dir,
vae=vae,
revision=args.revision,
torch_dtype=weight_dtype,
torch_dtype=torch.float32,
low_cpu_mem_usage=False,
device_map=None,
)
@ -1467,13 +1466,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
if args.use_lora:
s_pipeline.load_lora_weights(lora_save_file)
try:
s_pipeline.enable_vae_tiling()
s_pipeline.enable_vae_slicing()
s_pipeline.enable_sequential_cpu_offload()
s_pipeline.enable_xformers_memory_efficient_attention()
except:
pass
# Do not enable GPU-offload/xformers here; keep preview on CPU to avoid VRAM spikes
samples = []
sample_prompts = []