Feat: Add logging for effective learning rates in LoRA GUI

This commit introduces a helper function, `get_effective_lr_messages`, into `kohya_gui/lora_gui.py` and integrates it into the `train_model` function.

The purpose is to provide you with clearer information about how the learning rates set in the GUI (Main LR, Text Encoder LR, U-Net LR, T5XXL LR) will be interpreted and effectively applied by the underlying `sd-scripts` training engine.

Before training commences, the GUI will now log:
- The Main LR.
- The effective LR for the primary Text Encoder (CLIP), indicating if it's a specific value or a fallback to the Main LR.
- The effective LR for the T5XXL Text Encoder (if applicable), indicating its source (specific, inherited from primary TE, or fallback to Main LR).
- The effective LR for the U-Net, indicating if it's a specific value or a fallback to the Main LR.

This enhances transparency by helping you understand how your LR settings interact, without modifying the `sd-scripts` submodule.
pull/3264/head
google-labs-jules[bot] 2025-06-01 13:59:07 +00:00
parent d63a7fa2b6
commit 69d8b96c1c
1 changed files with 70 additions and 1 deletions

View File

@ -681,6 +681,64 @@ def open_configuration(
return tuple(values)
def get_effective_lr_messages(
main_lr_val: float,
text_encoder_lr_val: float, # Value from the 'Text Encoder learning rate' GUI field
unet_lr_val: float, # Value from the 'Unet learning rate' GUI field
t5xxl_lr_val: float # Value from the 'T5XXL learning rate' GUI field
) -> list[str]:
messages = []
# Format LRs to scientific notation with 2 decimal places for readability
f_main_lr = f"{main_lr_val:.2e}"
f_te_lr = f"{text_encoder_lr_val:.2e}"
f_unet_lr = f"{unet_lr_val:.2e}"
f_t5_lr = f"{t5xxl_lr_val:.2e}"
messages.append("Effective Learning Rate Configuration (based on GUI settings):")
messages.append(f" - Main LR (for optimizer & fallback): {f_main_lr}")
# --- Text Encoder (Primary/CLIP) LR ---
# If text_encoder_lr_val (from GUI) is non-zero, it's used. Otherwise, main_lr_val is the fallback.
effective_clip_lr_str = f_main_lr
clip_lr_source_msg = "(Fallback to Main LR)"
if text_encoder_lr_val != 0.0:
effective_clip_lr_str = f_te_lr
clip_lr_source_msg = "(Specific Value)"
messages.append(f" - Text Encoder (Primary/CLIP) Effective LR: {effective_clip_lr_str} {clip_lr_source_msg}")
# --- Text Encoder (T5XXL, if applicable) LR ---
# Logic based on how text_encoder_lr_list is formed in train_model for sd-scripts:
# 1. If t5xxl_lr_val is non-zero, it's used for T5.
# 2. Else, if text_encoder_lr_val (primary TE LR) is non-zero, it's used for T5.
# 3. Else (both primary TE LR and specific T5XXL LR are zero), T5 uses main_lr_val.
effective_t5_lr_str = f_main_lr # Default fallback
t5_lr_source_msg = "(Fallback to Main LR)"
if t5xxl_lr_val != 0.0:
effective_t5_lr_str = f_t5_lr
t5_lr_source_msg = "(Specific T5XXL Value)"
elif text_encoder_lr_val != 0.0: # No specific T5 LR, but main TE LR is set
effective_t5_lr_str = f_te_lr # T5 inherits from the primary TE LR setting
t5_lr_source_msg = "(Inherited from Primary TE LR)"
# If both t5xxl_lr_val and text_encoder_lr_val are 0.0, effective_t5_lr_str remains f_main_lr.
# The message for T5XXL LR is always added for completeness, indicating its potential value.
# Users should understand it's relevant only if their model architecture uses a T5XXL text encoder.
messages.append(f" - Text Encoder (T5XXL, if applicable) Effective LR: {effective_t5_lr_str} {t5_lr_source_msg}")
# --- U-Net LR ---
# If unet_lr_val (from GUI) is non-zero, it's used. Otherwise, main_lr_val is the fallback.
effective_unet_lr_str = f_main_lr
unet_lr_source_msg = "(Fallback to Main LR)"
if unet_lr_val != 0.0:
effective_unet_lr_str = f_unet_lr
unet_lr_source_msg = "(Specific Value)"
messages.append(f" - U-Net Effective LR: {effective_unet_lr_str} {unet_lr_source_msg}")
messages.append("Note: These LRs reflect the GUI's direct settings. Advanced options in sd-scripts (e.g., block LRs, LoRA+) can further modify rates for specific layers.")
return messages
def train_model(
headless,
print_only,
@ -1426,10 +1484,21 @@ def train_model(
float(text_encoder_lr) if text_encoder_lr is not None else 0.0
)
unet_lr_float = float(unet_lr) if unet_lr is not None else 0.0
t5xxl_lr_float = float(t5xxl_lr) if t5xxl_lr is not None else 0.0
# Log effective learning rate messages
lr_messages = get_effective_lr_messages(
learning_rate_float,
text_encoder_lr_float,
unet_lr_float,
t5xxl_lr_float
)
for message in lr_messages:
log.info(message)
# Determine the training configuration based on learning rate values
# Sets flags for training specific components based on the provided learning rates.
if learning_rate_float == unet_lr_float == text_encoder_lr_float == 0:
if learning_rate_float == 0.0 and text_encoder_lr_float == 0.0 and unet_lr_float == 0.0:
output_message(msg="Please input learning rate values.", headless=headless)
return TRAIN_BUTTON_VISIBLE
# Flag to train text encoder only if its learning rate is non-zero and unet's is zero.