mirror of https://github.com/bmaltais/kohya_ss
Refactor: Clean up LR handling logic in LoRA GUI
This commit refactors the learning rate (LR) handling in `kohya_gui/lora_gui.py` for LoRA training. The previous fix for LR misinterpretation involved commenting out a line. This commit completes the cleanup by: - Removing the `do_not_set_learning_rate` variable and its associated conditional logic, which became redundant. - Renaming the float-converted `learning_rate` to `learning_rate_float` for clarity. - Ensuring that `learning_rate_float` and the float-converted `unet_lr_float` are consistently used when preparing the `config_toml_data` for the training script. This makes the code cleaner and the intent of always passing the main learning rate (along with specific TE/UNet LRs) more direct. The functional behavior of the LR fix remains the same.pull/3264/head
parent
3a8b599ba9
commit
d63a7fa2b6
|
|
@ -1421,7 +1421,7 @@ def train_model(
|
|||
text_encoder_lr_list = [float(text_encoder_lr), float(text_encoder_lr)]
|
||||
|
||||
# Convert learning rates to float once and store the result for re-use
|
||||
learning_rate = float(learning_rate) if learning_rate is not None else 0.0
|
||||
learning_rate_float = float(learning_rate) if learning_rate is not None else 0.0
|
||||
text_encoder_lr_float = (
|
||||
float(text_encoder_lr) if text_encoder_lr is not None else 0.0
|
||||
)
|
||||
|
|
@ -1429,7 +1429,7 @@ def train_model(
|
|||
|
||||
# Determine the training configuration based on learning rate values
|
||||
# Sets flags for training specific components based on the provided learning rates.
|
||||
if float(learning_rate) == unet_lr_float == text_encoder_lr_float == 0:
|
||||
if learning_rate_float == unet_lr_float == text_encoder_lr_float == 0:
|
||||
output_message(msg="Please input learning rate values.", headless=headless)
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
# Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
|
||||
|
|
@ -1437,11 +1437,6 @@ def train_model(
|
|||
# Flag to train unet only if its learning rate is non-zero and text encoder's is zero.
|
||||
network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0
|
||||
|
||||
do_not_set_learning_rate = False # Initialize with a default value
|
||||
if text_encoder_lr_float != 0 or unet_lr_float != 0:
|
||||
log.info("Learning rate won't be used for training because text_encoder_lr or unet_lr is set.")
|
||||
# do_not_set_learning_rate = True # This line is now commented out
|
||||
|
||||
clip_l_value = None
|
||||
if sd3_checkbox:
|
||||
# print("Setting clip_l_value to sd3_clip_l")
|
||||
|
|
@ -1519,7 +1514,7 @@ def train_model(
|
|||
"ip_noise_gamma": ip_noise_gamma if ip_noise_gamma != 0 else None,
|
||||
"ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
|
||||
"keep_tokens": int(keep_tokens),
|
||||
"learning_rate": None if do_not_set_learning_rate else learning_rate,
|
||||
"learning_rate": learning_rate_float,
|
||||
"logging_dir": logging_dir,
|
||||
"log_config": log_config,
|
||||
"log_tracker_name": log_tracker_name,
|
||||
|
|
@ -1640,7 +1635,7 @@ def train_model(
|
|||
"train_batch_size": train_batch_size,
|
||||
"train_data_dir": train_data_dir,
|
||||
"training_comment": training_comment,
|
||||
"unet_lr": unet_lr if unet_lr != 0 else None,
|
||||
"unet_lr": unet_lr_float if unet_lr_float != 0.0 else None,
|
||||
"log_with": log_with,
|
||||
"v2": v2,
|
||||
"v_parameterization": v_parameterization,
|
||||
|
|
|
|||
Loading…
Reference in New Issue