use concurrent.futures for benchmarking

master
unknown 2024-03-22 16:26:00 -05:00
parent 3a9d87f821
commit 4ddb137b56
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
2 changed files with 35 additions and 46 deletions

View File

@@ -682,7 +682,9 @@ class Worker:
if vae is not None: if vae is not None:
self.loaded_vae = vae self.loaded_vae = vae
return response self.response = response
return self
def restart(self) -> bool: def restart(self) -> bool:
err_msg = f"could not restart worker '{self.label}'" err_msg = f"could not restart worker '{self.label}'"

View File

@@ -4,7 +4,7 @@ This module facilitates the creation of a stable-diffusion-webui centered distri
World: World:
The main class which should be instantiated in order to create a new sdwui distributed system. The main class which should be instantiated in order to create a new sdwui distributed system.
""" """
import concurrent.futures
import copy import copy
import json import json
import os import os
@@ -210,8 +210,6 @@ class World:
""" """
unbenched_workers = [] unbenched_workers = []
benchmark_threads: List[Thread] = []
sync_threads: List[Thread] = []
def benchmark_wrapped(worker): def benchmark_wrapped(worker):
bench_func = worker.benchmark if not worker.master else self.benchmark_master bench_func = worker.benchmark if not worker.master else self.benchmark_master
@@ -232,60 +230,49 @@ class World:
else: else:
worker.benchmarked = True worker.benchmarked = True
tasks = [] with concurrent.futures.ThreadPoolExecutor() as executor:
loop = asyncio.new_event_loop() futures = []
# have every unbenched worker load the same weights before the benchmark
for worker in unbenched_workers:
if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE):
continue
tasks.append( # have every unbenched worker load the same weights before the benchmark
loop.create_task( for worker in unbenched_workers:
asyncio.to_thread(worker.load_options, model=shared.opts.sd_model_checkpoint, vae=shared.opts.sd_vae) if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE):
, name=worker.label continue
futures.append(
executor.submit(worker.load_options, model=shared.opts.sd_model_checkpoint, vae=shared.opts.sd_vae)
) )
) for future in concurrent.futures.as_completed(futures):
if len(tasks) > 0: worker = future.result()
results = loop.run_until_complete(asyncio.wait(tasks)) if worker is None:
for task in results[0]: continue
worker = self[task.get_name()]
response = task.result()
if response.status_code != 200:
logger.error(f"refusing to benchmark worker '{worker.label}' as it failed to load the selected model '{shared.opts.sd_model_checkpoint}'\n"
f"*you may circumvent this by using the per-worker model override setting but this is not recommended as the same benchmark model should be used for all workers")
unbenched_workers = list(filter(lambda w: w != worker, unbenched_workers))
# benchmark those that haven't been if worker.response.status_code != 200:
tasks = [] logger.error(f"refusing to benchmark worker '{worker.label}' as it failed to load the selected model '{shared.opts.sd_model_checkpoint}'\n"
for worker in unbenched_workers: f"*you may circumvent this by using the per-worker model override setting but this is not recommended as the same benchmark model should be used for all workers")
if worker.state in (State.DISABLED, State.UNAVAILABLE): unbenched_workers = list(filter(lambda w: w != worker, unbenched_workers))
logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark") futures.clear()
continue
if worker.model_override is not None: # benchmark those that haven't been
logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" for worker in unbenched_workers:
f"*all workers should be evaluated against the same model") if worker.state in (State.DISABLED, State.UNAVAILABLE):
logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark")
continue
tasks.append( if worker.model_override is not None:
loop.create_task( logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
asyncio.to_thread(benchmark_wrapped, worker), f"*all workers should be evaluated against the same model")
name=worker.label
)
)
logger.info(f"benchmarking worker '{worker.label}'")
# wait for all benchmarks to finish and update stats on newly benchmarked workers futures.append(executor.submit(benchmark_wrapped, worker))
if len(tasks) > 0: logger.info(f"benchmarking worker '{worker.label}'")
results = loop.run_until_complete(asyncio.wait(tasks))
# wait for all benchmarks to finish and update stats on newly benchmarked workers
concurrent.futures.wait(futures)
logger.info("benchmarking finished") logger.info("benchmarking finished")
logger.debug(results)
# save benchmark results to workers.json # save benchmark results to workers.json
self.save_config() self.save_config()
logger.info(self.speed_summary()) logger.info(self.speed_summary())
loop.close()
def get_current_output_size(self) -> int: def get_current_output_size(self) -> int:
""" """
returns how many images would be returned from all jobs returns how many images would be returned from all jobs