# kohya_ss/kohya_gui/class_sd3.py
#
# 250 lines
# 11 KiB
# Python

from typing import Optional, Tuple, Union

import gradio as gr

from .common_gui import (
    get_folder_path,
    get_any_file_path,
    list_files,
    list_dirs,
    create_refresh_button,
    document_symbol,
)
class sd3Training:
    """
    Build the "SD3" accordion holding the SD3-specific training parameters.

    Every component created is stored as an attribute on the instance
    (e.g. ``weighting_scheme``, ``clip_l``, ``t5xxl_dtype``,
    ``sd3_fused_backward_pass``) so callers can read its value later.

    Attributes:
        headless (bool): When True, the file-picker buttons are hidden.
        finetuning (bool): Stored as given; not read by this class.
        training_type (str): Stored as given; not read by this class.
        config: Source of initial component values, queried with
            ``config.get(key, default)`` using dotted keys such as
            ``"sd3.logit_mean"``.
        sd3_checkbox: Gradio component (or plain bool) controlling whether
            the accordion is visible.
    """

    def __init__(
        self,
        headless: bool = False,
        finetuning: bool = False,
        training_type: str = "",
        config: Optional[dict] = None,
        sd3_checkbox: Union[gr.Checkbox, bool] = False,
    ) -> None:
        """
        Store the settings and build the SD3 accordion and its components.

        Parameters:
            headless (bool): When True, hide the file-picker buttons.
            finetuning (bool): Stored on the instance; not used by this class.
            training_type (str): Stored on the instance; not used by this class.
            config: Object providing ``get(key, default)``, looked up with
                dotted keys such as ``"sd3.logit_mean"``. Defaults to an empty
                dict, so every component falls back to its built-in default.
            sd3_checkbox: Either a Gradio component whose ``change`` event
                toggles the accordion's visibility, or a plain bool used as the
                accordion's fixed visibility (the default ``False`` keeps it
                hidden).
        """
        self.headless = headless
        self.finetuning = finetuning
        self.training_type = training_type
        # ``None`` instead of a ``{}`` default: a mutable default would be a
        # single object shared by every instance. ``is None`` (not ``or``)
        # keeps caller-supplied config objects even if they are falsy.
        self.config = config if config is not None else {}
        self.sd3_checkbox = sd3_checkbox

        # Only a component can drive visibility through ``.change``; calling it
        # on the default ``False`` would raise AttributeError. A plain bool is
        # used directly as the accordion's (static) visibility instead.
        checkbox_is_component = hasattr(self.sd3_checkbox, "change")
        initially_visible = (
            False if checkbox_is_component else bool(self.sd3_checkbox)
        )

        with gr.Accordion(
            "SD3", open=False, elem_id="sd3_tab", visible=initially_visible
        ) as sd3_accordion:
            with gr.Group():
                gr.Markdown("### SD3 Specific Parameters")
                with gr.Row():
                    self.weighting_scheme = gr.Dropdown(
                        label="Weighting Scheme",
                        choices=["logit_normal", "sigma_sqrt", "mode", "cosmap", "uniform"],
                        value=self.config.get("sd3.weighting_scheme", "logit_normal"),
                        interactive=True,
                    )
                    self.logit_mean = gr.Number(
                        label="Logit Mean",
                        value=self.config.get("sd3.logit_mean", 0.0),
                        interactive=True,
                    )
                    self.logit_std = gr.Number(
                        label="Logit Std",
                        value=self.config.get("sd3.logit_std", 1.0),
                        interactive=True,
                    )
                    self.mode_scale = gr.Number(
                        label="Mode Scale",
                        value=self.config.get("sd3.mode_scale", 1.29),
                        interactive=True,
                    )

                with gr.Row():
                    self.clip_l, self.clip_l_button = self._create_model_path_input(
                        "CLIP-L Path", "Path to CLIP-L model", "sd3.clip_l"
                    )
                    self.clip_g, self.clip_g_button = self._create_model_path_input(
                        "CLIP-G Path", "Path to CLIP-G model", "sd3.clip_g"
                    )
                    self.t5xxl, self.t5xxl_button = self._create_model_path_input(
                        "T5-XXL Path", "Path to T5-XXL model", "sd3.t5xxl"
                    )

                with gr.Row():
                    self.save_clip = gr.Checkbox(
                        label="Save CLIP models",
                        value=self.config.get("sd3.save_clip", False),
                        interactive=True,
                    )
                    self.save_t5xxl = gr.Checkbox(
                        label="Save T5-XXL model",
                        value=self.config.get("sd3.save_t5xxl", False),
                        interactive=True,
                    )

                with gr.Row():
                    self.t5xxl_device = gr.Textbox(
                        label="T5-XXL Device",
                        placeholder="Device for T5-XXL (e.g., cuda:0)",
                        value=self.config.get("sd3.t5xxl_device", ""),
                        interactive=True,
                    )
                    self.t5xxl_dtype = gr.Dropdown(
                        label="T5-XXL Dtype",
                        choices=["float32", "fp16", "bf16"],
                        value=self.config.get("sd3.t5xxl_dtype", "bf16"),
                        interactive=True,
                    )
                    # precision=0 makes Gradio return an int; a batch size
                    # must not be handed on as a float such as 1.0.
                    self.sd3_text_encoder_batch_size = gr.Number(
                        label="Text Encoder Batch Size",
                        value=self.config.get("sd3.text_encoder_batch_size", 1),
                        minimum=1,
                        maximum=1024,
                        step=1,
                        precision=0,
                        interactive=True,
                    )
                    self.sd3_cache_text_encoder_outputs = gr.Checkbox(
                        label="Cache Text Encoder Outputs",
                        value=self.config.get("sd3.cache_text_encoder_outputs", False),
                        info="Cache text encoder outputs to speed up inference",
                        interactive=True,
                    )
                    self.sd3_cache_text_encoder_outputs_to_disk = gr.Checkbox(
                        label="Cache Text Encoder Outputs to Disk",
                        value=self.config.get(
                            "sd3.cache_text_encoder_outputs_to_disk", False
                        ),
                        info="Cache text encoder outputs to disk to speed up inference",
                        interactive=True,
                    )

                # Rates below are treated as probabilities, hence bounded to [0, 1].
                with gr.Row():
                    self.clip_l_dropout_rate = gr.Number(
                        label="CLIP-L Dropout Rate",
                        value=self.config.get("sd3.clip_l_dropout_rate", 0.0),
                        interactive=True,
                        minimum=0.0,
                        maximum=1.0,
                        info="Dropout rate for CLIP-L encoder",
                    )
                    self.clip_g_dropout_rate = gr.Number(
                        label="CLIP-G Dropout Rate",
                        value=self.config.get("sd3.clip_g_dropout_rate", 0.0),
                        interactive=True,
                        minimum=0.0,
                        maximum=1.0,
                        info="Dropout rate for CLIP-G encoder",
                    )
                    self.t5_dropout_rate = gr.Number(
                        label="T5 Dropout Rate",
                        value=self.config.get("sd3.t5_dropout_rate", 0.0),
                        interactive=True,
                        minimum=0.0,
                        maximum=1.0,
                        info="Dropout rate for T5-XXL encoder",
                    )

                with gr.Row():
                    self.sd3_fused_backward_pass = gr.Checkbox(
                        label="Fused Backward Pass",
                        value=self.config.get("sd3.fused_backward_pass", False),
                        info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
                        interactive=True,
                    )
                    self.disable_mmap_load_safetensors = gr.Checkbox(
                        label="Disable mmap load safe tensors",
                        info="Disable memory mapping when loading the model's .safetensors in SD3.",
                        value=self.config.get("sd3.disable_mmap_load_safetensors", False),
                        interactive=True,
                    )
                    self.enable_scaled_pos_embed = gr.Checkbox(
                        label="Enable Scaled Positional Embeddings",
                        info="Enable scaled positional embeddings in the model.",
                        value=self.config.get("sd3.enable_scaled_pos_embed", False),
                        interactive=True,
                    )
                    self.pos_emb_random_crop_rate = gr.Number(
                        label="Positional Embedding Random Crop Rate",
                        value=self.config.get("sd3.pos_emb_random_crop_rate", 0.0),
                        interactive=True,
                        minimum=0.0,
                        maximum=1.0,
                        info="Random crop rate for positional embeddings",
                    )

        if checkbox_is_component:
            # Show/hide the whole SD3 accordion as the checkbox is toggled.
            self.sd3_checkbox.change(
                lambda sd3_checkbox: gr.Accordion(visible=sd3_checkbox),
                inputs=[self.sd3_checkbox],
                outputs=[sd3_accordion],
            )

    def _create_model_path_input(
        self, label: str, placeholder: str, config_key: str
    ) -> Tuple[gr.Textbox, gr.Button]:
        """
        Create a model-path textbox plus a file-picker button that fills it.

        Intended to be called inside a Gradio layout block (e.g. ``gr.Row``)
        so both components are placed there, textbox first.

        Parameters:
            label (str): Label of the textbox.
            placeholder (str): Placeholder text of the textbox.
            config_key (str): Key passed to ``self.config.get`` for the
                initial textbox value (falls back to ``""``).

        Returns:
            Tuple[gr.Textbox, gr.Button]: The textbox and its picker button.
        """
        textbox = gr.Textbox(
            label=label,
            placeholder=placeholder,
            value=self.config.get(config_key, ""),
            interactive=True,
        )
        button = gr.Button(
            document_symbol,
            elem_id="open_folder_small",
            visible=(not self.headless),
            interactive=True,
        )
        # Clicking the button writes the file path chosen via
        # get_any_file_path into the textbox.
        button.click(
            get_any_file_path,
            outputs=textbox,
            show_progress=False,
        )
        return textbox, button