bench threads -> coro
parent
c84d7c8a33
commit
3a9d87f821
|
|
@ -3,6 +3,7 @@ from PIL import Image
|
||||||
from modules.api.api import encode_pil_to_base64
|
from modules.api.api import encode_pil_to_base64
|
||||||
from scripts.spartan.shared import logger
|
from scripts.spartan.shared import logger
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
def np_to_b64(image: np.ndarray):
|
def np_to_b64(image: np.ndarray):
|
||||||
|
|
@ -62,4 +63,10 @@ def pack_control_net(cn_units) -> dict:
|
||||||
# remove anything unserializable
|
# remove anything unserializable
|
||||||
del unit['input_mode']
|
del unit['input_mode']
|
||||||
|
|
||||||
|
try:
|
||||||
|
json.dumps(controlnet)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"failed to serialize controlnet\nfirst unit:\n{controlnet['controlnet']['args'][0]}")
|
||||||
|
return {}
|
||||||
|
|
||||||
return controlnet
|
return controlnet
|
||||||
|
|
|
||||||
|
|
@ -312,7 +312,7 @@ class UI:
|
||||||
# API authentication
|
# API authentication
|
||||||
worker_api_auth_cbx = gradio.Checkbox(label='API Authentication')
|
worker_api_auth_cbx = gradio.Checkbox(label='API Authentication')
|
||||||
worker_user_field = gradio.Textbox(label='Username')
|
worker_user_field = gradio.Textbox(label='Username')
|
||||||
worker_password_field = gradio.Textbox(label='Password')
|
worker_password_field = gradio.Textbox(label='Password', type='password')
|
||||||
update_credentials_btn = gradio.Button(value='Update API Credentials')
|
update_credentials_btn = gradio.Button(value='Update API Credentials')
|
||||||
update_credentials_btn.click(self.update_credentials_btn, inputs=[
|
update_credentials_btn.click(self.update_credentials_btn, inputs=[
|
||||||
worker_api_auth_cbx,
|
worker_api_auth_cbx,
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import copy
|
import copy
|
||||||
import io
|
import io
|
||||||
|
|
@ -156,6 +157,11 @@ class Worker:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"'{self.label}'@{self.address}:{self.port}, speed: {self.avg_ipm} ipm, state: {self.state}"
|
return f"'{self.label}'@{self.address}:{self.port}, speed: {self.avg_ipm} ipm, state: {self.state}"
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if isinstance(other, Worker) and other.label == self.label:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model(self) -> Worker_Model:
|
def model(self) -> Worker_Model:
|
||||||
return Worker_Model(**self.__dict__)
|
return Worker_Model(**self.__dict__)
|
||||||
|
|
@ -510,7 +516,7 @@ class Worker:
|
||||||
t: Thread
|
t: Thread
|
||||||
samples = 2 # number of times to benchmark the remote / accuracy
|
samples = 2 # number of times to benchmark the remote / accuracy
|
||||||
|
|
||||||
if self.state == State.DISABLED or self.state == State.UNAVAILABLE:
|
if self.state in (State.DISABLED, State.UNAVAILABLE):
|
||||||
logger.debug(f"worker '{self.label}' is unavailable or disabled, refusing to benchmark")
|
logger.debug(f"worker '{self.label}' is unavailable or disabled, refusing to benchmark")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
@ -533,7 +539,6 @@ class Worker:
|
||||||
results: List[float] = []
|
results: List[float] = []
|
||||||
# it used to be lower for the first couple of generations
|
# it used to be lower for the first couple of generations
|
||||||
# this was due to something torch does at startup according to auto and is now done at sdwui startup
|
# this was due to something torch does at startup according to auto and is now done at sdwui startup
|
||||||
self.state = State.WORKING
|
|
||||||
for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up"
|
for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up"
|
||||||
if self.state == State.UNAVAILABLE:
|
if self.state == State.UNAVAILABLE:
|
||||||
self.response = None
|
self.response = None
|
||||||
|
|
@ -677,6 +682,8 @@ class Worker:
|
||||||
if vae is not None:
|
if vae is not None:
|
||||||
self.loaded_vae = vae
|
self.loaded_vae = vae
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
def restart(self) -> bool:
|
def restart(self) -> bool:
|
||||||
err_msg = f"could not restart worker '{self.label}'"
|
err_msg = f"could not restart worker '{self.label}'"
|
||||||
success_msg = f"worker '{self.label}' is restarting"
|
success_msg = f"worker '{self.label}' is restarting"
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from . import shared as sh
|
||||||
from .pmodels import ConfigModel, Benchmark_Payload
|
from .pmodels import ConfigModel, Benchmark_Payload
|
||||||
from .shared import logger, warmup_samples, extension_path
|
from .shared import logger, warmup_samples, extension_path
|
||||||
from .worker import Worker, State
|
from .worker import Worker, State
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
class NotBenchmarked(Exception):
|
class NotBenchmarked(Exception):
|
||||||
|
|
@ -231,38 +232,60 @@ class World:
|
||||||
else:
|
else:
|
||||||
worker.benchmarked = True
|
worker.benchmarked = True
|
||||||
|
|
||||||
|
tasks = []
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
# have every unbenched worker load the same weights before the benchmark
|
# have every unbenched worker load the same weights before the benchmark
|
||||||
for worker in unbenched_workers:
|
for worker in unbenched_workers:
|
||||||
if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE):
|
if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
sync_thread = Thread(target=worker.load_options, args=(shared.opts.sd_model_checkpoint, shared.opts.sd_vae))
|
tasks.append(
|
||||||
sync_threads.append(sync_thread)
|
loop.create_task(
|
||||||
sync_thread.start()
|
asyncio.to_thread(worker.load_options, model=shared.opts.sd_model_checkpoint, vae=shared.opts.sd_vae)
|
||||||
for thread in sync_threads:
|
, name=worker.label
|
||||||
thread.join()
|
)
|
||||||
|
)
|
||||||
|
if len(tasks) > 0:
|
||||||
|
results = loop.run_until_complete(asyncio.wait(tasks))
|
||||||
|
for task in results[0]:
|
||||||
|
worker = self[task.get_name()]
|
||||||
|
response = task.result()
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"refusing to benchmark worker '{worker.label}' as it failed to load the selected model '{shared.opts.sd_model_checkpoint}'\n"
|
||||||
|
f"*you may circumvent this by using the per-worker model override setting but this is not recommended as the same benchmark model should be used for all workers")
|
||||||
|
unbenched_workers = list(filter(lambda w: w != worker, unbenched_workers))
|
||||||
|
|
||||||
# benchmark those that haven't been
|
# benchmark those that haven't been
|
||||||
|
tasks = []
|
||||||
for worker in unbenched_workers:
|
for worker in unbenched_workers:
|
||||||
if worker.state in (State.DISABLED, State.UNAVAILABLE):
|
if worker.state in (State.DISABLED, State.UNAVAILABLE):
|
||||||
logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark")
|
logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
t = Thread(target=benchmark_wrapped, args=(worker, ), name=f"{worker.label}_benchmark")
|
if worker.model_override is not None:
|
||||||
benchmark_threads.append(t)
|
logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
|
||||||
t.start()
|
f"*all workers should be evaluated against the same model")
|
||||||
|
|
||||||
|
tasks.append(
|
||||||
|
loop.create_task(
|
||||||
|
asyncio.to_thread(benchmark_wrapped, worker),
|
||||||
|
name=worker.label
|
||||||
|
)
|
||||||
|
)
|
||||||
logger.info(f"benchmarking worker '{worker.label}'")
|
logger.info(f"benchmarking worker '{worker.label}'")
|
||||||
|
|
||||||
# wait for all benchmarks to finish and update stats on newly benchmarked workers
|
# wait for all benchmarks to finish and update stats on newly benchmarked workers
|
||||||
if len(benchmark_threads) > 0:
|
if len(tasks) > 0:
|
||||||
for t in benchmark_threads:
|
results = loop.run_until_complete(asyncio.wait(tasks))
|
||||||
t.join()
|
|
||||||
logger.info("benchmarking finished")
|
logger.info("benchmarking finished")
|
||||||
|
logger.debug(results)
|
||||||
|
|
||||||
# save benchmark results to workers.json
|
# save benchmark results to workers.json
|
||||||
self.save_config()
|
self.save_config()
|
||||||
logger.info(self.speed_summary())
|
logger.info(self.speed_summary())
|
||||||
|
|
||||||
|
loop.close()
|
||||||
|
|
||||||
def get_current_output_size(self) -> int:
|
def get_current_output_size(self) -> int:
|
||||||
"""
|
"""
|
||||||
returns how many images would be returned from all jobs
|
returns how many images would be returned from all jobs
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue