From c4247408fe40aaf393a5bf781eb85301b9291f60 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sun, 25 May 2025 17:44:03 -0400 Subject: [PATCH] Feat/add max grad norm dreambooth (#3251) * I've added a `max_grad_norm` parameter to the Dreambooth GUI. This change adds support for the `max_grad_norm` parameter to the Dreambooth training GUI. - The `max_grad_norm` option is now available in the 'Basic' training parameters section. - The value you set in the GUI for `max_grad_norm` is passed to the training script via the generated TOML configuration file, similar to its existing implementation in the LoRA GUI. * Fix mising entry for max_grad_norm --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- gui-uv.bat | 2 -- kohya_gui/class_basic_training.py | 8 +------- kohya_gui/dreambooth_gui.py | 5 +++++ 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/gui-uv.bat b/gui-uv.bat index 337a570..6a3e309 100644 --- a/gui-uv.bat +++ b/gui-uv.bat @@ -13,8 +13,6 @@ if %errorlevel% neq 0 ( echo Okay, please install uv manually from https://astral.sh/uv and then re-run this script. Exiting. exit /b 1 ) -) else ( - echo uv is already installed. ) endlocal diff --git a/kohya_gui/class_basic_training.py b/kohya_gui/class_basic_training.py index 7cda94b..883cdb9 100644 --- a/kohya_gui/class_basic_training.py +++ b/kohya_gui/class_basic_training.py @@ -230,13 +230,7 @@ class BasicTraining: """ with gr.Row(): # Initialize the maximum gradient norm slider - self.max_grad_norm = gr.Slider( - label="Max grad norm", - value=self.config.get("basic.max_grad_norm", 1.0), - minimum=0.0, - maximum=1.0, - interactive=True, - ) + self.max_grad_norm = gr.Number(label='Max grad norm', value=1.0, interactive=True) # Initialize the learning rate scheduler extra arguments textbox self.lr_scheduler_args = gr.Textbox( label="LR scheduler extra arguments", diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py index 98c6a8f..b2702e2 100644 --- a/kohya_gui/dreambooth_gui.py +++ b/kohya_gui/dreambooth_gui.py @@ -132,6 +132,7 @@ def save_configuration( keep_tokens, lr_scheduler_num_cycles, lr_scheduler_power, + max_grad_norm, persistent_data_loader_workers, bucket_no_upscale, random_crop, @@ -342,6 +343,7 @@ def open_configuration( keep_tokens, lr_scheduler_num_cycles, lr_scheduler_power, + max_grad_norm, persistent_data_loader_workers, bucket_no_upscale, random_crop, @@ -547,6 +549,7 @@ def train_model( keep_tokens, lr_scheduler_num_cycles, lr_scheduler_power, + max_grad_norm, persistent_data_loader_workers, bucket_no_upscale, random_crop, @@ -953,6 +956,7 @@ def train_model( "lr_warmup_steps": lr_warmup_steps, "masked_loss": masked_loss, "max_bucket_reso": max_bucket_reso, + "max_grad_norm": max_grad_norm, "max_timestep": max_timestep if max_timestep != 0 else None, "max_token_length": int(max_token_length), "max_train_epochs": ( @@ -1329,6 +1333,7 @@ def dreambooth_tab( advanced_training.keep_tokens, basic_training.lr_scheduler_num_cycles, basic_training.lr_scheduler_power, + basic_training.max_grad_norm, advanced_training.persistent_data_loader_workers, advanced_training.bucket_no_upscale, advanced_training.random_crop,