mirror of https://github.com/bmaltais/kohya_ss
Turn off text encoder (TE) LoRA training for Anima for now (not supported at inference).
parent
8009f31d78
commit
7906a4229e
|
|
@ -162,8 +162,8 @@ class animaTraining:
|
|||
with gr.Row():
|
||||
self.anima_cache_text_encoder_outputs = gr.Checkbox(
|
||||
label="Cache Text Encoder Outputs",
|
||||
value=self.config.get("anima.anima_cache_text_encoder_outputs", False),
|
||||
info="Cache Qwen3 outputs to reduce VRAM. Recommended when not training text encoder LoRA.",
|
||||
value=self.config.get("anima.anima_cache_text_encoder_outputs", True),
|
||||
info="Cache Qwen3 outputs to reduce VRAM. Enabled by default: TE LoRA is not supported at inference for Anima.",
|
||||
interactive=True,
|
||||
)
|
||||
self.anima_cache_text_encoder_outputs_to_disk = gr.Checkbox(
|
||||
|
|
@ -215,4 +215,4 @@ class animaTraining:
|
|||
lambda anima_checkbox: gr.Accordion(visible=anima_checkbox),
|
||||
inputs=[self.anima_checkbox],
|
||||
outputs=[anima_accordion],
|
||||
)
|
||||
)
|
||||
|
|
@ -1590,7 +1590,9 @@ def train_model(
|
|||
# Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
|
||||
network_train_text_encoder_only = text_encoder_lr_float != 0 and unet_lr_float == 0
|
||||
# Flag to train unet only if its learning rate is non-zero and text encoder's is zero.
|
||||
network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0
|
||||
# For Anima: always unet-only — TE LoRA is not applied at inference, so training it wastes
|
||||
# memory and produces unreloadable lora_te_* keys that confuse ComfyUI.
|
||||
network_train_unet_only = (text_encoder_lr_float == 0 and unet_lr_float != 0) or anima_checkbox
|
||||
|
||||
clip_l_value = None
|
||||
if sd3_checkbox:
|
||||
|
|
@ -3253,4 +3255,4 @@ def lora_tab(
|
|||
folders.reg_data_dir,
|
||||
folders.output_dir,
|
||||
folders.logging_dir,
|
||||
)
|
||||
)
|
||||
|
|
@ -216,14 +216,20 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
|||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
else:
|
||||
# move text encoder to device for encoding during training/validation
|
||||
text_encoders[0].to(accelerator.device)
|
||||
# Keep TE on CPU — it will be moved to GPU per-step in process_batch (step-level offloading).
|
||||
# Qwen3-0.6B (~1.2GB bf16) staying in VRAM throughout training is wasteful when frozen.
|
||||
logger.info("Text encoder kept on CPU; will be moved to GPU per step for encoding.")
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] # compatibility
|
||||
te = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
||||
qwen3_te = te[0] if te is not None else None
|
||||
|
||||
# Step-level TE offloading: move to GPU for sampling, back to CPU after
|
||||
te_was_on_cpu = qwen3_te is not None and qwen3_te.device.type == "cpu"
|
||||
if te_was_on_cpu:
|
||||
qwen3_te.to(accelerator.device)
|
||||
|
||||
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
anima_train_utils.sample_images(
|
||||
|
|
@ -239,6 +245,10 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
|||
self.sample_prompts_te_outputs,
|
||||
)
|
||||
|
||||
if te_was_on_cpu:
|
||||
qwen3_te.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
return noise_scheduler
|
||||
|
|
@ -340,7 +350,19 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
|||
train_text_encoder=True,
|
||||
train_unet=True,
|
||||
) -> torch.Tensor:
|
||||
"""Override base process_batch for caption dropout with cached text encoder outputs."""
|
||||
"""Override base process_batch for caption dropout and step-level TE CPU offloading."""
|
||||
|
||||
# Step-level TE offloading: move Qwen3 to GPU only for this step, then back to CPU.
|
||||
# Keeps ~1.2GB (Qwen3-0.6B bf16) free from VRAM when TE is frozen and not caching.
|
||||
te_was_on_cpu = (
|
||||
not args.cache_text_encoder_outputs
|
||||
and text_encoders is not None
|
||||
and len(text_encoders) > 0
|
||||
and text_encoders[0] is not None
|
||||
and text_encoders[0].device.type == "cpu"
|
||||
)
|
||||
if te_was_on_cpu:
|
||||
text_encoders[0].to(accelerator.device)
|
||||
|
||||
# Text encoder conditions
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
|
|
@ -355,7 +377,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
|||
)
|
||||
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
|
||||
|
||||
return super().process_batch(
|
||||
result = super().process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
|
|
@ -373,6 +395,12 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
|||
train_unet,
|
||||
)
|
||||
|
||||
if te_was_on_cpu:
|
||||
text_encoders[0].to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
return result
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
return loss
|
||||
|
||||
|
|
@ -445,4 +473,4 @@ if __name__ == "__main__":
|
|||
args.attn_mode = "torch" # backward compatibility
|
||||
|
||||
trainer = AnimaNetworkTrainer()
|
||||
trainer.train(args)
|
||||
trainer.train(args)
|
||||
Loading…
Reference in New Issue