Turn OFF TE (text encoder) training for now (not supported).

pull/3485/head
GameSpy 2026-02-11 19:37:43 +01:00
parent 8009f31d78
commit 7906a4229e
3 changed files with 40 additions and 10 deletions

View File

@ -162,8 +162,8 @@ class animaTraining:
with gr.Row():
self.anima_cache_text_encoder_outputs = gr.Checkbox(
label="Cache Text Encoder Outputs",
value=self.config.get("anima.anima_cache_text_encoder_outputs", False),
info="Cache Qwen3 outputs to reduce VRAM. Recommended when not training text encoder LoRA.",
value=self.config.get("anima.anima_cache_text_encoder_outputs", True),
info="Cache Qwen3 outputs to reduce VRAM. Enabled by default: TE LoRA is not supported at inference for Anima.",
interactive=True,
)
self.anima_cache_text_encoder_outputs_to_disk = gr.Checkbox(
@ -215,4 +215,4 @@ class animaTraining:
lambda anima_checkbox: gr.Accordion(visible=anima_checkbox),
inputs=[self.anima_checkbox],
outputs=[anima_accordion],
)
)

View File

@ -1590,7 +1590,9 @@ def train_model(
# Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
network_train_text_encoder_only = text_encoder_lr_float != 0 and unet_lr_float == 0
# Flag to train unet only if its learning rate is non-zero and text encoder's is zero.
network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0
# For Anima: always unet-only — TE LoRA is not applied at inference, so training it wastes
# memory and produces unreloadable lora_te_* keys that confuse ComfyUI.
network_train_unet_only = (text_encoder_lr_float == 0 and unet_lr_float != 0) or anima_checkbox
clip_l_value = None
if sd3_checkbox:
@ -3253,4 +3255,4 @@ def lora_tab(
folders.reg_data_dir,
folders.output_dir,
folders.logging_dir,
)
)

View File

@ -216,14 +216,20 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
clean_memory_on_device(accelerator.device)
else:
# move text encoder to device for encoding during training/validation
text_encoders[0].to(accelerator.device)
# Keep TE on CPU — it will be moved to GPU per-step in process_batch (step-level offloading).
# Qwen3-0.6B (~1.2GB bf16) staying in VRAM throughout training is wasteful when frozen.
logger.info("Text encoder kept on CPU; will be moved to GPU per step for encoding.")
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] # compatibility
te = self.get_models_for_text_encoding(args, accelerator, text_encoders)
qwen3_te = te[0] if te is not None else None
# Step-level TE offloading: move to GPU for sampling, back to CPU after
te_was_on_cpu = qwen3_te is not None and qwen3_te.device.type == "cpu"
if te_was_on_cpu:
qwen3_te.to(accelerator.device)
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
anima_train_utils.sample_images(
@ -239,6 +245,10 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
self.sample_prompts_te_outputs,
)
if te_was_on_cpu:
qwen3_te.to("cpu")
clean_memory_on_device(accelerator.device)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
    """Build the flow-matching Euler scheduler used for Anima training.

    The timestep shift is taken from ``args.discrete_flow_shift``; ``device`` is
    accepted for interface compatibility with the base trainer but is not used
    when constructing the scheduler.
    """
    return sd3_train_utils.FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000,
        shift=args.discrete_flow_shift,
    )
@ -340,7 +350,19 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_text_encoder=True,
train_unet=True,
) -> torch.Tensor:
"""Override base process_batch for caption dropout with cached text encoder outputs."""
"""Override base process_batch for caption dropout and step-level TE CPU offloading."""
# Step-level TE offloading: move Qwen3 to GPU only for this step, then back to CPU.
# Keeps ~1.2GB (Qwen3-0.6B bf16) free from VRAM when TE is frozen and not caching.
te_was_on_cpu = (
not args.cache_text_encoder_outputs
and text_encoders is not None
and len(text_encoders) > 0
and text_encoders[0] is not None
and text_encoders[0].device.type == "cpu"
)
if te_was_on_cpu:
text_encoders[0].to(accelerator.device)
# Text encoder conditions
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
@ -355,7 +377,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
return super().process_batch(
result = super().process_batch(
batch,
text_encoders,
unet,
@ -373,6 +395,12 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_unet,
)
if te_was_on_cpu:
text_encoders[0].to("cpu")
clean_memory_on_device(accelerator.device)
return result
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
    """Return the per-sample loss unchanged.

    Anima applies no timestep-dependent loss weighting, so this base-class hook
    is a deliberate no-op; ``args``, ``timesteps`` and ``noise_scheduler`` are
    ignored.
    """
    return loss
@ -445,4 +473,4 @@ if __name__ == "__main__":
args.attn_mode = "torch" # backward compatibility
trainer = AnimaNetworkTrainer()
trainer.train(args)
trainer.train(args)