Add new field for extra accelerate launch arguments (#2200)

pull/2219/head
bmaltais 2024-04-02 20:56:59 -04:00 committed by GitHub
parent 0a8395ddfe
commit c78c1ab4fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 33 additions and 0 deletions

View File

@ -440,6 +440,7 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG
- Move accelerate launch parameters to new `Accelerate launch` accordion above `Model` accordion.
- Add support for `Debiased Estimation loss` to Dreambooth settings.
- Add support for "Dataset Preparation" defaults via the config.toml file.
- Add a field allowing input of extra accelerate launch arguments.
### 2024/03/21 (v23.0.15)

View File

@ -59,10 +59,22 @@ class AccelerateLaunch:
maximum=65535,
info="The port to use to communicate with the machine of rank 0.",
)
with gr.Row():
self.extra_accelerate_launch_args = gr.Textbox(
label="Extra accelerate launch arguments",
value="",
placeholder="example: --same_network --machine_rank 4",
info="List of extra parameters to pass to accelerate launch",
)
def run_cmd(**kwargs):
run_cmd = ""
if "extra_accelerate_launch_args" in kwargs:
extra_accelerate_launch_args = kwargs.get("extra_accelerate_launch_args")
if extra_accelerate_launch_args != "":
run_cmd += fr' {extra_accelerate_launch_args}'
if "gpu_ids" in kwargs:
gpu_ids = kwargs.get("gpu_ids")
if not gpu_ids == "":

View File

@ -153,6 +153,7 @@ def save_configuration(
min_timestep,
max_timestep,
debiased_estimation_loss,
extra_accelerate_launch_args,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -292,6 +293,7 @@ def open_configuration(
min_timestep,
max_timestep,
debiased_estimation_loss,
extra_accelerate_launch_args,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -426,6 +428,7 @@ def train_model(
min_timestep,
max_timestep,
debiased_estimation_loss,
extra_accelerate_launch_args,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -564,6 +567,7 @@ def train_model(
main_process_port=main_process_port,
num_cpu_threads_per_process=num_cpu_threads_per_process,
mixed_precision=mixed_precision,
extra_accelerate_launch_args=extra_accelerate_launch_args,
)
if sdxl:
@ -913,6 +917,7 @@ def dreambooth_tab(
advanced_training.min_timestep,
advanced_training.max_timestep,
advanced_training.debiased_estimation_loss,
accelerate_launch.extra_accelerate_launch_args,
]
configuration.button_open_config.click(

View File

@ -161,6 +161,7 @@ def save_configuration(
sdxl_no_half_vae,
min_timestep,
max_timestep,
extra_accelerate_launch_args,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -307,6 +308,7 @@ def open_configuration(
sdxl_no_half_vae,
min_timestep,
max_timestep,
extra_accelerate_launch_args,
training_preset,
):
# Get list of function parameters and values
@ -460,6 +462,7 @@ def train_model(
sdxl_no_half_vae,
min_timestep,
max_timestep,
extra_accelerate_launch_args,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -585,6 +588,7 @@ def train_model(
main_process_port=main_process_port,
num_cpu_threads_per_process=num_cpu_threads_per_process,
mixed_precision=mixed_precision,
extra_accelerate_launch_args=extra_accelerate_launch_args,
)
if sdxl_checkbox:
@ -1001,6 +1005,7 @@ def finetune_tab(headless=False, config: dict = {}):
sdxl_params.sdxl_no_half_vae,
advanced_training.min_timestep,
advanced_training.max_timestep,
accelerate_launch.extra_accelerate_launch_args,
]
configuration.button_open_config.click(

View File

@ -223,6 +223,7 @@ def save_configuration(
vae,
LyCORIS_preset,
debiased_estimation_loss,
extra_accelerate_launch_args,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -410,6 +411,7 @@ def open_configuration(
vae,
LyCORIS_preset,
debiased_estimation_loss,
extra_accelerate_launch_args,
training_preset,
):
# Get list of function parameters and values
@ -625,6 +627,7 @@ def train_model(
vae,
LyCORIS_preset,
debiased_estimation_loss,
extra_accelerate_launch_args,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -792,6 +795,7 @@ def train_model(
main_process_port=main_process_port,
num_cpu_threads_per_process=num_cpu_threads_per_process,
mixed_precision=mixed_precision,
extra_accelerate_launch_args=extra_accelerate_launch_args,
)
if sdxl:
@ -2071,6 +2075,7 @@ def lora_tab(
advanced_training.vae,
LyCORIS_preset,
advanced_training.debiased_estimation_loss,
accelerate_launch.extra_accelerate_launch_args,
]
configuration.button_open_config.click(

View File

@ -151,6 +151,7 @@ def save_configuration(
min_timestep,
max_timestep,
sdxl_no_half_vae,
extra_accelerate_launch_args,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -291,6 +292,7 @@ def open_configuration(
min_timestep,
max_timestep,
sdxl_no_half_vae,
extra_accelerate_launch_args,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -424,6 +426,7 @@ def train_model(
min_timestep,
max_timestep,
sdxl_no_half_vae,
extra_accelerate_launch_args,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -548,6 +551,7 @@ def train_model(
main_process_port=main_process_port,
num_cpu_threads_per_process=num_cpu_threads_per_process,
mixed_precision=mixed_precision,
extra_accelerate_launch_args=extra_accelerate_launch_args,
)
if sdxl:
@ -973,6 +977,7 @@ def ti_tab(headless=False, default_output_dir=None, config: dict = {}):
advanced_training.min_timestep,
advanced_training.max_timestep,
sdxl_params.sdxl_no_half_vae,
accelerate_launch.extra_accelerate_launch_args,
]
configuration.button_open_config.click(