make benchmarking queueable

master
papuSpartan 2024-06-01 00:33:55 -05:00
parent 27c0523ec6
commit 9cd7c7c351
2 changed files with 43 additions and 32 deletions

View File

@ -7,6 +7,8 @@ from modules.shared import opts
from modules.shared import state as webui_state
from .shared import logger, LOG_LEVEL, gui_handler
from .worker import State
from modules.call_queue import queue_lock
from modules import progress
worker_select_dropdown = None
@ -61,6 +63,10 @@ class UI:
"""debug utility that will clear the internal webui queue. sometimes good for jams"""
logger.debug(webui_state.__dict__)
webui_state.end()
progress.pending_tasks.clear()
progress.current_task = None
if queue_lock._lock.locked():
queue_lock.release()
def status_btn(self):
"""updates a simplified overview of registered workers and their jobs"""

View File

@ -181,26 +181,14 @@ class World:
Thread(target=worker.refresh_checkpoints, args=()).start()
def sample_master(self) -> float:
# progress.finish_task(progress.current_task)
if queue_lock._lock.locked():
queue_lock.release()
master_bench_payload = StableDiffusionProcessingTxt2Img()
p = StableDiffusionProcessingTxt2Img()
d = sh.benchmark_payload.dict()
for key in d:
setattr(master_bench_payload, key, d[key])
setattr(p, key, d[key])
p.do_not_save_samples = True
# Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to.
master_bench_payload.do_not_save_samples = True
shared.state.end()
wrapped = wrap_queued_call(process_images)
start = time.time()
wrapped(master_bench_payload)
# seems counter-intuitive, but the lock will later be released again once the original task is ended by the webui
# only doing things this way so that we can bench and then have an original user request immediately resume
if progress.current_task is not None: # could be no task, ie. running bench from utils tab
queue_lock.acquire()
process_images(p)
return time.time() - start
def benchmark(self, rebenchmark: bool = False):
@ -208,6 +196,7 @@ class World:
Attempts to benchmark all workers that are part of the world.
"""
local_task_id = 'task(distributed_bench)'
unbenched_workers = []
if rebenchmark:
for worker in self._workers:
@ -246,26 +235,42 @@ class World:
futures.clear()
# benchmark those that haven't been
for worker in unbenched_workers:
if worker.state in (State.DISABLED, State.UNAVAILABLE):
logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark")
continue
if len(unbenched_workers) > 0:
queue_lock.acquire()
gradio.Info("Distributed: benchmarking in progress, please wait")
for worker in unbenched_workers:
if worker.state in (State.DISABLED, State.UNAVAILABLE):
logger.debug(f"worker '{worker.label}' is {worker.state.name}, refusing to benchmark")
continue
if worker.model_override is not None:
logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
f"*all workers should be evaluated against the same model")
if worker.model_override is not None:
logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
f"*all workers should be evaluated against the same model")
chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master)
futures.append(executor.submit(chosen, worker))
logger.info(f"benchmarking worker '{worker.label}'")
if worker.master:
if progress.current_task is None:
progress.add_task_to_queue(local_task_id)
progress.start_task(local_task_id)
shared.state.begin(job=local_task_id)
shared.state.job_count = sh.warmup_samples + sh.samples
# wait for all benchmarks to finish and update stats on newly benchmarked workers
concurrent.futures.wait(futures)
logger.info("benchmarking finished")
chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master)
futures.append(executor.submit(chosen, worker))
logger.info(f"benchmarking worker '{worker.label}'")
# save benchmark results to workers.json
self.save_config()
logger.info(self.speed_summary())
if len(futures) > 0:
# wait for all benchmarks to finish and update stats on newly benchmarked workers
concurrent.futures.wait(futures)
if progress.current_task == local_task_id:
shared.state.end()
progress.finish_task(local_task_id)
queue_lock.release()
logger.info("benchmarking finished")
logger.info(self.speed_summary())
gradio.Info("Distributed: benchmarking complete!")
self.save_config()
def get_current_output_size(self) -> int:
"""