diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py index 10347b7..4fe31fc 100644 --- a/dreambooth/train_dreambooth.py +++ b/dreambooth/train_dreambooth.py @@ -1444,21 +1444,20 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR # Get the path to a temporary directory del s_pipeline logger.debug(f"Loading image pipeline from {weights_dir}...") + # Build preview pipeline on CPU to avoid any overlap with training models on GPU if args.model_type == "SDXL": s_pipeline = StableDiffusionXLPipeline.from_pretrained( weights_dir, - vae=vae, revision=args.revision, - torch_dtype=weight_dtype, + torch_dtype=torch.float32, low_cpu_mem_usage=False, device_map=None, ) else: s_pipeline = StableDiffusionPipeline.from_pretrained( weights_dir, - vae=vae, revision=args.revision, - torch_dtype=weight_dtype, + torch_dtype=torch.float32, low_cpu_mem_usage=False, device_map=None, ) @@ -1467,13 +1466,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR if args.use_lora: s_pipeline.load_lora_weights(lora_save_file) - try: - s_pipeline.enable_vae_tiling() - s_pipeline.enable_vae_slicing() - s_pipeline.enable_sequential_cpu_offload() - s_pipeline.enable_xformers_memory_efficient_attention() - except: - pass + # Do not enable GPU-offload/xformers here; keep preview on CPU to avoid VRAM spikes samples = [] sample_prompts = []