mirror of https://github.com/bmaltais/kohya_ss
Add basic section to config example.toml
parent
0aa91418b3
commit
5bd439f3d0
|
|
@ -19,6 +19,35 @@ logging_dir = "./logs" # Logging directory
|
|||
[configuration]
|
||||
config_dir = "./presets" # Load/Save Config file
|
||||
|
||||
[basic]
|
||||
cache_latents = true # Cache latents
|
||||
cache_latents_to_disk = false # Cache latents to disk
|
||||
caption_extension = ".txt" # Caption extension
|
||||
enable_bucket = true # Enable bucket
|
||||
epoch = 1 # Epoch
|
||||
learning_rate = 0.0001 # Learning rate
|
||||
learning_rate_te = 0.0001 # Learning rate text encoder
|
||||
learning_rate_te1 = 0.0001 # Learning rate text encoder 1
|
||||
learning_rate_te2 = 0.0001 # Learning rate text encoder 2
|
||||
lr_scheduler = "cosine" # LR Scheduler
|
||||
lr_scheduler_args = "" # LR Scheduler args
|
||||
lr_warmup = 0 # LR Warmup (% of total steps)
|
||||
lr_scheduler_num_cycles = "" # LR Scheduler num cycles
|
||||
lr_scheduler_power = "" # LR Scheduler power
|
||||
max_bucket_reso = 2048 # Max bucket resolution
|
||||
max_grad_norm = 1.0 # Max grad norm
|
||||
max_resolution = "512,512" # Max resolution
|
||||
max_train_steps = "" # Max train steps
|
||||
max_train_epochs = "" # Max train epochs
|
||||
min_bucket_reso = 256 # Min bucket resolution
|
||||
optimizer = "AdamW8bit" # Optimizer (AdamW, AdamW8bit, Adafactor, DAdaptation, DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptAdamPreprint, DAdaptLion, DAdaptSGD, Lion, Lion8bit, PagedAdamW8bit, PagedAdamW32bit, PagedLion8bit, Prodigy, SGDNesterov, SGDNesterov8bit)
|
||||
optimizer_args = "" # Optimizer args
|
||||
save_every_n_epochs = 1 # Save every n epochs
|
||||
save_every_n_steps = 1 # Save every n steps
|
||||
seed = "1234" # Seed
|
||||
stop_text_encoder_training = 0 # Stop text encoder training (% of total steps)
|
||||
train_batch_size = 1 # Train batch size
|
||||
|
||||
[advanced]
|
||||
adaptive_noise_scale = 0 # Adaptive noise scale
|
||||
additional_parameters = "" # Additional parameters
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ class BasicTraining:
|
|||
lr_warmup_value: str = "0",
|
||||
finetuning: bool = False,
|
||||
dreambooth: bool = False,
|
||||
config: dict = {},
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the BasicTraining object with the given parameters.
|
||||
|
|
@ -42,6 +43,7 @@ class BasicTraining:
|
|||
self.lr_warmup_value = lr_warmup_value
|
||||
self.finetuning = finetuning
|
||||
self.dreambooth = dreambooth
|
||||
self.config = config
|
||||
|
||||
# Initialize the UI components
|
||||
self.initialize_ui_components()
|
||||
|
|
@ -75,28 +77,31 @@ class BasicTraining:
|
|||
with gr.Row():
|
||||
# Initialize the train batch size slider
|
||||
self.train_batch_size = gr.Slider(
|
||||
minimum=1, maximum=64, label="Train batch size", value=1, step=1
|
||||
minimum=1, maximum=64, label="Train batch size", value=1, step=self.config.get("basic.train_batch_size", 1),
|
||||
)
|
||||
# Initialize the epoch number input
|
||||
self.epoch = gr.Number(label="Epoch", value=1, precision=0)
|
||||
self.epoch = gr.Number(label="Epoch", value=self.config.get("basic.epoch", 1), precision=0)
|
||||
# Initialize the maximum train epochs input
|
||||
self.max_train_epochs = gr.Textbox(
|
||||
label="Max train epoch",
|
||||
placeholder="(Optional) Enforce # epochs",
|
||||
value=self.config.get("basic.max_train_epochs", ""),
|
||||
)
|
||||
# Initialize the maximum train steps input
|
||||
self.max_train_steps = gr.Textbox(
|
||||
label="Max train steps",
|
||||
placeholder="(Optional) Enforce # steps",
|
||||
value=self.config.get("basic.max_train_steps", ""),
|
||||
)
|
||||
# Initialize the save every N epochs input
|
||||
self.save_every_n_epochs = gr.Number(
|
||||
label="Save every N epochs", value=1, precision=0
|
||||
label="Save every N epochs", value=self.config.get("basic.save_every_n_epochs", 1), precision=0
|
||||
)
|
||||
# Initialize the caption extension input
|
||||
self.caption_extension = gr.Textbox(
|
||||
label="Caption Extension",
|
||||
placeholder="(Optional) default: .caption",
|
||||
value=self.config.get("basic.caption_extension", ""),
|
||||
)
|
||||
|
||||
def init_precision_and_resources_controls(self) -> None:
|
||||
|
|
@ -105,12 +110,12 @@ class BasicTraining:
|
|||
"""
|
||||
with gr.Row():
|
||||
# Initialize the seed textbox
|
||||
self.seed = gr.Textbox(label="Seed", placeholder="(Optional) eg:1234")
|
||||
self.seed = gr.Textbox(label="Seed", placeholder="(Optional) eg:1234", value=self.config.get("basic.seed", ""))
|
||||
# Initialize the cache latents checkbox
|
||||
self.cache_latents = gr.Checkbox(label="Cache latents", value=True)
|
||||
self.cache_latents = gr.Checkbox(label="Cache latents", value=self.config.get("basic.cache_latents", True))
|
||||
# Initialize the cache latents to disk checkbox
|
||||
self.cache_latents_to_disk = gr.Checkbox(
|
||||
label="Cache latents to disk", value=False
|
||||
label="Cache latents to disk", value=self.config.get("basic.cache_latents_to_disk", False)
|
||||
)
|
||||
|
||||
def init_lr_and_optimizer_controls(self) -> None:
|
||||
|
|
@ -130,7 +135,7 @@ class BasicTraining:
|
|||
"linear",
|
||||
"polynomial",
|
||||
],
|
||||
value=self.lr_scheduler_value,
|
||||
value=self.config.get("basic.lr_scheduler", self.lr_scheduler_value),
|
||||
)
|
||||
# Initialize the optimizer dropdown
|
||||
self.optimizer = gr.Dropdown(
|
||||
|
|
@ -156,7 +161,7 @@ class BasicTraining:
|
|||
"SGDNesterov",
|
||||
"SGDNesterov8bit",
|
||||
],
|
||||
value="AdamW8bit",
|
||||
value=self.config.get("basic.optimizer", "AdamW8bit"),
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
|
|
@ -167,19 +172,21 @@ class BasicTraining:
|
|||
with gr.Row():
|
||||
# Initialize the maximum gradient norm slider
|
||||
self.max_grad_norm = gr.Slider(
|
||||
label="Max grad norm", value=1.0, minimum=0.0, maximum=1.0
|
||||
label="Max grad norm", value=self.config.get("basic.max_grad_norm", 1.0), minimum=0.0, maximum=1.0
|
||||
)
|
||||
# Initialize the learning rate scheduler extra arguments textbox
|
||||
self.lr_scheduler_args = gr.Textbox(
|
||||
label="LR scheduler extra arguments",
|
||||
lines=2,
|
||||
placeholder='(Optional) eg: "milestones=[1,10,30,50]" "gamma=0.1"',
|
||||
value=self.config.get("basic.lr_scheduler_args", ""),
|
||||
)
|
||||
# Initialize the optimizer extra arguments textbox
|
||||
self.optimizer_args = gr.Textbox(
|
||||
label="Optimizer extra arguments",
|
||||
lines=2,
|
||||
placeholder="(Optional) eg: relative_step=True scale_parameter=True warmup_init=True",
|
||||
value=self.config.get("basic.optimizer_args", ""),
|
||||
)
|
||||
|
||||
def init_learning_rate_controls(self) -> None:
|
||||
|
|
@ -196,7 +203,7 @@ class BasicTraining:
|
|||
# Initialize the learning rate number input
|
||||
self.learning_rate = gr.Number(
|
||||
label=lr_label,
|
||||
value=self.learning_rate_value,
|
||||
value=self.config.get("basic.learning_rate", self.learning_rate_value),
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
info="Set to 0 to not train the Unet",
|
||||
|
|
@ -204,7 +211,7 @@ class BasicTraining:
|
|||
# Initialize the learning rate TE number input
|
||||
self.learning_rate_te = gr.Number(
|
||||
label="Learning rate TE",
|
||||
value=self.learning_rate_value,
|
||||
value=self.config.get("basic.learning_rate_te", self.learning_rate_value),
|
||||
visible=self.finetuning or self.dreambooth,
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
|
|
@ -213,7 +220,7 @@ class BasicTraining:
|
|||
# Initialize the learning rate TE1 number input
|
||||
self.learning_rate_te1 = gr.Number(
|
||||
label="Learning rate TE1",
|
||||
value=self.learning_rate_value,
|
||||
value=self.config.get("basic.learning_rate_te1", self.learning_rate_value),
|
||||
visible=False,
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
|
|
@ -222,7 +229,7 @@ class BasicTraining:
|
|||
# Initialize the learning rate TE2 number input
|
||||
self.learning_rate_te2 = gr.Number(
|
||||
label="Learning rate TE2",
|
||||
value=self.learning_rate_value,
|
||||
value=self.config.get("basic.learning_rate_te2", self.learning_rate_value),
|
||||
visible=False,
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
|
|
@ -231,7 +238,7 @@ class BasicTraining:
|
|||
# Initialize the learning rate warmup slider
|
||||
self.lr_warmup = gr.Slider(
|
||||
label="LR warmup (% of total steps)",
|
||||
value=self.lr_warmup_value,
|
||||
value=self.config.get("basic.lr_warmup", self.lr_warmup_value),
|
||||
minimum=0,
|
||||
maximum=100,
|
||||
step=1,
|
||||
|
|
@ -246,11 +253,13 @@ class BasicTraining:
|
|||
self.lr_scheduler_num_cycles = gr.Textbox(
|
||||
label="LR # cycles",
|
||||
placeholder="(Optional) For Cosine with restart and polynomial only",
|
||||
value=self.config.get("basic.lr_scheduler_num_cycles", ""),
|
||||
)
|
||||
# Initialize the learning rate scheduler power textbox
|
||||
self.lr_scheduler_power = gr.Textbox(
|
||||
label="LR power",
|
||||
placeholder="(Optional) For Cosine with restart and polynomial only",
|
||||
value=self.config.get("basic.lr_scheduler_power", ""),
|
||||
)
|
||||
|
||||
def init_resolution_and_bucket_controls(self) -> None:
|
||||
|
|
@ -260,22 +269,22 @@ class BasicTraining:
|
|||
with gr.Row(visible=not self.finetuning):
|
||||
# Initialize the maximum resolution textbox
|
||||
self.max_resolution = gr.Textbox(
|
||||
label="Max resolution", value="512,512", placeholder="512,512"
|
||||
label="Max resolution", value=self.config.get("basic.max_resolution", "512,512"), placeholder="512,512"
|
||||
)
|
||||
# Initialize the stop text encoder training slider
|
||||
self.stop_text_encoder_training = gr.Slider(
|
||||
minimum=-1,
|
||||
maximum=100,
|
||||
value=0,
|
||||
value=self.config.get("basic.stop_text_encoder_training", 0),
|
||||
step=1,
|
||||
label="Stop TE (% of total steps)",
|
||||
)
|
||||
# Initialize the enable buckets checkbox
|
||||
self.enable_bucket = gr.Checkbox(label="Enable buckets", value=True)
|
||||
self.enable_bucket = gr.Checkbox(label="Enable buckets", value=self.config.get("basic.enable_bucket", True))
|
||||
# Initialize the minimum bucket resolution slider
|
||||
self.min_bucket_reso = gr.Slider(
|
||||
label="Minimum bucket resolution",
|
||||
value=256,
|
||||
value=self.config.get("basic.min_bucket_reso", 256),
|
||||
minimum=64,
|
||||
maximum=4096,
|
||||
step=64,
|
||||
|
|
@ -284,7 +293,7 @@ class BasicTraining:
|
|||
# Initialize the maximum bucket resolution slider
|
||||
self.max_bucket_reso = gr.Slider(
|
||||
label="Maximum bucket resolution",
|
||||
value=2048,
|
||||
value=self.config.get("basic.max_bucket_reso", 2048),
|
||||
minimum=64,
|
||||
maximum=4096,
|
||||
step=64,
|
||||
|
|
|
|||
|
|
@ -753,6 +753,7 @@ def dreambooth_tab(
|
|||
lr_warmup_value="10",
|
||||
dreambooth=True,
|
||||
sdxl_checkbox=source_model.sdxl_checkbox,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# # Add SDXL Parameters
|
||||
|
|
|
|||
|
|
@ -804,6 +804,7 @@ def finetune_tab(headless=False, config: dict = {}):
|
|||
learning_rate_value="1e-5",
|
||||
finetuning=True,
|
||||
sdxl_checkbox=source_model.sdxl_checkbox,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Add SDXL Parameters
|
||||
|
|
|
|||
|
|
@ -1198,6 +1198,7 @@ def lora_tab(
|
|||
lr_scheduler_value="cosine",
|
||||
lr_warmup_value="10",
|
||||
sdxl_checkbox=source_model.sdxl_checkbox,
|
||||
config=config,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
|
|
|
|||
|
|
@ -812,6 +812,7 @@ def ti_tab(headless=False, default_output_dir=None, config: dict = {}):
|
|||
lr_scheduler_value="cosine",
|
||||
lr_warmup_value="10",
|
||||
sdxl_checkbox=source_model.sdxl_checkbox,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Add SDXL Parameters
|
||||
|
|
|
|||
Loading…
Reference in New Issue