make benchmarking queueable

parent 27c0523ec6
commit 9cd7c7c351
@@ -7,6 +7,8 @@ from modules.shared import opts
 from modules.shared import state as webui_state
 from .shared import logger, LOG_LEVEL, gui_handler
 from .worker import State
+from modules.call_queue import queue_lock
+from modules import progress

 worker_select_dropdown = None

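The two new imports are the heart of the change: `queue_lock` is the webui's global generation lock and `modules.progress` is its task registry. A minimal sketch of the queueing pattern they enable, with stand-in objects (`run_queued` is illustrative; the webui's own wrapper is `wrap_queued_call`):

```python
import threading

# Hypothetical stand-ins mirroring the webui's call_queue and progress modules.
queue_lock = threading.Lock()   # serializes generation jobs, one at a time
pending_tasks = {}              # task id -> enqueue timestamp
current_task = None

def run_queued(fn):
    """Serialize a callable the way the webui's wrap_queued_call does."""
    def wrapped(*args, **kwargs):
        with queue_lock:        # block until any in-flight job releases the lock
            return fn(*args, **kwargs)
    return wrapped
```

Anything holding `queue_lock` therefore blocks every queued user generation, which is exactly the property the benchmark exploits below.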
@@ -61,6 +63,10 @@ class UI:
         """debug utility that will clear the internal webui queue. sometimes good for jams"""
         logger.debug(webui_state.__dict__)
         webui_state.end()
+        progress.pending_tasks.clear()
+        progress.current_task = None
+        if queue_lock._lock.locked():
+            queue_lock.release()

     def status_btn(self):
         """updates a simplified overview of registered workers and their jobs"""
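The queue-clearing debug utility now also resets the progress registry and frees the lock. The same idea as a standalone helper, a sketch assuming `queue_lock` exposes its underlying `threading.Lock` as `_lock` (which is what the diff itself relies on):

```python
def force_clear_queue(progress, queue_lock):
    """Best-effort reset of a jammed webui queue (debug use only).

    Drops queued task ids, forgets the current task, and releases the
    generation lock if a crashed job left it held.
    """
    progress.pending_tasks.clear()
    progress.current_task = None
    if queue_lock._lock.locked():   # private attribute; mirrors the diff's own usage
        queue_lock.release()
```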
@@ -181,26 +181,14 @@ class World:
             Thread(target=worker.refresh_checkpoints, args=()).start()

     def sample_master(self) -> float:
-        # progress.finish_task(progress.current_task)
-        if queue_lock._lock.locked():
-            queue_lock.release()
-
-        master_bench_payload = StableDiffusionProcessingTxt2Img()
+        p = StableDiffusionProcessingTxt2Img()
         d = sh.benchmark_payload.dict()
         for key in d:
-            setattr(master_bench_payload, key, d[key])
+            setattr(p, key, d[key])
+        p.do_not_save_samples = True

-        # Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to.
-        master_bench_payload.do_not_save_samples = True
-        shared.state.end()
-        wrapped = wrap_queued_call(process_images)
         start = time.time()
-        wrapped(master_bench_payload)
-        # seems counter-intuitive but the lock will later be released again once the original task is ended by wui
-        # only doing things this way so that we can bench and then have an original user request immediately resume
-        if progress.current_task is not None:  # could be no task, ie. running bench from utils tab
-            queue_lock.acquire()
-
+        process_images(p)
         return time.time() - start

     def benchmark(self, rebenchmark: bool = False):
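After this hunk, `sample_master` is a plain timed generation: copy the shared benchmark payload onto a txt2img job, suppress saving, and time `process_images` end to end; the lock and task juggling moves up into `benchmark()`. A self-contained sketch of that timing pattern (`payload_cls`, `fields`, and `process_images` are parameters standing in for the webui objects):

```python
import time

def timed_sample(payload_cls, fields: dict, process_images) -> float:
    """Mirror of the simplified sample_master: one timed generation pass."""
    p = payload_cls()
    for key, value in fields.items():
        setattr(p, key, value)
    p.do_not_save_samples = True    # benchmark images are throwaway

    start = time.time()
    process_images(p)               # runs under whatever lock the caller holds
    return time.time() - start
```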
@@ -208,6 +196,7 @@ class World:
         Attempts to benchmark all workers that are a part of the world.
         """

+        local_task_id = 'task(distributed_bench)'
         unbenched_workers = []
         if rebenchmark:
             for worker in self._workers:
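The synthetic `local_task_id` lets the benchmark register with the webui's progress API exactly like a user job, so the UI can track it. A sketch of the lifecycle the rest of the diff builds on (the `progress` and `shared.state` calls are the webui's real API; the surrounding flow is illustrative):

```python
from modules import progress, shared

local_task_id = 'task(distributed_bench)'

if progress.current_task is None:             # no user job is already running
    progress.add_task_to_queue(local_task_id)
    progress.start_task(local_task_id)
    shared.state.begin(job=local_task_id)     # progress bar starts tracking us

# ... benchmark runs here ...

if progress.current_task == local_task_id:    # only tear down what we set up
    shared.state.end()
    progress.finish_task(local_task_id)
```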
@@ -246,26 +235,42 @@ class World:
                 futures.clear()

         # benchmark those that haven't been
-        for worker in unbenched_workers:
-            if worker.state in (State.DISABLED, State.UNAVAILABLE):
-                logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark")
-                continue
+        if len(unbenched_workers) > 0:
+            queue_lock.acquire()
+            gradio.Info("Distributed: benchmarking in progress, please wait")
+            for worker in unbenched_workers:
+                if worker.state in (State.DISABLED, State.UNAVAILABLE):
+                    logger.debug(f"worker '{worker.label}' is {worker.state.name}, refusing to benchmark")
+                    continue

-            if worker.model_override is not None:
-                logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
-                               f"*all workers should be evaluated against the same model")
+                if worker.model_override is not None:
+                    logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
+                                   f"*all workers should be evaluated against the same model")

-            chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master)
-            futures.append(executor.submit(chosen, worker))
-            logger.info(f"benchmarking worker '{worker.label}'")
+                if worker.master:
+                    if progress.current_task is None:
+                        progress.add_task_to_queue(local_task_id)
+                        progress.start_task(local_task_id)
+                        shared.state.begin(job=local_task_id)
+                    shared.state.job_count = sh.warmup_samples + sh.samples

-        # wait for all benchmarks to finish and update stats on newly benchmarked workers
-        concurrent.futures.wait(futures)
-        logger.info("benchmarking finished")
+                chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master)
+                futures.append(executor.submit(chosen, worker))
+                logger.info(f"benchmarking worker '{worker.label}'")

-        # save benchmark results to workers.json
-        self.save_config()
-        logger.info(self.speed_summary())
+        if len(futures) > 0:
+            # wait for all benchmarks to finish and update stats on newly benchmarked workers
+            concurrent.futures.wait(futures)
+
+            if progress.current_task == local_task_id:
+                shared.state.end()
+                progress.finish_task(local_task_id)
+            queue_lock.release()
+
+            logger.info("benchmarking finished")
+            logger.info(self.speed_summary())
+            gradio.Info("Distributed: benchmarking complete!")
+            self.save_config()

     def get_current_output_size(self) -> int:
         """
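The lock discipline to note: `queue_lock.acquire()` runs once before any futures are submitted, and the matching `release()` only after `concurrent.futures.wait(futures)` returns, so user generations queue behind the whole benchmark rather than interleaving with it. A hedged hardening sketch (not in the diff) that keeps the pair balanced even if a benchmark raises:

```python
import concurrent.futures

def run_benchmarks_exclusively(queue_lock, executor, jobs):
    """Hold the generation lock across every benchmark future."""
    queue_lock.acquire()
    try:
        futures = [executor.submit(job) for job in jobs]
        concurrent.futures.wait(futures)    # block until all workers report back
    finally:
        queue_lock.release()                # never leave the webui permanently locked
```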