diff --git a/scripts/extension.py b/scripts/extension.py index fb4a1f2..a86faf0 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -54,6 +54,8 @@ class Script(scripts.Script): else: logger.fatal(f"Found no worker info passed as arguments. Did you populate --distributed-remotes ?") + world.load_config() + def title(self): return "Distribute" @@ -110,7 +112,7 @@ class Script(scripts.Script): if p.n_iter > 1: # if splitting by batch count num_remote_images *= p.n_iter - 1 - logger.debug(f"iteration {iteration}/{p.n_iter}, image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, info-index: {info_index}") + logger.debug(f"image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, info-index: {info_index}") if Script.world.thin_client_mode: p.all_negative_prompts = processed.all_negative_prompts @@ -233,10 +235,6 @@ class Script(scripts.Script): # runs every time the generate button is hit def run(self, p, *args): current_thread().name = "distributed_main" - - if cmd_opts.distributed_remotes is None: - raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)") - Script.initialize(initial_payload=p) # strip scripts that aren't yet supported and warn user diff --git a/scripts/spartan/UI.py b/scripts/spartan/UI.py index a6399b3..2ef2460 100644 --- a/scripts/spartan/UI.py +++ b/scripts/spartan/UI.py @@ -1,3 +1,4 @@ +import io import os import subprocess from pathlib import Path @@ -5,6 +6,7 @@ import gradio from scripts.spartan.shared import logger, log_level from scripts.spartan.Worker import State from modules.shared import state as webui_state +import json class UI: @@ -41,12 +43,6 @@ class UI: logger.info("Redoing benchmarks...") self.world.benchmark(rebenchmark=True) - def interrupt_btn(self): - self.world.interrupt_remotes() - - def refresh_ckpts_btn(self): - self.world.refresh_checkpoints() - def clear_queue_btn(self): logger.debug(webui_state.__dict__) webui_state.end() @@ -76,6 +72,25 @@ class UI: self.world.job_timeout = job_timeout logger.debug(f"job timeout is now {job_timeout} seconds") + def save_worker_btn(self, name, address, port, tls): + worker = self.world.add_worker(name, address, port, tls) + + workers_info = {} + with open(self.world.worker_info_path, 'r', encoding='utf-8') as worker_info_file: + try: + workers_info = json.load(worker_info_file) + except json.decoder.JSONDecodeError: + logger.error(f"corrupt or invalid config file... ignoring") + except io.UnsupportedOperation: + pass + + with open(self.world.worker_info_path, 'w', encoding='utf-8') as worker_info_file: + inf: dict = worker.info() + workers_info[name] = inf[name] + + json.dump(workers_info, worker_info_file, indent=3) + + # end handlers def create_root(self): @@ -96,25 +111,41 @@ class UI: with gradio.Tab('Utils'): refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints') refresh_checkpoints_btn.style(full_width=False) - refresh_checkpoints_btn.click(self.refresh_ckpts_btn, inputs=[], outputs=[]) + refresh_checkpoints_btn.click(self.world.refresh_checkpoints) run_usr_btn = gradio.Button(value='Run user script') run_usr_btn.style(full_width=False) - run_usr_btn.click(self.user_script_btn, inputs=[], outputs=[]) + run_usr_btn.click(self.user_script_btn) interrupt_all_btn = gradio.Button(value='Interrupt all', variant='stop') interrupt_all_btn.style(full_width=False) - interrupt_all_btn.click(self.interrupt_btn, inputs=[], outputs=[]) + interrupt_all_btn.click(self.world.interrupt_remotes) redo_benchmarks_btn = gradio.Button(value='Redo benchmarks', variant='stop') redo_benchmarks_btn.style(full_width=False) redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[]) + reload_config_btn = gradio.Button(value='Reload config from file') + reload_config_btn.style(full_width=False) + reload_config_btn.click(self.world.load_config) + if log_level == 'DEBUG': clear_queue_btn = gradio.Button(value='Clear local webui queue', variant='stop') clear_queue_btn.style(full_width=False) clear_queue_btn.click(self.clear_queue_btn) + with gradio.Tab('Worker Config'): + worker_name_field = gradio.Textbox(label='Name') + worker_address_field = gradio.Textbox(label='Address') + worker_port_field = gradio.Textbox(label='Port', value='7860') + worker_tls_cbx = gradio.Checkbox( + label='connect to worker using https' + ) + save_worker_btn = gradio.Button( + value='Add Worker' + ) + save_worker_btn.click(self.save_worker_btn, inputs=[worker_name_field, worker_address_field, worker_port_field, worker_tls_cbx]) + with gradio.Tab('Settings'): thin_client_cbx = gradio.Checkbox( label='Thin-client mode (experimental)', diff --git a/scripts/spartan/Worker.py b/scripts/spartan/Worker.py index aecc0ae..eb1c9f8 100644 --- a/scripts/spartan/Worker.py +++ b/scripts/spartan/Worker.py @@ -44,8 +44,6 @@ class Worker: avg_ipm (int): The average images per minute of the node. Defaults to None. uuid (str): The unique identifier/name of the worker node. Defaults to None. queried (bool): Whether this worker's memory status has been polled yet. Defaults to False. - free_vram (bytes): The amount of (currently) available VRAM on the worker node. Defaults to 0. - # TODO check this verify_remotes (bool): Whether to verify the validity of remote worker certificates. Defaults to False. master (bool): Whether this worker is the master node. Defaults to False. benchmarked (bool): Whether this worker has been benchmarked. Defaults to False. @@ -60,7 +58,6 @@ class Worker: avg_ipm: float = None uuid: str = None queried: bool = False # whether this worker has been connected to yet - free_vram: bytes = 0 verify_remotes: bool = False master: bool = False benchmarked: bool = False @@ -70,6 +67,7 @@ class Worker: loaded_model: str = None loaded_vae: str = None state: State = None + tls: bool = False # Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A. # We compare to euler a because that is what we currently benchmark each node with. @@ -95,7 +93,7 @@ class Worker: } def __init__(self, address: str = None, port: int = None, uuid: str = None, verify_remotes: bool = None, - master: bool = False): + master: bool = False, tls: bool = False): if master is True: self.master = master self.uuid = 'master' @@ -125,18 +123,14 @@ class Worker: return f"{self.address}:{self.port}" def info(self) -> dict: - """ - Stores the payload used to benchmark the world and certain attributes of the worker. - These things are used to draw certain conclusions after the first session. - - Returns: - dict: Worker info, including how it was benchmarked. - """ - d = {} data = { "avg_ipm": self.avg_ipm, "master": self.master, + "address": self.address, + "port": self.port, + "last_mpe": self.last_mpe, + "tls": self.tls } d[self.uuid] = data @@ -169,8 +163,8 @@ class Worker: str: The full url. """ - # TODO check if using http or https - return f"http://{self.__str__()}/sdapi/v1/{route}" + protocol = 'http' if not self.tls else 'https' + return f"{protocol}://{self.__str__()}/sdapi/v1/{route}" def batch_eta_hr(self, payload: dict) -> float: """ diff --git a/scripts/spartan/World.py b/scripts/spartan/World.py index 2465ab0..43adcc6 100644 --- a/scripts/spartan/World.py +++ b/scripts/spartan/World.py @@ -137,7 +137,7 @@ class World: if job.worker.master: return job - def add_worker(self, uuid: str, address: str, port: int): + def add_worker(self, uuid: str, address: str, port: int, tls: bool = False): """ Registers a worker with the world. @@ -147,9 +147,15 @@ class World: port (int): The port number. """ - worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes) + worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes, tls=tls) + + for w in self.__workers: + if w.uuid == uuid: + self.__workers.remove(w) self.__workers.append(worker) + return worker + def interrupt_remotes(self): for worker in self.get_workers(): @@ -186,8 +192,11 @@ class World: if saved: with open(self.worker_info_path, 'r') as worker_info_file: workers_info = json.load(worker_info_file) - sh.benchmark_payload = workers_info['benchmark_payload'] - logger.info(f"Using saved benchmark config:\n{sh.benchmark_payload}") + try: + sh.benchmark_payload = workers_info['benchmark_payload'] + logger.info(f"Using saved benchmark config:\n{sh.benchmark_payload}") + except KeyError: + logger.debug(f"using default benchmark payload") saved = False workers = self.get_workers() @@ -211,10 +220,14 @@ class World: if saved and not rebenchmark: logger.debug(f"loaded saved configuration: \n{workers_info}") - for worker in self.get_workers(): + for worker in self.__workers: try: worker.avg_ipm = workers_info[worker.uuid]['avg_ipm'] - worker.benchmarked = True + if worker.avg_ipm <= 0: + logger.debug(f"{worker.uuid} has recorded ipm of 0... marking as unbenched") + unbenched_workers.append(worker) + else: + worker.benchmarked = True except KeyError: logger.debug(f"worker '{worker.uuid}' not found in workers.json") unbenched_workers.append(worker) @@ -491,3 +504,35 @@ class World: if self.jobs[last].batch_size < 1: del self.jobs[last] last -= 1 + + def load_config(self): + if not os.path.exists(self.worker_info_path): + logger.debug(f"Config was not found at '{self.worker_info_path}'") + return + + with open(self.worker_info_path, 'r') as config: + + try: + config_json = json.load(config) + except json.decoder.JSONDecodeError: + logger.debug(f"config is corrupt or invalid JSON, unable to load") + return + + for key in config_json: + if key == "benchmark_payload" or key == "master": + continue + + w = config_json[key] + try: + worker = self.add_worker( + uuid=key, + address=w['address'], + port=w['port'], + tls=w['tls'] + ) + worker.address = w['address'] + worker.port = w['port'] + worker.last_mpe = w['last_mpe'] + except KeyError: + logger.error(f"invalid configuration in file for worker {key}... ignoring") + continue