From 3a9d87f82138a332af190e20cf2d4e698bbc6e38 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 22 Mar 2024 15:14:26 -0500 Subject: [PATCH] bench threads -> coro --- scripts/spartan/control_net.py | 7 ++++++ scripts/spartan/ui.py | 2 +- scripts/spartan/worker.py | 11 +++++++-- scripts/spartan/world.py | 45 +++++++++++++++++++++++++--------- 4 files changed, 51 insertions(+), 14 deletions(-) diff --git a/scripts/spartan/control_net.py b/scripts/spartan/control_net.py index a6c70fa..dc55cd9 100644 --- a/scripts/spartan/control_net.py +++ b/scripts/spartan/control_net.py @@ -3,6 +3,7 @@ from PIL import Image from modules.api.api import encode_pil_to_base64 from scripts.spartan.shared import logger import numpy as np +import json def np_to_b64(image: np.ndarray): @@ -62,4 +63,10 @@ def pack_control_net(cn_units) -> dict: # remove anything unserializable del unit['input_mode'] + try: + json.dumps(controlnet) + except Exception as e: + logger.error(f"failed to serialize controlnet\nfirst unit:\n{controlnet['controlnet']['args'][0]}") + return {} + return controlnet diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index d23f065..4c710b6 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -312,7 +312,7 @@ class UI: # API authentication worker_api_auth_cbx = gradio.Checkbox(label='API Authentication') worker_user_field = gradio.Textbox(label='Username') - worker_password_field = gradio.Textbox(label='Password') + worker_password_field = gradio.Textbox(label='Password', type='password') update_credentials_btn = gradio.Button(value='Update API Credentials') update_credentials_btn.click(self.update_credentials_btn, inputs=[ worker_api_auth_cbx, diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index 9bc1057..cf92520 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -1,3 +1,4 @@ +import asyncio import base64 import copy import io @@ -156,6 +157,11 @@ class Worker: def __repr__(self): return f"'{self.label}'@{self.address}:{self.port}, speed: {self.avg_ipm} ipm, state: {self.state}" + def __eq__(self, other): + if isinstance(other, Worker) and other.label == self.label: + return True + return False + @property def model(self) -> Worker_Model: return Worker_Model(**self.__dict__) @@ -510,7 +516,7 @@ class Worker: t: Thread samples = 2 # number of times to benchmark the remote / accuracy - if self.state == State.DISABLED or self.state == State.UNAVAILABLE: + if self.state in (State.DISABLED, State.UNAVAILABLE): logger.debug(f"worker '{self.label}' is unavailable or disabled, refusing to benchmark") return 0 @@ -533,7 +539,6 @@ class Worker: results: List[float] = [] # it used to be lower for the first couple of generations # this was due to something torch does at startup according to auto and is now done at sdwui startup - self.state = State.WORKING for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up" if self.state == State.UNAVAILABLE: self.response = None @@ -677,6 +682,8 @@ class Worker: if vae is not None: self.loaded_vae = vae + return response + def restart(self) -> bool: err_msg = f"could not restart worker '{self.label}'" success_msg = f"worker '{self.label}' is restarting" diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index df6edd1..ae70006 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -18,6 +18,7 @@ from . import shared as sh from .pmodels import ConfigModel, Benchmark_Payload from .shared import logger, warmup_samples, extension_path from .worker import Worker, State +import asyncio class NotBenchmarked(Exception): @@ -231,38 +232,60 @@ class World: else: worker.benchmarked = True + tasks = [] + loop = asyncio.new_event_loop() # have every unbenched worker load the same weights before the benchmark for worker in unbenched_workers: if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE): continue - sync_thread = Thread(target=worker.load_options, args=(shared.opts.sd_model_checkpoint, shared.opts.sd_vae)) - sync_threads.append(sync_thread) - sync_thread.start() - for thread in sync_threads: - thread.join() + tasks.append( + loop.create_task( + asyncio.to_thread(worker.load_options, model=shared.opts.sd_model_checkpoint, vae=shared.opts.sd_vae) + , name=worker.label + ) + ) + if len(tasks) > 0: + results = loop.run_until_complete(asyncio.wait(tasks)) + for task in results[0]: + worker = self[task.get_name()] + response = task.result() + if response.status_code != 200: + logger.error(f"refusing to benchmark worker '{worker.label}' as it failed to load the selected model '{shared.opts.sd_model_checkpoint}'\n" + f"*you may circumvent this by using the per-worker model override setting but this is not recommended as the same benchmark model should be used for all workers") + unbenched_workers = list(filter(lambda w: w != worker, unbenched_workers)) # benchmark those that haven't been + tasks = [] for worker in unbenched_workers: if worker.state in (State.DISABLED, State.UNAVAILABLE): logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark") continue - t = Thread(target=benchmark_wrapped, args=(worker, ), name=f"{worker.label}_benchmark") - benchmark_threads.append(t) - t.start() + if worker.model_override is not None: + logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" + f"*all workers should be evaluated against the same model") + + tasks.append( + loop.create_task( + asyncio.to_thread(benchmark_wrapped, worker), + name=worker.label + ) + ) logger.info(f"benchmarking worker '{worker.label}'") # wait for all benchmarks to finish and update stats on newly benchmarked workers - if len(benchmark_threads) > 0: - for t in benchmark_threads: - t.join() + if len(tasks) > 0: + results = loop.run_until_complete(asyncio.wait(tasks)) logger.info("benchmarking finished") + logger.debug(results) # save benchmark results to workers.json self.save_config() logger.info(self.speed_summary()) + loop.close() + def get_current_output_size(self) -> int: """ returns how many images would be returned from all jobs