From 69d8b96c1c82449d808f2dd63b48989a0dae6c89 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 1 Jun 2025 13:59:07 +0000 Subject: [PATCH] Feat: Add logging for effective learning rates in LoRA GUI This commit introduces a helper function, `get_effective_lr_messages`, into `kohya_gui/lora_gui.py` and integrates it into the `train_model` function. The purpose is to provide you with clearer information about how the learning rates set in the GUI (Main LR, Text Encoder LR, U-Net LR, T5XXL LR) will be interpreted and effectively applied by the underlying `sd-scripts` training engine. Before training commences, the GUI will now log: - The Main LR. - The effective LR for the primary Text Encoder (CLIP), indicating if it's a specific value or a fallback to the Main LR. - The effective LR for the T5XXL Text Encoder (if applicable), indicating its source (specific, inherited from primary TE, or fallback to Main LR). - The effective LR for the U-Net, indicating if it's a specific value or a fallback to the Main LR. This enhances transparency by helping you understand how your LR settings interact, without modifying the `sd-scripts` submodule. --- kohya_gui/lora_gui.py | 71 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py index 4f42294..edbb340 100644 --- a/kohya_gui/lora_gui.py +++ b/kohya_gui/lora_gui.py @@ -681,6 +681,64 @@ def open_configuration( return tuple(values) +def get_effective_lr_messages( + main_lr_val: float, + text_encoder_lr_val: float, # Value from the 'Text Encoder learning rate' GUI field + unet_lr_val: float, # Value from the 'Unet learning rate' GUI field + t5xxl_lr_val: float # Value from the 'T5XXL learning rate' GUI field +) -> list[str]: + messages = [] + # Format LRs to scientific notation with 2 decimal places for readability + f_main_lr = f"{main_lr_val:.2e}" + f_te_lr = f"{text_encoder_lr_val:.2e}" + f_unet_lr = f"{unet_lr_val:.2e}" + f_t5_lr = f"{t5xxl_lr_val:.2e}" + + messages.append("Effective Learning Rate Configuration (based on GUI settings):") + messages.append(f" - Main LR (for optimizer & fallback): {f_main_lr}") + + # --- Text Encoder (Primary/CLIP) LR --- + # If text_encoder_lr_val (from GUI) is non-zero, it's used. Otherwise, main_lr_val is the fallback. + effective_clip_lr_str = f_main_lr + clip_lr_source_msg = "(Fallback to Main LR)" + if text_encoder_lr_val != 0.0: + effective_clip_lr_str = f_te_lr + clip_lr_source_msg = "(Specific Value)" + messages.append(f" - Text Encoder (Primary/CLIP) Effective LR: {effective_clip_lr_str} {clip_lr_source_msg}") + + # --- Text Encoder (T5XXL, if applicable) LR --- + # Logic based on how text_encoder_lr_list is formed in train_model for sd-scripts: + # 1. If t5xxl_lr_val is non-zero, it's used for T5. + # 2. Else, if text_encoder_lr_val (primary TE LR) is non-zero, it's used for T5. + # 3. Else (both primary TE LR and specific T5XXL LR are zero), T5 uses main_lr_val. + effective_t5_lr_str = f_main_lr # Default fallback + t5_lr_source_msg = "(Fallback to Main LR)" + + if t5xxl_lr_val != 0.0: + effective_t5_lr_str = f_t5_lr + t5_lr_source_msg = "(Specific T5XXL Value)" + elif text_encoder_lr_val != 0.0: # No specific T5 LR, but main TE LR is set + effective_t5_lr_str = f_te_lr # T5 inherits from the primary TE LR setting + t5_lr_source_msg = "(Inherited from Primary TE LR)" + # If both t5xxl_lr_val and text_encoder_lr_val are 0.0, effective_t5_lr_str remains f_main_lr. + + # The message for T5XXL LR is always added for completeness, indicating its potential value. + # Users should understand it's relevant only if their model architecture uses a T5XXL text encoder. + messages.append(f" - Text Encoder (T5XXL, if applicable) Effective LR: {effective_t5_lr_str} {t5_lr_source_msg}") + + # --- U-Net LR --- + # If unet_lr_val (from GUI) is non-zero, it's used. Otherwise, main_lr_val is the fallback. + effective_unet_lr_str = f_main_lr + unet_lr_source_msg = "(Fallback to Main LR)" + if unet_lr_val != 0.0: + effective_unet_lr_str = f_unet_lr + unet_lr_source_msg = "(Specific Value)" + messages.append(f" - U-Net Effective LR: {effective_unet_lr_str} {unet_lr_source_msg}") + + messages.append("Note: These LRs reflect the GUI's direct settings. Advanced options in sd-scripts (e.g., block LRs, LoRA+) can further modify rates for specific layers.") + return messages + + def train_model( headless, print_only, @@ -1426,10 +1484,21 @@ def train_model( float(text_encoder_lr) if text_encoder_lr is not None else 0.0 ) unet_lr_float = float(unet_lr) if unet_lr is not None else 0.0 + t5xxl_lr_float = float(t5xxl_lr) if t5xxl_lr is not None else 0.0 + + # Log effective learning rate messages + lr_messages = get_effective_lr_messages( + learning_rate_float, + text_encoder_lr_float, + unet_lr_float, + t5xxl_lr_float + ) + for message in lr_messages: + log.info(message) # Determine the training configuration based on learning rate values # Sets flags for training specific components based on the provided learning rates. - if learning_rate_float == unet_lr_float == text_encoder_lr_float == 0: + if learning_rate_float == 0.0 and text_encoder_lr_float == 0.0 and unet_lr_float == 0.0: output_message(msg="Please input learning rate values.", headless=headless) return TRAIN_BUTTON_VISIBLE # Flag to train text encoder only if its learning rate is non-zero and unet's is zero.