Update accelerate launch parameters

pull/2199/head
bmaltais 2024-04-02 12:55:11 -04:00
parent ce1369d47b
commit d4e23f96e5
1 changed files with 3 additions and 3 deletions

View File

@ -82,17 +82,17 @@ class AccelerateLaunch:
if "num_processes" in kwargs:
num_processes = kwargs.get("num_processes")
if int(num_processes) > 1:
if int(num_processes) > 0:
run_cmd += f" --num_processes={int(num_processes)}"
if "num_machines" in kwargs:
num_machines = kwargs.get("num_machines")
if int(num_machines) > 1:
if int(num_machines) > 0:
run_cmd += f" --num_machines={int(num_machines)}"
if "num_cpu_threads_per_process" in kwargs:
num_cpu_threads_per_process = kwargs.get("num_cpu_threads_per_process")
if int(num_cpu_threads_per_process) > 1:
if int(num_cpu_threads_per_process) > 0:
run_cmd += (
f" --num_cpu_threads_per_process={int(num_cpu_threads_per_process)}"
)