From 0ac23cbd73b8cd9bdd607d4d0d55fa6236f73eb7 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 9 May 2024 01:44:45 -0500 Subject: [PATCH] refactoring --- scripts/distributed.py | 110 ++++++++++++++++++++--------------------- 1 file changed, 54 insertions(+), 56 deletions(-) diff --git a/scripts/distributed.py b/scripts/distributed.py index 59f755a..ec92e0c 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -32,27 +32,32 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM) # noinspection PyMissingOrEmptyDocstring class Script(scripts.Script): - worker_threads: List[Thread] = [] - # Whether to verify worker certificates. Can be useful if your remotes are self-signed. - verify_remotes = not cmd_opts.distributed_skip_verify_remotes - is_img2img = True - is_txt2img = True - alwayson = True - master_start = None - runs_since_init = 0 - name = "distributed" - is_dropdown_handler_injected = False + def __init__(self): + super().__init__() + self.worker_threads: List[Thread] = [] + # Whether to verify worker certificates. Can be useful if your remotes are self-signed. + self.verify_remotes = not cmd_opts.distributed_skip_verify_remotes + self.is_img2img = True + self.is_txt2img = True + self.alwayson = True + self.master_start = None + self.runs_since_init = 0 + self.name = "distributed" + self.is_dropdown_handler_injected = False - if verify_remotes is False: - logger.warning(f"You have chosen to forego the verification of worker TLS certificates") - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + if self.verify_remotes is False: + logger.warning(f"You have chosen to forego the verification of worker TLS certificates") + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - # build world - world = World(verify_remotes=verify_remotes) - world.load_config() - logger.info("doing initial ping sweep to see which workers are reachable") - world.ping_remotes(indiscriminate=True) + # build world + self.world = World(verify_remotes=self.verify_remotes) + self.world.load_config() + logger.info("doing initial ping sweep to see which workers are reachable") + self.world.ping_remotes(indiscriminate=True) + + signal.signal(signal.SIGINT, self.signal_handler) + signal.signal(signal.SIGTERM, self.signal_handler) def title(self): return "Distribute" @@ -61,19 +66,18 @@ class Script(scripts.Script): return scripts.AlwaysVisible def ui(self, is_img2img): - Script.world.load_config() - extension_ui = UI(script=Script, world=Script.world) + self.world.load_config() + extension_ui = UI(script=Script, world=self.world) # root, api_exposed = extension_ui.create_ui() components = extension_ui.create_ui() # The first injection of handler for the models dropdown(sd_model_checkpoint) which is often present # in the quick-settings bar of a user. Helps ensure model swaps propagate to all nodes ASAP. - Script.world.inject_model_dropdown_handler() + self.world.inject_model_dropdown_handler() # return some components that should be exposed to the api return components - @staticmethod - def add_to_gallery(processed, p): + def add_to_gallery(self, processed, p): """adds generated images to the image gallery after waiting for all workers to finish""" def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None): @@ -118,10 +122,10 @@ class Script(scripts.Script): if p.n_iter > 1: # if splitting by batch count num_remote_images *= p.n_iter - 1 - logger.debug(f"image {true_image_pos + 1}/{Script.world.p.batch_size * p.n_iter}, " + logger.debug(f"image {true_image_pos + 1}/{self.world.p.batch_size * p.n_iter}, " f"info-index: {info_index}") - if Script.world.thin_client_mode: + if self.world.thin_client_mode: p.all_negative_prompts = processed.all_negative_prompts try: @@ -146,21 +150,21 @@ class Script(scripts.Script): ) # get master ipm by estimating based on worker speed - master_elapsed = time.time() - Script.master_start + master_elapsed = time.time() - self.master_start logger.debug(f"Took master {master_elapsed:.2f}s") # wait for response from all workers webui_state.textinfo = "Distributed - receiving results" - for thread in Script.worker_threads: + for thread in self.worker_threads: logger.debug(f"waiting for worker thread '{thread.name}'") thread.join() - Script.worker_threads.clear() + self.worker_threads.clear() logger.debug("all worker request threads returned") webui_state.textinfo = "Distributed - injecting images" # some worker which we know has a good response that we can use for generating the grid donor_worker = None - for job in Script.world.jobs: + for job in self.world.jobs: if job.batch_size < 1 or job.worker.master: continue @@ -209,7 +213,7 @@ class Script(scripts.Script): ) # cleanup after we're doing using all the responses - for worker in Script.world.get_workers(): + for worker in self.world.get_workers(): worker.response = None p.batch_size = len(processed.images) @@ -218,15 +222,14 @@ class Script(scripts.Script): # p's type is # "modules.processing.StableDiffusionProcessing*" - @staticmethod - def before_process(p, *args): - if not Script.world.enabled: + def before_process(self, p, *args): + if not self.world.enabled: logger.debug("extension is disabled") return - Script.world.update(p) + self.world.update(p) # save original process_images_inner function for later if we monkeypatch it - Script.original_process_images_inner = processing.process_images_inner + self.original_process_images_inner = processing.process_images_inner # strip scripts that aren't yet supported and warn user packed_script_args: List[dict] = [] # list of api formatted per-script argument objects @@ -262,7 +265,7 @@ class Script(scripts.Script): # encapsulating the request object within a txt2imgreq object is deprecated and no longer works # see test/basic_features/txt2img_test.py for an example payload = copy.copy(p.__dict__) - payload['batch_size'] = Script.world.default_batch_size() + payload['batch_size'] = self.world.default_batch_size() payload['scripts'] = None try: del payload['script_args'] @@ -291,11 +294,11 @@ class Script(scripts.Script): # start generating images assigned to remote machines sync = False # should only really need to sync once per job - Script.world.optimize_jobs(payload) # optimize work assignment before dispatching + self.world.optimize_jobs(payload) # optimize work assignment before dispatching started_jobs = [] # check if anything even needs to be done - if len(Script.world.jobs) == 1 and Script.world.jobs[0].worker.master: + if len(self.world.jobs) == 1 and self.world.jobs[0].worker.master: if payload['batch_size'] >= 2: msg = f"all remote workers are offline or unreachable" @@ -306,7 +309,7 @@ class Script(scripts.Script): return - for job in Script.world.jobs: + for job in self.world.jobs: payload_temp = copy.copy(payload) del payload_temp['scripts_value'] payload_temp = copy.deepcopy(payload_temp) @@ -338,35 +341,33 @@ class Script(scripts.Script): name=f"{job.worker.label}_request") t.start() - Script.worker_threads.append(t) + self.worker_threads.append(t) started_jobs.append(job) # if master batch size was changed again due to optimization change it to the updated value - if not Script.world.thin_client_mode: - p.batch_size = Script.world.master_job().batch_size - Script.master_start = time.time() + if not self.world.thin_client_mode: + p.batch_size = self.world.master_job().batch_size + self.master_start = time.time() # generate images assigned to local machine p.do_not_save_grid = True # don't generate grid from master as we are doing this later. - Script.runs_since_init += 1 + self.runs_since_init += 1 return - @staticmethod - def postprocess(p, processed, *args): - if not Script.world.enabled: + def postprocess(self, p, processed, *args): + if not self.world.enabled: return - if Script.master_start is not None: - Script.add_to_gallery(p=p, processed=processed) + if self.master_start is not None: + self.add_to_gallery(p=p, processed=processed) # restore process_images_inner if it was monkey-patched - processing.process_images_inner = Script.original_process_images_inner + processing.process_images_inner = self.original_process_images_inner - @staticmethod - def signal_handler(sig, frame): + def signal_handler(self, sig, frame): logger.debug("handling interrupt signal") # do cleanup - Script.world.save_config() + self.world.save_config() if sig == signal.SIGINT: if callable(old_sigint_handler): @@ -376,6 +377,3 @@ class Script(scripts.Script): old_sigterm_handler(sig, frame) else: sys.exit(0) - - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler)