diff --git a/scripts/distributed.py b/scripts/distributed.py index c1d68ae..059d9e2 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -75,9 +75,9 @@ class DistributedScript(scripts.Script): def add_to_gallery(self, pp, p): """adds generated images to the image gallery after waiting for all workers to finish""" - def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None): - image_params: json = response['parameters'] - image_info_post: json = json.loads(response["info"]) # image info known after processing + def processed_inject_image(image, info_index, save_path_override=None, grid=False, job=None): + image_params: json = job.worker.response['parameters'] + image_info_post: json = json.loads(job.worker.response["info"]) # image info known after processing num_response_images = image_params["batch_size"] * image_params["n_iter"] seed = None @@ -99,7 +99,7 @@ class DistributedScript(scripts.Script): except IndexError: # like with controlnet masks, there isn't always full post-gen info, so we use the first images' logger.debug(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data") - processed_inject_image(image=image, info_index=0, response=response) + processed_inject_image(image=image, info_index=0, job=job) return p.seeds.append(seed) @@ -110,12 +110,10 @@ class DistributedScript(scripts.Script): transform = ToTensor() pp.images.append(transform(image)) # actual received image - - # generate info-text string # modules.ui_common -> update_generation_info renders to html below gallery images_per_batch = p.n_iter * p.batch_size - # zero-indexed position of image in total batch (so including master results) - true_image_pos = len(pp.images) - 1 + # zero-indexed position of image in gallery (so including master/local results) + true_image_pos = (len(pp.images) - 1) + (not p.do_not_save_grid) num_remote_images = images_per_batch * p.batch_size if p.n_iter > 1: # if splitting by batch count num_remote_images *= p.n_iter - 1 @@ -126,26 +124,8 @@ class DistributedScript(scripts.Script): if self.world.thin_client_mode: p.all_negative_prompts = pp.all_negative_prompts - # try: - # info_text = image_info_post['infotexts'][i] - # except IndexError: - # if not grid: - # logger.warning(f"image {true_image_pos + 1} was missing info-text") - # info_text = processed.infotexts[0] - # info_text += f", Worker Label: {job.worker.label}" - # processed.infotexts.append(info_text) - - # automatically save received image to local disk if desired - #if cmd_opts.distributed_remotes_autosave: - # save_image( - # image=image, - # path=p.outpath_samples if save_path_override is None else save_path_override, - # basename="", - # seed=seed, - # prompt=pos_prompt, - # info=info_text, - # extension=opts.samples_format - # ) + # saves final position of image in gallery so that we can later modify the correct infotext + job.gallery_map.append(true_image_pos) # get master ipm by estimating based on worker speed master_elapsed = time.time() - self.master_start @@ -195,15 +175,12 @@ class DistributedScript(scripts.Script): image = Image.open(io.BytesIO(image_bytes)) # inject image - processed_inject_image(image=image, info_index=i, response=job.worker.response) + processed_inject_image(image=image, info_index=i, job=job) if donor_worker is None: logger.critical("couldn't collect any responses, the extension will have no effect") return - # cleanup after we're doing using all the responses - for worker in self.world.get_workers(): - worker.response = None p.batch_size = len(pp.images) webui_state.textinfo = "" @@ -357,21 +334,6 @@ class DistributedScript(scripts.Script): self.runs_since_init += 1 return - # def postprocess(self, p, processed, *args): - # is_img2img = getattr(p, 'init_images', False) - # if is_img2img and self.world.enabled_i2i is False: - # return - # elif not is_img2img and self.world.enabled is False: - # return - - # if self.master_start is not None: - # self.add_to_gallery(p=p, processed=processed) - - # # restore process_images_inner if it was monkey-patched - # processing.process_images_inner = self.original_process_images_inner - # # save any dangling state to prevent load_config in next iteration overwriting it - # self.world.save_config() - def postprocess_batch_list(self, p, pp, *args, **kwargs): is_img2img = getattr(p, 'init_images', False) if is_img2img and self.world.enabled_i2i is False: @@ -382,6 +344,20 @@ class DistributedScript(scripts.Script): if self.master_start is not None: self.add_to_gallery(p=p, pp=pp) + + def postprocess(self, p, processed, *args): + for job in self.world.jobs: + if job.worker.response is not None: + for i, v in enumerate(job.gallery_map): + infotext = json.loads(job.worker.response['info'])['infotexts'][i] + logger.debug(f"replacing image {v}'s infotext with\n" + f"> '{infotext}'") + infotext += f", Worker Label: {job.worker.label}" + processed.infotexts[v] = infotext + + # cleanup + for worker in self.world.get_workers(): + worker.response = None # restore process_images_inner if it was monkey-patched processing.process_images_inner = self.original_process_images_inner # save any dangling state to prevent load_config in next iteration overwriting it diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 73cc742..fc6e4cd 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -46,6 +46,7 @@ class Job: self.complementary: bool = False self.step_override = None self.thread = None + self.gallery_map: List[int] = [] def __str__(self): prefix = ''