diff --git a/scripts/distributed.py b/scripts/distributed.py index b3c10f6..9ff6f7b 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -116,19 +116,14 @@ class DistributedScript(scripts.Script): p.negative_prompts.extend(all_negative_prompts) p.prompts.extend(all_prompts) - num_local = self.world.p.n_iter * self.world.p.batch_size + opts.return_grid + num_local = self.world.p.n_iter * self.world.p.batch_size + (opts.return_grid - self.world.thin_client_mode) num_injected = len(pp.images) - self.world.p.batch_size for i, image in enumerate(images): # modules.ui_common -> update_generation_info renders to html below gallery - - # TODO probably shouldn't be here - if self.world.thin_client_mode: - p.all_negative_prompts = pp.all_negative_prompts - gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery job.gallery_map.append(gallery_index) # so we know where to edit infotext pp.images.append(image) - logger.debug(f"image {gallery_index + 1}/{self.world.num_gallery()}") + logger.debug(f"image {gallery_index + 1 + self.world.thin_client_mode}/{self.world.num_gallery()}") def update_gallery(self, pp, p): """adds all remotely generated images to the image gallery after waiting for all workers to finish""" @@ -328,13 +323,11 @@ class DistributedScript(scripts.Script): p.batch_size = self.world.master_job().batch_size self.master_start = time.time() - # generate images assigned to local machine - # p.do_not_save_grid = True # don't generate grid from master as we are doing this later. self.runs_since_init += 1 return def postprocess_batch_list(self, p, pp, *args, **kwargs): - if p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch + if not self.world.thin_client_mode and p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch return is_img2img = getattr(p, 'init_images', False) diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index 28616b9..7bff6f3 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -609,6 +609,7 @@ class Worker: self.full_url("memory"), timeout=3 ) + self.response = response return response.status_code == 200 except requests.exceptions.ConnectionError as e: diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 3b432b9..92f153b 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -21,6 +21,9 @@ from .worker import Worker, State from modules.call_queue import wrap_queued_call, queue_lock from modules import processing from modules import progress +from modules.scripts import PostprocessBatchListArgs +from torchvision.transforms import ToPILImage +from modules.images import image_grid class NotBenchmarked(Exception): @@ -562,15 +565,31 @@ class World: # save original process_images_inner for later so we can restore once we're done logger.debug(f"bypassing local generation completely") def process_images_inner_bypass(p) -> processing.Processed: + p.seeds, p.subseeds, p.negative_prompts, p.prompts = [], [], [], [] + pp = PostprocessBatchListArgs(images=[]) + self.p.scripts.postprocess_batch_list(p, pp) + processed = processing.Processed(p, [], p.seed, info="") - processed.all_prompts = [] - processed.all_seeds = [] - processed.all_subseeds = [] - processed.all_negative_prompts = [] - processed.infotexts = [] - processed.prompt = None + processed.all_prompts = p.prompts + processed.all_seeds = p.seeds + processed.all_subseeds = p.subseeds + processed.all_negative_prompts = p.negative_prompts + processed.images = pp.images + processed.infotexts = [''] * self.num_requested() + + transform = ToPILImage() + for i, image in enumerate(processed.images): + processed.images[i] = transform(image) + self.p.scripts.postprocess(p, processed) + + # generate grid if enabled + if shared.opts.return_grid and len(processed.images) > 1: + grid = image_grid(processed.images, len(processed.images)) + processed.images.insert(0, grid) + processed.infotexts.insert(0, processed.infotexts[0]) + return processed processing.process_images_inner = process_images_inner_bypass @@ -699,7 +718,7 @@ class World: ) with open(self.config_path, 'w+') as config_file: - config_file.write(config.model_dump_json(indent=3)) + config_file.write(config.json(indent=3)) logger.debug(f"config saved") def ping_remotes(self, indiscriminate: bool = False): @@ -749,6 +768,9 @@ class World: worker.set_state(State.IDLE, expect_cycle=True) else: msg = f"worker '{worker.label}' is unreachable" + if worker.response.status_code is not None: + msg += f" <{worker.response.status_code}>" + logger.info(msg) gradio.Warning("Distributed: "+msg) worker.set_state(State.UNAVAILABLE)