mirror of https://github.com/bmaltais/kohya_ss
Feat: Add logging for effective learning rates in LoRA GUI
This commit introduces a helper function, `get_effective_lr_messages`, into `kohya_gui/lora_gui.py` and integrates it into the `train_model` function. The purpose is to provide you with clearer information about how the learning rates set in the GUI (Main LR, Text Encoder LR, U-Net LR, T5XXL LR) will be interpreted and effectively applied by the underlying `sd-scripts` training engine.
Before training commences, the GUI will now log:
- The Main LR.
- The effective LR for the primary Text Encoder (CLIP), indicating whether it is a specific value or a fallback to the Main LR.
- The effective LR for the T5XXL Text Encoder (if applicable), indicating its source (specific, inherited from the primary TE, or fallback to the Main LR).
- The effective LR for the U-Net, indicating whether it is a specific value or a fallback to the Main LR.
This enhances transparency by helping you understand how your LR settings interact, without modifying the `sd-scripts` submodule.
pull/3264/head
parent
d63a7fa2b6
commit
69d8b96c1c
|
|
@ -681,6 +681,64 @@ def open_configuration(
|
||||||
return tuple(values)
|
return tuple(values)
|
||||||
|
|
||||||
|
|
||||||
|
def get_effective_lr_messages(
    main_lr_val: float,
    text_encoder_lr_val: float,  # Value from the 'Text Encoder learning rate' GUI field
    unet_lr_val: float,  # Value from the 'Unet learning rate' GUI field
    t5xxl_lr_val: float,  # Value from the 'T5XXL learning rate' GUI field
) -> list[str]:
    """Build log lines describing the effective learning rate of each component.

    A component-specific learning rate of ``0.0`` means "not set"; in that
    case the message reports the value the component falls back to:

    * Text Encoder (Primary/CLIP): its own LR, else the Main LR.
    * Text Encoder (T5XXL): its own LR, else the primary Text Encoder LR,
      else the Main LR.
    * U-Net: its own LR, else the Main LR.

    Args:
        main_lr_val: Main learning rate (also the final fallback).
        text_encoder_lr_val: Primary text encoder learning rate.
        unet_lr_val: U-Net learning rate.
        t5xxl_lr_val: T5XXL text encoder learning rate.

    Any argument passed as ``None`` is treated as ``0.0`` (unset).

    Returns:
        Six human-readable lines: a header, one line each for the Main LR,
        the primary text encoder, the T5XXL text encoder and the U-Net, and
        a closing note. Rates are rendered in scientific notation with two
        decimal places.
    """
    # Treat a missing value like an unset (zero) one so formatting never fails.
    main_lr_val = 0.0 if main_lr_val is None else main_lr_val
    text_encoder_lr_val = 0.0 if text_encoder_lr_val is None else text_encoder_lr_val
    unet_lr_val = 0.0 if unet_lr_val is None else unet_lr_val
    t5xxl_lr_val = 0.0 if t5xxl_lr_val is None else t5xxl_lr_val

    # Each component lists its candidate rates in priority order; the first
    # non-zero one wins, otherwise the Main LR is reported.
    clip_lr = _select_effective_lr(
        [(text_encoder_lr_val, "(Specific Value)")],
        main_lr_val,
    )
    t5_lr = _select_effective_lr(
        [
            (t5xxl_lr_val, "(Specific T5XXL Value)"),
            (text_encoder_lr_val, "(Inherited from Primary TE LR)"),
        ],
        main_lr_val,
    )
    unet_lr = _select_effective_lr(
        [(unet_lr_val, "(Specific Value)")],
        main_lr_val,
    )

    # The T5XXL line is always emitted; it only matters for model
    # architectures that actually use a T5XXL text encoder.
    return [
        "Effective Learning Rate Configuration (based on GUI settings):",
        f" - Main LR (for optimizer & fallback): {main_lr_val:.2e}",
        f" - Text Encoder (Primary/CLIP) Effective LR: {clip_lr}",
        f" - Text Encoder (T5XXL, if applicable) Effective LR: {t5_lr}",
        f" - U-Net Effective LR: {unet_lr}",
        "Note: These LRs reflect the GUI's direct settings. Advanced options in sd-scripts (e.g., block LRs, LoRA+) can further modify rates for specific layers.",
    ]


def _select_effective_lr(candidates, main_lr_val):
    """Return ``"<rate> <source>"`` for the first non-zero candidate rate.

    ``candidates`` is a sequence of ``(rate, source_label)`` pairs in priority
    order. When every candidate is zero, the Main LR is reported instead.
    """
    for rate, source_label in candidates:
        if rate != 0.0:
            return f"{rate:.2e} {source_label}"
    return f"{main_lr_val:.2e} (Fallback to Main LR)"
|
||||||
|
|
||||||
|
|
||||||
def train_model(
|
def train_model(
|
||||||
headless,
|
headless,
|
||||||
print_only,
|
print_only,
|
||||||
|
|
@ -1426,10 +1484,21 @@ def train_model(
|
||||||
float(text_encoder_lr) if text_encoder_lr is not None else 0.0
|
float(text_encoder_lr) if text_encoder_lr is not None else 0.0
|
||||||
)
|
)
|
||||||
unet_lr_float = float(unet_lr) if unet_lr is not None else 0.0
|
unet_lr_float = float(unet_lr) if unet_lr is not None else 0.0
|
||||||
|
t5xxl_lr_float = float(t5xxl_lr) if t5xxl_lr is not None else 0.0
|
||||||
|
|
||||||
|
# Log effective learning rate messages
|
||||||
|
lr_messages = get_effective_lr_messages(
|
||||||
|
learning_rate_float,
|
||||||
|
text_encoder_lr_float,
|
||||||
|
unet_lr_float,
|
||||||
|
t5xxl_lr_float
|
||||||
|
)
|
||||||
|
for message in lr_messages:
|
||||||
|
log.info(message)
|
||||||
|
|
||||||
# Determine the training configuration based on learning rate values
|
# Determine the training configuration based on learning rate values
|
||||||
# Sets flags for training specific components based on the provided learning rates.
|
# Sets flags for training specific components based on the provided learning rates.
|
||||||
if learning_rate_float == unet_lr_float == text_encoder_lr_float == 0:
|
if learning_rate_float == 0.0 and text_encoder_lr_float == 0.0 and unet_lr_float == 0.0:
|
||||||
output_message(msg="Please input learning rate values.", headless=headless)
|
output_message(msg="Please input learning rate values.", headless=headless)
|
||||||
return TRAIN_BUTTON_VISIBLE
|
return TRAIN_BUTTON_VISIBLE
|
||||||
# Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
|
# Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue