Alpha version of worker management/setup GUI

pull/17/head
unknown 2023-07-03 03:35:35 -05:00
parent c22af2d772
commit 80852253c9
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
4 changed files with 102 additions and 34 deletions

View File

@ -54,6 +54,8 @@ class Script(scripts.Script):
else:
logger.fatal(f"Found no worker info passed as arguments. Did you populate --distributed-remotes ?")
world.load_config()
def title(self):
return "Distribute"
@ -110,7 +112,7 @@ class Script(scripts.Script):
if p.n_iter > 1: # if splitting by batch count
num_remote_images *= p.n_iter - 1
logger.debug(f"iteration {iteration}/{p.n_iter}, image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, info-index: {info_index}")
logger.debug(f"image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, info-index: {info_index}")
if Script.world.thin_client_mode:
p.all_negative_prompts = processed.all_negative_prompts
@ -233,10 +235,6 @@ class Script(scripts.Script):
# runs every time the generate button is hit
def run(self, p, *args):
current_thread().name = "distributed_main"
if cmd_opts.distributed_remotes is None:
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
Script.initialize(initial_payload=p)
# strip scripts that aren't yet supported and warn user

View File

@ -1,3 +1,4 @@
import io
import os
import subprocess
from pathlib import Path
@ -5,6 +6,7 @@ import gradio
from scripts.spartan.shared import logger, log_level
from scripts.spartan.Worker import State
from modules.shared import state as webui_state
import json
class UI:
@ -41,12 +43,6 @@ class UI:
logger.info("Redoing benchmarks...")
self.world.benchmark(rebenchmark=True)
def interrupt_btn(self):
self.world.interrupt_remotes()
def refresh_ckpts_btn(self):
self.world.refresh_checkpoints()
def clear_queue_btn(self):
logger.debug(webui_state.__dict__)
webui_state.end()
@ -76,6 +72,25 @@ class UI:
self.world.job_timeout = job_timeout
logger.debug(f"job timeout is now {job_timeout} seconds")
def save_worker_btn(self, name, address, port, tls):
    """
    Registers a worker with the world and persists it to the worker config file.

    Entries already present in the config file are kept; the entry stored under
    `name` is added or overwritten with the info of the newly registered worker.

    Args:
        name: Key the worker is saved under; also the first argument given to `add_worker`.
        address: The worker's address, as entered in the UI.
        port: The worker's port, as entered in the UI.
        tls: Whether to connect to the worker using https.
    """
    worker = self.world.add_worker(name, address, port, tls)
    config_path = self.world.worker_info_path

    workers_info = {}
    try:
        with open(config_path, 'r', encoding='utf-8') as worker_info_file:
            workers_info = json.load(worker_info_file)
    except FileNotFoundError:
        # No config has been written yet: start from an empty one instead of crashing.
        pass
    except json.decoder.JSONDecodeError:
        logger.error(f"corrupt or invalid config file... ignoring")
    except io.UnsupportedOperation:
        pass

    if not isinstance(workers_info, dict):
        # A JSON root that is not an object cannot hold per-worker entries.
        logger.error(f"unexpected config layout in '{config_path}'... ignoring")
        workers_info = {}

    # Build the entry BEFORE opening the file for writing: mode 'w' truncates
    # immediately, so a failure here must not be able to wipe the existing config.
    workers_info[name] = worker.info()[name]

    with open(config_path, 'w', encoding='utf-8') as worker_info_file:
        json.dump(workers_info, worker_info_file, indent=3)
# end handlers
def create_root(self):
@ -96,25 +111,41 @@ class UI:
with gradio.Tab('Utils'):
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
refresh_checkpoints_btn.style(full_width=False)
refresh_checkpoints_btn.click(self.refresh_ckpts_btn, inputs=[], outputs=[])
refresh_checkpoints_btn.click(self.world.refresh_checkpoints)
run_usr_btn = gradio.Button(value='Run user script')
run_usr_btn.style(full_width=False)
run_usr_btn.click(self.user_script_btn, inputs=[], outputs=[])
run_usr_btn.click(self.user_script_btn)
interrupt_all_btn = gradio.Button(value='Interrupt all', variant='stop')
interrupt_all_btn.style(full_width=False)
interrupt_all_btn.click(self.interrupt_btn, inputs=[], outputs=[])
interrupt_all_btn.click(self.world.interrupt_remotes)
redo_benchmarks_btn = gradio.Button(value='Redo benchmarks', variant='stop')
redo_benchmarks_btn.style(full_width=False)
redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[])
reload_config_btn = gradio.Button(value='Reload config from file')
reload_config_btn.style(full_width=False)
reload_config_btn.click(self.world.load_config)
if log_level == 'DEBUG':
clear_queue_btn = gradio.Button(value='Clear local webui queue', variant='stop')
clear_queue_btn.style(full_width=False)
clear_queue_btn.click(self.clear_queue_btn)
with gradio.Tab('Worker Config'):
worker_name_field = gradio.Textbox(label='Name')
worker_address_field = gradio.Textbox(label='Address')
worker_port_field = gradio.Textbox(label='Port', value='7860')
worker_tls_cbx = gradio.Checkbox(
label='connect to worker using https'
)
save_worker_btn = gradio.Button(
value='Add Worker'
)
save_worker_btn.click(self.save_worker_btn, inputs=[worker_name_field, worker_address_field, worker_port_field, worker_tls_cbx])
with gradio.Tab('Settings'):
thin_client_cbx = gradio.Checkbox(
label='Thin-client mode (experimental)',

View File

@ -44,8 +44,6 @@ class Worker:
avg_ipm (int): The average images per minute of the node. Defaults to None.
uuid (str): The unique identifier/name of the worker node. Defaults to None.
queried (bool): Whether this worker's memory status has been polled yet. Defaults to False.
free_vram (bytes): The amount of (currently) available VRAM on the worker node. Defaults to 0.
# TODO check this
verify_remotes (bool): Whether to verify the validity of remote worker certificates. Defaults to False.
master (bool): Whether this worker is the master node. Defaults to False.
benchmarked (bool): Whether this worker has been benchmarked. Defaults to False.
@ -60,7 +58,6 @@ class Worker:
avg_ipm: float = None
uuid: str = None
queried: bool = False # whether this worker has been connected to yet
free_vram: bytes = 0
verify_remotes: bool = False
master: bool = False
benchmarked: bool = False
@ -70,6 +67,7 @@ class Worker:
loaded_model: str = None
loaded_vae: str = None
state: State = None
tls: bool = False
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
# We compare to euler a because that is what we currently benchmark each node with.
@ -95,7 +93,7 @@ class Worker:
}
def __init__(self, address: str = None, port: int = None, uuid: str = None, verify_remotes: bool = None,
master: bool = False):
master: bool = False, tls: bool = False):
if master is True:
self.master = master
self.uuid = 'master'
@ -125,18 +123,14 @@ class Worker:
return f"{self.address}:{self.port}"
def info(self) -> dict:
"""
Stores the payload used to benchmark the world and certain attributes of the worker.
These things are used to draw certain conclusions after the first session.
Returns:
dict: Worker info, including how it was benchmarked.
"""
d = {}
data = {
"avg_ipm": self.avg_ipm,
"master": self.master,
"address": self.address,
"port": self.port,
"last_mpe": self.last_mpe,
"tls": self.tls
}
d[self.uuid] = data
@ -169,8 +163,8 @@ class Worker:
str: The full url.
"""
# TODO check if using http or https
return f"http://{self.__str__()}/sdapi/v1/{route}"
protocol = 'http' if not self.tls else 'https'
return f"{protocol}://{self.__str__()}/sdapi/v1/{route}"
def batch_eta_hr(self, payload: dict) -> float:
"""

View File

@ -137,7 +137,7 @@ class World:
if job.worker.master:
return job
def add_worker(self, uuid: str, address: str, port: int, tls: bool = False):
    """
    Registers a worker with the world, replacing any worker that has the same uuid.

    Args:
        uuid (str): The name/identifier of the worker.
        address (str): The address of the worker.
        port (int): The port number.
        tls (bool): Whether the worker should be reached over https. Defaults to False.

    Returns:
        Worker: The newly created worker.
    """
    worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes, tls=tls)

    # Filter in place rather than calling remove() while iterating the same list,
    # which skips elements. Slice assignment keeps the list object itself intact.
    self.__workers[:] = [w for w in self.__workers if w.uuid != uuid]
    self.__workers.append(worker)

    return worker
def interrupt_remotes(self):
for worker in self.get_workers():
@ -186,8 +192,11 @@ class World:
if saved:
with open(self.worker_info_path, 'r') as worker_info_file:
workers_info = json.load(worker_info_file)
sh.benchmark_payload = workers_info['benchmark_payload']
logger.info(f"Using saved benchmark config:\n{sh.benchmark_payload}")
try:
sh.benchmark_payload = workers_info['benchmark_payload']
logger.info(f"Using saved benchmark config:\n{sh.benchmark_payload}")
except KeyError:
logger.debug(f"using default benchmark payload")
saved = False
workers = self.get_workers()
@ -211,10 +220,14 @@ class World:
if saved and not rebenchmark:
logger.debug(f"loaded saved configuration: \n{workers_info}")
for worker in self.get_workers():
for worker in self.__workers:
try:
worker.avg_ipm = workers_info[worker.uuid]['avg_ipm']
worker.benchmarked = True
if worker.avg_ipm <= 0:
logger.debug(f"{worker.uuid} has recorded ipm of 0... marking as unbenched")
unbenched_workers.append(worker)
else:
worker.benchmarked = True
except KeyError:
logger.debug(f"worker '{worker.uuid}' not found in workers.json")
unbenched_workers.append(worker)
@ -491,3 +504,35 @@ class World:
if self.jobs[last].batch_size < 1:
del self.jobs[last]
last -= 1
def load_config(self):
    """
    Registers every worker described in the config file at `self.worker_info_path`.

    The top-level keys "benchmark_payload" and "master" are skipped; every other
    top-level key is treated as a worker name. Nothing is registered when the file
    is missing, is not valid JSON, or does not contain a JSON object at its root.

    Each worker entry needs 'address' and 'port'. 'tls' falls back to False when
    absent, and 'last_mpe' is only applied when present.
    """
    try:
        with open(self.worker_info_path, 'r', encoding='utf-8') as config:
            config_json = json.load(config)
    except FileNotFoundError:
        logger.debug(f"Config was not found at '{self.worker_info_path}'")
        return
    except json.decoder.JSONDecodeError:
        logger.debug(f"config is corrupt or invalid JSON, unable to load")
        return

    if not isinstance(config_json, dict):
        logger.error(f"config root is not a JSON object, unable to load")
        return

    for key, w in config_json.items():
        if key in ("benchmark_payload", "master"):
            continue

        # Validate the required fields BEFORE registering anything, so an entry
        # reported as ignored never ends up half-registered.
        try:
            address = w['address']
            port = w['port']
        except (KeyError, TypeError):
            logger.error(f"invalid configuration in file for worker {key}... ignoring")
            continue

        worker = self.add_worker(
            uuid=key,
            address=address,
            port=port,
            # entries without a 'tls' key are loaded as plain http
            tls=w.get('tls', False)
        )
        worker.address = address
        worker.port = port
        if 'last_mpe' in w:
            worker.last_mpe = w['last_mpe']