mirror of https://github.com/bmaltais/kohya_ss
Feat: Add logging for effective learning rates in LoRA GUI
This commit introduces a helper function, `get_effective_lr_messages`, into `kohya_gui/lora_gui.py` and integrates it into the `train_model` function. The purpose is to provide you with clearer information about how the learning rates set in the GUI (Main LR, Text Encoder LR, U-Net LR, T5XXL LR) will be interpreted and effectively applied by the underlying `sd-scripts` training engine. Before training commences, the GUI will now log: - The Main LR. - The effective LR for the primary Text Encoder (CLIP), indicating if it's a specific value or a fallback to the Main LR. - The effective LR for the T5XXL Text Encoder (if applicable), indicating its source (specific, inherited from primary TE, or fallback to Main LR). - The effective LR for the U-Net, indicating if it's a specific value or a fallback to the Main LR. This enhances transparency by helping you understand how your LR settings interact, without modifying the `sd-scripts` submodule.pull/3264/head
parent
d63a7fa2b6
commit
69d8b96c1c
|
|
@ -681,6 +681,64 @@ def open_configuration(
|
|||
return tuple(values)
|
||||
|
||||
|
||||
def get_effective_lr_messages(
|
||||
main_lr_val: float,
|
||||
text_encoder_lr_val: float, # Value from the 'Text Encoder learning rate' GUI field
|
||||
unet_lr_val: float, # Value from the 'Unet learning rate' GUI field
|
||||
t5xxl_lr_val: float # Value from the 'T5XXL learning rate' GUI field
|
||||
) -> list[str]:
|
||||
messages = []
|
||||
# Format LRs to scientific notation with 2 decimal places for readability
|
||||
f_main_lr = f"{main_lr_val:.2e}"
|
||||
f_te_lr = f"{text_encoder_lr_val:.2e}"
|
||||
f_unet_lr = f"{unet_lr_val:.2e}"
|
||||
f_t5_lr = f"{t5xxl_lr_val:.2e}"
|
||||
|
||||
messages.append("Effective Learning Rate Configuration (based on GUI settings):")
|
||||
messages.append(f" - Main LR (for optimizer & fallback): {f_main_lr}")
|
||||
|
||||
# --- Text Encoder (Primary/CLIP) LR ---
|
||||
# If text_encoder_lr_val (from GUI) is non-zero, it's used. Otherwise, main_lr_val is the fallback.
|
||||
effective_clip_lr_str = f_main_lr
|
||||
clip_lr_source_msg = "(Fallback to Main LR)"
|
||||
if text_encoder_lr_val != 0.0:
|
||||
effective_clip_lr_str = f_te_lr
|
||||
clip_lr_source_msg = "(Specific Value)"
|
||||
messages.append(f" - Text Encoder (Primary/CLIP) Effective LR: {effective_clip_lr_str} {clip_lr_source_msg}")
|
||||
|
||||
# --- Text Encoder (T5XXL, if applicable) LR ---
|
||||
# Logic based on how text_encoder_lr_list is formed in train_model for sd-scripts:
|
||||
# 1. If t5xxl_lr_val is non-zero, it's used for T5.
|
||||
# 2. Else, if text_encoder_lr_val (primary TE LR) is non-zero, it's used for T5.
|
||||
# 3. Else (both primary TE LR and specific T5XXL LR are zero), T5 uses main_lr_val.
|
||||
effective_t5_lr_str = f_main_lr # Default fallback
|
||||
t5_lr_source_msg = "(Fallback to Main LR)"
|
||||
|
||||
if t5xxl_lr_val != 0.0:
|
||||
effective_t5_lr_str = f_t5_lr
|
||||
t5_lr_source_msg = "(Specific T5XXL Value)"
|
||||
elif text_encoder_lr_val != 0.0: # No specific T5 LR, but main TE LR is set
|
||||
effective_t5_lr_str = f_te_lr # T5 inherits from the primary TE LR setting
|
||||
t5_lr_source_msg = "(Inherited from Primary TE LR)"
|
||||
# If both t5xxl_lr_val and text_encoder_lr_val are 0.0, effective_t5_lr_str remains f_main_lr.
|
||||
|
||||
# The message for T5XXL LR is always added for completeness, indicating its potential value.
|
||||
# Users should understand it's relevant only if their model architecture uses a T5XXL text encoder.
|
||||
messages.append(f" - Text Encoder (T5XXL, if applicable) Effective LR: {effective_t5_lr_str} {t5_lr_source_msg}")
|
||||
|
||||
# --- U-Net LR ---
|
||||
# If unet_lr_val (from GUI) is non-zero, it's used. Otherwise, main_lr_val is the fallback.
|
||||
effective_unet_lr_str = f_main_lr
|
||||
unet_lr_source_msg = "(Fallback to Main LR)"
|
||||
if unet_lr_val != 0.0:
|
||||
effective_unet_lr_str = f_unet_lr
|
||||
unet_lr_source_msg = "(Specific Value)"
|
||||
messages.append(f" - U-Net Effective LR: {effective_unet_lr_str} {unet_lr_source_msg}")
|
||||
|
||||
messages.append("Note: These LRs reflect the GUI's direct settings. Advanced options in sd-scripts (e.g., block LRs, LoRA+) can further modify rates for specific layers.")
|
||||
return messages
|
||||
|
||||
|
||||
def train_model(
|
||||
headless,
|
||||
print_only,
|
||||
|
|
@ -1426,10 +1484,21 @@ def train_model(
|
|||
float(text_encoder_lr) if text_encoder_lr is not None else 0.0
|
||||
)
|
||||
unet_lr_float = float(unet_lr) if unet_lr is not None else 0.0
|
||||
t5xxl_lr_float = float(t5xxl_lr) if t5xxl_lr is not None else 0.0
|
||||
|
||||
# Log effective learning rate messages
|
||||
lr_messages = get_effective_lr_messages(
|
||||
learning_rate_float,
|
||||
text_encoder_lr_float,
|
||||
unet_lr_float,
|
||||
t5xxl_lr_float
|
||||
)
|
||||
for message in lr_messages:
|
||||
log.info(message)
|
||||
|
||||
# Determine the training configuration based on learning rate values
|
||||
# Sets flags for training specific components based on the provided learning rates.
|
||||
if learning_rate_float == unet_lr_float == text_encoder_lr_float == 0:
|
||||
if learning_rate_float == 0.0 and text_encoder_lr_float == 0.0 and unet_lr_float == 0.0:
|
||||
output_message(msg="Please input learning rate values.", headless=headless)
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
# Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
|
||||
|
|
|
|||
Loading…
Reference in New Issue