diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index cd0b87a..5aac350 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -7,6 +7,8 @@ from modules.shared import opts from modules.shared import state as webui_state from .shared import logger, LOG_LEVEL, gui_handler from .worker import State +from modules.call_queue import queue_lock +from modules import progress worker_select_dropdown = None @@ -61,6 +63,10 @@ class UI: """debug utility that will clear the internal webui queue. sometimes good for jams""" logger.debug(webui_state.__dict__) webui_state.end() + progress.pending_tasks.clear() + progress.current_task = None + if queue_lock._lock.locked(): + queue_lock.release() def status_btn(self): """updates a simplified overview of registered workers and their jobs""" diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 4ab5c24..0e3df16 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -181,26 +181,14 @@ class World: Thread(target=worker.refresh_checkpoints, args=()).start() def sample_master(self) -> float: - # progress.finish_task(progress.current_task) - if queue_lock._lock.locked(): - queue_lock.release() - - master_bench_payload = StableDiffusionProcessingTxt2Img() + p = StableDiffusionProcessingTxt2Img() d = sh.benchmark_payload.dict() for key in d: - setattr(master_bench_payload, key, d[key]) + setattr(p, key, d[key]) + p.do_not_save_samples = True - # Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to. - master_bench_payload.do_not_save_samples = True - shared.state.end() - wrapped = wrap_queued_call(process_images) start = time.time() - wrapped(master_bench_payload) - # seems counter-intuitive but the lock will later be released again once the original task is ended by wui - # only doing things this way so that we can bench and then have an original user request immediately resume - if progress.current_task is not None: # could be no task, ie. running bench from utils tab - queue_lock.acquire() - + process_images(p) return time.time() - start def benchmark(self, rebenchmark: bool = False): @@ -208,6 +196,7 @@ class World: Attempts to benchmark all workers a part of the world. """ + local_task_id = 'task(distributed_bench)' unbenched_workers = [] if rebenchmark: for worker in self._workers: @@ -246,26 +235,42 @@ class World: futures.clear() # benchmark those that haven't been - for worker in unbenched_workers: - if worker.state in (State.DISABLED, State.UNAVAILABLE): - logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark") - continue + if len(unbenched_workers) > 0: + queue_lock.acquire() + gradio.Info("Distributed: benchmarking in progress, please wait") + for worker in unbenched_workers: + if worker.state in (State.DISABLED, State.UNAVAILABLE): + logger.debug(f"worker '{worker.label}' is {worker.state.name}, refusing to benchmark") + continue - if worker.model_override is not None: - logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" - f"*all workers should be evaluated against the same model") + if worker.model_override is not None: + logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" + f"*all workers should be evaluated against the same model") - chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master) - futures.append(executor.submit(chosen, worker)) - logger.info(f"benchmarking worker '{worker.label}'") + if worker.master: + if progress.current_task is None: + progress.add_task_to_queue(local_task_id) + progress.start_task(local_task_id) + shared.state.begin(job=local_task_id) + shared.state.job_count = sh.warmup_samples + sh.samples - # wait for all benchmarks to finish and update stats on newly benchmarked workers - concurrent.futures.wait(futures) - logger.info("benchmarking finished") + chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master) + futures.append(executor.submit(chosen, worker)) + logger.info(f"benchmarking worker '{worker.label}'") - # save benchmark results to workers.json - self.save_config() - logger.info(self.speed_summary()) + if len(futures) > 0: + # wait for all benchmarks to finish and update stats on newly benchmarked workers + concurrent.futures.wait(futures) + + if progress.current_task == local_task_id: + shared.state.end() + progress.finish_task(local_task_id) + queue_lock.release() + + logger.info("benchmarking finished") + logger.info(self.speed_summary()) + gradio.Info("Distributed: benchmarking complete!") + self.save_config() def get_current_output_size(self) -> int: """