mirror of https://github.com/bmaltais/kohya_ss
Add new field for extra accelerate launch arguments (#2200)
parent
0a8395ddfe
commit
c78c1ab4fe
|
|
@ -440,6 +440,7 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG
|
|||
- Move accelerate launch parameters to new `Accelerate launch` accordion above `Model` accordion.
|
||||
- Add support for `Debiased Estimation loss` to Dreambooth settings.
|
||||
- Add support for "Dataset Preparation" defaults via the config.toml file.
|
||||
- Add field to allow for the input of extra accelerate launch arguments.
|
||||
|
||||
### 2024/03/21 (v23.0.15)
|
||||
|
||||
|
|
|
|||
|
|
@ -59,10 +59,22 @@ class AccelerateLaunch:
|
|||
maximum=65535,
|
||||
info="The port to use to communicate with the machine of rank 0.",
|
||||
)
|
||||
with gr.Row():
|
||||
self.extra_accelerate_launch_args = gr.Textbox(
|
||||
label="Extra accelerate launch arguments",
|
||||
value="",
|
||||
placeholder="example: --same_network --machine_rank 4",
|
||||
info="List of extra parameters to pass to accelerate launch",
|
||||
)
|
||||
|
||||
def run_cmd(**kwargs):
|
||||
run_cmd = ""
|
||||
|
||||
if "extra_accelerate_launch_args" in kwargs:
|
||||
extra_accelerate_launch_args = kwargs.get("extra_accelerate_launch_args")
|
||||
if extra_accelerate_launch_args != "":
|
||||
run_cmd += fr' {extra_accelerate_launch_args}'
|
||||
|
||||
if "gpu_ids" in kwargs:
|
||||
gpu_ids = kwargs.get("gpu_ids")
|
||||
if not gpu_ids == "":
|
||||
|
|
|
|||
|
|
@ -153,6 +153,7 @@ def save_configuration(
|
|||
min_timestep,
|
||||
max_timestep,
|
||||
debiased_estimation_loss,
|
||||
extra_accelerate_launch_args,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -292,6 +293,7 @@ def open_configuration(
|
|||
min_timestep,
|
||||
max_timestep,
|
||||
debiased_estimation_loss,
|
||||
extra_accelerate_launch_args,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -426,6 +428,7 @@ def train_model(
|
|||
min_timestep,
|
||||
max_timestep,
|
||||
debiased_estimation_loss,
|
||||
extra_accelerate_launch_args,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -564,6 +567,7 @@ def train_model(
|
|||
main_process_port=main_process_port,
|
||||
num_cpu_threads_per_process=num_cpu_threads_per_process,
|
||||
mixed_precision=mixed_precision,
|
||||
extra_accelerate_launch_args=extra_accelerate_launch_args,
|
||||
)
|
||||
|
||||
if sdxl:
|
||||
|
|
@ -913,6 +917,7 @@ def dreambooth_tab(
|
|||
advanced_training.min_timestep,
|
||||
advanced_training.max_timestep,
|
||||
advanced_training.debiased_estimation_loss,
|
||||
accelerate_launch.extra_accelerate_launch_args,
|
||||
]
|
||||
|
||||
configuration.button_open_config.click(
|
||||
|
|
|
|||
|
|
@ -161,6 +161,7 @@ def save_configuration(
|
|||
sdxl_no_half_vae,
|
||||
min_timestep,
|
||||
max_timestep,
|
||||
extra_accelerate_launch_args,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -307,6 +308,7 @@ def open_configuration(
|
|||
sdxl_no_half_vae,
|
||||
min_timestep,
|
||||
max_timestep,
|
||||
extra_accelerate_launch_args,
|
||||
training_preset,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
|
|
@ -460,6 +462,7 @@ def train_model(
|
|||
sdxl_no_half_vae,
|
||||
min_timestep,
|
||||
max_timestep,
|
||||
extra_accelerate_launch_args,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -585,6 +588,7 @@ def train_model(
|
|||
main_process_port=main_process_port,
|
||||
num_cpu_threads_per_process=num_cpu_threads_per_process,
|
||||
mixed_precision=mixed_precision,
|
||||
extra_accelerate_launch_args=extra_accelerate_launch_args,
|
||||
)
|
||||
|
||||
if sdxl_checkbox:
|
||||
|
|
@ -1001,6 +1005,7 @@ def finetune_tab(headless=False, config: dict = {}):
|
|||
sdxl_params.sdxl_no_half_vae,
|
||||
advanced_training.min_timestep,
|
||||
advanced_training.max_timestep,
|
||||
accelerate_launch.extra_accelerate_launch_args,
|
||||
]
|
||||
|
||||
configuration.button_open_config.click(
|
||||
|
|
|
|||
|
|
@ -223,6 +223,7 @@ def save_configuration(
|
|||
vae,
|
||||
LyCORIS_preset,
|
||||
debiased_estimation_loss,
|
||||
extra_accelerate_launch_args,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -410,6 +411,7 @@ def open_configuration(
|
|||
vae,
|
||||
LyCORIS_preset,
|
||||
debiased_estimation_loss,
|
||||
extra_accelerate_launch_args,
|
||||
training_preset,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
|
|
@ -625,6 +627,7 @@ def train_model(
|
|||
vae,
|
||||
LyCORIS_preset,
|
||||
debiased_estimation_loss,
|
||||
extra_accelerate_launch_args,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -792,6 +795,7 @@ def train_model(
|
|||
main_process_port=main_process_port,
|
||||
num_cpu_threads_per_process=num_cpu_threads_per_process,
|
||||
mixed_precision=mixed_precision,
|
||||
extra_accelerate_launch_args=extra_accelerate_launch_args,
|
||||
)
|
||||
|
||||
if sdxl:
|
||||
|
|
@ -2071,6 +2075,7 @@ def lora_tab(
|
|||
advanced_training.vae,
|
||||
LyCORIS_preset,
|
||||
advanced_training.debiased_estimation_loss,
|
||||
accelerate_launch.extra_accelerate_launch_args,
|
||||
]
|
||||
|
||||
configuration.button_open_config.click(
|
||||
|
|
|
|||
|
|
@ -151,6 +151,7 @@ def save_configuration(
|
|||
min_timestep,
|
||||
max_timestep,
|
||||
sdxl_no_half_vae,
|
||||
extra_accelerate_launch_args,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -291,6 +292,7 @@ def open_configuration(
|
|||
min_timestep,
|
||||
max_timestep,
|
||||
sdxl_no_half_vae,
|
||||
extra_accelerate_launch_args,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -424,6 +426,7 @@ def train_model(
|
|||
min_timestep,
|
||||
max_timestep,
|
||||
sdxl_no_half_vae,
|
||||
extra_accelerate_launch_args,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -548,6 +551,7 @@ def train_model(
|
|||
main_process_port=main_process_port,
|
||||
num_cpu_threads_per_process=num_cpu_threads_per_process,
|
||||
mixed_precision=mixed_precision,
|
||||
extra_accelerate_launch_args=extra_accelerate_launch_args,
|
||||
)
|
||||
|
||||
if sdxl:
|
||||
|
|
@ -973,6 +977,7 @@ def ti_tab(headless=False, default_output_dir=None, config: dict = {}):
|
|||
advanced_training.min_timestep,
|
||||
advanced_training.max_timestep,
|
||||
sdxl_params.sdxl_no_half_vae,
|
||||
accelerate_launch.extra_accelerate_launch_args,
|
||||
]
|
||||
|
||||
configuration.button_open_config.click(
|
||||
|
|
|
|||
Loading…
Reference in New Issue