pull/1166/head
bmaltais 2023-07-11 11:40:42 -04:00
parent b114e1f083
commit 689721cba5
11 changed files with 472 additions and 427 deletions

View File

@ -1 +1 @@
v21.8.1
v21.8.2

View File

@ -462,4 +462,8 @@ If you come across a `FileNotFoundError`, it is likely due to an installation is
* 2023/07/10 (v21.8.1)
- Let Tensorboard works in docker #1137
- Fix for accelerate issue
- Add SDXL TI training support
- Rework gui for common layout
- More LoRA tools to class
- Add no_half_vae option to TI

View File

@ -659,184 +659,186 @@ def dreambooth_tab(
dummy_db_true = gr.Label(value=True, visible=False)
dummy_db_false = gr.Label(value=False, visible=False)
dummy_headless = gr.Label(value=headless, visible=False)
gr.Markdown('Train a custom model using kohya dreambooth python code...')
# Setup Configuration Files Gradio
config = ConfigurationFile(headless)
source_model = SourceModel(headless=headless)
with gr.Tab('Training'):
gr.Markdown('Train a custom model using kohya dreambooth python code...')
# Setup Configuration Files Gradio
config = ConfigurationFile(headless)
source_model = SourceModel(headless=headless)
with gr.Tab('Folders'):
folders = Folders(headless=headless)
with gr.Tab('Parameters'):
basic_training = BasicTraining(
learning_rate_value='1e-5',
lr_scheduler_value='cosine',
lr_warmup_value='10',
)
with gr.Accordion('Advanced Configuration', open=False):
advanced_training = AdvancedTraining(headless=headless)
advanced_training.color_aug.change(
color_aug_changed,
inputs=[advanced_training.color_aug],
outputs=[basic_training.cache_latents],
with gr.Tab('Folders'):
folders = Folders(headless=headless)
with gr.Tab('Parameters'):
basic_training = BasicTraining(
learning_rate_value='1e-5',
lr_scheduler_value='cosine',
lr_warmup_value='10',
)
with gr.Accordion('Advanced Configuration', open=False):
advanced_training = AdvancedTraining(headless=headless)
advanced_training.color_aug.change(
color_aug_changed,
inputs=[advanced_training.color_aug],
outputs=[basic_training.cache_latents],
)
sample = SampleImages()
with gr.Tab('Tools'):
gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=folders.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,
output_dir_input=folders.output_dir,
logging_dir_input=folders.logging_dir,
headless=headless,
)
sample = SampleImages()
button_run = gr.Button('Train model', variant='primary')
with gr.Tab('Tools'):
gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=folders.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,
output_dir_input=folders.output_dir,
logging_dir_input=folders.logging_dir,
headless=headless,
button_print = gr.Button('Print training command')
# Setup gradio tensorboard buttons
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
button_start_tensorboard.click(
start_tensorboard,
inputs=folders.logging_dir,
show_progress=False,
)
button_run = gr.Button('Train model', variant='primary')
button_stop_tensorboard.click(
stop_tensorboard,
show_progress=False,
)
button_print = gr.Button('Print training command')
settings_list = [
source_model.pretrained_model_name_or_path,
source_model.v2,
source_model.v_parameterization,
source_model.sdxl_checkbox,
folders.logging_dir,
folders.train_data_dir,
folders.reg_data_dir,
folders.output_dir,
basic_training.max_resolution,
basic_training.learning_rate,
basic_training.lr_scheduler,
basic_training.lr_warmup,
basic_training.train_batch_size,
basic_training.epoch,
basic_training.save_every_n_epochs,
basic_training.mixed_precision,
basic_training.save_precision,
basic_training.seed,
basic_training.num_cpu_threads_per_process,
basic_training.cache_latents,
basic_training.cache_latents_to_disk,
basic_training.caption_extension,
basic_training.enable_bucket,
advanced_training.gradient_checkpointing,
advanced_training.full_fp16,
advanced_training.no_token_padding,
basic_training.stop_text_encoder_training,
advanced_training.xformers,
source_model.save_model_as,
advanced_training.shuffle_caption,
advanced_training.save_state,
advanced_training.resume,
advanced_training.prior_loss_weight,
advanced_training.color_aug,
advanced_training.flip_aug,
advanced_training.clip_skip,
advanced_training.vae,
folders.output_name,
advanced_training.max_token_length,
advanced_training.max_train_epochs,
advanced_training.max_data_loader_n_workers,
advanced_training.mem_eff_attn,
advanced_training.gradient_accumulation_steps,
source_model.model_list,
advanced_training.keep_tokens,
advanced_training.persistent_data_loader_workers,
advanced_training.bucket_no_upscale,
advanced_training.random_crop,
advanced_training.bucket_reso_steps,
advanced_training.caption_dropout_every_n_epochs,
advanced_training.caption_dropout_rate,
basic_training.optimizer,
basic_training.optimizer_args,
advanced_training.noise_offset_type,
advanced_training.noise_offset,
advanced_training.adaptive_noise_scale,
advanced_training.multires_noise_iterations,
advanced_training.multires_noise_discount,
sample.sample_every_n_steps,
sample.sample_every_n_epochs,
sample.sample_sampler,
sample.sample_prompts,
advanced_training.additional_parameters,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
advanced_training.weighted_captions,
advanced_training.save_every_n_steps,
advanced_training.save_last_n_steps,
advanced_training.save_last_n_steps_state,
advanced_training.use_wandb,
advanced_training.wandb_api_key,
advanced_training.scale_v_pred_loss_like_noise_pred,
advanced_training.min_timestep,
advanced_training.max_timestep,
]
# Setup gradio tensorboard buttons
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
config.button_open_config.click(
open_configuration,
inputs=[dummy_db_true, config.config_file_name] + settings_list,
outputs=[config.config_file_name] + settings_list,
show_progress=False,
)
button_start_tensorboard.click(
start_tensorboard,
inputs=folders.logging_dir,
show_progress=False,
)
config.button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config.config_file_name] + settings_list,
outputs=[config.config_file_name] + settings_list,
show_progress=False,
)
button_stop_tensorboard.click(
stop_tensorboard,
show_progress=False,
)
config.button_save_config.click(
save_configuration,
inputs=[dummy_db_false, config.config_file_name] + settings_list,
outputs=[config.config_file_name],
show_progress=False,
)
settings_list = [
source_model.pretrained_model_name_or_path,
source_model.v2,
source_model.v_parameterization,
source_model.sdxl_checkbox,
folders.logging_dir,
folders.train_data_dir,
folders.reg_data_dir,
folders.output_dir,
basic_training.max_resolution,
basic_training.learning_rate,
basic_training.lr_scheduler,
basic_training.lr_warmup,
basic_training.train_batch_size,
basic_training.epoch,
basic_training.save_every_n_epochs,
basic_training.mixed_precision,
basic_training.save_precision,
basic_training.seed,
basic_training.num_cpu_threads_per_process,
basic_training.cache_latents,
basic_training.cache_latents_to_disk,
basic_training.caption_extension,
basic_training.enable_bucket,
advanced_training.gradient_checkpointing,
advanced_training.full_fp16,
advanced_training.no_token_padding,
basic_training.stop_text_encoder_training,
advanced_training.xformers,
source_model.save_model_as,
advanced_training.shuffle_caption,
advanced_training.save_state,
advanced_training.resume,
advanced_training.prior_loss_weight,
advanced_training.color_aug,
advanced_training.flip_aug,
advanced_training.clip_skip,
advanced_training.vae,
folders.output_name,
advanced_training.max_token_length,
advanced_training.max_train_epochs,
advanced_training.max_data_loader_n_workers,
advanced_training.mem_eff_attn,
advanced_training.gradient_accumulation_steps,
source_model.model_list,
advanced_training.keep_tokens,
advanced_training.persistent_data_loader_workers,
advanced_training.bucket_no_upscale,
advanced_training.random_crop,
advanced_training.bucket_reso_steps,
advanced_training.caption_dropout_every_n_epochs,
advanced_training.caption_dropout_rate,
basic_training.optimizer,
basic_training.optimizer_args,
advanced_training.noise_offset_type,
advanced_training.noise_offset,
advanced_training.adaptive_noise_scale,
advanced_training.multires_noise_iterations,
advanced_training.multires_noise_discount,
sample.sample_every_n_steps,
sample.sample_every_n_epochs,
sample.sample_sampler,
sample.sample_prompts,
advanced_training.additional_parameters,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
advanced_training.weighted_captions,
advanced_training.save_every_n_steps,
advanced_training.save_last_n_steps,
advanced_training.save_last_n_steps_state,
advanced_training.use_wandb,
advanced_training.wandb_api_key,
advanced_training.scale_v_pred_loss_like_noise_pred,
advanced_training.min_timestep,
advanced_training.max_timestep,
]
config.button_save_as_config.click(
save_configuration,
inputs=[dummy_db_true, config.config_file_name] + settings_list,
outputs=[config.config_file_name],
show_progress=False,
)
config.button_open_config.click(
open_configuration,
inputs=[dummy_db_true, config.config_file_name] + settings_list,
outputs=[config.config_file_name] + settings_list,
show_progress=False,
)
button_run.click(
train_model,
inputs=[dummy_headless] + [dummy_db_false] + settings_list,
show_progress=False,
)
config.button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config.config_file_name] + settings_list,
outputs=[config.config_file_name] + settings_list,
show_progress=False,
)
button_print.click(
train_model,
inputs=[dummy_headless] + [dummy_db_true] + settings_list,
show_progress=False,
)
config.button_save_config.click(
save_configuration,
inputs=[dummy_db_false, config.config_file_name] + settings_list,
outputs=[config.config_file_name],
show_progress=False,
)
config.button_save_as_config.click(
save_configuration,
inputs=[dummy_db_true, config.config_file_name] + settings_list,
outputs=[config.config_file_name],
show_progress=False,
)
button_run.click(
train_model,
inputs=[dummy_headless] + [dummy_db_false] + settings_list,
show_progress=False,
)
button_print.click(
train_model,
inputs=[dummy_headless] + [dummy_db_true] + settings_list,
show_progress=False,
)
return (
folders.train_data_dir,
folders.reg_data_dir,
folders.output_dir,
folders.logging_dir,
)
return (
folders.train_data_dir,
folders.reg_data_dir,
folders.output_dir,
folders.logging_dir,
)
def UI(**kwargs):

View File

@ -78,6 +78,7 @@ def save_configuration(
seed,
num_cpu_threads_per_process,
train_text_encoder,
full_bf16,
create_caption,
create_buckets,
save_model_as,
@ -197,6 +198,7 @@ def open_configuration(
seed,
num_cpu_threads_per_process,
train_text_encoder,
full_bf16,
create_caption,
create_buckets,
save_model_as,
@ -313,6 +315,7 @@ def train_model(
seed,
num_cpu_threads_per_process,
train_text_encoder,
full_bf16,
generate_caption_database,
generate_image_buckets,
save_model_as,
@ -495,6 +498,8 @@ def train_model(
run_cmd += ' --v_parameterization'
if train_text_encoder:
run_cmd += ' --train_text_encoder'
if full_bf16:
run_cmd += ' --full_bf16'
if weighted_captions:
run_cmd += ' --weighted_captions'
run_cmd += (
@ -788,6 +793,9 @@ def finetune_tab(headless=False):
train_text_encoder = gr.Checkbox(
label='Train text encoder', value=True
)
full_bf16 = gr.Checkbox(
label='Full bf16', value = False
)
with gr.Accordion('Advanced parameters', open=False):
with gr.Row():
gradient_accumulation_steps = gr.Number(
@ -848,6 +856,7 @@ def finetune_tab(headless=False):
basic_training.seed,
basic_training.num_cpu_threads_per_process,
train_text_encoder,
full_bf16,
create_caption,
create_buckets,
source_model.save_model_as,

View File

@ -5,13 +5,8 @@ from dreambooth_gui import dreambooth_tab
from finetune_gui import finetune_tab
from textual_inversion_gui import ti_tab
from library.utilities import utilities_tab
from library.extract_lora_gui import gradio_extract_lora_tab
from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
from library.merge_lora_gui import gradio_merge_lora_tab
from library.resize_lora_gui import gradio_resize_lora_tab
from library.extract_lora_from_dylora_gui import gradio_extract_dylora_tab
from library.merge_lycoris_gui import gradio_merge_lycoris_tab
from lora_gui import lora_tab
from library.class_lora_tab import LoRATools
import os
from library.custom_logging import setup_logging
@ -67,12 +62,7 @@ def UI(**kwargs):
headless=headless,
)
with gr.Tab('LoRA'):
gradio_extract_dylora_tab(headless=headless)
gradio_extract_lora_tab(headless=headless)
gradio_extract_lycoris_locon_tab(headless=headless)
gradio_merge_lora_tab(headless=headless)
gradio_merge_lycoris_tab(headless=headless)
gradio_resize_lora_tab(headless=headless)
_ = LoRATools(headless=headless)
with gr.Tab('About'):
gr.Markdown(f'kohya_ss GUI release {release}')
with gr.Tab('README'):

42
library/class_lora_tab.py Normal file
View File

@ -0,0 +1,42 @@
import gradio as gr
from library.merge_lora_gui import gradio_merge_lora_tab
from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab
from library.verify_lora_gui import gradio_verify_lora_tab
from library.resize_lora_gui import gradio_resize_lora_tab
from library.extract_lora_gui import gradio_extract_lora_tab
from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
from library.extract_lora_from_dylora_gui import gradio_extract_dylora_tab
from library.merge_lycoris_gui import gradio_merge_lycoris_tab
# Deprecated code
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
class LoRATools:
    """Gradio panel that bundles the LoRA utility tabs.

    Instantiating this class renders the extract/merge/resize/verify tool
    tabs into the currently open Gradio layout context. The object itself
    only retains the constructor arguments for later reference.
    """

    def __init__(self, folders="", headless: bool = False):
        # Remember the arguments; ``headless`` is forwarded to every tab.
        self.headless = headless
        self.folders = folders

        gr.Markdown(
            'This section provide LoRA tools to help setup your dataset...'
        )

        # Render each utility tab in a fixed, deliberate order.
        for build_tab in (
            gradio_extract_dylora_tab,
            gradio_extract_lora_tab,
            gradio_extract_lycoris_locon_tab,
            gradio_merge_lora_tab,
            gradio_merge_lycoris_tab,
            gradio_svd_merge_lora_tab,
            gradio_resize_lora_tab,
            gradio_verify_lora_tab,
        ):
            build_tab(headless=headless)

        # Legacy tooling is only shown when a Folders object was supplied
        # (callers may pass the default empty string to omit it).
        if folders:
            with gr.Tab('Deprecated'):
                gradio_dreambooth_folder_creation_tab(
                    train_data_dir_input=folders.train_data_dir,
                    reg_data_dir_input=folders.reg_data_dir,
                    output_dir_input=folders.output_dir,
                    logging_dir_input=folders.logging_dir,
                    headless=headless,
                )
                gradio_dataset_balancing_tab(headless=headless)

View File

@ -2,8 +2,9 @@ import gradio as gr
### SDXL Parameters class
class SDXLParameters:
def __init__(self, sdxl_checkbox):
def __init__(self, sdxl_checkbox, show_sdxl_cache_text_encoder_outputs:bool = True):
self.sdxl_checkbox = sdxl_checkbox
self.show_sdxl_cache_text_encoder_outputs = show_sdxl_cache_text_encoder_outputs
with gr.Accordion(visible=False, open=True, label='SDXL Specific Parameters') as self.sdxl_row:
with gr.Row():
@ -11,11 +12,12 @@ class SDXLParameters:
label='Cache text encoder outputs',
info='Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.',
value=False,
visible=show_sdxl_cache_text_encoder_outputs
)
self.sdxl_no_half_vae = gr.Checkbox(
label='No half VAE',
info='Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.',
value=False
value=True
)
self.sdxl_checkbox.change(lambda sdxl_checkbox: gr.Accordion.update(visible=sdxl_checkbox), inputs=[self.sdxl_checkbox], outputs=[self.sdxl_row])

View File

@ -33,8 +33,8 @@ class SourceModel:
label='Model Quick Pick',
choices=[
'custom',
'stabilityai/stable-diffusion-xl-base-0.9',
'stabilityai/stable-diffusion-xl-refiner-0.9',
# 'stabilityai/stable-diffusion-xl-base-0.9',
# 'stabilityai/stable-diffusion-xl-refiner-0.9',
'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned',
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',

View File

@ -36,6 +36,11 @@ def svd_merge_lora(
new_conv_rank,
device,
):
# Check if the output file already exists
if os.path.isfile(save_to):
print(f"Output file '{save_to}' already exists. Aborting.")
return
# Check if the ratio total is equal to one. If not mormalise to 1
total_ratio = ratio_a + ratio_b + ratio_c + ratio_d
if total_ratio != 1:
@ -78,7 +83,7 @@ def svd_merge_lora(
run_cmd_ratios += f' {ratio_d}'
run_cmd += run_cmd_models
run_cmd += run_cmd_ratiosacti
run_cmd += run_cmd_ratios
run_cmd += f' --device {device}'
run_cmd += f' --new_rank "{new_rank}"'
run_cmd += f' --new_conv_rank "{new_conv_rank}"'

View File

@ -32,21 +32,14 @@ from library.class_basic_training import BasicTraining
from library.class_advanced_training import AdvancedTraining
from library.class_sdxl_parameters import SDXLParameters
from library.class_folders import Folders
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.tensorboard_gui import (
gradio_tensorboard,
start_tensorboard,
stop_tensorboard,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.utilities import utilities_tab
from library.merge_lora_gui import gradio_merge_lora_tab
from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab
from library.verify_lora_gui import gradio_verify_lora_tab
from library.resize_lora_gui import gradio_resize_lora_tab
from library.class_sample_images import SampleImages, run_cmd_sample
from library.class_lora_tab import LoRATools
from library.custom_logging import setup_logging
@ -1576,21 +1569,7 @@ def lora_tab(
)
with gr.Tab('Tools'):
gr.Markdown(
'This section provide LoRA tools to help setup your dataset...'
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=folders.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,
output_dir_input=folders.output_dir,
logging_dir_input=folders.logging_dir,
headless=headless,
)
gradio_dataset_balancing_tab(headless=headless)
gradio_merge_lora_tab(headless=headless)
gradio_svd_merge_lora_tab(headless=headless)
gradio_resize_lora_tab(headless=headless)
gradio_verify_lora_tab(headless=headless)
lora_tools = LoRATools(folders=folders, headless=headless)
with gr.Tab('Guides'):
gr.Markdown(

View File

@ -30,6 +30,7 @@ from library.class_source_model import SourceModel
from library.class_basic_training import BasicTraining
from library.class_advanced_training import AdvancedTraining
from library.class_folders import Folders
from library.class_sdxl_parameters import SDXLParameters
from library.tensorboard_gui import (
gradio_tensorboard,
start_tensorboard,
@ -129,6 +130,7 @@ def save_configuration(
scale_v_pred_loss_like_noise_pred,
min_timestep,
max_timestep,
sdxl_no_half_vae
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -245,6 +247,7 @@ def open_configuration(
scale_v_pred_loss_like_noise_pred,
min_timestep,
max_timestep,
sdxl_no_half_vae
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -358,6 +361,7 @@ def train_model(
scale_v_pred_loss_like_noise_pred,
min_timestep,
max_timestep,
sdxl_no_half_vae
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -421,13 +425,6 @@ def train_model(
):
return
if sdxl:
output_message(
msg='TI training is not compatible with an SDXL model.',
headless=headless_bool,
)
return
# if float(noise_offset) > 0 and (
# multires_noise_iterations > 0 or multires_noise_discount > 0
# ):
@ -520,7 +517,12 @@ def train_model(
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
log.info(f'lr_warmup_steps = {lr_warmup_steps}')
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_textual_inversion.py"'
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process}'
if sdxl:
run_cmd += f' "./sdxl_train_textual_inversion.py"'
else:
run_cmd += f' "./train_textual_inversion.py"'
if v2:
run_cmd += ' --v2'
if v_parameterization:
@ -563,6 +565,9 @@ def train_model(
)
if int(gradient_accumulation_steps) > 1:
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
if sdxl_no_half_vae:
run_cmd += f' --no_half_vae'
run_cmd += run_cmd_training(
learning_rate=learning_rate,
@ -679,237 +684,244 @@ def ti_tab(
dummy_db_true = gr.Label(value=True, visible=False)
dummy_db_false = gr.Label(value=False, visible=False)
dummy_headless = gr.Label(value=headless, visible=False)
gr.Markdown('Train a TI using kohya textual inversion python code...')
# Setup Configuration Files Gradio
config = ConfigurationFile(headless)
with gr.Tab('Training'):
gr.Markdown('Train a TI using kohya textual inversion python code...')
# Setup Configuration Files Gradio
config = ConfigurationFile(headless)
source_model = SourceModel(
save_model_as_choices=[
'ckpt',
'safetensors',
],
headless=headless,
)
with gr.Tab('Folders'):
folders = Folders(headless=headless)
with gr.Tab('Parameters'):
with gr.Row():
weights = gr.Textbox(
label='Resume TI training',
placeholder='(Optional) Path to existing TI embeding file to keep training',
)
weights_file_input = gr.Button(
'📂', elem_id='open_folder_small', visible=(not headless)
)
weights_file_input.click(
get_file_path,
outputs=weights,
show_progress=False,
)
with gr.Row():
token_string = gr.Textbox(
label='Token string',
placeholder='eg: cat',
)
init_word = gr.Textbox(
label='Init word',
value='*',
)
num_vectors_per_token = gr.Slider(
minimum=1,
maximum=75,
value=1,
step=1,
label='Vectors',
)
max_train_steps = gr.Textbox(
label='Max train steps',
placeholder='(Optional) Maximum number of steps',
)
template = gr.Dropdown(
label='Template',
choices=[
'caption',
'object template',
'style template',
],
value='caption',
)
basic_training = BasicTraining(
learning_rate_value='1e-5',
lr_scheduler_value='cosine',
lr_warmup_value='10',
)
with gr.Accordion('Advanced Configuration', open=False):
advanced_training = AdvancedTraining(headless=headless)
advanced_training.color_aug.change(
color_aug_changed,
inputs=[advanced_training.color_aug],
outputs=[basic_training.cache_latents],
)
sample = SampleImages()
with gr.Tab('Tools'):
gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=folders.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,
output_dir_input=folders.output_dir,
logging_dir_input=folders.logging_dir,
source_model = SourceModel(
save_model_as_choices=[
'ckpt',
'safetensors',
],
headless=headless,
)
button_run = gr.Button('Train model', variant='primary')
with gr.Tab('Folders'):
folders = Folders(headless=headless)
with gr.Tab('Parameters'):
with gr.Row():
weights = gr.Textbox(
label='Resume TI training',
placeholder='(Optional) Path to existing TI embeding file to keep training',
)
weights_file_input = gr.Button(
'📂', elem_id='open_folder_small', visible=(not headless)
)
weights_file_input.click(
get_file_path,
outputs=weights,
show_progress=False,
)
with gr.Row():
token_string = gr.Textbox(
label='Token string',
placeholder='eg: cat',
)
init_word = gr.Textbox(
label='Init word',
value='*',
)
num_vectors_per_token = gr.Slider(
minimum=1,
maximum=75,
value=1,
step=1,
label='Vectors',
)
max_train_steps = gr.Textbox(
label='Max train steps',
placeholder='(Optional) Maximum number of steps',
)
template = gr.Dropdown(
label='Template',
choices=[
'caption',
'object template',
'style template',
],
value='caption',
)
basic_training = BasicTraining(
learning_rate_value='1e-5',
lr_scheduler_value='cosine',
lr_warmup_value='10',
)
# Add SDXL Parameters
sdxl_params = SDXLParameters(source_model.sdxl_checkbox, show_sdxl_cache_text_encoder_outputs=False)
with gr.Accordion('Advanced Configuration', open=False):
advanced_training = AdvancedTraining(headless=headless)
advanced_training.color_aug.change(
color_aug_changed,
inputs=[advanced_training.color_aug],
outputs=[basic_training.cache_latents],
)
button_print = gr.Button('Print training command')
sample = SampleImages()
# Setup gradio tensorboard buttons
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
with gr.Tab('Tools'):
gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=folders.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,
output_dir_input=folders.output_dir,
logging_dir_input=folders.logging_dir,
headless=headless,
)
button_start_tensorboard.click(
start_tensorboard,
inputs=folders.logging_dir,
show_progress=False,
)
button_run = gr.Button('Train model', variant='primary')
button_stop_tensorboard.click(
stop_tensorboard,
show_progress=False,
)
button_print = gr.Button('Print training command')
settings_list = [
source_model.pretrained_model_name_or_path,
source_model.v2,
source_model.v_parameterization,
source_model.sdxl_checkbox,
folders.logging_dir,
folders.train_data_dir,
folders.reg_data_dir,
folders.output_dir,
basic_training.max_resolution,
basic_training.learning_rate,
basic_training.lr_scheduler,
basic_training.lr_warmup,
basic_training.train_batch_size,
basic_training.epoch,
basic_training.save_every_n_epochs,
basic_training.mixed_precision,
basic_training.save_precision,
basic_training.seed,
basic_training.num_cpu_threads_per_process,
basic_training.cache_latents,
basic_training.cache_latents_to_disk,
basic_training.caption_extension,
basic_training.enable_bucket,
advanced_training.gradient_checkpointing,
advanced_training.full_fp16,
advanced_training.no_token_padding,
basic_training.stop_text_encoder_training,
advanced_training.xformers,
source_model.save_model_as,
advanced_training.shuffle_caption,
advanced_training.save_state,
advanced_training.resume,
advanced_training.prior_loss_weight,
advanced_training.color_aug,
advanced_training.flip_aug,
advanced_training.clip_skip,
advanced_training.vae,
folders.output_name,
advanced_training.max_token_length,
advanced_training.max_train_epochs,
advanced_training.max_data_loader_n_workers,
advanced_training.mem_eff_attn,
advanced_training.gradient_accumulation_steps,
source_model.model_list,
token_string,
init_word,
num_vectors_per_token,
max_train_steps,
weights,
template,
advanced_training.keep_tokens,
advanced_training.persistent_data_loader_workers,
advanced_training.bucket_no_upscale,
advanced_training.random_crop,
advanced_training.bucket_reso_steps,
advanced_training.caption_dropout_every_n_epochs,
advanced_training.caption_dropout_rate,
basic_training.optimizer,
basic_training.optimizer_args,
advanced_training.noise_offset_type,
advanced_training.noise_offset,
advanced_training.adaptive_noise_scale,
advanced_training.multires_noise_iterations,
advanced_training.multires_noise_discount,
sample.sample_every_n_steps,
sample.sample_every_n_epochs,
sample.sample_sampler,
sample.sample_prompts,
advanced_training.additional_parameters,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
advanced_training.save_every_n_steps,
advanced_training.save_last_n_steps,
advanced_training.save_last_n_steps_state,
advanced_training.use_wandb,
advanced_training.wandb_api_key,
advanced_training.scale_v_pred_loss_like_noise_pred,
advanced_training.min_timestep,
advanced_training.max_timestep
]
# Setup gradio tensorboard buttons
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
config.button_open_config.click(
open_configuration,
inputs=[dummy_db_true, config.config_file_name] + settings_list,
outputs=[config.config_file_name] + settings_list,
show_progress=False,
)
button_start_tensorboard.click(
start_tensorboard,
inputs=folders.logging_dir,
show_progress=False,
)
config.button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config.config_file_name] + settings_list,
outputs=[config.config_file_name] + settings_list,
show_progress=False,
)
button_stop_tensorboard.click(
stop_tensorboard,
show_progress=False,
)
config.button_save_config.click(
save_configuration,
inputs=[dummy_db_false, config.config_file_name] + settings_list,
outputs=[config.config_file_name],
show_progress=False,
)
settings_list = [
source_model.pretrained_model_name_or_path,
source_model.v2,
source_model.v_parameterization,
source_model.sdxl_checkbox,
folders.logging_dir,
folders.train_data_dir,
folders.reg_data_dir,
folders.output_dir,
basic_training.max_resolution,
basic_training.learning_rate,
basic_training.lr_scheduler,
basic_training.lr_warmup,
basic_training.train_batch_size,
basic_training.epoch,
basic_training.save_every_n_epochs,
basic_training.mixed_precision,
basic_training.save_precision,
basic_training.seed,
basic_training.num_cpu_threads_per_process,
basic_training.cache_latents,
basic_training.cache_latents_to_disk,
basic_training.caption_extension,
basic_training.enable_bucket,
advanced_training.gradient_checkpointing,
advanced_training.full_fp16,
advanced_training.no_token_padding,
basic_training.stop_text_encoder_training,
advanced_training.xformers,
source_model.save_model_as,
advanced_training.shuffle_caption,
advanced_training.save_state,
advanced_training.resume,
advanced_training.prior_loss_weight,
advanced_training.color_aug,
advanced_training.flip_aug,
advanced_training.clip_skip,
advanced_training.vae,
folders.output_name,
advanced_training.max_token_length,
advanced_training.max_train_epochs,
advanced_training.max_data_loader_n_workers,
advanced_training.mem_eff_attn,
advanced_training.gradient_accumulation_steps,
source_model.model_list,
token_string,
init_word,
num_vectors_per_token,
max_train_steps,
weights,
template,
advanced_training.keep_tokens,
advanced_training.persistent_data_loader_workers,
advanced_training.bucket_no_upscale,
advanced_training.random_crop,
advanced_training.bucket_reso_steps,
advanced_training.caption_dropout_every_n_epochs,
advanced_training.caption_dropout_rate,
basic_training.optimizer,
basic_training.optimizer_args,
advanced_training.noise_offset_type,
advanced_training.noise_offset,
advanced_training.adaptive_noise_scale,
advanced_training.multires_noise_iterations,
advanced_training.multires_noise_discount,
sample.sample_every_n_steps,
sample.sample_every_n_epochs,
sample.sample_sampler,
sample.sample_prompts,
advanced_training.additional_parameters,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
advanced_training.save_every_n_steps,
advanced_training.save_last_n_steps,
advanced_training.save_last_n_steps_state,
advanced_training.use_wandb,
advanced_training.wandb_api_key,
advanced_training.scale_v_pred_loss_like_noise_pred,
advanced_training.min_timestep,
advanced_training.max_timestep,
sdxl_params.sdxl_no_half_vae,
]
config.button_save_as_config.click(
save_configuration,
inputs=[dummy_db_true, config.config_file_name] + settings_list,
outputs=[config.config_file_name],
show_progress=False,
)
config.button_open_config.click(
open_configuration,
inputs=[dummy_db_true, config.config_file_name] + settings_list,
outputs=[config.config_file_name] + settings_list,
show_progress=False,
)
button_run.click(
train_model,
inputs=[dummy_headless] + [dummy_db_false] + settings_list,
show_progress=False,
)
config.button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config.config_file_name] + settings_list,
outputs=[config.config_file_name] + settings_list,
show_progress=False,
)
button_print.click(
train_model,
inputs=[dummy_headless] + [dummy_db_true] + settings_list,
show_progress=False,
)
config.button_save_config.click(
save_configuration,
inputs=[dummy_db_false, config.config_file_name] + settings_list,
outputs=[config.config_file_name],
show_progress=False,
)
return (
folders.train_data_dir,
folders.reg_data_dir,
folders.output_dir,
folders.logging_dir,
)
config.button_save_as_config.click(
save_configuration,
inputs=[dummy_db_true, config.config_file_name] + settings_list,
outputs=[config.config_file_name],
show_progress=False,
)
button_run.click(
train_model,
inputs=[dummy_headless] + [dummy_db_false] + settings_list,
show_progress=False,
)
button_print.click(
train_model,
inputs=[dummy_headless] + [dummy_db_true] + settings_list,
show_progress=False,
)
return (
folders.train_data_dir,
folders.reg_data_dir,
folders.output_dir,
folders.logging_dir,
)
def UI(**kwargs):