diff --git a/.release b/.release index 042b749..3f77139 100644 --- a/.release +++ b/.release @@ -1 +1 @@ -v21.8.1 \ No newline at end of file +v21.8.2 \ No newline at end of file diff --git a/README.md b/README.md index 33e38be..96b2968 100644 --- a/README.md +++ b/README.md @@ -462,4 +462,8 @@ If you come across a `FileNotFoundError`, it is likely due to an installation is * 2023/07/10 (v21.8.1) - Let Tensorboard works in docker #1137 - - Fix for accelerate issue \ No newline at end of file + - Fix for accelerate issue + - Add SDXL TI training support + - Rework gui for common layout + - More LoRA tools to class + - Add no_half_vae option to TI \ No newline at end of file diff --git a/dreambooth_gui.py b/dreambooth_gui.py index ef6d11b..dd8fa14 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -659,184 +659,186 @@ def dreambooth_tab( dummy_db_true = gr.Label(value=True, visible=False) dummy_db_false = gr.Label(value=False, visible=False) dummy_headless = gr.Label(value=headless, visible=False) - gr.Markdown('Train a custom model using kohya dreambooth python code...') - # Setup Configuration Files Gradio - config = ConfigurationFile(headless) - - source_model = SourceModel(headless=headless) + with gr.Tab('Training'): + gr.Markdown('Train a custom model using kohya dreambooth python code...') + + # Setup Configuration Files Gradio + config = ConfigurationFile(headless) + + source_model = SourceModel(headless=headless) - with gr.Tab('Folders'): - folders = Folders(headless=headless) - with gr.Tab('Parameters'): - basic_training = BasicTraining( - learning_rate_value='1e-5', - lr_scheduler_value='cosine', - lr_warmup_value='10', - ) - with gr.Accordion('Advanced Configuration', open=False): - advanced_training = AdvancedTraining(headless=headless) - advanced_training.color_aug.change( - color_aug_changed, - inputs=[advanced_training.color_aug], - outputs=[basic_training.cache_latents], + with gr.Tab('Folders'): + folders = Folders(headless=headless) + with gr.Tab('Parameters'): + basic_training = BasicTraining( + learning_rate_value='1e-5', + lr_scheduler_value='cosine', + lr_warmup_value='10', + ) + with gr.Accordion('Advanced Configuration', open=False): + advanced_training = AdvancedTraining(headless=headless) + advanced_training.color_aug.change( + color_aug_changed, + inputs=[advanced_training.color_aug], + outputs=[basic_training.cache_latents], + ) + + sample = SampleImages() + + with gr.Tab('Tools'): + gr.Markdown( + 'This section provide Dreambooth tools to help setup your dataset...' + ) + gradio_dreambooth_folder_creation_tab( + train_data_dir_input=folders.train_data_dir, + reg_data_dir_input=folders.reg_data_dir, + output_dir_input=folders.output_dir, + logging_dir_input=folders.logging_dir, + headless=headless, ) - sample = SampleImages() + button_run = gr.Button('Train model', variant='primary') - with gr.Tab('Tools'): - gr.Markdown( - 'This section provide Dreambooth tools to help setup your dataset...' - ) - gradio_dreambooth_folder_creation_tab( - train_data_dir_input=folders.train_data_dir, - reg_data_dir_input=folders.reg_data_dir, - output_dir_input=folders.output_dir, - logging_dir_input=folders.logging_dir, - headless=headless, + button_print = gr.Button('Print training command') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=folders.logging_dir, + show_progress=False, ) - button_run = gr.Button('Train model', variant='primary') + button_stop_tensorboard.click( + stop_tensorboard, + show_progress=False, + ) - button_print = gr.Button('Print training command') + settings_list = [ + source_model.pretrained_model_name_or_path, + source_model.v2, + source_model.v_parameterization, + source_model.sdxl_checkbox, + folders.logging_dir, + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + basic_training.max_resolution, + basic_training.learning_rate, + basic_training.lr_scheduler, + basic_training.lr_warmup, + basic_training.train_batch_size, + basic_training.epoch, + basic_training.save_every_n_epochs, + basic_training.mixed_precision, + basic_training.save_precision, + basic_training.seed, + basic_training.num_cpu_threads_per_process, + basic_training.cache_latents, + basic_training.cache_latents_to_disk, + basic_training.caption_extension, + basic_training.enable_bucket, + advanced_training.gradient_checkpointing, + advanced_training.full_fp16, + advanced_training.no_token_padding, + basic_training.stop_text_encoder_training, + advanced_training.xformers, + source_model.save_model_as, + advanced_training.shuffle_caption, + advanced_training.save_state, + advanced_training.resume, + advanced_training.prior_loss_weight, + advanced_training.color_aug, + advanced_training.flip_aug, + advanced_training.clip_skip, + advanced_training.vae, + folders.output_name, + advanced_training.max_token_length, + advanced_training.max_train_epochs, + advanced_training.max_data_loader_n_workers, + advanced_training.mem_eff_attn, + advanced_training.gradient_accumulation_steps, + source_model.model_list, + advanced_training.keep_tokens, + advanced_training.persistent_data_loader_workers, + advanced_training.bucket_no_upscale, + advanced_training.random_crop, + advanced_training.bucket_reso_steps, + advanced_training.caption_dropout_every_n_epochs, + advanced_training.caption_dropout_rate, + basic_training.optimizer, + basic_training.optimizer_args, + advanced_training.noise_offset_type, + advanced_training.noise_offset, + advanced_training.adaptive_noise_scale, + advanced_training.multires_noise_iterations, + advanced_training.multires_noise_discount, + sample.sample_every_n_steps, + sample.sample_every_n_epochs, + sample.sample_sampler, + sample.sample_prompts, + advanced_training.additional_parameters, + advanced_training.vae_batch_size, + advanced_training.min_snr_gamma, + advanced_training.weighted_captions, + advanced_training.save_every_n_steps, + advanced_training.save_last_n_steps, + advanced_training.save_last_n_steps_state, + advanced_training.use_wandb, + advanced_training.wandb_api_key, + advanced_training.scale_v_pred_loss_like_noise_pred, + advanced_training.min_timestep, + advanced_training.max_timestep, + ] - # Setup gradio tensorboard buttons - button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + config.button_open_config.click( + open_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name] + settings_list, + show_progress=False, + ) - button_start_tensorboard.click( - start_tensorboard, - inputs=folders.logging_dir, - show_progress=False, - ) + config.button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name] + settings_list, + show_progress=False, + ) - button_stop_tensorboard.click( - stop_tensorboard, - show_progress=False, - ) + config.button_save_config.click( + save_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) - settings_list = [ - source_model.pretrained_model_name_or_path, - source_model.v2, - source_model.v_parameterization, - source_model.sdxl_checkbox, - folders.logging_dir, - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - basic_training.max_resolution, - basic_training.learning_rate, - basic_training.lr_scheduler, - basic_training.lr_warmup, - basic_training.train_batch_size, - basic_training.epoch, - basic_training.save_every_n_epochs, - basic_training.mixed_precision, - basic_training.save_precision, - basic_training.seed, - basic_training.num_cpu_threads_per_process, - basic_training.cache_latents, - basic_training.cache_latents_to_disk, - basic_training.caption_extension, - basic_training.enable_bucket, - advanced_training.gradient_checkpointing, - advanced_training.full_fp16, - advanced_training.no_token_padding, - basic_training.stop_text_encoder_training, - advanced_training.xformers, - source_model.save_model_as, - advanced_training.shuffle_caption, - advanced_training.save_state, - advanced_training.resume, - advanced_training.prior_loss_weight, - advanced_training.color_aug, - advanced_training.flip_aug, - advanced_training.clip_skip, - advanced_training.vae, - folders.output_name, - advanced_training.max_token_length, - advanced_training.max_train_epochs, - advanced_training.max_data_loader_n_workers, - advanced_training.mem_eff_attn, - advanced_training.gradient_accumulation_steps, - source_model.model_list, - advanced_training.keep_tokens, - advanced_training.persistent_data_loader_workers, - advanced_training.bucket_no_upscale, - advanced_training.random_crop, - advanced_training.bucket_reso_steps, - advanced_training.caption_dropout_every_n_epochs, - advanced_training.caption_dropout_rate, - basic_training.optimizer, - basic_training.optimizer_args, - advanced_training.noise_offset_type, - advanced_training.noise_offset, - advanced_training.adaptive_noise_scale, - advanced_training.multires_noise_iterations, - advanced_training.multires_noise_discount, - sample.sample_every_n_steps, - sample.sample_every_n_epochs, - sample.sample_sampler, - sample.sample_prompts, - advanced_training.additional_parameters, - advanced_training.vae_batch_size, - advanced_training.min_snr_gamma, - advanced_training.weighted_captions, - advanced_training.save_every_n_steps, - advanced_training.save_last_n_steps, - advanced_training.save_last_n_steps_state, - advanced_training.use_wandb, - advanced_training.wandb_api_key, - advanced_training.scale_v_pred_loss_like_noise_pred, - advanced_training.min_timestep, - advanced_training.max_timestep, - ] + config.button_save_as_config.click( + save_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) - config.button_open_config.click( - open_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name] + settings_list, - show_progress=False, - ) + button_run.click( + train_model, + inputs=[dummy_headless] + [dummy_db_false] + settings_list, + show_progress=False, + ) - config.button_load_config.click( - open_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name] + settings_list, - show_progress=False, - ) + button_print.click( + train_model, + inputs=[dummy_headless] + [dummy_db_true] + settings_list, + show_progress=False, + ) - config.button_save_config.click( - save_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) - - config.button_save_as_config.click( - save_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) - - button_run.click( - train_model, - inputs=[dummy_headless] + [dummy_db_false] + settings_list, - show_progress=False, - ) - - button_print.click( - train_model, - inputs=[dummy_headless] + [dummy_db_true] + settings_list, - show_progress=False, - ) - - return ( - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - folders.logging_dir, - ) + return ( + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + folders.logging_dir, + ) def UI(**kwargs): diff --git a/finetune_gui.py b/finetune_gui.py index ca1e51f..5253ac0 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -78,6 +78,7 @@ def save_configuration( seed, num_cpu_threads_per_process, train_text_encoder, + full_bf16, create_caption, create_buckets, save_model_as, @@ -197,6 +198,7 @@ def open_configuration( seed, num_cpu_threads_per_process, train_text_encoder, + full_bf16, create_caption, create_buckets, save_model_as, @@ -313,6 +315,7 @@ def train_model( seed, num_cpu_threads_per_process, train_text_encoder, + full_bf16, generate_caption_database, generate_image_buckets, save_model_as, @@ -495,6 +498,8 @@ def train_model( run_cmd += ' --v_parameterization' if train_text_encoder: run_cmd += ' --train_text_encoder' + if full_bf16: + run_cmd += ' --full_bf16' if weighted_captions: run_cmd += ' --weighted_captions' run_cmd += ( @@ -788,6 +793,9 @@ def finetune_tab(headless=False): train_text_encoder = gr.Checkbox( label='Train text encoder', value=True ) + full_bf16 = gr.Checkbox( + label='Full bf16', value = False + ) with gr.Accordion('Advanced parameters', open=False): with gr.Row(): gradient_accumulation_steps = gr.Number( @@ -848,6 +856,7 @@ def finetune_tab(headless=False): basic_training.seed, basic_training.num_cpu_threads_per_process, train_text_encoder, + full_bf16, create_caption, create_buckets, source_model.save_model_as, diff --git a/kohya_gui.py b/kohya_gui.py index 0ac0e15..a6043e4 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -5,13 +5,8 @@ from dreambooth_gui import dreambooth_tab from finetune_gui import finetune_tab from textual_inversion_gui import ti_tab from library.utilities import utilities_tab -from library.extract_lora_gui import gradio_extract_lora_tab -from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab -from library.merge_lora_gui import gradio_merge_lora_tab -from library.resize_lora_gui import gradio_resize_lora_tab -from library.extract_lora_from_dylora_gui import gradio_extract_dylora_tab -from library.merge_lycoris_gui import gradio_merge_lycoris_tab from lora_gui import lora_tab +from library.class_lora_tab import LoRATools import os from library.custom_logging import setup_logging @@ -67,12 +62,7 @@ def UI(**kwargs): headless=headless, ) with gr.Tab('LoRA'): - gradio_extract_dylora_tab(headless=headless) - gradio_extract_lora_tab(headless=headless) - gradio_extract_lycoris_locon_tab(headless=headless) - gradio_merge_lora_tab(headless=headless) - gradio_merge_lycoris_tab(headless=headless) - gradio_resize_lora_tab(headless=headless) + _ = LoRATools(headless=headless) with gr.Tab('About'): gr.Markdown(f'kohya_ss GUI release {release}') with gr.Tab('README'): diff --git a/library/class_lora_tab.py b/library/class_lora_tab.py new file mode 100644 index 0000000..a19f34a --- /dev/null +++ b/library/class_lora_tab.py @@ -0,0 +1,42 @@ +import gradio as gr +from library.merge_lora_gui import gradio_merge_lora_tab +from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab +from library.verify_lora_gui import gradio_verify_lora_tab +from library.resize_lora_gui import gradio_resize_lora_tab +from library.extract_lora_gui import gradio_extract_lora_tab +from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab +from library.extract_lora_from_dylora_gui import gradio_extract_dylora_tab +from library.merge_lycoris_gui import gradio_merge_lycoris_tab + +# Deprecated code +from library.dataset_balancing_gui import gradio_dataset_balancing_tab +from library.dreambooth_folder_creation_gui import ( + gradio_dreambooth_folder_creation_tab, +) + +class LoRATools: + def __init__(self, folders = "", headless:bool = False): + self.headless = headless + self.folders = folders + + gr.Markdown( + 'This section provide LoRA tools to help setup your dataset...' + ) + gradio_extract_dylora_tab(headless=headless) + gradio_extract_lora_tab(headless=headless) + gradio_extract_lycoris_locon_tab(headless=headless) + gradio_merge_lora_tab(headless=headless) + gradio_merge_lycoris_tab(headless=headless) + gradio_svd_merge_lora_tab(headless=headless) + gradio_resize_lora_tab(headless=headless) + gradio_verify_lora_tab(headless=headless) + if folders: + with gr.Tab('Deprecated'): + gradio_dreambooth_folder_creation_tab( + train_data_dir_input=folders.train_data_dir, + reg_data_dir_input=folders.reg_data_dir, + output_dir_input=folders.output_dir, + logging_dir_input=folders.logging_dir, + headless=headless, + ) + gradio_dataset_balancing_tab(headless=headless) \ No newline at end of file diff --git a/library/class_sdxl_parameters.py b/library/class_sdxl_parameters.py index 8f7883e..33af863 100644 --- a/library/class_sdxl_parameters.py +++ b/library/class_sdxl_parameters.py @@ -2,8 +2,9 @@ import gradio as gr ### SDXL Parameters class class SDXLParameters: - def __init__(self, sdxl_checkbox): + def __init__(self, sdxl_checkbox, show_sdxl_cache_text_encoder_outputs:bool = True): self.sdxl_checkbox = sdxl_checkbox + self.show_sdxl_cache_text_encoder_outputs = show_sdxl_cache_text_encoder_outputs with gr.Accordion(visible=False, open=True, label='SDXL Specific Parameters') as self.sdxl_row: with gr.Row(): @@ -11,11 +12,12 @@ class SDXLParameters: label='Cache text encoder outputs', info='Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.', value=False, + visible=show_sdxl_cache_text_encoder_outputs ) self.sdxl_no_half_vae = gr.Checkbox( label='No half VAE', info='Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.', - value=False + value=True ) self.sdxl_checkbox.change(lambda sdxl_checkbox: gr.Accordion.update(visible=sdxl_checkbox), inputs=[self.sdxl_checkbox], outputs=[self.sdxl_row]) diff --git a/library/class_source_model.py b/library/class_source_model.py index 4080f04..509bc41 100644 --- a/library/class_source_model.py +++ b/library/class_source_model.py @@ -33,8 +33,8 @@ class SourceModel: label='Model Quick Pick', choices=[ 'custom', - 'stabilityai/stable-diffusion-xl-base-0.9', - 'stabilityai/stable-diffusion-xl-refiner-0.9', + # 'stabilityai/stable-diffusion-xl-base-0.9', + # 'stabilityai/stable-diffusion-xl-refiner-0.9', 'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned', 'stabilityai/stable-diffusion-2-1-base', 'stabilityai/stable-diffusion-2-base', diff --git a/library/svd_merge_lora_gui.py b/library/svd_merge_lora_gui.py index c422272..9b5cce9 100644 --- a/library/svd_merge_lora_gui.py +++ b/library/svd_merge_lora_gui.py @@ -36,6 +36,11 @@ def svd_merge_lora( new_conv_rank, device, ): + # Check if the output file already exists + if os.path.isfile(save_to): + print(f"Output file '{save_to}' already exists. Aborting.") + return + # Check if the ratio total is equal to one. If not mormalise to 1 total_ratio = ratio_a + ratio_b + ratio_c + ratio_d if total_ratio != 1: @@ -78,7 +83,7 @@ def svd_merge_lora( run_cmd_ratios += f' {ratio_d}' run_cmd += run_cmd_models - run_cmd += run_cmd_ratiosacti + run_cmd += run_cmd_ratios run_cmd += f' --device {device}' run_cmd += f' --new_rank "{new_rank}"' run_cmd += f' --new_conv_rank "{new_conv_rank}"' diff --git a/lora_gui.py b/lora_gui.py index 2e9becf..0a6fc49 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -32,21 +32,14 @@ from library.class_basic_training import BasicTraining from library.class_advanced_training import AdvancedTraining from library.class_sdxl_parameters import SDXLParameters from library.class_folders import Folders -from library.dreambooth_folder_creation_gui import ( - gradio_dreambooth_folder_creation_tab, -) from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, ) -from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.utilities import utilities_tab -from library.merge_lora_gui import gradio_merge_lora_tab -from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab -from library.verify_lora_gui import gradio_verify_lora_tab -from library.resize_lora_gui import gradio_resize_lora_tab from library.class_sample_images import SampleImages, run_cmd_sample +from library.class_lora_tab import LoRATools from library.custom_logging import setup_logging @@ -1576,21 +1569,7 @@ def lora_tab( ) with gr.Tab('Tools'): - gr.Markdown( - 'This section provide LoRA tools to help setup your dataset...' - ) - gradio_dreambooth_folder_creation_tab( - train_data_dir_input=folders.train_data_dir, - reg_data_dir_input=folders.reg_data_dir, - output_dir_input=folders.output_dir, - logging_dir_input=folders.logging_dir, - headless=headless, - ) - gradio_dataset_balancing_tab(headless=headless) - gradio_merge_lora_tab(headless=headless) - gradio_svd_merge_lora_tab(headless=headless) - gradio_resize_lora_tab(headless=headless) - gradio_verify_lora_tab(headless=headless) + lora_tools = LoRATools(folders=folders, headless=headless) with gr.Tab('Guides'): gr.Markdown( diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index f91b8dd..702241c 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -30,6 +30,7 @@ from library.class_source_model import SourceModel from library.class_basic_training import BasicTraining from library.class_advanced_training import AdvancedTraining from library.class_folders import Folders +from library.class_sdxl_parameters import SDXLParameters from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, @@ -129,6 +130,7 @@ def save_configuration( scale_v_pred_loss_like_noise_pred, min_timestep, max_timestep, + sdxl_no_half_vae ): # Get list of function parameters and values parameters = list(locals().items()) @@ -245,6 +247,7 @@ def open_configuration( scale_v_pred_loss_like_noise_pred, min_timestep, max_timestep, + sdxl_no_half_vae ): # Get list of function parameters and values parameters = list(locals().items()) @@ -358,6 +361,7 @@ def train_model( scale_v_pred_loss_like_noise_pred, min_timestep, max_timestep, + sdxl_no_half_vae ): # Get list of function parameters and values parameters = list(locals().items()) @@ -421,13 +425,6 @@ def train_model( ): return - if sdxl: - output_message( - msg='TI training is not compatible with an SDXL model.', - headless=headless_bool, - ) - return - # if float(noise_offset) > 0 and ( # multires_noise_iterations > 0 or multires_noise_discount > 0 # ): @@ -520,7 +517,12 @@ def train_model( lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) log.info(f'lr_warmup_steps = {lr_warmup_steps}') - run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_textual_inversion.py"' + run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process}' + if sdxl: + run_cmd += f' "./sdxl_train_textual_inversion.py"' + else: + run_cmd += f' "./train_textual_inversion.py"' + if v2: run_cmd += ' --v2' if v_parameterization: @@ -563,6 +565,9 @@ def train_model( ) if int(gradient_accumulation_steps) > 1: run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' + + if sdxl_no_half_vae: + run_cmd += f' --no_half_vae' run_cmd += run_cmd_training( learning_rate=learning_rate, @@ -679,237 +684,244 @@ def ti_tab( dummy_db_true = gr.Label(value=True, visible=False) dummy_db_false = gr.Label(value=False, visible=False) dummy_headless = gr.Label(value=headless, visible=False) - gr.Markdown('Train a TI using kohya textual inversion python code...') - # Setup Configuration Files Gradio - config = ConfigurationFile(headless) + with gr.Tab('Training'): + gr.Markdown('Train a TI using kohya textual inversion python code...') + + # Setup Configuration Files Gradio + config = ConfigurationFile(headless) - source_model = SourceModel( - save_model_as_choices=[ - 'ckpt', - 'safetensors', - ], - headless=headless, - ) - - with gr.Tab('Folders'): - folders = Folders(headless=headless) - with gr.Tab('Parameters'): - with gr.Row(): - weights = gr.Textbox( - label='Resume TI training', - placeholder='(Optional) Path to existing TI embeding file to keep training', - ) - weights_file_input = gr.Button( - '📂', elem_id='open_folder_small', visible=(not headless) - ) - weights_file_input.click( - get_file_path, - outputs=weights, - show_progress=False, - ) - with gr.Row(): - token_string = gr.Textbox( - label='Token string', - placeholder='eg: cat', - ) - init_word = gr.Textbox( - label='Init word', - value='*', - ) - num_vectors_per_token = gr.Slider( - minimum=1, - maximum=75, - value=1, - step=1, - label='Vectors', - ) - max_train_steps = gr.Textbox( - label='Max train steps', - placeholder='(Optional) Maximum number of steps', - ) - template = gr.Dropdown( - label='Template', - choices=[ - 'caption', - 'object template', - 'style template', - ], - value='caption', - ) - basic_training = BasicTraining( - learning_rate_value='1e-5', - lr_scheduler_value='cosine', - lr_warmup_value='10', - ) - with gr.Accordion('Advanced Configuration', open=False): - advanced_training = AdvancedTraining(headless=headless) - advanced_training.color_aug.change( - color_aug_changed, - inputs=[advanced_training.color_aug], - outputs=[basic_training.cache_latents], - ) - - sample = SampleImages() - - with gr.Tab('Tools'): - gr.Markdown( - 'This section provide Dreambooth tools to help setup your dataset...' - ) - gradio_dreambooth_folder_creation_tab( - train_data_dir_input=folders.train_data_dir, - reg_data_dir_input=folders.reg_data_dir, - output_dir_input=folders.output_dir, - logging_dir_input=folders.logging_dir, + source_model = SourceModel( + save_model_as_choices=[ + 'ckpt', + 'safetensors', + ], headless=headless, ) - button_run = gr.Button('Train model', variant='primary') + with gr.Tab('Folders'): + folders = Folders(headless=headless) + with gr.Tab('Parameters'): + with gr.Row(): + weights = gr.Textbox( + label='Resume TI training', + placeholder='(Optional) Path to existing TI embeding file to keep training', + ) + weights_file_input = gr.Button( + '📂', elem_id='open_folder_small', visible=(not headless) + ) + weights_file_input.click( + get_file_path, + outputs=weights, + show_progress=False, + ) + with gr.Row(): + token_string = gr.Textbox( + label='Token string', + placeholder='eg: cat', + ) + init_word = gr.Textbox( + label='Init word', + value='*', + ) + num_vectors_per_token = gr.Slider( + minimum=1, + maximum=75, + value=1, + step=1, + label='Vectors', + ) + max_train_steps = gr.Textbox( + label='Max train steps', + placeholder='(Optional) Maximum number of steps', + ) + template = gr.Dropdown( + label='Template', + choices=[ + 'caption', + 'object template', + 'style template', + ], + value='caption', + ) + basic_training = BasicTraining( + learning_rate_value='1e-5', + lr_scheduler_value='cosine', + lr_warmup_value='10', + ) + + # Add SDXL Parameters + sdxl_params = SDXLParameters(source_model.sdxl_checkbox, show_sdxl_cache_text_encoder_outputs=False) + + with gr.Accordion('Advanced Configuration', open=False): + advanced_training = AdvancedTraining(headless=headless) + advanced_training.color_aug.change( + color_aug_changed, + inputs=[advanced_training.color_aug], + outputs=[basic_training.cache_latents], + ) - button_print = gr.Button('Print training command') + sample = SampleImages() - # Setup gradio tensorboard buttons - button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + with gr.Tab('Tools'): + gr.Markdown( + 'This section provide Dreambooth tools to help setup your dataset...' + ) + gradio_dreambooth_folder_creation_tab( + train_data_dir_input=folders.train_data_dir, + reg_data_dir_input=folders.reg_data_dir, + output_dir_input=folders.output_dir, + logging_dir_input=folders.logging_dir, + headless=headless, + ) - button_start_tensorboard.click( - start_tensorboard, - inputs=folders.logging_dir, - show_progress=False, - ) + button_run = gr.Button('Train model', variant='primary') - button_stop_tensorboard.click( - stop_tensorboard, - show_progress=False, - ) + button_print = gr.Button('Print training command') - settings_list = [ - source_model.pretrained_model_name_or_path, - source_model.v2, - source_model.v_parameterization, - source_model.sdxl_checkbox, - folders.logging_dir, - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - basic_training.max_resolution, - basic_training.learning_rate, - basic_training.lr_scheduler, - basic_training.lr_warmup, - basic_training.train_batch_size, - basic_training.epoch, - basic_training.save_every_n_epochs, - basic_training.mixed_precision, - basic_training.save_precision, - basic_training.seed, - basic_training.num_cpu_threads_per_process, - basic_training.cache_latents, - basic_training.cache_latents_to_disk, - basic_training.caption_extension, - basic_training.enable_bucket, - advanced_training.gradient_checkpointing, - advanced_training.full_fp16, - advanced_training.no_token_padding, - basic_training.stop_text_encoder_training, - advanced_training.xformers, - source_model.save_model_as, - advanced_training.shuffle_caption, - advanced_training.save_state, - advanced_training.resume, - advanced_training.prior_loss_weight, - advanced_training.color_aug, - advanced_training.flip_aug, - advanced_training.clip_skip, - advanced_training.vae, - folders.output_name, - advanced_training.max_token_length, - advanced_training.max_train_epochs, - advanced_training.max_data_loader_n_workers, - advanced_training.mem_eff_attn, - advanced_training.gradient_accumulation_steps, - source_model.model_list, - token_string, - init_word, - num_vectors_per_token, - max_train_steps, - weights, - template, - advanced_training.keep_tokens, - advanced_training.persistent_data_loader_workers, - advanced_training.bucket_no_upscale, - advanced_training.random_crop, - advanced_training.bucket_reso_steps, - advanced_training.caption_dropout_every_n_epochs, - advanced_training.caption_dropout_rate, - basic_training.optimizer, - basic_training.optimizer_args, - advanced_training.noise_offset_type, - advanced_training.noise_offset, - advanced_training.adaptive_noise_scale, - advanced_training.multires_noise_iterations, - advanced_training.multires_noise_discount, - sample.sample_every_n_steps, - sample.sample_every_n_epochs, - sample.sample_sampler, - sample.sample_prompts, - advanced_training.additional_parameters, - advanced_training.vae_batch_size, - advanced_training.min_snr_gamma, - advanced_training.save_every_n_steps, - advanced_training.save_last_n_steps, - advanced_training.save_last_n_steps_state, - advanced_training.use_wandb, - advanced_training.wandb_api_key, - advanced_training.scale_v_pred_loss_like_noise_pred, - advanced_training.min_timestep, - advanced_training.max_timestep - ] + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() - config.button_open_config.click( - open_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name] + settings_list, - show_progress=False, - ) + button_start_tensorboard.click( + start_tensorboard, + inputs=folders.logging_dir, + show_progress=False, + ) - config.button_load_config.click( - open_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name] + settings_list, - show_progress=False, - ) + button_stop_tensorboard.click( + stop_tensorboard, + show_progress=False, + ) - config.button_save_config.click( - save_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) + settings_list = [ + source_model.pretrained_model_name_or_path, + source_model.v2, + source_model.v_parameterization, + source_model.sdxl_checkbox, + folders.logging_dir, + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + basic_training.max_resolution, + basic_training.learning_rate, + basic_training.lr_scheduler, + basic_training.lr_warmup, + basic_training.train_batch_size, + basic_training.epoch, + basic_training.save_every_n_epochs, + basic_training.mixed_precision, + basic_training.save_precision, + basic_training.seed, + basic_training.num_cpu_threads_per_process, + basic_training.cache_latents, + basic_training.cache_latents_to_disk, + basic_training.caption_extension, + basic_training.enable_bucket, + advanced_training.gradient_checkpointing, + advanced_training.full_fp16, + advanced_training.no_token_padding, + basic_training.stop_text_encoder_training, + advanced_training.xformers, + source_model.save_model_as, + advanced_training.shuffle_caption, + advanced_training.save_state, + advanced_training.resume, + advanced_training.prior_loss_weight, + advanced_training.color_aug, + advanced_training.flip_aug, + advanced_training.clip_skip, + advanced_training.vae, + folders.output_name, + advanced_training.max_token_length, + advanced_training.max_train_epochs, + advanced_training.max_data_loader_n_workers, + advanced_training.mem_eff_attn, + advanced_training.gradient_accumulation_steps, + source_model.model_list, + token_string, + init_word, + num_vectors_per_token, + max_train_steps, + weights, + template, + advanced_training.keep_tokens, + advanced_training.persistent_data_loader_workers, + advanced_training.bucket_no_upscale, + advanced_training.random_crop, + advanced_training.bucket_reso_steps, + advanced_training.caption_dropout_every_n_epochs, + advanced_training.caption_dropout_rate, + basic_training.optimizer, + basic_training.optimizer_args, + advanced_training.noise_offset_type, + advanced_training.noise_offset, + advanced_training.adaptive_noise_scale, + advanced_training.multires_noise_iterations, + advanced_training.multires_noise_discount, + sample.sample_every_n_steps, + sample.sample_every_n_epochs, + sample.sample_sampler, + sample.sample_prompts, + advanced_training.additional_parameters, + advanced_training.vae_batch_size, + advanced_training.min_snr_gamma, + advanced_training.save_every_n_steps, + advanced_training.save_last_n_steps, + advanced_training.save_last_n_steps_state, + advanced_training.use_wandb, + advanced_training.wandb_api_key, + advanced_training.scale_v_pred_loss_like_noise_pred, + advanced_training.min_timestep, + advanced_training.max_timestep, + sdxl_params.sdxl_no_half_vae, + ] - config.button_save_as_config.click( - save_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) + config.button_open_config.click( + open_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name] + settings_list, + show_progress=False, + ) - button_run.click( - train_model, - inputs=[dummy_headless] + [dummy_db_false] + settings_list, - show_progress=False, - ) + config.button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name] + settings_list, + show_progress=False, + ) - button_print.click( - train_model, - inputs=[dummy_headless] + [dummy_db_true] + settings_list, - show_progress=False, - ) + config.button_save_config.click( + save_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) - return ( - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - folders.logging_dir, - ) + config.button_save_as_config.click( + save_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) + + button_run.click( + train_model, + inputs=[dummy_headless] + [dummy_db_false] + settings_list, + show_progress=False, + ) + + button_print.click( + train_model, + inputs=[dummy_headless] + [dummy_db_true] + settings_list, + show_progress=False, + ) + + return ( + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + folders.logging_dir, + ) def UI(**kwargs):