Add support for LoRA-GGPO

pull/3174/head
bmaltais 2025-03-30 14:41:40 -04:00
parent f2efdcf207
commit 1c7ab4d4f3
4 changed files with 527 additions and 493 deletions

View File

@ -68,7 +68,7 @@ The GUI allows you to set the training parameters and generate and run the requi
## ToDo
- [ ] Add support for LoRA-GGPO introduced in sd-scripts merge of March 30, 2025
- [X] Add support for LoRA-GGPO introduced in sd-scripts merge of March 30, 2025
## 🦒 Colab

View File

@ -61,7 +61,7 @@ huggingface = None
use_shell = False
train_state_value = time.time()
document_symbol = "\U0001F4C4" # 📄
document_symbol = "\U0001f4c4" # 📄
presets_dir = rf"{scriptdir}/presets"
@ -79,7 +79,6 @@ LYCORIS_PRESETS_CHOICES = [
def save_configuration(
save_as_bool,
file_path,
# source model section
pretrained_model_name_or_path,
v2,
@ -93,12 +92,10 @@ def save_configuration(
output_name,
model_list,
training_comment,
# folders section
logging_dir,
reg_data_dir,
output_dir,
# basic training section
max_resolution,
learning_rate,
@ -125,7 +122,6 @@ def save_configuration(
lr_scheduler_args,
lr_scheduler_type,
max_grad_norm,
# accelerate launch section
mixed_precision,
num_cpu_threads_per_process,
@ -139,7 +135,6 @@ def save_configuration(
dynamo_use_fullgraph,
dynamo_use_dynamic,
extra_accelerate_launch_args,
### advanced training section
gradient_checkpointing,
fp8_base,
@ -203,11 +198,9 @@ def save_configuration(
vae,
weighted_captions,
debiased_estimation_loss,
# sdxl parameters section
sdxl_cache_text_encoder_outputs,
sdxl_no_half_vae,
###
text_encoder_lr,
t5xxl_lr,
@ -252,7 +245,9 @@ def save_configuration(
loraplus_lr_ratio,
loraplus_text_encoder_lr_ratio,
loraplus_unet_lr_ratio,
train_lora_ggpo,
ggpo_sigma,
ggpo_beta,
# huggingface section
huggingface_repo_id,
huggingface_token,
@ -262,14 +257,12 @@ def save_configuration(
save_state_to_huggingface,
resume_from_huggingface,
async_upload,
# metadata section
metadata_author,
metadata_description,
metadata_license,
metadata_tags,
metadata_title,
# Flux1
flux1_cache_text_encoder_outputs,
flux1_cache_text_encoder_outputs_to_disk,
@ -303,7 +296,6 @@ def save_configuration(
in_dims,
train_double_block_indices,
train_single_block_indices,
# SD3 parameters
sd3_cache_text_encoder_outputs,
sd3_cache_text_encoder_outputs_to_disk,
@ -373,7 +365,6 @@ def open_configuration(
ask_for_file,
apply_preset,
file_path,
# source model section
pretrained_model_name_or_path,
v2,
@ -387,12 +378,10 @@ def open_configuration(
output_name,
model_list,
training_comment,
# folders section
logging_dir,
reg_data_dir,
output_dir,
# basic training section
max_resolution,
learning_rate,
@ -419,7 +408,6 @@ def open_configuration(
lr_scheduler_args,
lr_scheduler_type,
max_grad_norm,
# accelerate launch section
mixed_precision,
num_cpu_threads_per_process,
@ -433,7 +421,6 @@ def open_configuration(
dynamo_use_fullgraph,
dynamo_use_dynamic,
extra_accelerate_launch_args,
### advanced training section
gradient_checkpointing,
fp8_base,
@ -497,11 +484,9 @@ def open_configuration(
vae,
weighted_captions,
debiased_estimation_loss,
# sdxl parameters section
sdxl_cache_text_encoder_outputs,
sdxl_no_half_vae,
###
text_encoder_lr,
t5xxl_lr,
@ -546,7 +531,9 @@ def open_configuration(
loraplus_lr_ratio,
loraplus_text_encoder_lr_ratio,
loraplus_unet_lr_ratio,
train_lora_ggpo,
ggpo_sigma,
ggpo_beta,
# huggingface section
huggingface_repo_id,
huggingface_token,
@ -556,14 +543,12 @@ def open_configuration(
save_state_to_huggingface,
resume_from_huggingface,
async_upload,
# metadata section
metadata_author,
metadata_description,
metadata_license,
metadata_tags,
metadata_title,
# Flux1
flux1_cache_text_encoder_outputs,
flux1_cache_text_encoder_outputs_to_disk,
@ -597,7 +582,6 @@ def open_configuration(
in_dims,
train_double_block_indices,
train_single_block_indices,
# SD3 parameters
sd3_cache_text_encoder_outputs,
sd3_cache_text_encoder_outputs_to_disk,
@ -621,7 +605,6 @@ def open_configuration(
sd3_text_encoder_batch_size,
weighting_scheme,
sd3_checkbox,
##
training_preset,
):
@ -701,7 +684,6 @@ def open_configuration(
def train_model(
headless,
print_only,
# source model section
pretrained_model_name_or_path,
v2,
@ -715,12 +697,10 @@ def train_model(
output_name,
model_list,
training_comment,
# folders section
logging_dir,
reg_data_dir,
output_dir,
# basic training section
max_resolution,
learning_rate,
@ -747,7 +727,6 @@ def train_model(
lr_scheduler_args,
lr_scheduler_type,
max_grad_norm,
# accelerate launch section
mixed_precision,
num_cpu_threads_per_process,
@ -761,7 +740,6 @@ def train_model(
dynamo_use_fullgraph,
dynamo_use_dynamic,
extra_accelerate_launch_args,
### advanced training section
gradient_checkpointing,
fp8_base,
@ -825,11 +803,9 @@ def train_model(
vae,
weighted_captions,
debiased_estimation_loss,
# sdxl parameters section
sdxl_cache_text_encoder_outputs,
sdxl_no_half_vae,
###
text_encoder_lr,
t5xxl_lr,
@ -874,7 +850,9 @@ def train_model(
loraplus_lr_ratio,
loraplus_text_encoder_lr_ratio,
loraplus_unet_lr_ratio,
train_lora_ggpo,
ggpo_sigma,
ggpo_beta,
# huggingface section
huggingface_repo_id,
huggingface_token,
@ -884,14 +862,12 @@ def train_model(
save_state_to_huggingface,
resume_from_huggingface,
async_upload,
# metadata section
metadata_author,
metadata_description,
metadata_license,
metadata_tags,
metadata_title,
# Flux1
flux1_cache_text_encoder_outputs,
flux1_cache_text_encoder_outputs_to_disk,
@ -925,7 +901,6 @@ def train_model(
in_dims,
train_double_block_indices,
train_single_block_indices,
# SD3 parameters
sd3_cache_text_encoder_outputs,
sd3_cache_text_encoder_outputs_to_disk,
@ -976,8 +951,14 @@ def train_model(
if flux1_checkbox:
log.info(f"Validating lora type is Flux1 if flux1 checkbox is checked...")
if (LoRA_type != "Flux1") and (LoRA_type != "Flux1 OFT") and ("LyCORIS" not in LoRA_type):
log.error("LoRA type must be set to 'Flux1', 'Flux1 OFT' or 'LyCORIS' if Flux1 checkbox is checked.")
if (
(LoRA_type != "Flux1")
and (LoRA_type != "Flux1 OFT")
and ("LyCORIS" not in LoRA_type)
):
log.error(
"LoRA type must be set to 'Flux1', 'Flux1 OFT' or 'LyCORIS' if Flux1 checkbox is checked."
)
return TRAIN_BUTTON_VISIBLE
#
@ -1182,7 +1163,9 @@ def train_model(
if lr_warmup_steps > 0:
lr_warmup_steps = int(lr_warmup_steps)
if lr_warmup > 0:
log.warning("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.")
log.warning(
"Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used."
)
elif lr_warmup != 0:
lr_warmup_steps = lr_warmup / 100
else:
@ -1255,7 +1238,7 @@ def train_model(
if LoRA_type == "LyCORIS/LoHa":
network_module = "lycoris.kohya"
network_args = f' preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} use_tucker={use_tucker} rank_dropout={rank_dropout} bypass_mode={bypass_mode} dora_wd={dora_wd} module_dropout={module_dropout} use_tucker={use_tucker} use_scalar={use_scalar} rank_dropout_scale={rank_dropout_scale} algo=loha train_norm={train_norm}'
network_args = f" preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} use_tucker={use_tucker} rank_dropout={rank_dropout} bypass_mode={bypass_mode} dora_wd={dora_wd} module_dropout={module_dropout} use_tucker={use_tucker} use_scalar={use_scalar} rank_dropout_scale={rank_dropout_scale} algo=loha train_norm={train_norm}"
if LoRA_type == "LyCORIS/LoKr":
network_module = "lycoris.kohya"
@ -1265,7 +1248,7 @@ def train_model(
network_module = "lycoris.kohya"
network_args = f" preset={LyCORIS_preset} rank_dropout={rank_dropout} module_dropout={module_dropout} rank_dropout_scale={rank_dropout_scale} algo=full train_norm={train_norm}"
if LoRA_type == "Flux1":
if LoRA_type in ["Flux1"]:
# Add a list of supported network arguments for Flux1 below when supported
kohya_lora_var_list = [
"img_attn_dim",
@ -1280,6 +1263,11 @@ def train_model(
"train_double_block_indices",
"train_single_block_indices",
]
if train_lora_ggpo:
kohya_lora_var_list += [
"ggpo_beta",
"ggpo_sigma",
]
network_module = "networks.lora_flux"
kohya_lora_vars = {
key: value
@ -1418,7 +1406,9 @@ def train_model(
# Set the text_encoder_lr to multiple values if both text_encoder_lr and t5xxl_lr are set
if text_encoder_lr == 0 and t5xxl_lr > 0:
log.error("When specifying T5XXL learning rate, text encoder learning rate need to be a value greater than 0.")
log.error(
"When specifying T5XXL learning rate, text encoder learning rate need to be a value greater than 0."
)
return TRAIN_BUTTON_VISIBLE
text_encoder_lr_list = []
@ -1489,8 +1479,10 @@ def train_model(
),
"cache_text_encoder_outputs_to_disk": (
True
if flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk
or sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk
if flux1_checkbox
and flux1_cache_text_encoder_outputs_to_disk
or sd3_checkbox
and sd3_cache_text_encoder_outputs_to_disk
else None
),
"caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
@ -1658,7 +1650,6 @@ def train_model(
"wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
"weighted_captions": weighted_captions,
"xformers": True if xformers == "xformers" else None,
# SD3 only Parameters
# "cache_text_encoder_outputs": see previous assignment above for code
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
@ -1666,7 +1657,9 @@ def train_model(
"clip_g_dropout_rate": clip_g_dropout_rate if sd3_checkbox else None,
# "clip_l": see previous assignment above for code
"clip_l_dropout_rate": sd3_clip_l_dropout_rate if sd3_checkbox else None,
"enable_scaled_pos_embed": sd3_enable_scaled_pos_embed if sd3_checkbox else None,
"enable_scaled_pos_embed": (
sd3_enable_scaled_pos_embed if sd3_checkbox else None
),
"logit_mean": logit_mean if sd3_checkbox else None,
"logit_std": logit_std if sd3_checkbox else None,
"mode_scale": mode_scale if sd3_checkbox else None,
@ -1681,7 +1674,6 @@ def train_model(
sd3_text_encoder_batch_size if sd3_checkbox else None
),
"weighting_scheme": weighting_scheme if sd3_checkbox else None,
# Flux.1 specific parameters
# "cache_text_encoder_outputs": see previous assignment above for code
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
@ -1692,11 +1684,15 @@ def train_model(
"model_prediction_type": model_prediction_type if flux1_checkbox else None,
"timestep_sampling": timestep_sampling if flux1_checkbox else None,
"split_mode": split_mode if flux1_checkbox else None,
"t5xxl_max_token_length": int(t5xxl_max_token_length) if flux1_checkbox else None,
"t5xxl_max_token_length": (
int(t5xxl_max_token_length) if flux1_checkbox else None
),
"guidance_scale": float(guidance_scale) if flux1_checkbox else None,
"mem_eff_save": mem_eff_save if flux1_checkbox else None,
"apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None,
"cpu_offload_checkpointing": cpu_offload_checkpointing if flux1_checkbox else None,
"cpu_offload_checkpointing": (
cpu_offload_checkpointing if flux1_checkbox else None
),
"blocks_to_swap": blocks_to_swap if flux1_checkbox or sd3_checkbox else None,
"single_blocks_to_swap": single_blocks_to_swap if flux1_checkbox else None,
"double_blocks_to_swap": double_blocks_to_swap if flux1_checkbox else None,
@ -1887,7 +1883,7 @@ def lora_tab(
visible=False,
interactive=True,
allow_custom_value=True,
info="Use path_to_config_file.toml to choose config file (for LyCORIS module settings)"
info="Use path_to_config_file.toml to choose config file (for LyCORIS module settings)",
)
with gr.Group():
with gr.Row():
@ -1971,9 +1967,7 @@ def lora_tab(
maximum=128,
)
# Add SDXL Parameters
sdxl_params = SDXLParameters(
source_model.sdxl_checkbox, config=config
)
sdxl_params = SDXLParameters(source_model.sdxl_checkbox, config=config)
# LyCORIS Specific parameters
with gr.Accordion("LyCORIS", visible=False) as lycoris_accordion:
@ -2131,6 +2125,40 @@ def lora_tab(
interactive=True,
)
with gr.Row(visible=False) as train_lora_ggpo_row:
train_lora_ggpo = gr.Checkbox(
label="Train LoRA GGPO",
value=False,
info="Train LoRA GGPO",
interactive=True,
)
with gr.Row(visible=False) as ggpo_row:
ggpo_sigma = gr.Number(
label="GGPO sigma",
value=0.03,
info="Specify the sigma of GGPO.",
interactive=True,
)
ggpo_beta = gr.Number(
label="GGPO beta",
value=0.01,
info="Specify the beta of GGPO.",
interactive=True,
)
# Update the visibility of the GGPO row based on the state of the "Train LoRA GGPO" checkbox
train_lora_ggpo.change(
lambda train_lora_ggpo: gr.Row(visible=train_lora_ggpo),
inputs=[train_lora_ggpo],
outputs=[ggpo_row],
)
# Update the visibility of the train lora ggpo row based on the model type being Flux.1
source_model.flux1_checkbox.change(
lambda flux1_checkbox: gr.Row(visible=flux1_checkbox),
inputs=[source_model.flux1_checkbox],
outputs=[train_lora_ggpo_row],
)
# Show or hide LoCon conv settings depending on LoRA type selection
def update_LoRA_settings(
LoRA_type,
@ -2586,7 +2614,9 @@ def lora_tab(
headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox
)
with gr.Accordion("Advanced", open=False, elem_classes="advanced_background"):
with gr.Accordion(
"Advanced", open=False, elem_classes="advanced_background"
):
# with gr.Accordion('Advanced Configuration', open=False):
with gr.Row(visible=True) as kohya_advanced_lora:
with gr.Tab(label="Weights"):
@ -2648,7 +2678,9 @@ def lora_tab(
sample = SampleImages(config=config)
global huggingface
with gr.Accordion("HuggingFace", open=False, elem_classes="huggingface_background"):
with gr.Accordion(
"HuggingFace", open=False, elem_classes="huggingface_background"
):
huggingface = HuggingFace(config=config)
LoRA_type.change(
@ -2860,6 +2892,9 @@ def lora_tab(
loraplus_lr_ratio,
loraplus_text_encoder_lr_ratio,
loraplus_unet_lr_ratio,
train_lora_ggpo,
ggpo_sigma,
ggpo_beta,
huggingface.huggingface_repo_id,
huggingface.huggingface_token,
huggingface.huggingface_repo_type,
@ -2906,7 +2941,6 @@ def lora_tab(
flux1_training.in_dims,
flux1_training.train_double_block_indices,
flux1_training.train_single_block_indices,
# SD3 Parameters
sd3_training.sd3_cache_text_encoder_outputs,
sd3_training.sd3_cache_text_encoder_outputs_to_disk,

View File

@ -12,7 +12,7 @@
"cache_latents_to_disk": false,
"caption_dropout_every_n_epochs": 0,
"caption_dropout_rate": 0.05,
"caption_extension": "",
"caption_extension": ".txt",
"clip_g": "",
"clip_l": "",
"clip_skip": 2,

View File

@ -748,7 +748,7 @@ wheels = [
[[package]]
name = "kohya-ss"
version = "0.1.0"
version = "25.0.4"
source = { virtual = "." }
dependencies = [
{ name = "accelerate" },