diff --git a/scripts/distributed.py b/scripts/distributed.py index 07b9b8f..d361507 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -25,7 +25,7 @@ from modules.shared import state as webui_state from scripts.spartan.control_net import pack_control_net from scripts.spartan.shared import logger from scripts.spartan.ui import UI -from scripts.spartan.world import World, State +from scripts.spartan.world import World, State, Job old_sigint_handler = signal.getsignal(signal.SIGINT) old_sigterm_handler = signal.getsignal(signal.SIGTERM) @@ -72,7 +72,7 @@ class DistributedScript(scripts.Script): # return some components that should be exposed to the api return components - def api_to_internal(self, job): + def api_to_internal(self, job) -> ([], [], [], [], []): # takes worker response received from api and returns parsed objects in internal sdwui format. E.g. all_seeds image_params: json = job.worker.response['parameters'] @@ -107,7 +107,8 @@ class DistributedScript(scripts.Script): return all_seeds, all_subseeds, all_negative_prompts, all_prompts, images - def processed_inject_image(self, job, p, pp): + def inject_job(self, job: Job, p, pp): + """Adds the work completed by one Job via its worker response to the processing and postprocessing objects""" all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = self.api_to_internal(job) p.seeds.extend(all_seeds) @@ -115,7 +116,7 @@ class DistributedScript(scripts.Script): p.negative_prompts.extend(all_negative_prompts) p.prompts.extend(all_prompts) - for i, image in enumerate(images): + for image in images: pp.images.append(image) # add one image to the gallery # modules.ui_common -> update_generation_info renders to html below gallery @@ -126,6 +127,7 @@ class DistributedScript(scripts.Script): if p.n_iter > 1: # if splitting by batch count num_remote_images *= p.n_iter - 1 + # TODO slightly off from changes logger.debug(f"image {true_image_pos + 1}/{(self.world.p.batch_size * p.n_iter) + (not p.do_not_save_grid) + 1}, " f"info-index: fix me") @@ -135,8 +137,8 @@ class DistributedScript(scripts.Script): # saves final position of image in gallery so that we can later modify the correct infotext job.gallery_map.append(true_image_pos) - def add_to_gallery(self, pp, p): - """adds generated images to the image gallery after waiting for all workers to finish""" + def update_gallery(self, pp, p): + """adds all remotely generated images to the image gallery after waiting for all workers to finish""" # get master ipm by estimating based on worker speed master_elapsed = time.time() - self.master_start @@ -153,8 +155,7 @@ class DistributedScript(scripts.Script): logger.debug("all worker request threads returned") webui_state.textinfo = "Distributed - injecting images" - # some worker which we know has a good response that we can use for generating the grid - donor_worker = None + received_images = False for job in self.world.jobs: if job.worker.response is None or job.batch_size < 1 or job.worker.master: continue @@ -165,8 +166,7 @@ class DistributedScript(scripts.Script): if (job.batch_size * p.n_iter) < len(images): logger.debug(f"requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}") - if donor_worker is None: - donor_worker = job.worker + received_images = True except KeyError: if job.batch_size > 0: logger.warning(f"Worker '{job.worker.label}' had no images") @@ -180,20 +180,14 @@ class DistributedScript(scripts.Script): logger.exception(e) continue - # visibly add work from workers to the image gallery - # for i in range(0, len(images)): - # image_bytes = base64.b64decode(images[i]) - # image = Image.open(io.BytesIO(image_bytes)) + # adding the images in + self.inject_job(job, p, pp) - # # inject image - # self.processed_inject_image(image=image, info_index=i, job=job, p=p, pp=pp) - self.processed_inject_image(job, p, pp) - - if donor_worker is None: + # TODO fix controlnet masks returned via api having no generation info + if received_images is False: logger.critical("couldn't collect any responses, the extension will have no effect") return - p.batch_size = len(pp.images) webui_state.textinfo = "" return @@ -354,7 +348,7 @@ class DistributedScript(scripts.Script): return if self.master_start is not None: - self.add_to_gallery(p=p, pp=pp) + self.update_gallery(p=p, pp=pp) def postprocess(self, p, processed, *args):