bench threads -> coro

master
unknown 2024-03-22 15:14:26 -05:00
parent c84d7c8a33
commit 3a9d87f821
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
4 changed files with 51 additions and 14 deletions

View File

@ -3,6 +3,7 @@ from PIL import Image
from modules.api.api import encode_pil_to_base64 from modules.api.api import encode_pil_to_base64
from scripts.spartan.shared import logger from scripts.spartan.shared import logger
import numpy as np import numpy as np
import json
def np_to_b64(image: np.ndarray): def np_to_b64(image: np.ndarray):
@ -62,4 +63,10 @@ def pack_control_net(cn_units) -> dict:
# remove anything unserializable # remove anything unserializable
del unit['input_mode'] del unit['input_mode']
try:
json.dumps(controlnet)
except Exception as e:
logger.error(f"failed to serialize controlnet\nfirst unit:\n{controlnet['controlnet']['args'][0]}")
return {}
return controlnet return controlnet

View File

@ -312,7 +312,7 @@ class UI:
# API authentication # API authentication
worker_api_auth_cbx = gradio.Checkbox(label='API Authentication') worker_api_auth_cbx = gradio.Checkbox(label='API Authentication')
worker_user_field = gradio.Textbox(label='Username') worker_user_field = gradio.Textbox(label='Username')
worker_password_field = gradio.Textbox(label='Password') worker_password_field = gradio.Textbox(label='Password', type='password')
update_credentials_btn = gradio.Button(value='Update API Credentials') update_credentials_btn = gradio.Button(value='Update API Credentials')
update_credentials_btn.click(self.update_credentials_btn, inputs=[ update_credentials_btn.click(self.update_credentials_btn, inputs=[
worker_api_auth_cbx, worker_api_auth_cbx,

View File

@ -1,3 +1,4 @@
import asyncio
import base64 import base64
import copy import copy
import io import io
@ -156,6 +157,11 @@ class Worker:
def __repr__(self): def __repr__(self):
return f"'{self.label}'@{self.address}:{self.port}, speed: {self.avg_ipm} ipm, state: {self.state}" return f"'{self.label}'@{self.address}:{self.port}, speed: {self.avg_ipm} ipm, state: {self.state}"
def __eq__(self, other):
    """Label-based identity: two Worker objects are equal iff their labels match.

    Anything that is not a Worker compares unequal. This is what lets
    asyncio task names (set to worker.label) be mapped back to workers.
    NOTE(review): defining __eq__ without __hash__ makes Worker unhashable —
    confirm no caller relies on putting workers in sets/dict keys.
    """
    if not isinstance(other, Worker):
        return False
    return self.label == other.label
@property @property
def model(self) -> Worker_Model: def model(self) -> Worker_Model:
return Worker_Model(**self.__dict__) return Worker_Model(**self.__dict__)
@ -510,7 +516,7 @@ class Worker:
t: Thread t: Thread
samples = 2 # number of times to benchmark the remote / accuracy samples = 2 # number of times to benchmark the remote / accuracy
if self.state == State.DISABLED or self.state == State.UNAVAILABLE: if self.state in (State.DISABLED, State.UNAVAILABLE):
logger.debug(f"worker '{self.label}' is unavailable or disabled, refusing to benchmark") logger.debug(f"worker '{self.label}' is unavailable or disabled, refusing to benchmark")
return 0 return 0
@ -533,7 +539,6 @@ class Worker:
results: List[float] = [] results: List[float] = []
# it used to be lower for the first couple of generations # it used to be lower for the first couple of generations
# this was due to something torch does at startup according to auto and is now done at sdwui startup # this was due to something torch does at startup according to auto and is now done at sdwui startup
self.state = State.WORKING
for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up" for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up"
if self.state == State.UNAVAILABLE: if self.state == State.UNAVAILABLE:
self.response = None self.response = None
@ -677,6 +682,8 @@ class Worker:
if vae is not None: if vae is not None:
self.loaded_vae = vae self.loaded_vae = vae
return response
def restart(self) -> bool: def restart(self) -> bool:
err_msg = f"could not restart worker '{self.label}'" err_msg = f"could not restart worker '{self.label}'"
success_msg = f"worker '{self.label}' is restarting" success_msg = f"worker '{self.label}' is restarting"

View File

@ -18,6 +18,7 @@ from . import shared as sh
from .pmodels import ConfigModel, Benchmark_Payload from .pmodels import ConfigModel, Benchmark_Payload
from .shared import logger, warmup_samples, extension_path from .shared import logger, warmup_samples, extension_path
from .worker import Worker, State from .worker import Worker, State
import asyncio
class NotBenchmarked(Exception): class NotBenchmarked(Exception):
@ -231,38 +232,60 @@ class World:
else: else:
worker.benchmarked = True worker.benchmarked = True
tasks = []
loop = asyncio.new_event_loop()
# have every unbenched worker load the same weights before the benchmark # have every unbenched worker load the same weights before the benchmark
for worker in unbenched_workers: for worker in unbenched_workers:
if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE): if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE):
continue continue
sync_thread = Thread(target=worker.load_options, args=(shared.opts.sd_model_checkpoint, shared.opts.sd_vae)) tasks.append(
sync_threads.append(sync_thread) loop.create_task(
sync_thread.start() asyncio.to_thread(worker.load_options, model=shared.opts.sd_model_checkpoint, vae=shared.opts.sd_vae)
for thread in sync_threads: , name=worker.label
thread.join() )
)
if len(tasks) > 0:
results = loop.run_until_complete(asyncio.wait(tasks))
for task in results[0]:
worker = self[task.get_name()]
response = task.result()
if response.status_code != 200:
logger.error(f"refusing to benchmark worker '{worker.label}' as it failed to load the selected model '{shared.opts.sd_model_checkpoint}'\n"
f"*you may circumvent this by using the per-worker model override setting but this is not recommended as the same benchmark model should be used for all workers")
unbenched_workers = list(filter(lambda w: w != worker, unbenched_workers))
# benchmark those that haven't been # benchmark those that haven't been
tasks = []
for worker in unbenched_workers: for worker in unbenched_workers:
if worker.state in (State.DISABLED, State.UNAVAILABLE): if worker.state in (State.DISABLED, State.UNAVAILABLE):
logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark") logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark")
continue continue
t = Thread(target=benchmark_wrapped, args=(worker, ), name=f"{worker.label}_benchmark") if worker.model_override is not None:
benchmark_threads.append(t) logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
t.start() f"*all workers should be evaluated against the same model")
tasks.append(
loop.create_task(
asyncio.to_thread(benchmark_wrapped, worker),
name=worker.label
)
)
logger.info(f"benchmarking worker '{worker.label}'") logger.info(f"benchmarking worker '{worker.label}'")
# wait for all benchmarks to finish and update stats on newly benchmarked workers # wait for all benchmarks to finish and update stats on newly benchmarked workers
if len(benchmark_threads) > 0: if len(tasks) > 0:
for t in benchmark_threads: results = loop.run_until_complete(asyncio.wait(tasks))
t.join()
logger.info("benchmarking finished") logger.info("benchmarking finished")
logger.debug(results)
# save benchmark results to workers.json # save benchmark results to workers.json
self.save_config() self.save_config()
logger.info(self.speed_summary()) logger.info(self.speed_summary())
loop.close()
def get_current_output_size(self) -> int: def get_current_output_size(self) -> int:
""" """
returns how many images would be returned from all jobs returns how many images would be returned from all jobs