diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index cf92520..2ff1c71 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -682,7 +682,9 @@ class Worker: if vae is not None: self.loaded_vae = vae - return response + self.response = response + + return self def restart(self) -> bool: err_msg = f"could not restart worker '{self.label}'" diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index ae70006..f1e238b 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -4,7 +4,7 @@ This module facilitates the creation of a stable-diffusion-webui centered distri World: The main class which should be instantiated in order to create a new sdwui distributed system. """ - +import concurrent.futures import copy import json import os @@ -210,8 +210,6 @@ class World: """ unbenched_workers = [] - benchmark_threads: List[Thread] = [] - sync_threads: List[Thread] = [] def benchmark_wrapped(worker): bench_func = worker.benchmark if not worker.master else self.benchmark_master @@ -232,60 +230,49 @@ class World: else: worker.benchmarked = True - tasks = [] - loop = asyncio.new_event_loop() - # have every unbenched worker load the same weights before the benchmark - for worker in unbenched_workers: - if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE): - continue + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] - tasks.append( - loop.create_task( - asyncio.to_thread(worker.load_options, model=shared.opts.sd_model_checkpoint, vae=shared.opts.sd_vae) - , name=worker.label + # have every unbenched worker load the same weights before the benchmark + for worker in unbenched_workers: + if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE): + continue + + futures.append( + executor.submit(worker.load_options, model=shared.opts.sd_model_checkpoint, vae=shared.opts.sd_vae) ) - ) - if len(tasks) > 0: - results = loop.run_until_complete(asyncio.wait(tasks)) - for task in results[0]: - worker = self[task.get_name()] - response = task.result() - if response.status_code != 200: - logger.error(f"refusing to benchmark worker '{worker.label}' as it failed to load the selected model '{shared.opts.sd_model_checkpoint}'\n" - f"*you may circumvent this by using the per-worker model override setting but this is not recommended as the same benchmark model should be used for all workers") - unbenched_workers = list(filter(lambda w: w != worker, unbenched_workers)) + for future in concurrent.futures.as_completed(futures): + worker = future.result() + if worker is None: + continue - # benchmark those that haven't been - tasks = [] - for worker in unbenched_workers: - if worker.state in (State.DISABLED, State.UNAVAILABLE): - logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark") - continue + if worker.response.status_code != 200: + logger.error(f"refusing to benchmark worker '{worker.label}' as it failed to load the selected model '{shared.opts.sd_model_checkpoint}'\n" + f"*you may circumvent this by using the per-worker model override setting but this is not recommended as the same benchmark model should be used for all workers") + unbenched_workers = list(filter(lambda w: w != worker, unbenched_workers)) + futures.clear() - if worker.model_override is not None: - logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" - f"*all workers should be evaluated against the same model") + # benchmark those that haven't been + for worker in unbenched_workers: + if worker.state in (State.DISABLED, State.UNAVAILABLE): + logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark") + continue - tasks.append( - loop.create_task( - asyncio.to_thread(benchmark_wrapped, worker), - name=worker.label - ) - ) - logger.info(f"benchmarking worker '{worker.label}'") + if worker.model_override is not None: + logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" + f"*all workers should be evaluated against the same model") - # wait for all benchmarks to finish and update stats on newly benchmarked workers - if len(tasks) > 0: - results = loop.run_until_complete(asyncio.wait(tasks)) + futures.append(executor.submit(benchmark_wrapped, worker)) + logger.info(f"benchmarking worker '{worker.label}'") + + # wait for all benchmarks to finish and update stats on newly benchmarked workers + concurrent.futures.wait(futures) logger.info("benchmarking finished") - logger.debug(results) # save benchmark results to workers.json self.save_config() logger.info(self.speed_summary()) - loop.close() - def get_current_output_size(self) -> int: """ returns how many images would be returned from all jobs