diff --git a/scripts/distributed.py b/scripts/distributed.py
index 524e73b..839d93b 100644
--- a/scripts/distributed.py
+++ b/scripts/distributed.py
@@ -13,6 +13,7 @@ import time
 from threading import Thread
 from typing import List
 import gradio
+from torchvision.transforms import ToTensor
 import urllib3
 from PIL import Image
 from modules import processing
@@ -71,7 +72,7 @@ class DistributedScript(scripts.Script):
         # return some components that should be exposed to the api
         return components
 
-    def add_to_gallery(self, processed, p):
+    def add_to_gallery(self, pp, p):
         """adds generated images to the image gallery after waiting for all workers to finish"""
 
         def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None):
@@ -101,17 +102,20 @@ class DistributedScript(scripts.Script):
                 processed_inject_image(image=image, info_index=0, response=response)
                 return
 
-            processed.all_seeds.append(seed)
-            processed.all_subseeds.append(subseed)
-            processed.all_negative_prompts.append(negative_prompt)
-            processed.all_prompts.append(pos_prompt)
-            processed.images.append(image)  # actual received image
+            p.seeds.append(seed)
+            p.subseeds.append(subseed)
+            p.negative_prompts.append(negative_prompt)
+            p.prompts.append(pos_prompt)
+
+            transform = ToTensor()
+            pp.images.append(transform(image))  # actual received image
+
             # generate info-text string
             # modules.ui_common -> update_generation_info renders to html below gallery
             images_per_batch = p.n_iter * p.batch_size
             # zero-indexed position of image in total batch (so including master results)
-            true_image_pos = len(processed.images) - 1
+            true_image_pos = len(pp.images) - 1
             num_remote_images = images_per_batch * p.batch_size
             if p.n_iter > 1:  # if splitting by batch count
                 num_remote_images *= p.n_iter - 1
 
@@ -120,28 +124,28 @@
                          f"info-index: {info_index}")
 
             if self.world.thin_client_mode:
-                p.all_negative_prompts = processed.all_negative_prompts
+                p.all_negative_prompts = pp.all_negative_prompts
 
-            try:
-                info_text = image_info_post['infotexts'][i]
-            except IndexError:
-                if not grid:
-                    logger.warning(f"image {true_image_pos + 1} was missing info-text")
-                info_text = processed.infotexts[0]
-            info_text += f", Worker Label: {job.worker.label}"
-            processed.infotexts.append(info_text)
+            # try:
+            #     info_text = image_info_post['infotexts'][i]
+            # except IndexError:
+            #     if not grid:
+            #         logger.warning(f"image {true_image_pos + 1} was missing info-text")
+            #     info_text = processed.infotexts[0]
+            # info_text += f", Worker Label: {job.worker.label}"
+            # processed.infotexts.append(info_text)
 
             # automatically save received image to local disk if desired
-            if cmd_opts.distributed_remotes_autosave:
-                save_image(
-                    image=image,
-                    path=p.outpath_samples if save_path_override is None else save_path_override,
-                    basename="",
-                    seed=seed,
-                    prompt=pos_prompt,
-                    info=info_text,
-                    extension=opts.samples_format
-                )
+            #if cmd_opts.distributed_remotes_autosave:
+            #    save_image(
+            #        image=image,
+            #        path=p.outpath_samples if save_path_override is None else save_path_override,
+            #        basename="",
+            #        seed=seed,
+            #        prompt=pos_prompt,
+            #        info=info_text,
+            #        extension=opts.samples_format
+            #    )
 
             # get master ipm by estimating based on worker speed
             master_elapsed = time.time() - self.master_start
@@ -197,22 +201,11 @@ class DistributedScript(scripts.Script):
             logger.critical("couldn't collect any responses, the extension will have no effect")
             return
 
-        # generate and inject grid
-        if opts.return_grid and len(processed.images) > 1:
-            grid = image_grid(processed.images, len(processed.images))
-            processed_inject_image(
-                image=grid,
-                info_index=0,
-                save_path_override=p.outpath_grids,
-                grid=True,
-                response=donor_worker.response
-            )
-
         # cleanup after we're doing using all the responses
         for worker in self.world.get_workers():
             worker.response = None
 
-        p.batch_size = len(processed.images)
+        p.batch_size = len(pp.images)
         return
 
     # p's type is
@@ -351,11 +344,26 @@ class DistributedScript(scripts.Script):
         self.master_start = time.time()
 
         # generate images assigned to local machine
-        p.do_not_save_grid = True  # don't generate grid from master as we are doing this later.
+        # p.do_not_save_grid = True  # don't generate grid from master as we are doing this later.
         self.runs_since_init += 1
         return
 
-    def postprocess(self, p, processed, *args):
+    # def postprocess(self, p, processed, *args):
+    #     is_img2img = getattr(p, 'init_images', False)
+    #     if is_img2img and self.world.enabled_i2i is False:
+    #         return
+    #     elif not is_img2img and self.world.enabled is False:
+    #         return
+
+    #     if self.master_start is not None:
+    #         self.add_to_gallery(p=p, processed=processed)
+
+    #     # restore process_images_inner if it was monkey-patched
+    #     processing.process_images_inner = self.original_process_images_inner
+    #     # save any dangling state to prevent load_config in next iteration overwriting it
+    #     self.world.save_config()
+
+
+    def postprocess_batch_list(self, p, pp, *args, **kwargs):
         is_img2img = getattr(p, 'init_images', False)
         if is_img2img and self.world.enabled_i2i is False:
             return
@@ -363,7 +371,7 @@ class DistributedScript(scripts.Script):
             return
 
         if self.master_start is not None:
-            self.add_to_gallery(p=p, processed=processed)
+            self.add_to_gallery(p=p, pp=pp)
 
         # restore process_images_inner if it was monkey-patched
         processing.process_images_inner = self.original_process_images_inner