diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py
index 0048189..4f42294 100644
--- a/kohya_gui/lora_gui.py
+++ b/kohya_gui/lora_gui.py
@@ -1421,7 +1421,7 @@ def train_model(
         text_encoder_lr_list = [float(text_encoder_lr), float(text_encoder_lr)]
 
     # Convert learning rates to float once and store the result for re-use
-    learning_rate = float(learning_rate) if learning_rate is not None else 0.0
+    learning_rate_float = float(learning_rate) if learning_rate is not None else 0.0
     text_encoder_lr_float = (
         float(text_encoder_lr) if text_encoder_lr is not None else 0.0
     )
@@ -1429,7 +1429,7 @@ def train_model(
 
     # Determine the training configuration based on learning rate values
     # Sets flags for training specific components based on the provided learning rates.
-    if float(learning_rate) == unet_lr_float == text_encoder_lr_float == 0:
+    if learning_rate_float == unet_lr_float == text_encoder_lr_float == 0:
         output_message(msg="Please input learning rate values.", headless=headless)
         return TRAIN_BUTTON_VISIBLE
     # Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
@@ -1437,11 +1437,6 @@ def train_model(
     # Flag to train unet only if its learning rate is non-zero and text encoder's is zero.
     network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0
 
-    do_not_set_learning_rate = False # Initialize with a default value
-    if text_encoder_lr_float != 0 or unet_lr_float != 0:
-        log.info("Learning rate won't be used for training because text_encoder_lr or unet_lr is set.")
-        # do_not_set_learning_rate = True # This line is now commented out
-
     clip_l_value = None
     if sd3_checkbox:
         # print("Setting clip_l_value to sd3_clip_l")
@@ -1519,7 +1514,7 @@ def train_model(
         "ip_noise_gamma": ip_noise_gamma if ip_noise_gamma != 0 else None,
         "ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
         "keep_tokens": int(keep_tokens),
-        "learning_rate": None if do_not_set_learning_rate else learning_rate,
+        "learning_rate": learning_rate_float,
         "logging_dir": logging_dir,
         "log_config": log_config,
         "log_tracker_name": log_tracker_name,
@@ -1640,7 +1635,7 @@ def train_model(
         "train_batch_size": train_batch_size,
         "train_data_dir": train_data_dir,
         "training_comment": training_comment,
-        "unet_lr": unet_lr if unet_lr != 0 else None,
+        "unet_lr": unet_lr_float if unet_lr_float != 0.0 else None,
         "log_with": log_with,
         "v2": v2,
         "v_parameterization": v_parameterization,