diff --git a/scripts/extension.py b/scripts/extension.py index 6eb0796..d93e2be 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -14,11 +14,11 @@ from typing import List import urllib3 import copy from modules.images import save_image -from modules.shared import cmd_opts +from modules.shared import opts, cmd_opts +from modules.shared import state as webui_state import time from scripts.spartan.World import World, WorldAlreadyInitialized from scripts.spartan.UI import UI -from modules.shared import opts from scripts.spartan.shared import logger from scripts.spartan.control_net import pack_control_net from modules.processing import fix_seed, Processed @@ -47,8 +47,12 @@ class Script(scripts.Script): # build world world = World(initial_payload=None, verify_remotes=verify_remotes) # add workers to the world - for worker in cmd_opts.distributed_remotes: - world.add_worker(uuid=worker[0], address=worker[1], port=worker[2]) + # make sure arguments aren't missing + if cmd_opts.distributed_remotes is not None and len(cmd_opts.distributed_remotes) > 0: + for worker in cmd_opts.distributed_remotes: + world.add_worker(uuid=worker[0], address=worker[1], port=worker[2]) + else: + logger.fatal(f"Found no worker info passed as arguments. Did you populate --distributed-remotes ?") def title(self): return "Distribute" @@ -64,6 +68,7 @@ class Script(scripts.Script): @staticmethod def add_to_gallery(processed, p): """adds generated images to the image gallery after waiting for all workers to finish""" + webui_state.textinfo = "Distributed - injecting images" def processed_inject_image(image, info_index, iteration: int, save_path_override=None, grid=False, response=None): image_params: json = response["parameters"] diff --git a/scripts/spartan/UI.py b/scripts/spartan/UI.py index 9ffa5fe..a6399b3 100644 --- a/scripts/spartan/UI.py +++ b/scripts/spartan/UI.py @@ -2,8 +2,9 @@ import os import subprocess from pathlib import Path import gradio -from scripts.spartan.shared import logger +from scripts.spartan.shared import logger, log_level from scripts.spartan.Worker import State +from modules.shared import state as webui_state class UI: @@ -46,6 +47,11 @@ class UI: def refresh_ckpts_btn(self): self.world.refresh_checkpoints() + def clear_queue_btn(self): + logger.debug(webui_state.__dict__) + webui_state.end() + + def status_btn(self): worker_status = '' workers = self.world.get_workers() @@ -104,6 +110,11 @@ class UI: redo_benchmarks_btn.style(full_width=False) redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[]) + if log_level == 'DEBUG': + clear_queue_btn = gradio.Button(value='Clear local webui queue', variant='stop') + clear_queue_btn.style(full_width=False) + clear_queue_btn.click(self.clear_queue_btn) + with gradio.Tab('Settings'): thin_client_cbx = gradio.Checkbox( label='Thin-client mode (experimental)', diff --git a/scripts/spartan/Worker.py b/scripts/spartan/Worker.py index 83d8475..56431fd 100644 --- a/scripts/spartan/Worker.py +++ b/scripts/spartan/Worker.py @@ -281,7 +281,7 @@ class Worker: free_vram = int(memory_response['free']) / (1024 * 1024 * 1024) total_vram = int(memory_response['total']) / (1024 * 1024 * 1024) logger.debug(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n") - self.free_vram = bytes(memory_response['free']) + self.free_vram = memory_response['free'] except KeyError: error = memory_response['cuda']['error'] logger.debug(f"CUDA doesn't seem to be available for worker '{self.uuid}'\nError: {error}") diff --git a/scripts/spartan/World.py b/scripts/spartan/World.py index 2a17061..3444629 100644 --- a/scripts/spartan/World.py +++ b/scripts/spartan/World.py @@ -170,7 +170,7 @@ class World: """ Attempts to benchmark all workers a part of the world. """ - global benchmark_payload + from scripts.spartan.shared import benchmark_payload workers_info: dict = {} saved: bool = os.path.exists(self.worker_info_path) @@ -178,11 +178,18 @@ class World: benchmark_threads = [] def benchmark_wrapped(worker): + logger.critical(f"benchmark payload is: {benchmark_payload}") bench_func = worker.benchmark if not worker.master else self.benchmark_master worker.avg_ipm = bench_func() worker.benchmarked = True if rebenchmark: + if saved: + with open(self.worker_info_path, 'r') as worker_info_file: + workers_info = json.load(worker_info_file) + benchmark_payload = workers_info['benchmark_payload'] + logger.info(f"Using saved benchmark config:\n{benchmark_payload}") + saved = False workers = self.get_workers() @@ -194,6 +201,8 @@ class World: with open(self.worker_info_path, 'r') as worker_info_file: try: workers_info = json.load(worker_info_file) + benchmark_payload = workers_info['benchmark_payload'] + logger.info(f"Using saved benchmark config:\n{benchmark_payload}") except json.JSONDecodeError: logger.error(f"workers.json is not valid JSON, regenerating") rebenchmark = True @@ -237,7 +246,6 @@ class World: logger.info(self.speed_summary()) - def get_current_output_size(self) -> int: """ returns how many images would be returned from all jobs