Feat/add max grad norm dreambooth (#3251)

* I've added a `max_grad_norm` parameter to the Dreambooth GUI.

This change adds support for the `max_grad_norm` parameter to the Dreambooth training GUI.

- The `max_grad_norm` option is now available in the 'Basic' training parameters section.
- The value you set in the GUI for `max_grad_norm` is passed to the training script via the generated TOML configuration file, similar to its existing implementation in the LoRA GUI.

* Fix missing entry for max_grad_norm

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
pull/3252/head
bmaltais 2025-05-25 17:44:03 -04:00 committed by GitHub
parent 93a06c7fc3
commit c4247408fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 9 deletions

View File

@@ -13,8 +13,6 @@ if %errorlevel% neq 0 (
echo Okay, please install uv manually from https://astral.sh/uv and then re-run this script. Exiting.
exit /b 1
)
) else (
echo uv is already installed.
)
endlocal

View File

@@ -230,13 +230,7 @@ class BasicTraining:
"""
with gr.Row():
# Initialize the maximum gradient norm slider
self.max_grad_norm = gr.Slider(
label="Max grad norm",
value=self.config.get("basic.max_grad_norm", 1.0),
minimum=0.0,
maximum=1.0,
interactive=True,
)
self.max_grad_norm = gr.Number(label='Max grad norm', value=1.0, interactive=True)
# Initialize the learning rate scheduler extra arguments textbox
self.lr_scheduler_args = gr.Textbox(
label="LR scheduler extra arguments",

View File

@@ -132,6 +132,7 @@ def save_configuration(
keep_tokens,
lr_scheduler_num_cycles,
lr_scheduler_power,
max_grad_norm,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
@@ -342,6 +343,7 @@ def open_configuration(
keep_tokens,
lr_scheduler_num_cycles,
lr_scheduler_power,
max_grad_norm,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
@@ -547,6 +549,7 @@ def train_model(
keep_tokens,
lr_scheduler_num_cycles,
lr_scheduler_power,
max_grad_norm,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
@@ -953,6 +956,7 @@ def train_model(
"lr_warmup_steps": lr_warmup_steps,
"masked_loss": masked_loss,
"max_bucket_reso": max_bucket_reso,
"max_grad_norm": max_grad_norm,
"max_timestep": max_timestep if max_timestep != 0 else None,
"max_token_length": int(max_token_length),
"max_train_epochs": (
@@ -1329,6 +1333,7 @@ def dreambooth_tab(
advanced_training.keep_tokens,
basic_training.lr_scheduler_num_cycles,
basic_training.lr_scheduler_power,
basic_training.max_grad_norm,
advanced_training.persistent_data_loader_workers,
advanced_training.bucket_no_upscale,
advanced_training.random_crop,