Convert str numbers to proper int or float

pull/2292/head
bmaltais 2024-04-15 08:02:39 -04:00
parent 5316db3fa7
commit eec6f9baf4
5 changed files with 185 additions and 74 deletions

View File

@@ -42,19 +42,19 @@ learning_rate_te2 = 0.0001 # Learning rate text encoder 2
lr_scheduler = "cosine" # LR Scheduler
lr_scheduler_args = "" # LR Scheduler args
lr_warmup = 0 # LR Warmup (% of total steps)
lr_scheduler_num_cycles = "" # LR Scheduler num cycles
lr_scheduler_power = "" # LR Scheduler power
lr_scheduler_num_cycles = 1 # LR Scheduler num cycles
lr_scheduler_power = 1.0 # LR Scheduler power
max_bucket_reso = 2048 # Max bucket resolution
max_grad_norm = 1.0 # Max grad norm
max_resolution = "512,512" # Max resolution
max_train_steps = "" # Max train steps
max_train_epochs = "" # Max train epochs
max_train_steps = 0 # Max train steps
max_train_epochs = 0 # Max train epochs
min_bucket_reso = 256 # Min bucket resolution
optimizer = "AdamW8bit" # Optimizer (AdamW, AdamW8bit, Adafactor, DAdaptation, DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptAdamPreprint, DAdaptLion, DAdaptSGD, Lion, Lion8bit, PagedAdam
optimizer_args = "" # Optimizer args
save_every_n_epochs = 1 # Save every n epochs
save_every_n_steps = 1 # Save every n steps
seed = "1234" # Seed
seed = 1234 # Seed
stop_text_encoder_training = 0 # Stop text encoder training (% of total steps)
train_batch_size = 1 # Train batch size
@@ -83,9 +83,9 @@ log_tracker_config_dir = "./logs" # Log tracker configs directory
log_tracker_name = "" # Log tracker name
loss_type = "l2" # Loss type (l2, huber, smooth_l1)
masked_loss = false # Masked loss
max_data_loader_n_workers = "0" # Max data loader n workers (string)
max_data_loader_n_workers = 0 # Max data loader n workers
max_timestep = 1000 # Max timestep
max_token_length = "150" # Max token length ("75", "150", "225")
max_token_length = 150 # Max token length (75, 150, 225)
mem_eff_attn = false # Memory efficient attention
min_snr_gamma = 0 # Min SNR gamma
min_timestep = 0 # Min timestep
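The point of the TOML changes above: TOML is a typed format, so a quoted value always parses as a str, and the GUI previously had to cast values like seed = "1234" by hand. A minimal sketch of the difference (not part of the commit; tomllib is the stdlib parser, and the repo appears to use the third-party toml package, which behaves the same here):

import tomllib  # stdlib since Python 3.11

old_cfg = tomllib.loads('seed = "1234"')
new_cfg = tomllib.loads("seed = 1234\nlr_scheduler_power = 1.0")

print(type(old_cfg["seed"]))                # <class 'str'>  -> needed int() downstream
print(type(new_cfg["seed"]))                # <class 'int'>  -> usable as-is
print(type(new_cfg["lr_scheduler_power"]))  # <class 'float'>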

View File

@@ -75,7 +75,8 @@ class AdvancedTraining:
# Exclude token padding option for LoRA training type.
if training_type != "lora":
self.no_token_padding = gr.Checkbox(
label="No token padding", value=self.config.get("advanced.no_token_padding", False)
label="No token padding",
value=self.config.get("advanced.no_token_padding", False),
)
self.gradient_accumulation_steps = gr.Slider(
label="Gradient accumulate steps",
@@ -85,9 +86,15 @@ class AdvancedTraining:
maximum=120,
step=1,
)
self.weighted_captions = gr.Checkbox(label="Weighted captions", value=self.config.get("advanced.weighted_captions", False))
self.weighted_captions = gr.Checkbox(
label="Weighted captions",
value=self.config.get("advanced.weighted_captions", False),
)
with gr.Group(), gr.Row(visible=not finetuning):
self.prior_loss_weight = gr.Number(label="Prior loss weight", value=self.config.get("advanced.prior_loss_weight", 1.0))
self.prior_loss_weight = gr.Number(
label="Prior loss weight",
value=self.config.get("advanced.prior_loss_weight", 1.0),
)
def list_vae_files(path):
self.current_vae_dir = path if not path == "" else "."
@@ -96,14 +103,18 @@ class AdvancedTraining:
self.vae = gr.Dropdown(
label="VAE (Optional: Path to checkpoint of vae for training)",
interactive=True,
choices=[self.config.get("advanced.vae_dir", "")] + list_vae_files(self.current_vae_dir),
choices=[self.config.get("advanced.vae_dir", "")]
+ list_vae_files(self.current_vae_dir),
value=self.config.get("advanced.vae_dir", ""),
allow_custom_value=True,
)
create_refresh_button(
self.vae,
lambda: None,
lambda: {"choices": [self.config.get("advanced.vae_dir", "")] + list_vae_files(self.current_vae_dir)},
lambda: {
"choices": [self.config.get("advanced.vae_dir", "")]
+ list_vae_files(self.current_vae_dir)
},
"open_folder_small",
)
self.vae_button = gr.Button(
@@ -116,7 +127,10 @@ class AdvancedTraining:
)
self.vae.change(
fn=lambda path: gr.Dropdown(choices=[self.config.get("advanced.vae_dir", "")] + list_vae_files(path)),
fn=lambda path: gr.Dropdown(
choices=[self.config.get("advanced.vae_dir", "")]
+ list_vae_files(path)
),
inputs=self.vae,
outputs=self.vae,
show_progress=False,
@@ -189,19 +203,28 @@ class AdvancedTraining:
), gr.Checkbox(interactive=full_bf16_active)
self.keep_tokens = gr.Slider(
label="Keep n tokens", value=self.config.get("advanced.keep_tokens", 0), minimum=0, maximum=32, step=1
label="Keep n tokens",
value=self.config.get("advanced.keep_tokens", 0),
minimum=0,
maximum=32,
step=1,
)
self.clip_skip = gr.Slider(
label="Clip skip", value=self.config.get("advanced.clip_skip", 1), minimum=1, maximum=12, step=1
label="Clip skip",
value=self.config.get("advanced.clip_skip", 1),
minimum=1,
maximum=12,
step=1,
)
self.max_token_length = gr.Dropdown(
label="Max Token Length",
choices=[
"75",
"150",
"225",
75,
150,
225,
],
value=self.config.get("advanced.max_token_length", "75"),
info="max token length of text encoder",
value=self.config.get("advanced.max_token_length", 75),
)
with gr.Row():
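A side effect of the Dropdown change above: in gradio (4.x assumed here), the selected value is returned with the same type as the choice it came from, so int choices spare another str-to-int cast downstream. A minimal sketch:

import gradio as gr

# The component now yields e.g. 150 (int) rather than "150" (str).
max_token_length = gr.Dropdown(
    choices=[75, 150, 225],
    value=75,
    label="Max Token Length",
)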
@@ -234,14 +257,20 @@ class AdvancedTraining:
with gr.Row():
self.gradient_checkpointing = gr.Checkbox(
label="Gradient checkpointing", value=self.config.get("advanced.gradient_checkpointing", False)
label="Gradient checkpointing",
value=self.config.get("advanced.gradient_checkpointing", False),
)
self.shuffle_caption = gr.Checkbox(
label="Shuffle caption",
value=self.config.get("advanced.shuffle_caption", False),
)
self.shuffle_caption = gr.Checkbox(label="Shuffle caption", value=self.config.get("advanced.shuffle_caption", False))
self.persistent_data_loader_workers = gr.Checkbox(
label="Persistent data loader", value=self.config.get("advanced.persistent_data_loader_workers", False)
label="Persistent data loader",
value=self.config.get("advanced.persistent_data_loader_workers", False),
)
self.mem_eff_attn = gr.Checkbox(
label="Memory efficient attention", value=self.config.get("advanced.mem_eff_attn", False)
label="Memory efficient attention",
value=self.config.get("advanced.mem_eff_attn", False),
)
with gr.Row():
self.xformers = gr.Dropdown(
@@ -267,7 +296,9 @@ class AdvancedTraining:
with gr.Row():
self.scale_v_pred_loss_like_noise_pred = gr.Checkbox(
label="Scale v prediction loss",
value=self.config.get("advanced.scale_v_pred_loss_like_noise_pred", False),
value=self.config.get(
"advanced.scale_v_pred_loss_like_noise_pred", False
),
info="Only for SD v2 models. By scaling the loss according to the time step, the weights of global noise prediction and local noise prediction become the same, and the improvement of details may be expected.",
)
self.min_snr_gamma = gr.Slider(
@@ -286,7 +317,8 @@ class AdvancedTraining:
with gr.Row():
# self.sdpa = gr.Checkbox(label='Use sdpa', value=False, info='Use sdpa for CrossAttention')
self.bucket_no_upscale = gr.Checkbox(
label="Don't upscale bucket resolution", value=self.config.get("advanced.bucket_no_upscale", True)
label="Don't upscale bucket resolution",
value=self.config.get("advanced.bucket_no_upscale", True),
)
self.bucket_reso_steps = gr.Slider(
label="Bucket resolution steps",
@@ -295,7 +327,8 @@ class AdvancedTraining:
maximum=128,
)
self.random_crop = gr.Checkbox(
label="Random crop instead of center crop", value=self.config.get("advanced.random_crop", False)
label="Random crop instead of center crop",
value=self.config.get("advanced.random_crop", False),
)
self.v_pred_like_loss = gr.Slider(
label="V Pred like loss",
@@ -345,7 +378,9 @@ class AdvancedTraining:
)
self.noise_offset_random_strength = gr.Checkbox(
label="Noise offset random strength",
value=self.config.get("advanced.noise_offset_random_strength", False),
value=self.config.get(
"advanced.noise_offset_random_strength", False
),
info="Use random strength between 0~noise_offset for noise offset",
)
self.adaptive_noise_scale = gr.Slider(
@@ -384,7 +419,9 @@ class AdvancedTraining:
)
self.ip_noise_gamma_random_strength = gr.Checkbox(
label="IP noise gamma random strength",
value=self.config.get("advanced.ip_noise_gamma_random_strength", False),
value=self.config.get(
"advanced.ip_noise_gamma_random_strength", False
),
info="Use random strength between 0~ip_noise_gamma for input perturbation noise",
)
self.noise_offset_type.change(
@@ -397,19 +434,31 @@ class AdvancedTraining:
)
with gr.Row():
self.caption_dropout_every_n_epochs = gr.Number(
label="Dropout caption every n epochs", value=self.config.get("advanced.caption_dropout_every_n_epochs", 0),
label="Dropout caption every n epochs",
value=self.config.get("advanced.caption_dropout_every_n_epochs", 0),
)
self.caption_dropout_rate = gr.Slider(
label="Rate of caption dropout", value=self.config.get("advanced.caption_dropout_rate", 0), minimum=0, maximum=1
label="Rate of caption dropout",
value=self.config.get("advanced.caption_dropout_rate", 0),
minimum=0,
maximum=1,
)
self.vae_batch_size = gr.Slider(
label="VAE batch size", minimum=0, maximum=32, value=self.config.get("advanced.vae_batch_size", 0), step=1
label="VAE batch size",
minimum=0,
maximum=32,
value=self.config.get("advanced.vae_batch_size", 0),
step=1,
)
with gr.Group(), gr.Row():
self.save_state = gr.Checkbox(label="Save training state", value=self.config.get("advanced.save_state", False))
self.save_state = gr.Checkbox(
label="Save training state",
value=self.config.get("advanced.save_state", False),
)
self.save_state_on_train_end = gr.Checkbox(
label="Save training state at end of training", value=self.config.get("advanced.save_state_on_train_end", False)
label="Save training state at end of training",
value=self.config.get("advanced.save_state_on_train_end", False),
)
def list_state_dirs(path):
@@ -418,7 +467,8 @@ class AdvancedTraining:
self.resume = gr.Dropdown(
label='Resume from saved training state (path to "last-state" state folder)',
choices=[self.config.get("advanced.state_dir", "")] + list_state_dirs(self.current_state_dir),
choices=[self.config.get("advanced.state_dir", "")]
+ list_state_dirs(self.current_state_dir),
value=self.config.get("advanced.state_dir", ""),
interactive=True,
allow_custom_value=True,
@@ -426,7 +476,10 @@ class AdvancedTraining:
create_refresh_button(
self.resume,
lambda: None,
lambda: {"choices": [self.config.get("advanced.state_dir", "")] + list_state_dirs(self.current_state_dir)},
lambda: {
"choices": [self.config.get("advanced.state_dir", "")]
+ list_state_dirs(self.current_state_dir)
},
"open_folder_small",
)
self.resume_button = gr.Button(
@@ -438,15 +491,20 @@ class AdvancedTraining:
show_progress=False,
)
self.resume.change(
fn=lambda path: gr.Dropdown(choices=[self.config.get("advanced.state_dir", "")] + list_state_dirs(path)),
fn=lambda path: gr.Dropdown(
choices=[self.config.get("advanced.state_dir", "")]
+ list_state_dirs(path)
),
inputs=self.resume,
outputs=self.resume,
show_progress=False,
)
self.max_data_loader_n_workers = gr.Textbox(
self.max_data_loader_n_workers = gr.Number(
label="Max num workers for DataLoader",
placeholder="(Optional) Override number of epoch. Default: 8",
value=self.config.get("advanced.max_data_loader_n_workers", "0"),
info="Override number of epoch. Default: 0",
step=1,
minimum=0,
value=self.config.get("advanced.max_data_loader_n_workers", 0),
)
with gr.Row():
self.use_wandb = gr.Checkbox(
@@ -506,7 +564,8 @@ class AdvancedTraining:
)
self.log_tracker_config.change(
fn=lambda path: gr.Dropdown(
choices=[self.config.get("log_tracker_config_dir", "")] + list_log_tracker_config_files(path)
choices=[self.config.get("log_tracker_config_dir", "")]
+ list_log_tracker_config_files(path)
),
inputs=self.log_tracker_config,
outputs=self.log_tracker_config,
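The recurring pattern in this file is replacing gr.Textbox with gr.Number, so the component itself yields an int or float and can enforce bounds and step size, rather than returning a string that every caller must cast. A minimal sketch of the contrast (gradio 4.x API assumed; names illustrative, not from the repo):

import gradio as gr

with gr.Blocks() as demo:
    # Old style: always returns str ("1234"), cast at every use site.
    seed_text = gr.Textbox(label="Seed (old)", value="1234")
    # New style: returns a number directly; minimum/step are enforced in the UI.
    seed_num = gr.Number(label="Seed (new)", value=1234, minimum=0, step=1)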

View File

@@ -88,16 +88,21 @@ class BasicTraining:
label="Epoch", value=self.config.get("basic.epoch", 1), precision=0
)
# Initialize the maximum train epochs input
self.max_train_epochs = gr.Textbox(
self.max_train_epochs = gr.Number(
label="Max train epoch",
placeholder="(Optional) Enforce # epochs",
value=self.config.get("basic.max_train_epochs", ""),
info="training epochs (overrides max_train_steps). 0 = no override",
step=1,
# precision=0,
minimum=0,
value=self.config.get("basic.max_train_epochs", 0),
)
# Initialize the maximum train steps input
self.max_train_steps = gr.Textbox(
self.max_train_steps = gr.Number(
label="Max train steps",
placeholder="(Optional) Enforce # steps",
value=self.config.get("basic.max_train_steps", ""),
info="Overrides # training steps. 0 = no override",
step=1,
# precision=0,
value=self.config.get("basic.max_train_steps", 0),
)
# Initialize the save every N epochs input
self.save_every_n_epochs = gr.Number(
@@ -119,10 +124,13 @@ class BasicTraining:
"""
with gr.Row():
# Initialize the seed textbox
self.seed = gr.Textbox(
self.seed = gr.Number(
label="Seed",
placeholder="(Optional) eg:1234",
value=self.config.get("basic.seed", ""),
# precision=0,
step=1,
minimum=0,
value=self.config.get("basic.seed", 0),
info="Set to 0 to make random",
)
# Initialize the cache latents checkbox
self.cache_latents = gr.Checkbox(
@@ -277,16 +285,21 @@ class BasicTraining:
"""
with gr.Row(visible=not self.finetuning):
# Initialize the learning rate scheduler number of cycles textbox
self.lr_scheduler_num_cycles = gr.Textbox(
self.lr_scheduler_num_cycles = gr.Number(
label="LR # cycles",
placeholder="(Optional) For Cosine with restart and polynomial only",
value=self.config.get("basic.lr_scheduler_num_cycles", ""),
minimum=1,
# precision=0, # round to nearest integer
step=1, # Increment value by 1
info="Number of restarts for cosine scheduler with restarts",
value=self.config.get("basic.lr_scheduler_num_cycles", 1),
)
# Initialize the learning rate scheduler power textbox
self.lr_scheduler_power = gr.Textbox(
self.lr_scheduler_power = gr.Number(
label="LR power",
placeholder="(Optional) For Cosine with restart and polynomial only",
value=self.config.get("basic.lr_scheduler_power", ""),
minimum=1.0,
step=0.01,
info="Polynomial power for polynomial scheduler",
value=self.config.get("basic.lr_scheduler_power", 1.0),
)
def init_resolution_and_bucket_controls(self) -> None:
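With gr.Number, the "optional" semantics that the old placeholders carried move into a 0 sentinel ("0 = no override"). A hedged sketch of how such a sentinel is typically consumed when the launch command is assembled (append_if_set is a hypothetical helper, not from the repo):

def append_if_set(cmd: list, flag: str, value: int) -> None:
    # 0 means "not set": omit the flag so the trainer's own default applies.
    if value > 0:
        cmd.extend([flag, str(value)])

cmd = ["accelerate", "launch", "train_network.py"]
append_if_set(cmd, "--max_train_epochs", 0)    # omitted: no override
append_if_set(cmd, "--max_train_steps", 1600)  # included
print(cmd)  # [..., 'train_network.py', '--max_train_steps', '1600']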

View File

@@ -323,24 +323,63 @@ def update_my_data(my_data):
my_data["model_list"] = "custom"
# Convert values to int if they are strings
for key in ["epoch", "save_every_n_epochs", "lr_warmup"]:
value = my_data.get(key, 0)
if isinstance(value, str) and value.strip().isdigit():
my_data[key] = int(value)
elif not value:
my_data[key] = 0
for key in [
"epoch",
"keep_tokens",
"lr_warmup",
"max_data_loader_n_workers",
"max_train_epochs",
"max_train_steps",
"save_every_n_epochs",
"seed",
]:
value = my_data.get(key)
if value is not None:
try:
my_data[key] = int(value)
            except ValueError:
                # Handle the case where the value is not a valid integer
                my_data[key] = 0
    # Convert values to int if they are strings, defaulting to 1
for key in ["lr_scheduler_num_cycles"]:
value = my_data.get(key)
if value is not None:
try:
my_data[key] = int(value)
            except ValueError:
                # Handle the case where the value is not a valid integer
                my_data[key] = 1
    # Convert values to int if they are strings, defaulting to 75
for key in ["max_token_length"]:
value = my_data.get(key)
if value is not None:
try:
my_data[key] = int(value)
            except ValueError:
                # Handle the case where the value is not a valid integer
                my_data[key] = 75
# Convert values to float if they are strings, correctly handling float representations
for key in ["noise_offset", "learning_rate", "text_encoder_lr", "unet_lr"]:
value = my_data.get(key, 0)
if isinstance(value, str):
value = my_data.get(key)
if value is not None:
try:
my_data[key] = float(value)
except ValueError:
# Handle the case where the string is not a valid float
my_data[key] = 0
elif not value:
my_data[key] = 0
                my_data[key] = 0.0
    # Convert values to float if they are strings, defaulting to 1.0
for key in ["lr_scheduler_power"]:
value = my_data.get(key)
if value is not None:
try:
my_data[key] = float(value)
except ValueError:
# Handle the case where the string is not a valid float
                my_data[key] = 1.0
# Update LoRA_type if it is set to LoCon
if my_data.get("LoRA_type", "Standard") == "LoCon":

View File

@@ -31,7 +31,7 @@
"huggingface_token": "",
"ip_noise_gamma": 0.1,
"ip_noise_gamma_random_strength": true,
"keep_tokens": "0",
"keep_tokens": 0,
"learning_rate": 5e-05,
"learning_rate_te": 1e-05,
"learning_rate_te1": 1e-05,
@@ -42,18 +42,18 @@
"loss_type": "l2",
"lr_scheduler": "constant",
"lr_scheduler_args": "",
"lr_scheduler_num_cycles": "",
"lr_scheduler_power": "",
"lr_scheduler_num_cycles": 1,
"lr_scheduler_power": 1.02,
"lr_warmup": 0,
"main_process_port": 0,
"masked_loss": false,
"max_bucket_reso": 2048,
"max_data_loader_n_workers": "0",
"max_data_loader_n_workers": 0,
"max_resolution": "512,512",
"max_timestep": 1000,
"max_token_length": "75",
"max_train_epochs": "",
"max_train_steps": "",
"max_token_length": 75,
"max_train_epochs": 0,
"max_train_steps": 0,
"mem_eff_attn": false,
"min_bucket_reso": 256,
"min_snr_gamma": 0,
@@ -97,7 +97,7 @@
"save_state_to_huggingface": false,
"scale_v_pred_loss_like_noise_pred": false,
"sdxl": false,
"seed": "1234",
"seed": 1234,
"shuffle_caption": false,
"stop_text_encoder_training": 0,
"train_batch_size": 4,