mirror of https://github.com/bmaltais/kohya_ss
Add support for LoRA-GGPO
parent
f2efdcf207
commit
1c7ab4d4f3
|
|
@ -68,7 +68,7 @@ The GUI allows you to set the training parameters and generate and run the requi
|
|||
|
||||
## ToDo
|
||||
|
||||
- [ ] Add support for LoRA-GGPO introduced in sd-scripts merge of March 30, 2025
|
||||
- [X] Add support for LoRA-GGPO introduced in sd-scripts merge of March 30, 2025
|
||||
|
||||
## 🦒 Colab
|
||||
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ huggingface = None
|
|||
use_shell = False
|
||||
train_state_value = time.time()
|
||||
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
document_symbol = "\U0001f4c4" # 📄
|
||||
|
||||
|
||||
presets_dir = rf"{scriptdir}/presets"
|
||||
|
|
@ -79,7 +79,6 @@ LYCORIS_PRESETS_CHOICES = [
|
|||
def save_configuration(
|
||||
save_as_bool,
|
||||
file_path,
|
||||
|
||||
# source model section
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
|
|
@ -93,12 +92,10 @@ def save_configuration(
|
|||
output_name,
|
||||
model_list,
|
||||
training_comment,
|
||||
|
||||
# folders section
|
||||
logging_dir,
|
||||
reg_data_dir,
|
||||
output_dir,
|
||||
|
||||
# basic training section
|
||||
max_resolution,
|
||||
learning_rate,
|
||||
|
|
@ -125,7 +122,6 @@ def save_configuration(
|
|||
lr_scheduler_args,
|
||||
lr_scheduler_type,
|
||||
max_grad_norm,
|
||||
|
||||
# accelerate launch section
|
||||
mixed_precision,
|
||||
num_cpu_threads_per_process,
|
||||
|
|
@ -139,7 +135,6 @@ def save_configuration(
|
|||
dynamo_use_fullgraph,
|
||||
dynamo_use_dynamic,
|
||||
extra_accelerate_launch_args,
|
||||
|
||||
### advanced training section
|
||||
gradient_checkpointing,
|
||||
fp8_base,
|
||||
|
|
@ -203,11 +198,9 @@ def save_configuration(
|
|||
vae,
|
||||
weighted_captions,
|
||||
debiased_estimation_loss,
|
||||
|
||||
# sdxl parameters section
|
||||
sdxl_cache_text_encoder_outputs,
|
||||
sdxl_no_half_vae,
|
||||
|
||||
###
|
||||
text_encoder_lr,
|
||||
t5xxl_lr,
|
||||
|
|
@ -252,7 +245,9 @@ def save_configuration(
|
|||
loraplus_lr_ratio,
|
||||
loraplus_text_encoder_lr_ratio,
|
||||
loraplus_unet_lr_ratio,
|
||||
|
||||
train_lora_ggpo,
|
||||
ggpo_sigma,
|
||||
ggpo_beta,
|
||||
# huggingface section
|
||||
huggingface_repo_id,
|
||||
huggingface_token,
|
||||
|
|
@ -262,14 +257,12 @@ def save_configuration(
|
|||
save_state_to_huggingface,
|
||||
resume_from_huggingface,
|
||||
async_upload,
|
||||
|
||||
# metadata section
|
||||
metadata_author,
|
||||
metadata_description,
|
||||
metadata_license,
|
||||
metadata_tags,
|
||||
metadata_title,
|
||||
|
||||
# Flux1
|
||||
flux1_cache_text_encoder_outputs,
|
||||
flux1_cache_text_encoder_outputs_to_disk,
|
||||
|
|
@ -303,7 +296,6 @@ def save_configuration(
|
|||
in_dims,
|
||||
train_double_block_indices,
|
||||
train_single_block_indices,
|
||||
|
||||
# SD3 parameters
|
||||
sd3_cache_text_encoder_outputs,
|
||||
sd3_cache_text_encoder_outputs_to_disk,
|
||||
|
|
@ -373,7 +365,6 @@ def open_configuration(
|
|||
ask_for_file,
|
||||
apply_preset,
|
||||
file_path,
|
||||
|
||||
# source model section
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
|
|
@ -387,12 +378,10 @@ def open_configuration(
|
|||
output_name,
|
||||
model_list,
|
||||
training_comment,
|
||||
|
||||
# folders section
|
||||
logging_dir,
|
||||
reg_data_dir,
|
||||
output_dir,
|
||||
|
||||
# basic training section
|
||||
max_resolution,
|
||||
learning_rate,
|
||||
|
|
@ -419,7 +408,6 @@ def open_configuration(
|
|||
lr_scheduler_args,
|
||||
lr_scheduler_type,
|
||||
max_grad_norm,
|
||||
|
||||
# accelerate launch section
|
||||
mixed_precision,
|
||||
num_cpu_threads_per_process,
|
||||
|
|
@ -433,7 +421,6 @@ def open_configuration(
|
|||
dynamo_use_fullgraph,
|
||||
dynamo_use_dynamic,
|
||||
extra_accelerate_launch_args,
|
||||
|
||||
### advanced training section
|
||||
gradient_checkpointing,
|
||||
fp8_base,
|
||||
|
|
@ -497,11 +484,9 @@ def open_configuration(
|
|||
vae,
|
||||
weighted_captions,
|
||||
debiased_estimation_loss,
|
||||
|
||||
# sdxl parameters section
|
||||
sdxl_cache_text_encoder_outputs,
|
||||
sdxl_no_half_vae,
|
||||
|
||||
###
|
||||
text_encoder_lr,
|
||||
t5xxl_lr,
|
||||
|
|
@ -546,7 +531,9 @@ def open_configuration(
|
|||
loraplus_lr_ratio,
|
||||
loraplus_text_encoder_lr_ratio,
|
||||
loraplus_unet_lr_ratio,
|
||||
|
||||
train_lora_ggpo,
|
||||
ggpo_sigma,
|
||||
ggpo_beta,
|
||||
# huggingface section
|
||||
huggingface_repo_id,
|
||||
huggingface_token,
|
||||
|
|
@ -556,14 +543,12 @@ def open_configuration(
|
|||
save_state_to_huggingface,
|
||||
resume_from_huggingface,
|
||||
async_upload,
|
||||
|
||||
# metadata section
|
||||
metadata_author,
|
||||
metadata_description,
|
||||
metadata_license,
|
||||
metadata_tags,
|
||||
metadata_title,
|
||||
|
||||
# Flux1
|
||||
flux1_cache_text_encoder_outputs,
|
||||
flux1_cache_text_encoder_outputs_to_disk,
|
||||
|
|
@ -597,7 +582,6 @@ def open_configuration(
|
|||
in_dims,
|
||||
train_double_block_indices,
|
||||
train_single_block_indices,
|
||||
|
||||
# SD3 parameters
|
||||
sd3_cache_text_encoder_outputs,
|
||||
sd3_cache_text_encoder_outputs_to_disk,
|
||||
|
|
@ -621,7 +605,6 @@ def open_configuration(
|
|||
sd3_text_encoder_batch_size,
|
||||
weighting_scheme,
|
||||
sd3_checkbox,
|
||||
|
||||
##
|
||||
training_preset,
|
||||
):
|
||||
|
|
@ -701,7 +684,6 @@ def open_configuration(
|
|||
def train_model(
|
||||
headless,
|
||||
print_only,
|
||||
|
||||
# source model section
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
|
|
@ -715,12 +697,10 @@ def train_model(
|
|||
output_name,
|
||||
model_list,
|
||||
training_comment,
|
||||
|
||||
# folders section
|
||||
logging_dir,
|
||||
reg_data_dir,
|
||||
output_dir,
|
||||
|
||||
# basic training section
|
||||
max_resolution,
|
||||
learning_rate,
|
||||
|
|
@ -747,7 +727,6 @@ def train_model(
|
|||
lr_scheduler_args,
|
||||
lr_scheduler_type,
|
||||
max_grad_norm,
|
||||
|
||||
# accelerate launch section
|
||||
mixed_precision,
|
||||
num_cpu_threads_per_process,
|
||||
|
|
@ -761,7 +740,6 @@ def train_model(
|
|||
dynamo_use_fullgraph,
|
||||
dynamo_use_dynamic,
|
||||
extra_accelerate_launch_args,
|
||||
|
||||
### advanced training section
|
||||
gradient_checkpointing,
|
||||
fp8_base,
|
||||
|
|
@ -825,11 +803,9 @@ def train_model(
|
|||
vae,
|
||||
weighted_captions,
|
||||
debiased_estimation_loss,
|
||||
|
||||
# sdxl parameters section
|
||||
sdxl_cache_text_encoder_outputs,
|
||||
sdxl_no_half_vae,
|
||||
|
||||
###
|
||||
text_encoder_lr,
|
||||
t5xxl_lr,
|
||||
|
|
@ -874,7 +850,9 @@ def train_model(
|
|||
loraplus_lr_ratio,
|
||||
loraplus_text_encoder_lr_ratio,
|
||||
loraplus_unet_lr_ratio,
|
||||
|
||||
train_lora_ggpo,
|
||||
ggpo_sigma,
|
||||
ggpo_beta,
|
||||
# huggingface section
|
||||
huggingface_repo_id,
|
||||
huggingface_token,
|
||||
|
|
@ -884,14 +862,12 @@ def train_model(
|
|||
save_state_to_huggingface,
|
||||
resume_from_huggingface,
|
||||
async_upload,
|
||||
|
||||
# metadata section
|
||||
metadata_author,
|
||||
metadata_description,
|
||||
metadata_license,
|
||||
metadata_tags,
|
||||
metadata_title,
|
||||
|
||||
# Flux1
|
||||
flux1_cache_text_encoder_outputs,
|
||||
flux1_cache_text_encoder_outputs_to_disk,
|
||||
|
|
@ -925,7 +901,6 @@ def train_model(
|
|||
in_dims,
|
||||
train_double_block_indices,
|
||||
train_single_block_indices,
|
||||
|
||||
# SD3 parameters
|
||||
sd3_cache_text_encoder_outputs,
|
||||
sd3_cache_text_encoder_outputs_to_disk,
|
||||
|
|
@ -976,8 +951,14 @@ def train_model(
|
|||
|
||||
if flux1_checkbox:
|
||||
log.info(f"Validating lora type is Flux1 if flux1 checkbox is checked...")
|
||||
if (LoRA_type != "Flux1") and (LoRA_type != "Flux1 OFT") and ("LyCORIS" not in LoRA_type):
|
||||
log.error("LoRA type must be set to 'Flux1', 'Flux1 OFT' or 'LyCORIS' if Flux1 checkbox is checked.")
|
||||
if (
|
||||
(LoRA_type != "Flux1")
|
||||
and (LoRA_type != "Flux1 OFT")
|
||||
and ("LyCORIS" not in LoRA_type)
|
||||
):
|
||||
log.error(
|
||||
"LoRA type must be set to 'Flux1', 'Flux1 OFT' or 'LyCORIS' if Flux1 checkbox is checked."
|
||||
)
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
#
|
||||
|
|
@ -1182,7 +1163,9 @@ def train_model(
|
|||
if lr_warmup_steps > 0:
|
||||
lr_warmup_steps = int(lr_warmup_steps)
|
||||
if lr_warmup > 0:
|
||||
log.warning("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.")
|
||||
log.warning(
|
||||
"Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used."
|
||||
)
|
||||
elif lr_warmup != 0:
|
||||
lr_warmup_steps = lr_warmup / 100
|
||||
else:
|
||||
|
|
@ -1255,7 +1238,7 @@ def train_model(
|
|||
|
||||
if LoRA_type == "LyCORIS/LoHa":
|
||||
network_module = "lycoris.kohya"
|
||||
network_args = f' preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} use_tucker={use_tucker} rank_dropout={rank_dropout} bypass_mode={bypass_mode} dora_wd={dora_wd} module_dropout={module_dropout} use_tucker={use_tucker} use_scalar={use_scalar} rank_dropout_scale={rank_dropout_scale} algo=loha train_norm={train_norm}'
|
||||
network_args = f" preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} use_tucker={use_tucker} rank_dropout={rank_dropout} bypass_mode={bypass_mode} dora_wd={dora_wd} module_dropout={module_dropout} use_tucker={use_tucker} use_scalar={use_scalar} rank_dropout_scale={rank_dropout_scale} algo=loha train_norm={train_norm}"
|
||||
|
||||
if LoRA_type == "LyCORIS/LoKr":
|
||||
network_module = "lycoris.kohya"
|
||||
|
|
@ -1265,7 +1248,7 @@ def train_model(
|
|||
network_module = "lycoris.kohya"
|
||||
network_args = f" preset={LyCORIS_preset} rank_dropout={rank_dropout} module_dropout={module_dropout} rank_dropout_scale={rank_dropout_scale} algo=full train_norm={train_norm}"
|
||||
|
||||
if LoRA_type == "Flux1":
|
||||
if LoRA_type in ["Flux1"]:
|
||||
# Add a list of supported network arguments for Flux1 below when supported
|
||||
kohya_lora_var_list = [
|
||||
"img_attn_dim",
|
||||
|
|
@ -1280,6 +1263,11 @@ def train_model(
|
|||
"train_double_block_indices",
|
||||
"train_single_block_indices",
|
||||
]
|
||||
if train_lora_ggpo:
|
||||
kohya_lora_var_list += [
|
||||
"ggpo_beta",
|
||||
"ggpo_sigma",
|
||||
]
|
||||
network_module = "networks.lora_flux"
|
||||
kohya_lora_vars = {
|
||||
key: value
|
||||
|
|
@ -1418,7 +1406,9 @@ def train_model(
|
|||
|
||||
# Set the text_encoder_lr to multiple values if both text_encoder_lr and t5xxl_lr are set
|
||||
if text_encoder_lr == 0 and t5xxl_lr > 0:
|
||||
log.error("When specifying T5XXL learning rate, text encoder learning rate need to be a value greater than 0.")
|
||||
log.error(
|
||||
"When specifying T5XXL learning rate, text encoder learning rate need to be a value greater than 0."
|
||||
)
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
text_encoder_lr_list = []
|
||||
|
|
@ -1489,8 +1479,10 @@ def train_model(
|
|||
),
|
||||
"cache_text_encoder_outputs_to_disk": (
|
||||
True
|
||||
if flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk
|
||||
or sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk
|
||||
if flux1_checkbox
|
||||
and flux1_cache_text_encoder_outputs_to_disk
|
||||
or sd3_checkbox
|
||||
and sd3_cache_text_encoder_outputs_to_disk
|
||||
else None
|
||||
),
|
||||
"caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
|
||||
|
|
@ -1658,7 +1650,6 @@ def train_model(
|
|||
"wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
|
||||
"weighted_captions": weighted_captions,
|
||||
"xformers": True if xformers == "xformers" else None,
|
||||
|
||||
# SD3 only Parameters
|
||||
# "cache_text_encoder_outputs": see previous assignment above for code
|
||||
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
|
||||
|
|
@ -1666,7 +1657,9 @@ def train_model(
|
|||
"clip_g_dropout_rate": clip_g_dropout_rate if sd3_checkbox else None,
|
||||
# "clip_l": see previous assignment above for code
|
||||
"clip_l_dropout_rate": sd3_clip_l_dropout_rate if sd3_checkbox else None,
|
||||
"enable_scaled_pos_embed": sd3_enable_scaled_pos_embed if sd3_checkbox else None,
|
||||
"enable_scaled_pos_embed": (
|
||||
sd3_enable_scaled_pos_embed if sd3_checkbox else None
|
||||
),
|
||||
"logit_mean": logit_mean if sd3_checkbox else None,
|
||||
"logit_std": logit_std if sd3_checkbox else None,
|
||||
"mode_scale": mode_scale if sd3_checkbox else None,
|
||||
|
|
@ -1681,7 +1674,6 @@ def train_model(
|
|||
sd3_text_encoder_batch_size if sd3_checkbox else None
|
||||
),
|
||||
"weighting_scheme": weighting_scheme if sd3_checkbox else None,
|
||||
|
||||
# Flux.1 specific parameters
|
||||
# "cache_text_encoder_outputs": see previous assignment above for code
|
||||
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
|
||||
|
|
@ -1692,11 +1684,15 @@ def train_model(
|
|||
"model_prediction_type": model_prediction_type if flux1_checkbox else None,
|
||||
"timestep_sampling": timestep_sampling if flux1_checkbox else None,
|
||||
"split_mode": split_mode if flux1_checkbox else None,
|
||||
"t5xxl_max_token_length": int(t5xxl_max_token_length) if flux1_checkbox else None,
|
||||
"t5xxl_max_token_length": (
|
||||
int(t5xxl_max_token_length) if flux1_checkbox else None
|
||||
),
|
||||
"guidance_scale": float(guidance_scale) if flux1_checkbox else None,
|
||||
"mem_eff_save": mem_eff_save if flux1_checkbox else None,
|
||||
"apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None,
|
||||
"cpu_offload_checkpointing": cpu_offload_checkpointing if flux1_checkbox else None,
|
||||
"cpu_offload_checkpointing": (
|
||||
cpu_offload_checkpointing if flux1_checkbox else None
|
||||
),
|
||||
"blocks_to_swap": blocks_to_swap if flux1_checkbox or sd3_checkbox else None,
|
||||
"single_blocks_to_swap": single_blocks_to_swap if flux1_checkbox else None,
|
||||
"double_blocks_to_swap": double_blocks_to_swap if flux1_checkbox else None,
|
||||
|
|
@ -1887,7 +1883,7 @@ def lora_tab(
|
|||
visible=False,
|
||||
interactive=True,
|
||||
allow_custom_value=True,
|
||||
info="Use path_to_config_file.toml to choose config file (for LyCORIS module settings)"
|
||||
info="Use path_to_config_file.toml to choose config file (for LyCORIS module settings)",
|
||||
)
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
|
|
@ -1971,9 +1967,7 @@ def lora_tab(
|
|||
maximum=128,
|
||||
)
|
||||
# Add SDXL Parameters
|
||||
sdxl_params = SDXLParameters(
|
||||
source_model.sdxl_checkbox, config=config
|
||||
)
|
||||
sdxl_params = SDXLParameters(source_model.sdxl_checkbox, config=config)
|
||||
|
||||
# LyCORIS Specific parameters
|
||||
with gr.Accordion("LyCORIS", visible=False) as lycoris_accordion:
|
||||
|
|
@ -2131,6 +2125,40 @@ def lora_tab(
|
|||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row(visible=False) as train_lora_ggpo_row:
|
||||
train_lora_ggpo = gr.Checkbox(
|
||||
label="Train LoRA GGPO",
|
||||
value=False,
|
||||
info="Train LoRA GGPO",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row(visible=False) as ggpo_row:
|
||||
ggpo_sigma = gr.Number(
|
||||
label="GGPO sigma",
|
||||
value=0.03,
|
||||
info="Specify the sigma of GGPO.",
|
||||
interactive=True,
|
||||
)
|
||||
ggpo_beta = gr.Number(
|
||||
label="GGPO beta",
|
||||
value=0.01,
|
||||
info="Specify the beta of GGPO.",
|
||||
interactive=True,
|
||||
)
|
||||
# Update the visibility of the GGPO row based on the state of the "Train LoRA GGPO" checkbox
|
||||
train_lora_ggpo.change(
|
||||
lambda train_lora_ggpo: gr.Row(visible=train_lora_ggpo),
|
||||
inputs=[train_lora_ggpo],
|
||||
outputs=[ggpo_row],
|
||||
)
|
||||
|
||||
# Update the visibility of the train lora ggpo row based on the model type being Flux.1
|
||||
source_model.flux1_checkbox.change(
|
||||
lambda flux1_checkbox: gr.Row(visible=flux1_checkbox),
|
||||
inputs=[source_model.flux1_checkbox],
|
||||
outputs=[train_lora_ggpo_row],
|
||||
)
|
||||
|
||||
# Show or hide LoCon conv settings depending on LoRA type selection
|
||||
def update_LoRA_settings(
|
||||
LoRA_type,
|
||||
|
|
@ -2586,7 +2614,9 @@ def lora_tab(
|
|||
headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox
|
||||
)
|
||||
|
||||
with gr.Accordion("Advanced", open=False, elem_classes="advanced_background"):
|
||||
with gr.Accordion(
|
||||
"Advanced", open=False, elem_classes="advanced_background"
|
||||
):
|
||||
# with gr.Accordion('Advanced Configuration', open=False):
|
||||
with gr.Row(visible=True) as kohya_advanced_lora:
|
||||
with gr.Tab(label="Weights"):
|
||||
|
|
@ -2648,7 +2678,9 @@ def lora_tab(
|
|||
sample = SampleImages(config=config)
|
||||
|
||||
global huggingface
|
||||
with gr.Accordion("HuggingFace", open=False, elem_classes="huggingface_background"):
|
||||
with gr.Accordion(
|
||||
"HuggingFace", open=False, elem_classes="huggingface_background"
|
||||
):
|
||||
huggingface = HuggingFace(config=config)
|
||||
|
||||
LoRA_type.change(
|
||||
|
|
@ -2860,6 +2892,9 @@ def lora_tab(
|
|||
loraplus_lr_ratio,
|
||||
loraplus_text_encoder_lr_ratio,
|
||||
loraplus_unet_lr_ratio,
|
||||
train_lora_ggpo,
|
||||
ggpo_sigma,
|
||||
ggpo_beta,
|
||||
huggingface.huggingface_repo_id,
|
||||
huggingface.huggingface_token,
|
||||
huggingface.huggingface_repo_type,
|
||||
|
|
@ -2906,7 +2941,6 @@ def lora_tab(
|
|||
flux1_training.in_dims,
|
||||
flux1_training.train_double_block_indices,
|
||||
flux1_training.train_single_block_indices,
|
||||
|
||||
# SD3 Parameters
|
||||
sd3_training.sd3_cache_text_encoder_outputs,
|
||||
sd3_training.sd3_cache_text_encoder_outputs_to_disk,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
"cache_latents_to_disk": false,
|
||||
"caption_dropout_every_n_epochs": 0,
|
||||
"caption_dropout_rate": 0.05,
|
||||
"caption_extension": "",
|
||||
"caption_extension": ".txt",
|
||||
"clip_g": "",
|
||||
"clip_l": "",
|
||||
"clip_skip": 2,
|
||||
|
|
|
|||
Loading…
Reference in New Issue