diff --git a/gui-uv.bat b/gui-uv.bat index 337a570..6a3e309 100644 --- a/gui-uv.bat +++ b/gui-uv.bat @@ -13,8 +13,6 @@ if %errorlevel% neq 0 ( echo Okay, please install uv manually from https://astral.sh/uv and then re-run this script. Exiting. exit /b 1 ) -) else ( - echo uv is already installed. ) endlocal diff --git a/kohya_gui/class_basic_training.py b/kohya_gui/class_basic_training.py index 7cda94b..883cdb9 100644 --- a/kohya_gui/class_basic_training.py +++ b/kohya_gui/class_basic_training.py @@ -230,13 +230,7 @@ class BasicTraining: """ with gr.Row(): # Initialize the maximum gradient norm slider - self.max_grad_norm = gr.Slider( - label="Max grad norm", - value=self.config.get("basic.max_grad_norm", 1.0), - minimum=0.0, - maximum=1.0, - interactive=True, - ) + self.max_grad_norm = gr.Number(label="Max grad norm", value=self.config.get("basic.max_grad_norm", 1.0), interactive=True) # Initialize the learning rate scheduler extra arguments textbox self.lr_scheduler_args = gr.Textbox( label="LR scheduler extra arguments", diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py index 98c6a8f..b2702e2 100644 --- a/kohya_gui/dreambooth_gui.py +++ b/kohya_gui/dreambooth_gui.py @@ -132,6 +132,7 @@ def save_configuration( keep_tokens, lr_scheduler_num_cycles, lr_scheduler_power, + max_grad_norm, persistent_data_loader_workers, bucket_no_upscale, random_crop, @@ -342,6 +343,7 @@ def open_configuration( keep_tokens, lr_scheduler_num_cycles, lr_scheduler_power, + max_grad_norm, persistent_data_loader_workers, bucket_no_upscale, random_crop, @@ -547,6 +549,7 @@ def train_model( keep_tokens, lr_scheduler_num_cycles, lr_scheduler_power, + max_grad_norm, persistent_data_loader_workers, bucket_no_upscale, random_crop, @@ -953,6 +956,7 @@ def train_model( "lr_warmup_steps": lr_warmup_steps, "masked_loss": masked_loss, "max_bucket_reso": max_bucket_reso, + "max_grad_norm": max_grad_norm, "max_timestep": max_timestep if max_timestep != 0 else None, "max_token_length": int(max_token_length), 
"max_train_epochs": ( @@ -1329,6 +1333,7 @@ def dreambooth_tab( advanced_training.keep_tokens, basic_training.lr_scheduler_num_cycles, basic_training.lr_scheduler_power, + basic_training.max_grad_norm, advanced_training.persistent_data_loader_workers, advanced_training.bucket_no_upscale, advanced_training.random_crop,