alpha version of worker management/setup gui
parent
c22af2d772
commit
80852253c9
|
|
@ -54,6 +54,8 @@ class Script(scripts.Script):
|
|||
else:
|
||||
logger.fatal(f"Found no worker info passed as arguments. Did you populate --distributed-remotes ?")
|
||||
|
||||
world.load_config()
|
||||
|
||||
def title(self):
|
||||
return "Distribute"
|
||||
|
||||
|
|
@ -110,7 +112,7 @@ class Script(scripts.Script):
|
|||
if p.n_iter > 1: # if splitting by batch count
|
||||
num_remote_images *= p.n_iter - 1
|
||||
|
||||
logger.debug(f"iteration {iteration}/{p.n_iter}, image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, info-index: {info_index}")
|
||||
logger.debug(f"image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, info-index: {info_index}")
|
||||
|
||||
if Script.world.thin_client_mode:
|
||||
p.all_negative_prompts = processed.all_negative_prompts
|
||||
|
|
@ -233,10 +235,6 @@ class Script(scripts.Script):
|
|||
# runs every time the generate button is hit
|
||||
def run(self, p, *args):
|
||||
current_thread().name = "distributed_main"
|
||||
|
||||
if cmd_opts.distributed_remotes is None:
|
||||
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
|
||||
|
||||
Script.initialize(initial_payload=p)
|
||||
|
||||
# strip scripts that aren't yet supported and warn user
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import io
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
|
@ -5,6 +6,7 @@ import gradio
|
|||
from scripts.spartan.shared import logger, log_level
|
||||
from scripts.spartan.Worker import State
|
||||
from modules.shared import state as webui_state
|
||||
import json
|
||||
|
||||
|
||||
class UI:
|
||||
|
|
@ -41,12 +43,6 @@ class UI:
|
|||
logger.info("Redoing benchmarks...")
|
||||
self.world.benchmark(rebenchmark=True)
|
||||
|
||||
def interrupt_btn(self):
|
||||
self.world.interrupt_remotes()
|
||||
|
||||
def refresh_ckpts_btn(self):
|
||||
self.world.refresh_checkpoints()
|
||||
|
||||
def clear_queue_btn(self):
|
||||
logger.debug(webui_state.__dict__)
|
||||
webui_state.end()
|
||||
|
|
@ -76,6 +72,25 @@ class UI:
|
|||
self.world.job_timeout = job_timeout
|
||||
logger.debug(f"job timeout is now {job_timeout} seconds")
|
||||
|
||||
def save_worker_btn(self, name, address, port, tls):
    """
    Register a worker with the world and persist its entry to the worker config file.

    The config file at `self.world.worker_info_path` is read (if present), the entry
    for `name` is added or replaced, and the whole mapping is written back as JSON.

    Args:
        name: Worker name; also the key of the worker's entry in the config file.
        address: Worker address, passed through to `self.world.add_worker`.
        port: Worker port, passed through to `self.world.add_worker`.
        tls: Whether to reach the worker over https, passed through to `self.world.add_worker`.
    """
    worker = self.world.add_worker(name, address, port, tls)

    # Build the new entry BEFORE opening the file for writing: mode 'w' truncates,
    # so a failure here must not be able to wipe the existing configuration.
    worker_entry = worker.info()[name]

    workers_info = {}
    try:
        with open(self.world.worker_info_path, 'r', encoding='utf-8') as worker_info_file:
            workers_info = json.load(worker_info_file)
    except FileNotFoundError:
        # No config yet (e.g. the very first worker being added): start from an empty mapping.
        pass
    except json.decoder.JSONDecodeError:
        # Unreadable config: keep going with an empty mapping, which replaces the bad file below.
        logger.error("corrupt or invalid config file... ignoring")

    workers_info[name] = worker_entry

    with open(self.world.worker_info_path, 'w', encoding='utf-8') as worker_info_file:
        json.dump(workers_info, worker_info_file, indent=3)
|
||||
|
||||
|
||||
# end handlers
|
||||
|
||||
def create_root(self):
|
||||
|
|
@ -96,25 +111,41 @@ class UI:
|
|||
with gradio.Tab('Utils'):
|
||||
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
|
||||
refresh_checkpoints_btn.style(full_width=False)
|
||||
refresh_checkpoints_btn.click(self.refresh_ckpts_btn, inputs=[], outputs=[])
|
||||
refresh_checkpoints_btn.click(self.world.refresh_checkpoints)
|
||||
|
||||
run_usr_btn = gradio.Button(value='Run user script')
|
||||
run_usr_btn.style(full_width=False)
|
||||
run_usr_btn.click(self.user_script_btn, inputs=[], outputs=[])
|
||||
run_usr_btn.click(self.user_script_btn)
|
||||
|
||||
interrupt_all_btn = gradio.Button(value='Interrupt all', variant='stop')
|
||||
interrupt_all_btn.style(full_width=False)
|
||||
interrupt_all_btn.click(self.interrupt_btn, inputs=[], outputs=[])
|
||||
interrupt_all_btn.click(self.world.interrupt_remotes)
|
||||
|
||||
redo_benchmarks_btn = gradio.Button(value='Redo benchmarks', variant='stop')
|
||||
redo_benchmarks_btn.style(full_width=False)
|
||||
redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[])
|
||||
|
||||
reload_config_btn = gradio.Button(value='Reload config from file')
|
||||
reload_config_btn.style(full_width=False)
|
||||
reload_config_btn.click(self.world.load_config)
|
||||
|
||||
if log_level == 'DEBUG':
|
||||
clear_queue_btn = gradio.Button(value='Clear local webui queue', variant='stop')
|
||||
clear_queue_btn.style(full_width=False)
|
||||
clear_queue_btn.click(self.clear_queue_btn)
|
||||
|
||||
with gradio.Tab('Worker Config'):
|
||||
worker_name_field = gradio.Textbox(label='Name')
|
||||
worker_address_field = gradio.Textbox(label='Address')
|
||||
worker_port_field = gradio.Textbox(label='Port', value='7860')
|
||||
worker_tls_cbx = gradio.Checkbox(
|
||||
label='connect to worker using https'
|
||||
)
|
||||
save_worker_btn = gradio.Button(
|
||||
value='Add Worker'
|
||||
)
|
||||
save_worker_btn.click(self.save_worker_btn, inputs=[worker_name_field, worker_address_field, worker_port_field, worker_tls_cbx])
|
||||
|
||||
with gradio.Tab('Settings'):
|
||||
thin_client_cbx = gradio.Checkbox(
|
||||
label='Thin-client mode (experimental)',
|
||||
|
|
|
|||
|
|
@ -44,8 +44,6 @@ class Worker:
|
|||
avg_ipm (int): The average images per minute of the node. Defaults to None.
|
||||
uuid (str): The unique identifier/name of the worker node. Defaults to None.
|
||||
queried (bool): Whether this worker's memory status has been polled yet. Defaults to False.
|
||||
free_vram (bytes): The amount of (currently) available VRAM on the worker node. Defaults to 0.
|
||||
# TODO check this
|
||||
verify_remotes (bool): Whether to verify the validity of remote worker certificates. Defaults to False.
|
||||
master (bool): Whether this worker is the master node. Defaults to False.
|
||||
benchmarked (bool): Whether this worker has been benchmarked. Defaults to False.
|
||||
|
|
@ -60,7 +58,6 @@ class Worker:
|
|||
avg_ipm: float = None
|
||||
uuid: str = None
|
||||
queried: bool = False # whether this worker has been connected to yet
|
||||
free_vram: bytes = 0
|
||||
verify_remotes: bool = False
|
||||
master: bool = False
|
||||
benchmarked: bool = False
|
||||
|
|
@ -70,6 +67,7 @@ class Worker:
|
|||
loaded_model: str = None
|
||||
loaded_vae: str = None
|
||||
state: State = None
|
||||
tls: bool = False
|
||||
|
||||
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
|
||||
# We compare to euler a because that is what we currently benchmark each node with.
|
||||
|
|
@ -95,7 +93,7 @@ class Worker:
|
|||
}
|
||||
|
||||
def __init__(self, address: str = None, port: int = None, uuid: str = None, verify_remotes: bool = None,
|
||||
master: bool = False):
|
||||
master: bool = False, tls: bool = False):
|
||||
if master is True:
|
||||
self.master = master
|
||||
self.uuid = 'master'
|
||||
|
|
@ -125,18 +123,14 @@ class Worker:
|
|||
return f"{self.address}:{self.port}"
|
||||
|
||||
def info(self) -> dict:
|
||||
"""
|
||||
Stores the payload used to benchmark the world and certain attributes of the worker.
|
||||
These things are used to draw certain conclusions after the first session.
|
||||
|
||||
Returns:
|
||||
dict: Worker info, including how it was benchmarked.
|
||||
"""
|
||||
|
||||
d = {}
|
||||
data = {
|
||||
"avg_ipm": self.avg_ipm,
|
||||
"master": self.master,
|
||||
"address": self.address,
|
||||
"port": self.port,
|
||||
"last_mpe": self.last_mpe,
|
||||
"tls": self.tls
|
||||
}
|
||||
|
||||
d[self.uuid] = data
|
||||
|
|
@ -169,8 +163,8 @@ class Worker:
|
|||
str: The full url.
|
||||
"""
|
||||
|
||||
# TODO check if using http or https
|
||||
return f"http://{self.__str__()}/sdapi/v1/{route}"
|
||||
protocol = 'http' if not self.tls else 'https'
|
||||
return f"{protocol}://{self.__str__()}/sdapi/v1/{route}"
|
||||
|
||||
def batch_eta_hr(self, payload: dict) -> float:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@ class World:
|
|||
if job.worker.master:
|
||||
return job
|
||||
|
||||
def add_worker(self, uuid: str, address: str, port: int):
|
||||
def add_worker(self, uuid: str, address: str, port: int, tls: bool = False):
|
||||
"""
|
||||
Registers a worker with the world.
|
||||
|
||||
|
|
@ -147,9 +147,15 @@ class World:
|
|||
port (int): The port number.
|
||||
"""
|
||||
|
||||
worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes)
|
||||
worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes, tls=tls)
|
||||
|
||||
for w in self.__workers:
|
||||
if w.uuid == uuid:
|
||||
self.__workers.remove(w)
|
||||
self.__workers.append(worker)
|
||||
|
||||
return worker
|
||||
|
||||
def interrupt_remotes(self):
|
||||
|
||||
for worker in self.get_workers():
|
||||
|
|
@ -186,8 +192,11 @@ class World:
|
|||
if saved:
|
||||
with open(self.worker_info_path, 'r') as worker_info_file:
|
||||
workers_info = json.load(worker_info_file)
|
||||
sh.benchmark_payload = workers_info['benchmark_payload']
|
||||
logger.info(f"Using saved benchmark config:\n{sh.benchmark_payload}")
|
||||
try:
|
||||
sh.benchmark_payload = workers_info['benchmark_payload']
|
||||
logger.info(f"Using saved benchmark config:\n{sh.benchmark_payload}")
|
||||
except KeyError:
|
||||
logger.debug(f"using default benchmark payload")
|
||||
|
||||
saved = False
|
||||
workers = self.get_workers()
|
||||
|
|
@ -211,10 +220,14 @@ class World:
|
|||
if saved and not rebenchmark:
|
||||
logger.debug(f"loaded saved configuration: \n{workers_info}")
|
||||
|
||||
for worker in self.get_workers():
|
||||
for worker in self.__workers:
|
||||
try:
|
||||
worker.avg_ipm = workers_info[worker.uuid]['avg_ipm']
|
||||
worker.benchmarked = True
|
||||
if worker.avg_ipm <= 0:
|
||||
logger.debug(f"{worker.uuid} has recorded ipm of 0... marking as unbenched")
|
||||
unbenched_workers.append(worker)
|
||||
else:
|
||||
worker.benchmarked = True
|
||||
except KeyError:
|
||||
logger.debug(f"worker '{worker.uuid}' not found in workers.json")
|
||||
unbenched_workers.append(worker)
|
||||
|
|
@ -491,3 +504,35 @@ class World:
|
|||
if self.jobs[last].batch_size < 1:
|
||||
del self.jobs[last]
|
||||
last -= 1
|
||||
|
||||
def load_config(self):
    """
    Load worker definitions from the config file at `self.worker_info_path`.

    Every top-level key other than "benchmark_payload" and "master" is treated as a
    worker uuid and registered through `self.add_worker`. Returns early (logging at
    debug level) when the file is missing or is not valid JSON. Entries that are
    malformed are logged and skipped without being registered.
    """
    if not os.path.exists(self.worker_info_path):
        logger.debug(f"Config was not found at '{self.worker_info_path}'")
        return

    # Read as utf-8 to match how the config file is written elsewhere in this project.
    with open(self.worker_info_path, 'r', encoding='utf-8') as config:
        try:
            config_json = json.load(config)
        except json.decoder.JSONDecodeError:
            logger.debug(f"config is corrupt or invalid JSON, unable to load")
            return

    for key, w in config_json.items():
        if key in ("benchmark_payload", "master"):
            continue

        # Pull out every field BEFORE registering, so an entry reported as
        # "ignoring" is never left half-registered in the world.
        try:
            address = w['address']
            port = w['port']
            last_mpe = w['last_mpe']
            # `tls` is optional: it defaults to False, the same default add_worker uses.
            tls = w.get('tls', False)
        except (KeyError, TypeError):
            # KeyError: a required field is missing; TypeError: the entry is not a mapping.
            logger.error(f"invalid configuration in file for worker {key}... ignoring")
            continue

        worker = self.add_worker(uuid=key, address=address, port=port, tls=tls)
        worker.address = address
        worker.port = port
        worker.last_mpe = last_mpe
|
||||
|
|
|
|||
Loading…
Reference in New Issue