diff --git a/scripts/distributed.py b/scripts/distributed.py index b771250..59f755a 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -24,13 +24,12 @@ from modules.shared import state as webui_state from scripts.spartan.control_net import pack_control_net from scripts.spartan.shared import logger from scripts.spartan.ui import UI -from scripts.spartan.world import World, WorldAlreadyInitialized +from scripts.spartan.world import World old_sigint_handler = signal.getsignal(signal.SIGINT) old_sigterm_handler = signal.getsignal(signal.SIGTERM) -# TODO implement advertisement of some sort in sdwui api to allow extension to automatically discover workers? # noinspection PyMissingOrEmptyDocstring class Script(scripts.Script): worker_threads: List[Thread] = [] @@ -50,17 +49,8 @@ class Script(scripts.Script): urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) # build world - world = World(initial_payload=None, verify_remotes=verify_remotes) - # add workers to the world + world = World(verify_remotes=verify_remotes) world.load_config() - if cmd_opts.distributed_remotes is not None and len(cmd_opts.distributed_remotes) > 0: - logger.warning(f"--distributed-remotes is deprecated and may be removed in the future\n" - f"gui/external modification of {world.config_path} will be prioritized going forward") - - for worker in cmd_opts.distributed_remotes: - world.add_worker(uuid=worker[0], address=worker[1], port=worker[2], tls=False) - world.save_config() - # do an early check to see which workers are online logger.info("doing initial ping sweep to see which workers are reachable") world.ping_remotes(indiscriminate=True) @@ -71,7 +61,7 @@ class Script(scripts.Script): return scripts.AlwaysVisible def ui(self, is_img2img): - self.world.load_config() + Script.world.load_config() extension_ui = UI(script=Script, world=Script.world) # root, api_exposed = extension_ui.create_ui() components = extension_ui.create_ui() @@ -128,7 +118,7 @@ class Script(scripts.Script): if p.n_iter > 1: # if splitting by batch count num_remote_images *= p.n_iter - 1 - logger.debug(f"image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, " + logger.debug(f"image {true_image_pos + 1}/{Script.world.p.batch_size * p.n_iter}, " f"info-index: {info_index}") if Script.world.thin_client_mode: @@ -225,30 +215,15 @@ class Script(scripts.Script): p.batch_size = len(processed.images) return - @staticmethod - def initialize(initial_payload): - # get default batch size - try: - batch_size = initial_payload.batch_size - except AttributeError: - batch_size = 1 - - try: - Script.world.initialize(batch_size) - Script.world.initial_payload = initial_payload - logger.debug(f"World initialized!") - except WorldAlreadyInitialized: - Script.world.update_world(total_batch_size=batch_size) # p's type is - # "modules.processing.StableDiffusionProcessingTxt2Img" - def before_process(self, p, *args): - if not self.world.enabled: + # "modules.processing.StableDiffusionProcessing*" + @staticmethod + def before_process(p, *args): + if not Script.world.enabled: logger.debug("extension is disabled") return - - current_thread().name = "distributed_main" - Script.initialize(initial_payload=p) + Script.world.update(p) # save original process_images_inner function for later if we monkeypatch it Script.original_process_images_inner = processing.process_images_inner @@ -367,22 +342,12 @@ class Script(scripts.Script): started_jobs.append(job) # if master batch size was changed again due to optimization change it to the updated value - if not self.world.thin_client_mode: + if not Script.world.thin_client_mode: p.batch_size = Script.world.master_job().batch_size Script.master_start = time.time() # generate images assigned to local machine p.do_not_save_grid = True # don't generate grid from master as we are doing this later. - if Script.world.thin_client_mode: - p.batch_size = 0 - processed = Processed(p=p, images_list=[]) - processed.all_prompts = [] - processed.all_seeds = [] - processed.all_subseeds = [] - processed.all_negative_prompts = [] - processed.infotexts = [] - processed.prompt = None - Script.runs_since_init += 1 return diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index 2dc8149..39cf884 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -1,4 +1,3 @@ -import asyncio import base64 import copy import io diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 494c6fe..5156cc6 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -30,13 +30,6 @@ class NotBenchmarked(Exception): pass -class WorldAlreadyInitialized(Exception): - """ - Raised when attempting to initialize the World when it has already been initialized. - """ - pass - - class Job: """ Keeps track of how much work a given worker should handle. @@ -78,7 +71,7 @@ class World: The frame or "world" which holds all workers (including the local machine). Args: - initial_payload: The original txt2img payload created by the user initiating the generation request on master. + p: The original processing state object created by the user initiating the generation request on master. verify_remotes (bool): Whether to validate remote worker certificates. """ @@ -86,15 +79,14 @@ class World: config_path = shared.cmd_opts.distributed_config old_config_path = worker_info_path = extension_path.joinpath('workers.json') - def __init__(self, initial_payload, verify_remotes: bool = True): + def __init__(self, verify_remotes: bool = True): + self.p = None self.master_worker = Worker(master=True) - self.total_batch_size: int = 0 self._workers: List[Worker] = [self.master_worker] self.jobs: List[Job] = [] self.job_timeout: int = 3 # seconds self.initialized: bool = False self.verify_remotes = verify_remotes - self.initial_payload = copy.copy(initial_payload) self.thin_client_mode = False self.enabled = True self.is_dropdown_handler_injected = False @@ -109,32 +101,12 @@ class World: def __repr__(self): return f"{len(self._workers)} workers" - def update_world(self, total_batch_size): - """ - Updates the world with information vital to handling the local generation request after - the world has already been initialized. - - Args: - total_batch_size (int): The total number of images requested by the local/master sdwui instance. - """ - - self.total_batch_size = total_batch_size - self.update_jobs() - - def initialize(self, total_batch_size): - """should be called before a world instance is used for anything""" - if self.initialized: - raise WorldAlreadyInitialized("This world instance was already initialized") - - self.benchmark() - self.update_world(total_batch_size=total_batch_size) - self.initialized = True def default_batch_size(self) -> int: """the amount of images/total images requested that a worker would compute if conditions were perfect and each worker generated at the same speed. assumes one batch only""" - return self.total_batch_size // self.size() + return self.p.batch_size // self.size() def size(self) -> int: """ @@ -396,6 +368,18 @@ class World: self.jobs.append(Job(worker=worker, batch_size=batch_size)) + def update(self, p): + """preps world for another run""" + if not self.initialized: + self.benchmark() + self.initialized = True + logger.debug("world initialized!") + else: + logger.debug("world was already initialized") + + self.p = p + self.update_jobs() + def get_workers(self): filtered: List[Worker] = [] for worker in self._workers: @@ -431,7 +415,7 @@ class World: logger.debug(f"worker '{job.worker.label}' would stall the image gallery by ~{lag:.2f}s\n") job.complementary = True - if deferred_images + images_checked + payload['batch_size'] > self.total_batch_size: + if deferred_images + images_checked + payload['batch_size'] > self.p.batch_size: logger.debug(f"would go over actual requested size") else: deferred_images += payload['batch_size'] @@ -474,9 +458,9 @@ class World: ####################### # when total number of requested images was not cleanly divisible by world size then we tack the remainder on - remainder_images = self.total_batch_size - self.get_current_output_size() + remainder_images = self.p.batch_size - self.get_current_output_size() if remainder_images >= 1: - logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed") + logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed") realtime_jobs = self.realtime_jobs() realtime_jobs.sort(key=lambda x: x.batch_size) @@ -555,9 +539,9 @@ class World: iterations = payload['n_iter'] num_returning = self.get_current_output_size() - num_complementary = num_returning - self.total_batch_size + num_complementary = num_returning - self.p.batch_size distro_summary = "Job distribution:\n" - distro_summary += f"{self.total_batch_size} * {iterations} iteration(s)" + distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)" if num_complementary > 0: distro_summary += f" + {num_complementary} complementary" distro_summary += f": {num_returning} images total\n" @@ -577,7 +561,7 @@ class World: processed.infotexts = [] processed.prompt = None - self.initial_payload.scripts.postprocess(p, processed) + self.p.scripts.postprocess(p, processed) return processed processing.process_images_inner = process_images_inner_bypass