mirror of https://github.com/bmaltais/kohya_ss
250 lines
11 KiB
Python
250 lines
11 KiB
Python
from typing import Optional, Tuple

import gradio as gr

from .common_gui import (
    get_folder_path,
    get_any_file_path,
    list_files,
    list_dirs,
    create_refresh_button,
    document_symbol,
)
|
|
|
|
|
|
class sd3Training:
    """
    Build the "SD3" accordion of the training GUI and expose its widgets.

    Instantiating the class creates the Gradio components immediately, so it
    is meant to be constructed while a Gradio layout is being built. Initial
    widget values are read from ``config`` under ``sd3.*`` keys.

    Attributes:
        headless (bool): When True, the file-picker buttons are created hidden.
        finetuning (bool): Stored as given; not read by this class.
        training_type (str): Stored as given; not read by this class.
        config: Object providing ``get(key, default)`` for initial values.
        sd3_checkbox: Component whose ``change`` event toggles the accordion.
        weighting_scheme, logit_mean, logit_std, mode_scale: Timestep
            weighting controls.
        clip_l, clip_g, t5xxl: Textboxes holding text-encoder model paths,
            each paired with a ``*_button`` file picker.
        save_clip, save_t5xxl: Checkboxes for saving the text encoders.
        t5xxl_device, t5xxl_dtype, sd3_text_encoder_batch_size,
        sd3_cache_text_encoder_outputs,
        sd3_cache_text_encoder_outputs_to_disk: Text-encoder runtime options.
        clip_l_dropout_rate, clip_g_dropout_rate, t5_dropout_rate: Per-encoder
            dropout rates.
        sd3_fused_backward_pass, disable_mmap_load_safetensors,
        enable_scaled_pos_embed, pos_emb_random_crop_rate: Remaining SD3
            options.
    """

    def __init__(
        self,
        headless: bool = False,
        finetuning: bool = False,
        training_type: str = "",
        config: Optional[dict] = None,
        sd3_checkbox: gr.Checkbox = False,
    ) -> None:
        """
        Create the SD3 accordion and all of its input components.

        Parameters:
            headless (bool): Hide the file-picker buttons when True.
            finetuning (bool): Stored on the instance for callers.
            training_type (str): Stored on the instance for callers.
            config (dict): Source of initial values, queried with
                ``get("sd3.<name>", default)``. ``None`` means "no saved
                configuration" and yields the built-in defaults.
            sd3_checkbox (gr.Checkbox): Checkbox that shows or hides the
                accordion. The ``False`` default is a placeholder; with it
                the accordion is created hidden and no toggle is wired.
        """
        self.headless = headless
        self.finetuning = finetuning
        self.training_type = training_type
        # Use a fresh dict per instance instead of a shared mutable default.
        # ``is None`` (not truthiness) keeps a caller-supplied empty config.
        self.config = {} if config is None else config
        self.sd3_checkbox = sd3_checkbox

        with gr.Accordion(
            "SD3", open=False, elem_id="sd3_tab", visible=False
        ) as sd3_accordion:
            with gr.Group():
                gr.Markdown("### SD3 Specific Parameters")

                # Timestep weighting controls.
                with gr.Row():
                    self.weighting_scheme = gr.Dropdown(
                        label="Weighting Scheme",
                        choices=["logit_normal", "sigma_sqrt", "mode", "cosmap", "uniform"],
                        value=self.config.get("sd3.weighting_scheme", "logit_normal"),
                        interactive=True,
                    )
                    self.logit_mean = gr.Number(
                        label="Logit Mean",
                        value=self.config.get("sd3.logit_mean", 0.0),
                        interactive=True,
                    )
                    self.logit_std = gr.Number(
                        label="Logit Std",
                        value=self.config.get("sd3.logit_std", 1.0),
                        interactive=True,
                    )
                    self.mode_scale = gr.Number(
                        label="Mode Scale",
                        value=self.config.get("sd3.mode_scale", 1.29),
                        interactive=True,
                    )

                # Text-encoder model paths, each with its own file picker.
                with gr.Row():
                    self.clip_l, self.clip_l_button = self._model_path_picker(
                        label="CLIP-L Path",
                        placeholder="Path to CLIP-L model",
                        config_key="sd3.clip_l",
                    )
                    self.clip_g, self.clip_g_button = self._model_path_picker(
                        label="CLIP-G Path",
                        placeholder="Path to CLIP-G model",
                        config_key="sd3.clip_g",
                    )
                    self.t5xxl, self.t5xxl_button = self._model_path_picker(
                        label="T5-XXL Path",
                        placeholder="Path to T5-XXL model",
                        config_key="sd3.t5xxl",
                    )

                with gr.Row():
                    self.save_clip = gr.Checkbox(
                        label="Save CLIP models",
                        value=self.config.get("sd3.save_clip", False),
                        interactive=True,
                    )
                    self.save_t5xxl = gr.Checkbox(
                        label="Save T5-XXL model",
                        value=self.config.get("sd3.save_t5xxl", False),
                        interactive=True,
                    )

                # Text-encoder runtime options.
                with gr.Row():
                    self.t5xxl_device = gr.Textbox(
                        label="T5-XXL Device",
                        placeholder="Device for T5-XXL (e.g., cuda:0)",
                        value=self.config.get("sd3.t5xxl_device", ""),
                        interactive=True,
                    )
                    self.t5xxl_dtype = gr.Dropdown(
                        label="T5-XXL Dtype",
                        choices=["float32", "fp16", "bf16"],
                        value=self.config.get("sd3.t5xxl_dtype", "bf16"),
                        interactive=True,
                    )
                    self.sd3_text_encoder_batch_size = gr.Number(
                        label="Text Encoder Batch Size",
                        value=self.config.get("sd3.text_encoder_batch_size", 1),
                        minimum=1,
                        maximum=1024,
                        step=1,
                        interactive=True,
                    )
                    self.sd3_cache_text_encoder_outputs = gr.Checkbox(
                        label="Cache Text Encoder Outputs",
                        value=self.config.get("sd3.cache_text_encoder_outputs", False),
                        info="Cache text encoder outputs to speed up inference",
                        interactive=True,
                    )
                    self.sd3_cache_text_encoder_outputs_to_disk = gr.Checkbox(
                        label="Cache Text Encoder Outputs to Disk",
                        value=self.config.get(
                            "sd3.cache_text_encoder_outputs_to_disk", False
                        ),
                        info="Cache text encoder outputs to disk to speed up inference",
                        interactive=True,
                    )

                # Per-encoder dropout rates.
                with gr.Row():
                    self.clip_l_dropout_rate = gr.Number(
                        label="CLIP-L Dropout Rate",
                        value=self.config.get("sd3.clip_l_dropout_rate", 0.0),
                        interactive=True,
                        minimum=0.0,
                        info="Dropout rate for CLIP-L encoder",
                    )
                    self.clip_g_dropout_rate = gr.Number(
                        label="CLIP-G Dropout Rate",
                        value=self.config.get("sd3.clip_g_dropout_rate", 0.0),
                        interactive=True,
                        minimum=0.0,
                        info="Dropout rate for CLIP-G encoder",
                    )
                    self.t5_dropout_rate = gr.Number(
                        label="T5 Dropout Rate",
                        value=self.config.get("sd3.t5_dropout_rate", 0.0),
                        interactive=True,
                        minimum=0.0,
                        info="Dropout rate for T5-XXL encoder",
                    )

                with gr.Row():
                    self.sd3_fused_backward_pass = gr.Checkbox(
                        label="Fused Backward Pass",
                        value=self.config.get("sd3.fused_backward_pass", False),
                        info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
                        interactive=True,
                    )
                    # NOTE(review): the info text mentions SDXL although this
                    # is the SD3 panel; looks like a copy/paste leftover, so
                    # confirm the intended wording before changing it.
                    self.disable_mmap_load_safetensors = gr.Checkbox(
                        label="Disable mmap load safe tensors",
                        info="Disable memory mapping when loading the model's .safetensors in SDXL.",
                        value=self.config.get("sd3.disable_mmap_load_safetensors", False),
                    )
                    self.enable_scaled_pos_embed = gr.Checkbox(
                        label="Enable Scaled Positional Embeddings",
                        info="Enable scaled positional embeddings in the model.",
                        value=self.config.get("sd3.enable_scaled_pos_embed", False),
                    )
                    self.pos_emb_random_crop_rate = gr.Number(
                        label="Positional Embedding Random Crop Rate",
                        value=self.config.get("sd3.pos_emb_random_crop_rate", 0.0),
                        interactive=True,
                        minimum=0.0,
                        info="Random crop rate for positional embeddings",
                    )

            # The ``False`` placeholder default has no ``change`` event, so
            # only wire the visibility toggle when a real component is given.
            if hasattr(self.sd3_checkbox, "change"):
                self.sd3_checkbox.change(
                    lambda checked: gr.Accordion(visible=checked),
                    inputs=[self.sd3_checkbox],
                    outputs=[sd3_accordion],
                )

    def _model_path_picker(self, label: str, placeholder: str, config_key: str):
        """
        Create a path textbox plus a file-picker button that fills it.

        Parameters:
            label (str): Label shown on the textbox.
            placeholder (str): Placeholder text for the empty textbox.
            config_key (str): Key looked up in ``self.config`` for the
                initial path; a missing key yields an empty string.

        Returns:
            tuple: ``(textbox, button)`` in the order they were created.
        """
        path_box = gr.Textbox(
            label=label,
            placeholder=placeholder,
            value=self.config.get(config_key, ""),
            interactive=True,
        )
        picker = gr.Button(
            document_symbol,
            elem_id="open_folder_small",
            visible=(not self.headless),
            interactive=True,
        )
        # Clicking the button writes the chosen file path into the textbox.
        picker.click(
            get_any_file_path,
            outputs=path_box,
            show_progress=False,
        )
        return path_box, picker
|