diff --git a/mikazuki/process.py b/mikazuki/process.py index bb292ad..3b6b138 100644 --- a/mikazuki/process.py +++ b/mikazuki/process.py @@ -21,7 +21,7 @@ def run_train(toml_path: str, "--config_file", toml_path, ] if multi_gpu: - args[3:2] = ["--multi_gpu", "--num_processes=2"] + args[3:3] = ["--multi_gpu", "--num_processes", "2"] customize_env = os.environ.copy() customize_env["ACCELERATE_DISABLE_RICH"] = "1"