mirror of https://github.com/bmaltais/kohya_ss
Feat/add max grad norm dreambooth (#3251)
* I've added a `max_grad_norm` parameter to the Dreambooth GUI. This change adds support for the `max_grad_norm` parameter to the Dreambooth training GUI. - The `max_grad_norm` option is now available in the 'Basic' training parameters section. - The value you set in the GUI for `max_grad_norm` is passed to the training script via the generated TOML configuration file, similar to its existing implementation in the LoRA GUI. * Fix mising entry for max_grad_norm --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>pull/3252/head
parent
93a06c7fc3
commit
c4247408fe
|
|
@ -13,8 +13,6 @@ if %errorlevel% neq 0 (
|
|||
echo Okay, please install uv manually from https://astral.sh/uv and then re-run this script. Exiting.
|
||||
exit /b 1
|
||||
)
|
||||
) else (
|
||||
echo uv is already installed.
|
||||
)
|
||||
endlocal
|
||||
|
||||
|
|
|
|||
|
|
@ -230,13 +230,7 @@ class BasicTraining:
|
|||
"""
|
||||
with gr.Row():
|
||||
# Initialize the maximum gradient norm slider
|
||||
self.max_grad_norm = gr.Slider(
|
||||
label="Max grad norm",
|
||||
value=self.config.get("basic.max_grad_norm", 1.0),
|
||||
minimum=0.0,
|
||||
maximum=1.0,
|
||||
interactive=True,
|
||||
)
|
||||
self.max_grad_norm = gr.Number(label='Max grad norm', value=1.0, interactive=True)
|
||||
# Initialize the learning rate scheduler extra arguments textbox
|
||||
self.lr_scheduler_args = gr.Textbox(
|
||||
label="LR scheduler extra arguments",
|
||||
|
|
|
|||
|
|
@ -132,6 +132,7 @@ def save_configuration(
|
|||
keep_tokens,
|
||||
lr_scheduler_num_cycles,
|
||||
lr_scheduler_power,
|
||||
max_grad_norm,
|
||||
persistent_data_loader_workers,
|
||||
bucket_no_upscale,
|
||||
random_crop,
|
||||
|
|
@ -342,6 +343,7 @@ def open_configuration(
|
|||
keep_tokens,
|
||||
lr_scheduler_num_cycles,
|
||||
lr_scheduler_power,
|
||||
max_grad_norm,
|
||||
persistent_data_loader_workers,
|
||||
bucket_no_upscale,
|
||||
random_crop,
|
||||
|
|
@ -547,6 +549,7 @@ def train_model(
|
|||
keep_tokens,
|
||||
lr_scheduler_num_cycles,
|
||||
lr_scheduler_power,
|
||||
max_grad_norm,
|
||||
persistent_data_loader_workers,
|
||||
bucket_no_upscale,
|
||||
random_crop,
|
||||
|
|
@ -953,6 +956,7 @@ def train_model(
|
|||
"lr_warmup_steps": lr_warmup_steps,
|
||||
"masked_loss": masked_loss,
|
||||
"max_bucket_reso": max_bucket_reso,
|
||||
"max_grad_norm": max_grad_norm,
|
||||
"max_timestep": max_timestep if max_timestep != 0 else None,
|
||||
"max_token_length": int(max_token_length),
|
||||
"max_train_epochs": (
|
||||
|
|
@ -1329,6 +1333,7 @@ def dreambooth_tab(
|
|||
advanced_training.keep_tokens,
|
||||
basic_training.lr_scheduler_num_cycles,
|
||||
basic_training.lr_scheduler_power,
|
||||
basic_training.max_grad_norm,
|
||||
advanced_training.persistent_data_loader_workers,
|
||||
advanced_training.bucket_no_upscale,
|
||||
advanced_training.random_crop,
|
||||
|
|
|
|||
Loading…
Reference in New Issue