# kohya_ss/kohya_gui/class_anima.py
from typing import Optional

import gradio as gr

from .common_gui import (
    get_any_file_path,
    document_symbol,
)
class animaTraining:
    """Gradio UI panel holding all Anima-specific training options.

    Builds an accordion containing three sections (model paths, training
    parameters, memory/speed settings) and wires the accordion's visibility
    to the state of the supplied ``anima_checkbox``.

    NOTE(review): SOURCE arrived with indentation stripped; the nesting of
    the ``with`` blocks below is reconstructed from the context-manager
    syntax and the section Markdown headers — confirm against the original
    file layout.
    """

    def __init__(
        self,
        headless: bool = False,
        finetuning: bool = False,
        config: Optional[dict] = None,
        anima_checkbox: gr.Checkbox = False,
    ) -> None:
        """Create the Anima settings accordion.

        Args:
            headless: When True, hide the folder-picker buttons (no local
                file dialog is possible without a display).
            finetuning: Stored for callers; not read inside this class in
                the visible code.
            config: Saved-settings mapping queried via ``config.get`` with
                dotted ``"anima.*"`` keys. Defaults to an empty dict.
                (Was a mutable default argument ``config: dict = {}``;
                replaced with a ``None`` sentinel so instances never share
                one dict object.)
            anima_checkbox: Checkbox elsewhere in the UI that toggles this
                accordion's visibility. The ``False`` default is not a
                usable value — ``.change`` is called on it below, so a real
                ``gr.Checkbox`` must be supplied in practice.
        """
        self.headless = headless
        self.finetuning = finetuning
        self.config = config if config is not None else {}
        self.anima_checkbox = anima_checkbox

        # Hidden by default; made visible by the checkbox handler at the end.
        with gr.Accordion(
            "Anima", open=True, visible=False, elem_classes=["anima_background"]
        ) as anima_accordion:
            with gr.Group():
                gr.Markdown("### Anima Model Paths")
                with gr.Row():
                    self.qwen3 = gr.Textbox(
                        label="Qwen3-0.6B Text Encoder Path",
                        placeholder="Path to Qwen3-0.6B model directory or .safetensors",
                        value=self.config.get("anima.qwen3", ""),
                        interactive=True,
                    )
                    # Folder-picker button; hidden in headless mode.
                    self.qwen3_button = gr.Button(
                        document_symbol,
                        elem_id="open_folder_small",
                        visible=(not headless),
                        interactive=True,
                    )
                    # No inputs: get_any_file_path is called with its defaults
                    # and its return value fills the textbox.
                    self.qwen3_button.click(
                        get_any_file_path,
                        outputs=self.qwen3,
                        show_progress=False,
                    )
                    self.anima_vae = gr.Textbox(
                        label="Qwen-Image VAE Path",
                        placeholder="Path to Qwen-Image VAE .safetensors or .pth",
                        value=self.config.get("anima.anima_vae", ""),
                        interactive=True,
                    )
                    self.anima_vae_button = gr.Button(
                        document_symbol,
                        elem_id="open_folder_small",
                        visible=(not headless),
                        interactive=True,
                    )
                    self.anima_vae_button.click(
                        get_any_file_path,
                        outputs=self.anima_vae,
                        show_progress=False,
                    )
                with gr.Row():
                    self.llm_adapter_path = gr.Textbox(
                        label="LLM Adapter Path (Optional)",
                        placeholder="Path to LLM adapter .safetensors. If empty, loaded from DiT if present.",
                        value=self.config.get("anima.llm_adapter_path", ""),
                        interactive=True,
                    )
                    self.llm_adapter_path_button = gr.Button(
                        document_symbol,
                        elem_id="open_folder_small",
                        visible=(not headless),
                        interactive=True,
                    )
                    self.llm_adapter_path_button.click(
                        get_any_file_path,
                        outputs=self.llm_adapter_path,
                        show_progress=False,
                    )
                    self.t5_tokenizer_path = gr.Textbox(
                        label="T5 Tokenizer Path (Optional)",
                        placeholder="Path to T5 tokenizer directory. If empty, uses bundled configs/t5_old/.",
                        value=self.config.get("anima.t5_tokenizer_path", ""),
                        interactive=True,
                    )
                    self.t5_tokenizer_path_button = gr.Button(
                        document_symbol,
                        elem_id="open_folder_small",
                        visible=(not headless),
                        interactive=True,
                    )
                    self.t5_tokenizer_path_button.click(
                        get_any_file_path,
                        outputs=self.t5_tokenizer_path,
                        show_progress=False,
                    )

                gr.Markdown("### Anima Training Parameters")
                with gr.Row():
                    self.anima_timestep_sampling = gr.Dropdown(
                        label="Timestep Sampling",
                        choices=["sigmoid", "sigma", "uniform", "shift", "flux_shift"],
                        value=self.config.get("anima.anima_timestep_sampling", "sigmoid"),
                        info="Timestep sampling method. Same options as FLUX training. Default: sigmoid.",
                        interactive=True,
                    )
                    self.anima_discrete_flow_shift = gr.Number(
                        label="Discrete Flow Shift",
                        value=self.config.get("anima.anima_discrete_flow_shift", 1.0),
                        info="Shift for timestep distribution in Rectified Flow. Default 1.0. Used when timestep_sampling=shift.",
                        minimum=0.0,
                        maximum=100.0,
                        step=0.1,
                        interactive=True,
                    )
                    self.anima_sigmoid_scale = gr.Number(
                        label="Sigmoid Scale",
                        value=self.config.get("anima.anima_sigmoid_scale", 1.0),
                        info="Scale factor for sigmoid/shift/flux_shift timestep sampling. Default 1.0.",
                        minimum=0.001,
                        maximum=100.0,
                        step=0.1,
                        interactive=True,
                    )
                with gr.Row():
                    self.qwen3_max_token_length = gr.Number(
                        label="Qwen3 Max Token Length",
                        value=self.config.get("anima.qwen3_max_token_length", 512),
                        info="Maximum token length for Qwen3 tokenizer. Default 512.",
                        minimum=1,
                        maximum=4096,
                        step=1,
                        interactive=True,
                    )
                    self.t5_max_token_length = gr.Number(
                        label="T5 Max Token Length",
                        value=self.config.get("anima.t5_max_token_length", 512),
                        info="Maximum token length for T5 tokenizer. Default 512.",
                        minimum=1,
                        maximum=4096,
                        step=1,
                        interactive=True,
                    )
                    self.anima_split_attn = gr.Checkbox(
                        label="Split Attention",
                        value=self.config.get("anima.anima_split_attn", False),
                        info="Split attention per-sequence to save memory. Optional with xformers (uses BlockDiagonalMask otherwise). Useful when xformers lacks mask support or for max VRAM savings.",
                        interactive=True,
                    )

                gr.Markdown("### Memory & Speed")
                with gr.Row():
                    self.anima_cache_text_encoder_outputs = gr.Checkbox(
                        label="Cache Text Encoder Outputs",
                        value=self.config.get("anima.anima_cache_text_encoder_outputs", True),
                        info="Cache Qwen3 outputs to reduce VRAM. Enabled by default: TE LoRA is not supported at inference for Anima.",
                        interactive=True,
                    )
                    self.anima_cache_text_encoder_outputs_to_disk = gr.Checkbox(
                        label="Cache Text Encoder Outputs to Disk",
                        value=self.config.get("anima.anima_cache_text_encoder_outputs_to_disk", False),
                        info="Cache text encoder outputs to disk.",
                        interactive=True,
                    )
                    self.anima_blocks_to_swap = gr.Slider(
                        label="Blocks to Swap",
                        value=self.config.get("anima.anima_blocks_to_swap", 0),
                        info="Number of Transformer blocks to swap CPU<->GPU. 28-block model: max 26. Reduces VRAM at cost of speed.",
                        minimum=0,
                        maximum=34,
                        step=1,
                        interactive=True,
                    )
                    self.anima_unsloth_offload_checkpointing = gr.Checkbox(
                        label="Unsloth Offload Checkpointing",
                        value=self.config.get("anima.anima_unsloth_offload_checkpointing", False),
                        info="Offload activations to CPU RAM using async non-blocking transfers. Faster than cpu_offload_checkpointing. Cannot combine with blocks_to_swap.",
                        interactive=True,
                    )
                    self.anima_disable_mmap_load_safetensors = gr.Checkbox(
                        label="Disable mmap Load",
                        value=self.config.get("anima.anima_disable_mmap_load_safetensors", False),
                        info="Disable mmap for safetensors loading. Speeds up model loading on WSL2 or network drives.",
                        interactive=True,
                    )
                with gr.Row():
                    self.vae_chunk_size = gr.Number(
                        label="VAE Chunk Size",
                        value=self.config.get("anima.vae_chunk_size", 0),
                        info="Chunk size for Qwen-Image VAE processing to reduce VRAM. 0 = no chunking.",
                        minimum=0,
                        maximum=1024,
                        step=8,
                        interactive=True,
                    )
                    self.vae_disable_cache = gr.Checkbox(
                        label="VAE Disable Cache",
                        value=self.config.get("anima.vae_disable_cache", False),
                        info="Disable internal caching in Qwen-Image VAE to reduce VRAM.",
                        interactive=True,
                    )
                    self.anima_train_llm_adapter = gr.Checkbox(
                        label="Train LLM Adapter LoRA",
                        value=self.config.get("anima.anima_train_llm_adapter", False),
                        info="Apply LoRA to LLM Adapter blocks (6-layer transformer bridge from Qwen3 to T5-compatible space). Only supported with LoRA type 'Anima', ignored for LyCORIS variants.",
                        interactive=True,
                    )

        # Show/hide the whole accordion whenever the Anima checkbox changes.
        self.anima_checkbox.change(
            lambda anima_checkbox: gr.Accordion(visible=anima_checkbox),
            inputs=[self.anima_checkbox],
            outputs=[anima_accordion],
        )