Refactor: Clean up LR handling logic in LoRA GUI

This commit refactors the learning rate (LR) handling in `kohya_gui/lora_gui.py` for LoRA training.

The previous fix for LR misinterpretation involved commenting out a line. This commit completes the cleanup by:
- Removing the `do_not_set_learning_rate` variable and its associated conditional logic, which became redundant.
- Renaming the float-converted `learning_rate` to `learning_rate_float` for clarity.
- Ensuring that `learning_rate_float` and the float-converted `unet_lr_float` are consistently used when preparing the `config_toml_data` for the training script.

This makes the code cleaner and the intent of always passing the main learning rate (along with specific TE/UNet LRs) more direct. The functional behavior of the LR fix remains the same.
pull/3264/head
google-labs-jules[bot] 2025-06-01 12:29:08 +00:00
parent 3a8b599ba9
commit d63a7fa2b6
1 changed file with 4 additions and 9 deletions

View File

@@ -1421,7 +1421,7 @@ def train_model(
text_encoder_lr_list = [float(text_encoder_lr), float(text_encoder_lr)]
# Convert learning rates to float once and store the result for re-use
learning_rate = float(learning_rate) if learning_rate is not None else 0.0
learning_rate_float = float(learning_rate) if learning_rate is not None else 0.0
text_encoder_lr_float = (
float(text_encoder_lr) if text_encoder_lr is not None else 0.0
)
@@ -1429,7 +1429,7 @@ def train_model(
# Determine the training configuration based on learning rate values
# Sets flags for training specific components based on the provided learning rates.
if float(learning_rate) == unet_lr_float == text_encoder_lr_float == 0:
if learning_rate_float == unet_lr_float == text_encoder_lr_float == 0:
output_message(msg="Please input learning rate values.", headless=headless)
return TRAIN_BUTTON_VISIBLE
# Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
@@ -1437,11 +1437,6 @@ def train_model(
# Flag to train unet only if its learning rate is non-zero and text encoder's is zero.
network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0
do_not_set_learning_rate = False # Initialize with a default value
if text_encoder_lr_float != 0 or unet_lr_float != 0:
log.info("Learning rate won't be used for training because text_encoder_lr or unet_lr is set.")
# do_not_set_learning_rate = True # This line is now commented out
clip_l_value = None
if sd3_checkbox:
# print("Setting clip_l_value to sd3_clip_l")
@@ -1519,7 +1514,7 @@ def train_model(
"ip_noise_gamma": ip_noise_gamma if ip_noise_gamma != 0 else None,
"ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
"keep_tokens": int(keep_tokens),
"learning_rate": None if do_not_set_learning_rate else learning_rate,
"learning_rate": learning_rate_float,
"logging_dir": logging_dir,
"log_config": log_config,
"log_tracker_name": log_tracker_name,
@@ -1640,7 +1635,7 @@ def train_model(
"train_batch_size": train_batch_size,
"train_data_dir": train_data_dir,
"training_comment": training_comment,
"unet_lr": unet_lr if unet_lr != 0 else None,
"unet_lr": unet_lr_float if unet_lr_float != 0.0 else None,
"log_with": log_with,
"v2": v2,
"v_parameterization": v_parameterization,