mirror of https://github.com/bmaltais/kohya_ss
init anima UI
parent
4f45d7cf17
commit
8009f31d78
|
|
@ -0,0 +1,218 @@
|
|||
import gradio as gr
|
||||
from .common_gui import (
|
||||
get_any_file_path,
|
||||
document_symbol,
|
||||
)
|
||||
|
||||
|
||||
class animaTraining:
    """Gradio UI section for Anima model training options.

    Builds an initially hidden "Anima" accordion containing model-path
    textboxes, training parameters, and memory/speed toggles, and wires the
    accordion's visibility to ``anima_checkbox``.
    """

    def __init__(
        self,
        headless: bool = False,
        finetuning: bool = False,
        config: dict | None = None,
        anima_checkbox: gr.Checkbox = False,
    ) -> None:
        """Create the Anima UI widgets.

        Args:
            headless: When True, the file-picker buttons are hidden.
            finetuning: Stored on the instance; not read while building this UI.
            config: Saved GUI configuration used to pre-fill widget values.
                ``None`` (the default) means an empty configuration — a fresh
                dict is created per instance to avoid the shared
                mutable-default-argument pitfall.
            anima_checkbox: Checkbox whose state toggles the accordion's
                visibility.
        """
        self.headless = headless
        self.finetuning = finetuning
        # Fresh dict per instance instead of a shared mutable default.
        self.config = config if config is not None else {}
        self.anima_checkbox = anima_checkbox

        with gr.Accordion(
            "Anima", open=True, visible=False, elem_classes=["anima_background"]
        ) as anima_accordion:
            with gr.Group():
                gr.Markdown("### Anima Model Paths")
                with gr.Row():
                    self.qwen3 = gr.Textbox(
                        label="Qwen3-0.6B Text Encoder Path",
                        placeholder="Path to Qwen3-0.6B model directory or .safetensors",
                        value=self.config.get("anima.qwen3", ""),
                        interactive=True,
                    )
                    self.qwen3_button = gr.Button(
                        document_symbol,
                        elem_id="open_folder_small",
                        visible=(not headless),
                        interactive=True,
                    )
                    self.qwen3_button.click(
                        get_any_file_path,
                        outputs=self.qwen3,
                        show_progress=False,
                    )

                    self.anima_vae = gr.Textbox(
                        label="Qwen-Image VAE Path",
                        placeholder="Path to Qwen-Image VAE .safetensors or .pth",
                        value=self.config.get("anima.anima_vae", ""),
                        interactive=True,
                    )
                    self.anima_vae_button = gr.Button(
                        document_symbol,
                        elem_id="open_folder_small",
                        visible=(not headless),
                        interactive=True,
                    )
                    self.anima_vae_button.click(
                        get_any_file_path,
                        outputs=self.anima_vae,
                        show_progress=False,
                    )

                with gr.Row():
                    self.llm_adapter_path = gr.Textbox(
                        label="LLM Adapter Path (Optional)",
                        placeholder="Path to LLM adapter .safetensors. If empty, loaded from DiT if present.",
                        value=self.config.get("anima.llm_adapter_path", ""),
                        interactive=True,
                    )
                    self.llm_adapter_path_button = gr.Button(
                        document_symbol,
                        elem_id="open_folder_small",
                        visible=(not headless),
                        interactive=True,
                    )
                    self.llm_adapter_path_button.click(
                        get_any_file_path,
                        outputs=self.llm_adapter_path,
                        show_progress=False,
                    )

                    self.t5_tokenizer_path = gr.Textbox(
                        label="T5 Tokenizer Path (Optional)",
                        placeholder="Path to T5 tokenizer directory. If empty, uses bundled configs/t5_old/.",
                        value=self.config.get("anima.t5_tokenizer_path", ""),
                        interactive=True,
                    )
                    self.t5_tokenizer_path_button = gr.Button(
                        document_symbol,
                        elem_id="open_folder_small",
                        visible=(not headless),
                        interactive=True,
                    )
                    self.t5_tokenizer_path_button.click(
                        get_any_file_path,
                        outputs=self.t5_tokenizer_path,
                        show_progress=False,
                    )

                gr.Markdown("### Anima Training Parameters")
                with gr.Row():
                    self.anima_timestep_sampling = gr.Dropdown(
                        label="Timestep Sampling",
                        choices=["sigmoid", "sigma", "uniform", "shift", "flux_shift"],
                        value=self.config.get("anima.anima_timestep_sampling", "sigmoid"),
                        info="Timestep sampling method. Same options as FLUX training. Default: sigmoid.",
                        interactive=True,
                    )
                    self.anima_discrete_flow_shift = gr.Number(
                        label="Discrete Flow Shift",
                        value=self.config.get("anima.anima_discrete_flow_shift", 1.0),
                        info="Shift for timestep distribution in Rectified Flow. Default 1.0. Used when timestep_sampling=shift.",
                        minimum=0.0,
                        maximum=100.0,
                        step=0.1,
                        interactive=True,
                    )
                    self.anima_sigmoid_scale = gr.Number(
                        label="Sigmoid Scale",
                        value=self.config.get("anima.anima_sigmoid_scale", 1.0),
                        info="Scale factor for sigmoid/shift/flux_shift timestep sampling. Default 1.0.",
                        minimum=0.001,
                        maximum=100.0,
                        step=0.1,
                        interactive=True,
                    )

                with gr.Row():
                    self.qwen3_max_token_length = gr.Number(
                        label="Qwen3 Max Token Length",
                        value=self.config.get("anima.qwen3_max_token_length", 512),
                        info="Maximum token length for Qwen3 tokenizer. Default 512.",
                        minimum=1,
                        maximum=4096,
                        step=1,
                        interactive=True,
                    )
                    self.t5_max_token_length = gr.Number(
                        label="T5 Max Token Length",
                        value=self.config.get("anima.t5_max_token_length", 512),
                        info="Maximum token length for T5 tokenizer. Default 512.",
                        minimum=1,
                        maximum=4096,
                        step=1,
                        interactive=True,
                    )
                    self.anima_attn_mode = gr.Dropdown(
                        label="Attention Mode",
                        choices=["torch", "xformers", "flash", "sageattn"],
                        value=self.config.get("anima.anima_attn_mode", "torch"),
                        info="Attention implementation. xformers requires split_attn. sageattn is inference-only.",
                        interactive=True,
                    )
                    self.anima_split_attn = gr.Checkbox(
                        label="Split Attention",
                        value=self.config.get("anima.anima_split_attn", False),
                        info="Split attention computation to reduce memory. Required when using xformers attn_mode.",
                        interactive=True,
                    )

                gr.Markdown("### Memory & Speed")
                with gr.Row():
                    self.anima_cache_text_encoder_outputs = gr.Checkbox(
                        label="Cache Text Encoder Outputs",
                        value=self.config.get("anima.anima_cache_text_encoder_outputs", False),
                        info="Cache Qwen3 outputs to reduce VRAM. Recommended when not training text encoder LoRA.",
                        interactive=True,
                    )
                    self.anima_cache_text_encoder_outputs_to_disk = gr.Checkbox(
                        label="Cache Text Encoder Outputs to Disk",
                        value=self.config.get("anima.anima_cache_text_encoder_outputs_to_disk", False),
                        info="Cache text encoder outputs to disk.",
                        interactive=True,
                    )
                    self.anima_blocks_to_swap = gr.Slider(
                        label="Blocks to Swap",
                        value=self.config.get("anima.anima_blocks_to_swap", 0),
                        info="Number of Transformer blocks to swap CPU<->GPU. 28-block model: max 26. Reduces VRAM at cost of speed.",
                        minimum=0,
                        maximum=34,
                        step=1,
                        interactive=True,
                    )
                    self.anima_unsloth_offload_checkpointing = gr.Checkbox(
                        label="Unsloth Offload Checkpointing",
                        value=self.config.get("anima.anima_unsloth_offload_checkpointing", False),
                        info="Offload activations to CPU RAM using async non-blocking transfers. Faster than cpu_offload_checkpointing. Cannot combine with blocks_to_swap.",
                        interactive=True,
                    )

                with gr.Row():
                    self.vae_chunk_size = gr.Number(
                        label="VAE Chunk Size",
                        value=self.config.get("anima.vae_chunk_size", 0),
                        info="Chunk size for Qwen-Image VAE processing to reduce VRAM. 0 = no chunking.",
                        minimum=0,
                        maximum=1024,
                        step=8,
                        interactive=True,
                    )
                    self.vae_disable_cache = gr.Checkbox(
                        label="VAE Disable Cache",
                        value=self.config.get("anima.vae_disable_cache", False),
                        info="Disable internal caching in Qwen-Image VAE to reduce VRAM.",
                        interactive=True,
                    )
                    self.anima_train_llm_adapter = gr.Checkbox(
                        label="Train LLM Adapter LoRA",
                        value=self.config.get("anima.anima_train_llm_adapter", False),
                        info="Apply LoRA to LLM Adapter blocks (6-layer transformer bridge from Qwen3 to T5-compatible space). Only supported with LoRA type 'Anima', ignored for LyCORIS variants.",
                        interactive=True,
                    )

        # Show/hide the whole accordion when the Anima model-type checkbox flips.
        self.anima_checkbox.change(
            lambda anima_checkbox: gr.Accordion(visible=anima_checkbox),
            inputs=[self.anima_checkbox],
            outputs=[anima_accordion],
        )
|
||||
|
|
@ -275,58 +275,45 @@ class SourceModel:
|
|||
min_width=60,
|
||||
interactive=True,
|
||||
)
|
||||
self.anima_checkbox = gr.Checkbox(
|
||||
label="Anima",
|
||||
value=False,
|
||||
visible=False,
|
||||
min_width=70,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
def toggle_checkboxes(v2, v_parameterization, sdxl_checkbox, sd3_checkbox, flux1_checkbox, anima_checkbox):
    """Return updated interactivity for the mutually-exclusive model-type checkboxes.

    Returns one ``gr.Checkbox`` update per checkbox, in the order
    (v2, v_parameterization, sdxl, sd3, flux1, anima). When no model type is
    selected everything becomes selectable again; when one is selected, only
    that checkbox stays interactive so it can be un-checked.

    Note: this span of the scraped diff interleaved the removed and added
    versions (duplicated condition and return tuples); this is the added,
    post-commit version reconstructed as one clean function.
    """
    # All unchecked: re-enable every option. v_parameterization stays locked
    # off (interactive=False, value=False) until a compatible type is chosen.
    if not v2 and not sdxl_checkbox and not sd3_checkbox and not flux1_checkbox and not anima_checkbox:
        return (
            gr.Checkbox(interactive=True),  # v2
            gr.Checkbox(interactive=False, value=False),  # v_parameterization
            gr.Checkbox(interactive=True),  # sdxl_checkbox
            gr.Checkbox(interactive=True),  # sd3_checkbox
            gr.Checkbox(interactive=True),  # flux1_checkbox
            gr.Checkbox(interactive=True),  # anima_checkbox
        )
    else:
        # If any checkbox is checked, only allow that one to be toggled
        return (
            gr.Checkbox(interactive=v2),  # v2
            # NOTE(review): v_parameterization interactivity follows
            # sdxl_checkbox here rather than v2 — present in both the old and
            # new versions of the diff; confirm this is intentional.
            gr.Checkbox(interactive=sdxl_checkbox),  # v_parameterization
            gr.Checkbox(interactive=sdxl_checkbox),  # sdxl_checkbox
            gr.Checkbox(interactive=sd3_checkbox),  # sd3_checkbox
            gr.Checkbox(interactive=flux1_checkbox),  # flux1_checkbox
            gr.Checkbox(interactive=anima_checkbox),  # anima_checkbox
        )
|
||||
|
||||
self.v2.change(
|
||||
fn=toggle_checkboxes,
|
||||
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
show_progress=False,
|
||||
)
|
||||
self.v_parameterization.change(
|
||||
fn=toggle_checkboxes,
|
||||
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
show_progress=False,
|
||||
)
|
||||
self.sdxl_checkbox.change(
|
||||
fn=toggle_checkboxes,
|
||||
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
show_progress=False,
|
||||
)
|
||||
self.sd3_checkbox.change(
|
||||
fn=toggle_checkboxes,
|
||||
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
show_progress=False,
|
||||
)
|
||||
self.flux1_checkbox.change(
|
||||
fn=toggle_checkboxes,
|
||||
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
show_progress=False,
|
||||
)
|
||||
_all_checkboxes = [self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox, self.anima_checkbox]
|
||||
for _cb in [self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox, self.anima_checkbox]:
|
||||
_cb.change(
|
||||
fn=toggle_checkboxes,
|
||||
inputs=_all_checkboxes,
|
||||
outputs=_all_checkboxes,
|
||||
show_progress=False,
|
||||
)
|
||||
with gr.Column():
|
||||
gr.Group(visible=False)
|
||||
|
||||
|
|
@ -364,6 +351,7 @@ class SourceModel:
|
|||
self.sdxl_checkbox,
|
||||
self.sd3_checkbox,
|
||||
self.flux1_checkbox,
|
||||
self.anima_checkbox,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -966,6 +966,7 @@ def set_pretrained_model_name_or_path_input(
|
|||
sdxl = gr.Checkbox(value=True, visible=False)
|
||||
sd3 = gr.Checkbox(value=False, visible=False)
|
||||
flux1 = gr.Checkbox(value=False, visible=False)
|
||||
anima = gr.Checkbox(value=False, visible=False)
|
||||
return (
|
||||
gr.Dropdown(),
|
||||
v2,
|
||||
|
|
@ -973,6 +974,7 @@ def set_pretrained_model_name_or_path_input(
|
|||
sdxl,
|
||||
sd3,
|
||||
flux1,
|
||||
anima,
|
||||
)
|
||||
|
||||
# Check if the given pretrained_model_name_or_path is in the list of V2 base models
|
||||
|
|
@ -983,6 +985,7 @@ def set_pretrained_model_name_or_path_input(
|
|||
sdxl = gr.Checkbox(value=False, visible=False)
|
||||
sd3 = gr.Checkbox(value=False, visible=False)
|
||||
flux1 = gr.Checkbox(value=False, visible=False)
|
||||
anima = gr.Checkbox(value=False, visible=False)
|
||||
return (
|
||||
gr.Dropdown(),
|
||||
v2,
|
||||
|
|
@ -990,6 +993,7 @@ def set_pretrained_model_name_or_path_input(
|
|||
sdxl,
|
||||
sd3,
|
||||
flux1,
|
||||
anima,
|
||||
)
|
||||
|
||||
# Check if the given pretrained_model_name_or_path is in the list of V parameterization models
|
||||
|
|
@ -1002,6 +1006,7 @@ def set_pretrained_model_name_or_path_input(
|
|||
sdxl = gr.Checkbox(value=False, visible=False)
|
||||
sd3 = gr.Checkbox(value=False, visible=False)
|
||||
flux1 = gr.Checkbox(value=False, visible=False)
|
||||
anima = gr.Checkbox(value=False, visible=False)
|
||||
return (
|
||||
gr.Dropdown(),
|
||||
v2,
|
||||
|
|
@ -1009,6 +1014,7 @@ def set_pretrained_model_name_or_path_input(
|
|||
sdxl,
|
||||
sd3,
|
||||
flux1,
|
||||
anima,
|
||||
)
|
||||
|
||||
# Check if the given pretrained_model_name_or_path is in the list of V1 models
|
||||
|
|
@ -1019,6 +1025,7 @@ def set_pretrained_model_name_or_path_input(
|
|||
sdxl = gr.Checkbox(value=False, visible=False)
|
||||
sd3 = gr.Checkbox(value=False, visible=False)
|
||||
flux1 = gr.Checkbox(value=False, visible=False)
|
||||
anima = gr.Checkbox(value=False, visible=False)
|
||||
return (
|
||||
gr.Dropdown(),
|
||||
v2,
|
||||
|
|
@ -1026,6 +1033,7 @@ def set_pretrained_model_name_or_path_input(
|
|||
sdxl,
|
||||
sd3,
|
||||
flux1,
|
||||
anima,
|
||||
)
|
||||
|
||||
# Check if the model_list is set to 'custom'
|
||||
|
|
@ -1034,6 +1042,7 @@ def set_pretrained_model_name_or_path_input(
|
|||
sdxl = gr.Checkbox(visible=True)
|
||||
sd3 = gr.Checkbox(visible=True)
|
||||
flux1 = gr.Checkbox(visible=True)
|
||||
anima = gr.Checkbox(visible=True)
|
||||
|
||||
# Auto-detect model type if safetensors file path is given
|
||||
if pretrained_model_name_or_path.lower().endswith(".safetensors"):
|
||||
|
|
@ -1058,6 +1067,7 @@ def set_pretrained_model_name_or_path_input(
|
|||
sdxl,
|
||||
sd3,
|
||||
flux1,
|
||||
anima,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ from .class_huggingface import HuggingFace
|
|||
from .class_metadata import MetaData
|
||||
from .class_gui_config import KohyaSSGUIConfig
|
||||
from .class_flux1 import flux1Training
|
||||
from .class_anima import animaTraining
|
||||
|
||||
from .dreambooth_folder_creation_gui import (
|
||||
gradio_dreambooth_folder_creation_tab,
|
||||
|
|
@ -319,6 +320,26 @@ def save_configuration(
|
|||
sd3_text_encoder_batch_size,
|
||||
weighting_scheme,
|
||||
sd3_checkbox,
|
||||
# Anima parameters
|
||||
anima_checkbox,
|
||||
anima_qwen3,
|
||||
anima_vae,
|
||||
anima_llm_adapter_path,
|
||||
anima_t5_tokenizer_path,
|
||||
anima_timestep_sampling,
|
||||
anima_discrete_flow_shift,
|
||||
anima_sigmoid_scale,
|
||||
anima_qwen3_max_token_length,
|
||||
anima_t5_max_token_length,
|
||||
anima_attn_mode,
|
||||
anima_split_attn,
|
||||
anima_cache_text_encoder_outputs,
|
||||
anima_cache_text_encoder_outputs_to_disk,
|
||||
anima_blocks_to_swap,
|
||||
anima_unsloth_offload_checkpointing,
|
||||
anima_vae_chunk_size,
|
||||
anima_vae_disable_cache,
|
||||
anima_train_llm_adapter,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -605,6 +626,26 @@ def open_configuration(
|
|||
sd3_text_encoder_batch_size,
|
||||
weighting_scheme,
|
||||
sd3_checkbox,
|
||||
# Anima parameters
|
||||
anima_checkbox,
|
||||
anima_qwen3,
|
||||
anima_vae,
|
||||
anima_llm_adapter_path,
|
||||
anima_t5_tokenizer_path,
|
||||
anima_timestep_sampling,
|
||||
anima_discrete_flow_shift,
|
||||
anima_sigmoid_scale,
|
||||
anima_qwen3_max_token_length,
|
||||
anima_t5_max_token_length,
|
||||
anima_attn_mode,
|
||||
anima_split_attn,
|
||||
anima_cache_text_encoder_outputs,
|
||||
anima_cache_text_encoder_outputs_to_disk,
|
||||
anima_blocks_to_swap,
|
||||
anima_unsloth_offload_checkpointing,
|
||||
anima_vae_chunk_size,
|
||||
anima_vae_disable_cache,
|
||||
anima_train_llm_adapter,
|
||||
##
|
||||
training_preset,
|
||||
):
|
||||
|
|
@ -982,6 +1023,26 @@ def train_model(
|
|||
sd3_text_encoder_batch_size,
|
||||
weighting_scheme,
|
||||
sd3_checkbox,
|
||||
# Anima parameters
|
||||
anima_checkbox,
|
||||
anima_qwen3,
|
||||
anima_vae,
|
||||
anima_llm_adapter_path,
|
||||
anima_t5_tokenizer_path,
|
||||
anima_timestep_sampling,
|
||||
anima_discrete_flow_shift,
|
||||
anima_sigmoid_scale,
|
||||
anima_qwen3_max_token_length,
|
||||
anima_t5_max_token_length,
|
||||
anima_attn_mode,
|
||||
anima_split_attn,
|
||||
anima_cache_text_encoder_outputs,
|
||||
anima_cache_text_encoder_outputs_to_disk,
|
||||
anima_blocks_to_swap,
|
||||
anima_unsloth_offload_checkpointing,
|
||||
anima_vae_chunk_size,
|
||||
anima_vae_disable_cache,
|
||||
anima_train_llm_adapter,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -1019,6 +1080,24 @@ def train_model(
|
|||
)
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
if anima_checkbox:
|
||||
log.info(f"Validating lora type is Anima if Anima checkbox is checked...")
|
||||
if LoRA_type != "Anima" and "LyCORIS" not in LoRA_type:
|
||||
log.error(
|
||||
"LoRA type must be set to 'Anima' or 'LyCORIS' if Anima checkbox is checked."
|
||||
)
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
if not anima_qwen3:
|
||||
log.error(
|
||||
"Qwen3 Text Encoder path is required for Anima training. Please set it in the Anima section."
|
||||
)
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
if not anima_vae:
|
||||
log.error(
|
||||
"Anima VAE path is required for Anima training. Please set it in the Anima section."
|
||||
)
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
#
|
||||
# Validate paths
|
||||
#
|
||||
|
|
@ -1265,6 +1344,8 @@ def train_model(
|
|||
run_cmd.append(rf"{scriptdir}/sd-scripts/flux_train_network.py")
|
||||
elif sd3_checkbox:
|
||||
run_cmd.append(rf"{scriptdir}/sd-scripts/sd3_train_network.py")
|
||||
elif anima_checkbox:
|
||||
run_cmd.append(rf"{scriptdir}/sd-scripts/anima_train_network.py")
|
||||
else:
|
||||
run_cmd.append(rf"{scriptdir}/sd-scripts/train_network.py")
|
||||
|
||||
|
|
@ -1375,6 +1456,11 @@ def train_model(
|
|||
if value:
|
||||
network_args += f" {key}={value}"
|
||||
|
||||
if LoRA_type == "Anima":
|
||||
network_module = "networks.lora_anima"
|
||||
if anima_train_llm_adapter:
|
||||
network_args += " train_llm_adapter=True"
|
||||
|
||||
if LoRA_type in ["Kohya LoCon", "Standard"]:
|
||||
kohya_lora_var_list = [
|
||||
"down_lr_weight",
|
||||
|
|
@ -1540,14 +1626,14 @@ def train_model(
|
|||
if (sdxl and sdxl_cache_text_encoder_outputs)
|
||||
or (flux1_checkbox and flux1_cache_text_encoder_outputs)
|
||||
or (sd3_checkbox and sd3_cache_text_encoder_outputs)
|
||||
or (anima_checkbox and anima_cache_text_encoder_outputs)
|
||||
else None
|
||||
),
|
||||
"cache_text_encoder_outputs_to_disk": (
|
||||
True
|
||||
if flux1_checkbox
|
||||
and flux1_cache_text_encoder_outputs_to_disk
|
||||
or sd3_checkbox
|
||||
and sd3_cache_text_encoder_outputs_to_disk
|
||||
if (flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk)
|
||||
or (sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk)
|
||||
or (anima_checkbox and anima_cache_text_encoder_outputs_to_disk)
|
||||
else None
|
||||
),
|
||||
"caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
|
||||
|
|
@ -1609,7 +1695,7 @@ def train_model(
|
|||
"max_bucket_reso": max_bucket_reso,
|
||||
"max_grad_norm": max_grad_norm,
|
||||
"max_timestep": max_timestep if max_timestep != 0 else None,
|
||||
"max_token_length": int(max_token_length) if not flux1_checkbox else None,
|
||||
"max_token_length": int(max_token_length) if not flux1_checkbox and not anima_checkbox else None,
|
||||
"max_train_epochs": (
|
||||
int(max_train_epochs) if int(max_train_epochs) != 0 else None
|
||||
),
|
||||
|
|
@ -1693,7 +1779,7 @@ def train_model(
|
|||
"save_state_to_huggingface": save_state_to_huggingface,
|
||||
"scale_v_pred_loss_like_noise_pred": scale_v_pred_loss_like_noise_pred,
|
||||
"scale_weight_norms": scale_weight_norms,
|
||||
"sdpa": True if xformers == "sdpa" else None,
|
||||
"sdpa": True if xformers == "sdpa" and not anima_checkbox else None,
|
||||
"seed": int(seed) if int(seed) != 0 else None,
|
||||
"shuffle_caption": shuffle_caption,
|
||||
"skip_cache_check": skip_cache_check,
|
||||
|
|
@ -1709,12 +1795,12 @@ def train_model(
|
|||
"v2": v2,
|
||||
"v_parameterization": v_parameterization,
|
||||
"v_pred_like_loss": v_pred_like_loss if v_pred_like_loss != 0 else None,
|
||||
"vae": vae,
|
||||
"vae": anima_vae if anima_checkbox else vae,
|
||||
"vae_batch_size": vae_batch_size if vae_batch_size != 0 else None,
|
||||
"wandb_api_key": wandb_api_key,
|
||||
"wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
|
||||
"weighted_captions": weighted_captions,
|
||||
"xformers": True if xformers == "xformers" else None,
|
||||
"xformers": True if xformers == "xformers" and not anima_checkbox else None,
|
||||
# SD3 only Parameters
|
||||
# "cache_text_encoder_outputs": see previous assignment above for code
|
||||
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
|
||||
|
|
@ -1745,9 +1831,7 @@ def train_model(
|
|||
"ae": ae if flux1_checkbox else None,
|
||||
# "clip_l": see previous assignment above for code
|
||||
"t5xxl": t5xxl_value,
|
||||
"discrete_flow_shift": float(discrete_flow_shift) if flux1_checkbox else None,
|
||||
"model_prediction_type": model_prediction_type if flux1_checkbox else None,
|
||||
"timestep_sampling": timestep_sampling if flux1_checkbox else None,
|
||||
"split_mode": split_mode if flux1_checkbox else None,
|
||||
"t5xxl_max_token_length": (
|
||||
int(t5xxl_max_token_length) if flux1_checkbox else None
|
||||
|
|
@ -1758,9 +1842,33 @@ def train_model(
|
|||
"cpu_offload_checkpointing": (
|
||||
cpu_offload_checkpointing if flux1_checkbox else None
|
||||
),
|
||||
"blocks_to_swap": blocks_to_swap if flux1_checkbox or sd3_checkbox else None,
|
||||
"blocks_to_swap": blocks_to_swap if flux1_checkbox or sd3_checkbox else (
|
||||
int(anima_blocks_to_swap) if anima_checkbox and anima_blocks_to_swap else None
|
||||
),
|
||||
"single_blocks_to_swap": single_blocks_to_swap if flux1_checkbox else None,
|
||||
"double_blocks_to_swap": double_blocks_to_swap if flux1_checkbox else None,
|
||||
# Anima specific parameters
|
||||
"qwen3": anima_qwen3 if anima_checkbox else None,
|
||||
"llm_adapter_path": anima_llm_adapter_path if anima_checkbox and anima_llm_adapter_path else None,
|
||||
"t5_tokenizer_path": anima_t5_tokenizer_path if anima_checkbox and anima_t5_tokenizer_path else None,
|
||||
"timestep_sampling": (
|
||||
timestep_sampling if flux1_checkbox
|
||||
else anima_timestep_sampling if anima_checkbox
|
||||
else None
|
||||
),
|
||||
"discrete_flow_shift": (
|
||||
float(discrete_flow_shift) if flux1_checkbox
|
||||
else float(anima_discrete_flow_shift) if anima_checkbox
|
||||
else None
|
||||
),
|
||||
"sigmoid_scale": float(anima_sigmoid_scale) if anima_checkbox else None,
|
||||
"qwen3_max_token_length": int(anima_qwen3_max_token_length) if anima_checkbox else None,
|
||||
"t5_max_token_length": int(anima_t5_max_token_length) if anima_checkbox else None,
|
||||
"attn_mode": anima_attn_mode if anima_checkbox else None,
|
||||
"split_attn": anima_split_attn if anima_checkbox else None,
|
||||
"unsloth_offload_checkpointing": anima_unsloth_offload_checkpointing if anima_checkbox else None,
|
||||
"vae_chunk_size": int(anima_vae_chunk_size) if anima_checkbox and anima_vae_chunk_size else None,
|
||||
"vae_disable_cache": anima_vae_disable_cache if anima_checkbox else None,
|
||||
}
|
||||
|
||||
# Given dictionary `config_toml_data`
|
||||
|
|
@ -1923,6 +2031,7 @@ def lora_tab(
|
|||
LoRA_type = gr.Dropdown(
|
||||
label="LoRA type",
|
||||
choices=[
|
||||
"Anima",
|
||||
"Flux1",
|
||||
"Flux1 OFT",
|
||||
"Kohya DyLoRA",
|
||||
|
|
@ -2238,6 +2347,7 @@ def lora_tab(
|
|||
"update_params": {
|
||||
"visible": LoRA_type
|
||||
in {
|
||||
"Anima",
|
||||
"Flux1",
|
||||
"Flux1 OFT",
|
||||
"Kohya DyLoRA",
|
||||
|
|
@ -2292,6 +2402,7 @@ def lora_tab(
|
|||
"update_params": {
|
||||
"visible": LoRA_type
|
||||
in {
|
||||
"Anima",
|
||||
"Flux1",
|
||||
"Flux1 OFT",
|
||||
"Standard",
|
||||
|
|
@ -2314,6 +2425,7 @@ def lora_tab(
|
|||
"update_params": {
|
||||
"visible": LoRA_type
|
||||
in {
|
||||
"Anima",
|
||||
"Flux1",
|
||||
"Flux1 OFT",
|
||||
"Standard",
|
||||
|
|
@ -2336,6 +2448,7 @@ def lora_tab(
|
|||
"update_params": {
|
||||
"visible": LoRA_type
|
||||
in {
|
||||
"Anima",
|
||||
"Flux1",
|
||||
"Flux1 OFT",
|
||||
"Standard",
|
||||
|
|
@ -2679,6 +2792,11 @@ def lora_tab(
|
|||
headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox
|
||||
)
|
||||
|
||||
# Add Anima Parameters
|
||||
anima_training = animaTraining(
|
||||
headless=headless, config=config, anima_checkbox=source_model.anima_checkbox
|
||||
)
|
||||
|
||||
with gr.Accordion(
|
||||
"Advanced", open=False, elem_classes="advanced_background"
|
||||
):
|
||||
|
|
@ -3029,6 +3147,26 @@ def lora_tab(
|
|||
sd3_training.sd3_text_encoder_batch_size,
|
||||
sd3_training.weighting_scheme,
|
||||
source_model.sd3_checkbox,
|
||||
# Anima Parameters
|
||||
source_model.anima_checkbox,
|
||||
anima_training.qwen3,
|
||||
anima_training.anima_vae,
|
||||
anima_training.llm_adapter_path,
|
||||
anima_training.t5_tokenizer_path,
|
||||
anima_training.anima_timestep_sampling,
|
||||
anima_training.anima_discrete_flow_shift,
|
||||
anima_training.anima_sigmoid_scale,
|
||||
anima_training.qwen3_max_token_length,
|
||||
anima_training.t5_max_token_length,
|
||||
anima_training.anima_attn_mode,
|
||||
anima_training.anima_split_attn,
|
||||
anima_training.anima_cache_text_encoder_outputs,
|
||||
anima_training.anima_cache_text_encoder_outputs_to_disk,
|
||||
anima_training.anima_blocks_to_swap,
|
||||
anima_training.anima_unsloth_offload_checkpointing,
|
||||
anima_training.vae_chunk_size,
|
||||
anima_training.vae_disable_cache,
|
||||
anima_training.anima_train_llm_adapter,
|
||||
]
|
||||
|
||||
configuration.button_open_config.click(
|
||||
|
|
|
|||
Loading…
Reference in New Issue