From eec6f9baf4e59bbb9bdac2ef24eff53e2d778bdc Mon Sep 17 00:00:00 2001 From: bmaltais Date: Mon, 15 Apr 2024 08:02:39 -0400 Subject: [PATCH] Convert str numbers to proper int or float --- config example.toml | 14 +-- kohya_gui/class_advanced_training.py | 125 +++++++++++++++++++------- kohya_gui/class_basic_training.py | 43 +++++---- kohya_gui/common_gui.py | 61 ++++++++++--- test/config/dreambooth-AdamW8bit.json | 16 ++-- 5 files changed, 185 insertions(+), 74 deletions(-) diff --git a/config example.toml b/config example.toml index 4565228..d764d01 100644 --- a/config example.toml +++ b/config example.toml @@ -42,19 +42,19 @@ learning_rate_te2 = 0.0001 # Learning rate text encoder 2 lr_scheduler = "cosine" # LR Scheduler lr_scheduler_args = "" # LR Scheduler args lr_warmup = 0 # LR Warmup (% of total steps) -lr_scheduler_num_cycles = "" # LR Scheduler num cycles -lr_scheduler_power = "" # LR Scheduler power +lr_scheduler_num_cycles = 1 # LR Scheduler num cycles +lr_scheduler_power = 1.0 # LR Scheduler power max_bucket_reso = 2048 # Max bucket resolution max_grad_norm = 1.0 # Max grad norm max_resolution = "512,512" # Max resolution -max_train_steps = "" # Max train steps -max_train_epochs = "" # Max train epochs +max_train_steps = 0 # Max train steps +max_train_epochs = 0 # Max train epochs min_bucket_reso = 256 # Min bucket resolution optimizer = "AdamW8bit" # Optimizer (AdamW, AdamW8bit, Adafactor, DAdaptation, DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptAdamPreprint, DAdaptLion, DAdaptSGD, Lion, Lion8bit, PagedAdam optimizer_args = "" # Optimizer args save_every_n_epochs = 1 # Save every n epochs save_every_n_steps = 1 # Save every n steps -seed = "1234" # Seed +seed = 1234 # Seed stop_text_encoder_training = 0 # Stop text encoder training (% of total steps) train_batch_size = 1 # Train batch size @@ -83,9 +83,9 @@ log_tracker_config_dir = "./logs" # Log tracker configs directory log_tracker_name = "" # Log tracker name loss_type = "l2" # Loss type (l2, huber, smooth_l1) masked_loss = false # Masked loss -max_data_loader_n_workers = "0" # Max data loader n workers (string) +max_data_loader_n_workers = 0 # Max data loader n workers (string) max_timestep = 1000 # Max timestep -max_token_length = "150" # Max token length ("75", "150", "225") +max_token_length = 150 # Max token length ("75", "150", "225") mem_eff_attn = false # Memory efficient attention min_snr_gamma = 0 # Min SNR gamma min_timestep = 0 # Min timestep diff --git a/kohya_gui/class_advanced_training.py b/kohya_gui/class_advanced_training.py index 35f1ffd..ffed020 100644 --- a/kohya_gui/class_advanced_training.py +++ b/kohya_gui/class_advanced_training.py @@ -75,7 +75,8 @@ class AdvancedTraining: # Exclude token padding option for LoRA training type. if training_type != "lora": self.no_token_padding = gr.Checkbox( - label="No token padding", value=self.config.get("advanced.no_token_padding", False) + label="No token padding", + value=self.config.get("advanced.no_token_padding", False), ) self.gradient_accumulation_steps = gr.Slider( label="Gradient accumulate steps", @@ -85,9 +86,15 @@ class AdvancedTraining: maximum=120, step=1, ) - self.weighted_captions = gr.Checkbox(label="Weighted captions", value=self.config.get("advanced.weighted_captions", False)) + self.weighted_captions = gr.Checkbox( + label="Weighted captions", + value=self.config.get("advanced.weighted_captions", False), + ) with gr.Group(), gr.Row(visible=not finetuning): - self.prior_loss_weight = gr.Number(label="Prior loss weight", value=self.config.get("advanced.prior_loss_weight", 1.0)) + self.prior_loss_weight = gr.Number( + label="Prior loss weight", + value=self.config.get("advanced.prior_loss_weight", 1.0), + ) def list_vae_files(path): self.current_vae_dir = path if not path == "" else "." @@ -96,14 +103,18 @@ class AdvancedTraining: self.vae = gr.Dropdown( label="VAE (Optional: Path to checkpoint of vae for training)", interactive=True, - choices=[self.config.get("advanced.vae_dir", "")] + list_vae_files(self.current_vae_dir), + choices=[self.config.get("advanced.vae_dir", "")] + + list_vae_files(self.current_vae_dir), value=self.config.get("advanced.vae_dir", ""), allow_custom_value=True, ) create_refresh_button( self.vae, lambda: None, - lambda: {"choices": [self.config.get("advanced.vae_dir", "")] + list_vae_files(self.current_vae_dir)}, + lambda: { + "choices": [self.config.get("advanced.vae_dir", "")] + + list_vae_files(self.current_vae_dir) + }, "open_folder_small", ) self.vae_button = gr.Button( @@ -116,7 +127,10 @@ class AdvancedTraining: ) self.vae.change( - fn=lambda path: gr.Dropdown(choices=[self.config.get("advanced.vae_dir", "")] + list_vae_files(path)), + fn=lambda path: gr.Dropdown( + choices=[self.config.get("advanced.vae_dir", "")] + + list_vae_files(path) + ), inputs=self.vae, outputs=self.vae, show_progress=False, @@ -189,19 +203,28 @@ class AdvancedTraining: ), gr.Checkbox(interactive=full_bf16_active) self.keep_tokens = gr.Slider( - label="Keep n tokens", value=self.config.get("advanced.keep_tokens", 0), minimum=0, maximum=32, step=1 + label="Keep n tokens", + value=self.config.get("advanced.keep_tokens", 0), + minimum=0, + maximum=32, + step=1, ) self.clip_skip = gr.Slider( - label="Clip skip", value=self.config.get("advanced.clip_skip", 1), minimum=1, maximum=12, step=1 + label="Clip skip", + value=self.config.get("advanced.clip_skip", 1), + minimum=1, + maximum=12, + step=1, ) self.max_token_length = gr.Dropdown( label="Max Token Length", choices=[ - "75", - "150", - "225", + 75, + 150, + 225, ], - value=self.config.get("advanced.max_token_length", "75"), + info="max token length of text encoder", + value=self.config.get("advanced.max_token_length", 75), ) with gr.Row(): @@ -234,14 +257,20 @@ class AdvancedTraining: with gr.Row(): self.gradient_checkpointing = gr.Checkbox( - label="Gradient checkpointing", value=self.config.get("advanced.gradient_checkpointing", False) + label="Gradient checkpointing", + value=self.config.get("advanced.gradient_checkpointing", False), + ) + self.shuffle_caption = gr.Checkbox( + label="Shuffle caption", + value=self.config.get("advanced.shuffle_caption", False), ) - self.shuffle_caption = gr.Checkbox(label="Shuffle caption", value=self.config.get("advanced.shuffle_caption", False)) self.persistent_data_loader_workers = gr.Checkbox( - label="Persistent data loader", value=self.config.get("advanced.persistent_data_loader_workers", False) + label="Persistent data loader", + value=self.config.get("advanced.persistent_data_loader_workers", False), ) self.mem_eff_attn = gr.Checkbox( - label="Memory efficient attention", value=self.config.get("advanced.mem_eff_attn", False) + label="Memory efficient attention", + value=self.config.get("advanced.mem_eff_attn", False), ) with gr.Row(): self.xformers = gr.Dropdown( @@ -267,7 +296,9 @@ class AdvancedTraining: with gr.Row(): self.scale_v_pred_loss_like_noise_pred = gr.Checkbox( label="Scale v prediction loss", - value=self.config.get("advanced.scale_v_pred_loss_like_noise_pred", False), + value=self.config.get( + "advanced.scale_v_pred_loss_like_noise_pred", False + ), info="Only for SD v2 models. By scaling the loss according to the time step, the weights of global noise prediction and local noise prediction become the same, and the improvement of details may be expected.", ) self.min_snr_gamma = gr.Slider( @@ -286,7 +317,8 @@ class AdvancedTraining: with gr.Row(): # self.sdpa = gr.Checkbox(label='Use sdpa', value=False, info='Use sdpa for CrossAttention') self.bucket_no_upscale = gr.Checkbox( - label="Don't upscale bucket resolution", value=self.config.get("advanced.bucket_no_upscale", True) + label="Don't upscale bucket resolution", + value=self.config.get("advanced.bucket_no_upscale", True), ) self.bucket_reso_steps = gr.Slider( label="Bucket resolution steps", @@ -295,7 +327,8 @@ class AdvancedTraining: maximum=128, ) self.random_crop = gr.Checkbox( - label="Random crop instead of center crop", value=self.config.get("advanced.random_crop", False) + label="Random crop instead of center crop", + value=self.config.get("advanced.random_crop", False), ) self.v_pred_like_loss = gr.Slider( label="V Pred like loss", @@ -345,7 +378,9 @@ class AdvancedTraining: ) self.noise_offset_random_strength = gr.Checkbox( label="Noise offset random strength", - value=self.config.get("advanced.noise_offset_random_strength", False), + value=self.config.get( + "advanced.noise_offset_random_strength", False + ), info="Use random strength between 0~noise_offset for noise offset", ) self.adaptive_noise_scale = gr.Slider( @@ -384,7 +419,9 @@ class AdvancedTraining: ) self.ip_noise_gamma_random_strength = gr.Checkbox( label="IP noise gamma random strength", - value=self.config.get("advanced.ip_noise_gamma_random_strength", False), + value=self.config.get( + "advanced.ip_noise_gamma_random_strength", False + ), info="Use random strength between 0~ip_noise_gamma for input perturbation noise", ) self.noise_offset_type.change( @@ -397,19 +434,31 @@ class AdvancedTraining: ) with gr.Row(): self.caption_dropout_every_n_epochs = gr.Number( - label="Dropout caption every n epochs", value=self.config.get("advanced.caption_dropout_every_n_epochs", 0), + label="Dropout caption every n epochs", + value=self.config.get("advanced.caption_dropout_every_n_epochs", 0), ) self.caption_dropout_rate = gr.Slider( - label="Rate of caption dropout", value=self.config.get("advanced.caption_dropout_rate", 0), minimum=0, maximum=1 + label="Rate of caption dropout", + value=self.config.get("advanced.caption_dropout_rate", 0), + minimum=0, + maximum=1, ) self.vae_batch_size = gr.Slider( - label="VAE batch size", minimum=0, maximum=32, value=self.config.get("advanced.vae_batch_size", 0), step=1 + label="VAE batch size", + minimum=0, + maximum=32, + value=self.config.get("advanced.vae_batch_size", 0), + step=1, ) with gr.Group(), gr.Row(): - self.save_state = gr.Checkbox(label="Save training state", value=self.config.get("advanced.save_state", False)) + self.save_state = gr.Checkbox( + label="Save training state", + value=self.config.get("advanced.save_state", False), + ) self.save_state_on_train_end = gr.Checkbox( - label="Save training state at end of training", value=self.config.get("advanced.save_state_on_train_end", False) + label="Save training state at end of training", + value=self.config.get("advanced.save_state_on_train_end", False), ) def list_state_dirs(path): @@ -418,7 +467,8 @@ class AdvancedTraining: self.resume = gr.Dropdown( label='Resume from saved training state (path to "last-state" state folder)', - choices=[self.config.get("advanced.state_dir", "")] + list_state_dirs(self.current_state_dir), + choices=[self.config.get("advanced.state_dir", "")] + + list_state_dirs(self.current_state_dir), value=self.config.get("advanced.state_dir", ""), interactive=True, allow_custom_value=True, @@ -426,7 +476,10 @@ class AdvancedTraining: create_refresh_button( self.resume, lambda: None, - lambda: {"choices": [self.config.get("advanced.state_dir", "")] + list_state_dirs(self.current_state_dir)}, + lambda: { + "choices": [self.config.get("advanced.state_dir", "")] + + list_state_dirs(self.current_state_dir) + }, "open_folder_small", ) self.resume_button = gr.Button( @@ -438,15 +491,20 @@ class AdvancedTraining: show_progress=False, ) self.resume.change( - fn=lambda path: gr.Dropdown(choices=[self.config.get("advanced.state_dir", "")] + list_state_dirs(path)), + fn=lambda path: gr.Dropdown( + choices=[self.config.get("advanced.state_dir", "")] + + list_state_dirs(path) + ), inputs=self.resume, outputs=self.resume, show_progress=False, ) - self.max_data_loader_n_workers = gr.Textbox( + self.max_data_loader_n_workers = gr.Number( label="Max num workers for DataLoader", - placeholder="(Optional) Override number of epoch. Default: 8", - value=self.config.get("advanced.max_data_loader_n_workers", "0"), + info="Override number of epoch. Default: 0", + step=1, + minimum=0, + value=self.config.get("advanced.max_data_loader_n_workers", 0), ) with gr.Row(): self.use_wandb = gr.Checkbox( @@ -506,7 +564,8 @@ class AdvancedTraining: ) self.log_tracker_config.change( fn=lambda path: gr.Dropdown( - choices=[self.config.get("log_tracker_config_dir", "")] + list_log_tracker_config_files(path) + choices=[self.config.get("log_tracker_config_dir", "")] + + list_log_tracker_config_files(path) ), inputs=self.log_tracker_config, outputs=self.log_tracker_config, diff --git a/kohya_gui/class_basic_training.py b/kohya_gui/class_basic_training.py index 70270d8..85f38a9 100644 --- a/kohya_gui/class_basic_training.py +++ b/kohya_gui/class_basic_training.py @@ -88,16 +88,21 @@ class BasicTraining: label="Epoch", value=self.config.get("basic.epoch", 1), precision=0 ) # Initialize the maximum train epochs input - self.max_train_epochs = gr.Textbox( + self.max_train_epochs = gr.Number( label="Max train epoch", - placeholder="(Optional) Enforce # epochs", - value=self.config.get("basic.max_train_epochs", ""), + info="training epochs (overrides max_train_steps). 0 = no override", + step=1, + # precision=0, + minimum=0, + value=self.config.get("basic.max_train_epochs", 0), ) # Initialize the maximum train steps input - self.max_train_steps = gr.Textbox( + self.max_train_steps = gr.Number( label="Max train steps", - placeholder="(Optional) Enforce # steps", - value=self.config.get("basic.max_train_steps", ""), + info="Overrides # training steps. 0 = no override", + step=1, + # precision=0, + value=self.config.get("basic.max_train_steps", 0), ) # Initialize the save every N epochs input self.save_every_n_epochs = gr.Number( @@ -119,10 +124,13 @@ class BasicTraining: """ with gr.Row(): # Initialize the seed textbox - self.seed = gr.Textbox( + self.seed = gr.Number( label="Seed", - placeholder="(Optional) eg:1234", - value=self.config.get("basic.seed", ""), + # precision=0, + step=1, + minimum=0, + value=self.config.get("basic.seed", 0), + info="Set to 0 to make random", ) # Initialize the cache latents checkbox self.cache_latents = gr.Checkbox( @@ -277,16 +285,21 @@ class BasicTraining: """ with gr.Row(visible=not self.finetuning): # Initialize the learning rate scheduler number of cycles textbox - self.lr_scheduler_num_cycles = gr.Textbox( + self.lr_scheduler_num_cycles = gr.Number( label="LR # cycles", - placeholder="(Optional) For Cosine with restart and polynomial only", - value=self.config.get("basic.lr_scheduler_num_cycles", ""), + minimum=1, + # precision=0, # round to nearest integer + step=1, # Increment value by 1 + info="Number of restarts for cosine scheduler with restarts", + value=self.config.get("basic.lr_scheduler_num_cycles", 1), ) # Initialize the learning rate scheduler power textbox - self.lr_scheduler_power = gr.Textbox( + self.lr_scheduler_power = gr.Number( label="LR power", - placeholder="(Optional) For Cosine with restart and polynomial only", - value=self.config.get("basic.lr_scheduler_power", ""), + minimum=1.0, + step=0.01, + info="Polynomial power for polynomial scheduler", + value=self.config.get("basic.lr_scheduler_power", 1.0), ) def init_resolution_and_bucket_controls(self) -> None: diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py index 6bbb7a3..0d9b9d4 100644 --- a/kohya_gui/common_gui.py +++ b/kohya_gui/common_gui.py @@ -323,24 +323,63 @@ def update_my_data(my_data): my_data["model_list"] = "custom" # Convert values to int if they are strings - for key in ["epoch", "save_every_n_epochs", "lr_warmup"]: - value = my_data.get(key, 0) - if isinstance(value, str) and value.strip().isdigit(): - my_data[key] = int(value) - elif not value: - my_data[key] = 0 + for key in [ + "epoch", + "keep_tokens", + "lr_warmup", + "max_data_loader_n_workers", + "max_train_epochs", + "max_train_steps", + "save_every_n_epochs", + "seed", + ]: + value = my_data.get(key) + if value is not None: + try: + my_data[key] = int(value) + except ValueError: + # Handle the case where the string is not a valid float + my_data[key] = int(0) + + # Convert values to int if they are strings + for key in ["lr_scheduler_num_cycles"]: + value = my_data.get(key) + if value is not None: + try: + my_data[key] = int(value) + except ValueError: + # Handle the case where the string is not a valid float + my_data[key] = int(1) + + # Convert values to int if they are strings + for key in ["max_token_length"]: + value = my_data.get(key) + if value is not None: + try: + my_data[key] = int(value) + except ValueError: + # Handle the case where the string is not a valid float + my_data[key] = int(75) # Convert values to float if they are strings, correctly handling float representations for key in ["noise_offset", "learning_rate", "text_encoder_lr", "unet_lr"]: - value = my_data.get(key, 0) - if isinstance(value, str): + value = my_data.get(key) + if value is not None: try: my_data[key] = float(value) except ValueError: # Handle the case where the string is not a valid float - my_data[key] = 0 - elif not value: - my_data[key] = 0 + my_data[key] = float(0.0) + + # Convert values to float if they are strings, correctly handling float representations + for key in ["lr_scheduler_power"]: + value = my_data.get(key) + if value is not None: + try: + my_data[key] = float(value) + except ValueError: + # Handle the case where the string is not a valid float + my_data[key] = float(1.0) # Update LoRA_type if it is set to LoCon if my_data.get("LoRA_type", "Standard") == "LoCon": diff --git a/test/config/dreambooth-AdamW8bit.json b/test/config/dreambooth-AdamW8bit.json index ea1ef35..9a91ea5 100644 --- a/test/config/dreambooth-AdamW8bit.json +++ b/test/config/dreambooth-AdamW8bit.json @@ -31,7 +31,7 @@ "huggingface_token": "", "ip_noise_gamma": 0.1, "ip_noise_gamma_random_strength": true, - "keep_tokens": "0", + "keep_tokens": 0, "learning_rate": 5e-05, "learning_rate_te": 1e-05, "learning_rate_te1": 1e-05, @@ -42,18 +42,18 @@ "loss_type": "l2", "lr_scheduler": "constant", "lr_scheduler_args": "", - "lr_scheduler_num_cycles": "", - "lr_scheduler_power": "", + "lr_scheduler_num_cycles": 1, + "lr_scheduler_power": 1.02, "lr_warmup": 0, "main_process_port": 0, "masked_loss": false, "max_bucket_reso": 2048, - "max_data_loader_n_workers": "0", + "max_data_loader_n_workers": 0, "max_resolution": "512,512", "max_timestep": 1000, - "max_token_length": "75", - "max_train_epochs": "", - "max_train_steps": "", + "max_token_length": 75, + "max_train_epochs": 0, + "max_train_steps": 0, "mem_eff_attn": false, "min_bucket_reso": 256, "min_snr_gamma": 0, @@ -97,7 +97,7 @@ "save_state_to_huggingface": false, "scale_v_pred_loss_like_noise_pred": false, "sdxl": false, - "seed": "1234", + "seed": 1234, "shuffle_caption": false, "stop_text_encoder_training": 0, "train_batch_size": 4,