diff --git a/CHANGELOG.md b/CHANGELOG.md index 83f3d94..82459e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,23 @@ # Change Log Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html) +## [2.3.0] - 2024-10-26 + +## Added +- Compatibility for some extensions which mostly only do postprocessing (e.g. Adetailer) +- Separate toggle state for img2img tab so txt2img can be enabled and t2i disabled or vice versa + +## Changed +- Status tab will now automatically refresh +- Main toggle is now in the form of an InputAccordion + +## Fixed +- An issue affecting controlnet and inpainting +- Toggle state sometimes desyncing when the page was refreshed + ## [2.2.2] - 2024-8-30 ### Fixed - - Unavailable state sometimes being ignored ## [2.2.1] - 2024-5-16 @@ -82,4 +95,4 @@ Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic - Worker randomly disconnecting when under high load due to handling a previous request ### Removed -- Certain superfluous warnings in logs related to third party extensions \ No newline at end of file +- Certain superfluous warnings in logs related to third party extensions diff --git a/javascript/distributed.js b/javascript/distributed.js index 5827386..ba54cae 100644 --- a/javascript/distributed.js +++ b/javascript/distributed.js @@ -1,4 +1,23 @@ function confirm_restart_workers(_) { return confirm('Restart remote workers?') -} \ No newline at end of file +} + +// live updates +function update() { + try { + let currentTab = get_uiCurrentTabContent() + let buttons = document.querySelectorAll('#distributed-refresh-status') + for(let i = 0; i < buttons.length; i++) { + if(currentTab.contains(buttons[i])) { + buttons[i].click() + break + } + } + } catch (e) { + if (!(e instanceof TypeError)) { + throw e + } + } +} +setInterval(update, 1500) \ No newline at end of file diff --git a/scripts/distributed.py b/scripts/distributed.py index 492d4d4..9ff6f7b 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -13,6 +13,7 @@ import time from threading import Thread from typing import List import gradio +from torchvision.transforms import ToTensor import urllib3 from PIL import Image from modules import processing @@ -24,7 +25,7 @@ from modules.shared import state as webui_state from scripts.spartan.control_net import pack_control_net from scripts.spartan.shared import logger from scripts.spartan.ui import UI -from scripts.spartan.world import World, State +from scripts.spartan.world import World, State, Job old_sigint_handler = signal.getsignal(signal.SIGINT) old_sigterm_handler = signal.getsignal(signal.SIGTERM) @@ -61,7 +62,7 @@ class DistributedScript(scripts.Script): return scripts.AlwaysVisible def ui(self, is_img2img): - extension_ui = UI(world=self.world) + extension_ui = UI(world=self.world, is_img2img=is_img2img) # root, api_exposed = extension_ui.create_ui() components = extension_ui.create_ui() @@ -71,77 +72,61 @@ class DistributedScript(scripts.Script): # return some components that should be exposed to the api return components - def add_to_gallery(self, processed, p): - """adds generated images to the image gallery after waiting for all workers to finish""" + def api_to_internal(self, job) -> ([], [], [], [], []): + # takes worker response received from api and returns parsed objects in internal sdwui format. E.g. all_seeds - def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None): - image_params: json = response['parameters'] - image_info_post: json = json.loads(response["info"]) # image info known after processing - num_response_images = image_params["batch_size"] * image_params["n_iter"] - - seed = None - subseed = None - negative_prompt = None - pos_prompt = None + image_params: json = job.worker.response['parameters'] + image_info_post: json = json.loads(job.worker.response["info"]) # image info known after processing + all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = [], [], [], [], [] + for i in range(len(job.worker.response["images"])): try: - if num_response_images > 1: - seed = image_info_post['all_seeds'][info_index] - subseed = image_info_post['all_subseeds'][info_index] - negative_prompt = image_info_post['all_negative_prompts'][info_index] - pos_prompt = image_info_post['all_prompts'][info_index] - else: - seed = image_info_post['seed'] - subseed = image_info_post['subseed'] - negative_prompt = image_info_post['negative_prompt'] - pos_prompt = image_info_post['prompt'] + if image_params["batch_size"] * image_params["n_iter"] > 1: + all_seeds.append(image_info_post['all_seeds'][i]) + all_subseeds.append(image_info_post['all_subseeds'][i]) + all_negative_prompts.append(image_info_post['all_negative_prompts'][i]) + all_prompts.append(image_info_post['all_prompts'][i]) + else: # only a single image received + all_seeds.append(image_info_post['seed']) + all_subseeds.append(image_info_post['subseed']) + all_negative_prompts.append(image_info_post['negative_prompt']) + all_prompts.append(image_info_post['prompt']) except IndexError: - # like with controlnet masks, there isn't always full post-gen info, so we use the first images' - logger.debug(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data") - processed_inject_image(image=image, info_index=0, response=response) - return + # # like with controlnet masks, there isn't always full post-gen info, so we use the first images' + # logger.debug(f"Image at index {info_index} for '{job.worker.label}' was missing some post-generation data") + # self.processed_inject_image(image=image, info_index=0, job=job, p=p) + # return + logger.critical(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data") + continue - processed.all_seeds.append(seed) - processed.all_subseeds.append(subseed) - processed.all_negative_prompts.append(negative_prompt) - processed.all_prompts.append(pos_prompt) - processed.images.append(image) # actual received image + # parse image + image_bytes = base64.b64decode(job.worker.response["images"][i]) + image = Image.open(io.BytesIO(image_bytes)) + transform = ToTensor() + images.append(transform(image)) - # generate info-text string + return all_seeds, all_subseeds, all_negative_prompts, all_prompts, images + + def inject_job(self, job: Job, p, pp): + """Adds the work completed by one Job via its worker response to the processing and postprocessing objects""" + all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = self.api_to_internal(job) + + p.seeds.extend(all_seeds) + p.subseeds.extend(all_subseeds) + p.negative_prompts.extend(all_negative_prompts) + p.prompts.extend(all_prompts) + + num_local = self.world.p.n_iter * self.world.p.batch_size + (opts.return_grid - self.world.thin_client_mode) + num_injected = len(pp.images) - self.world.p.batch_size + for i, image in enumerate(images): # modules.ui_common -> update_generation_info renders to html below gallery - images_per_batch = p.n_iter * p.batch_size - # zero-indexed position of image in total batch (so including master results) - true_image_pos = len(processed.images) - 1 - num_remote_images = images_per_batch * p.batch_size - if p.n_iter > 1: # if splitting by batch count - num_remote_images *= p.n_iter - 1 + gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery + job.gallery_map.append(gallery_index) # so we know where to edit infotext + pp.images.append(image) + logger.debug(f"image {gallery_index + 1 + self.world.thin_client_mode}/{self.world.num_gallery()}") - logger.debug(f"image {true_image_pos + 1}/{self.world.p.batch_size * p.n_iter}, " - f"info-index: {info_index}") - - if self.world.thin_client_mode: - p.all_negative_prompts = processed.all_negative_prompts - - try: - info_text = image_info_post['infotexts'][i] - except IndexError: - if not grid: - logger.warning(f"image {true_image_pos + 1} was missing info-text") - info_text = processed.infotexts[0] - info_text += f", Worker Label: {job.worker.label}" - processed.infotexts.append(info_text) - - # automatically save received image to local disk if desired - if cmd_opts.distributed_remotes_autosave: - save_image( - image=image, - path=p.outpath_samples if save_path_override is None else save_path_override, - basename="", - seed=seed, - prompt=pos_prompt, - info=info_text, - extension=opts.samples_format - ) + def update_gallery(self, pp, p): + """adds all remotely generated images to the image gallery after waiting for all workers to finish""" # get master ipm by estimating based on worker speed master_elapsed = time.time() - self.master_start @@ -158,8 +143,7 @@ class DistributedScript(scripts.Script): logger.debug("all worker request threads returned") webui_state.textinfo = "Distributed - injecting images" - # some worker which we know has a good response that we can use for generating the grid - donor_worker = None + received_images = False for job in self.world.jobs: if job.worker.response is None or job.batch_size < 1 or job.worker.master: continue @@ -170,8 +154,7 @@ class DistributedScript(scripts.Script): if (job.batch_size * p.n_iter) < len(images): logger.debug(f"requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}") - if donor_worker is None: - donor_worker = job.worker + received_images = True except KeyError: if job.batch_size > 0: logger.warning(f"Worker '{job.worker.label}' had no images") @@ -185,41 +168,27 @@ class DistributedScript(scripts.Script): logger.exception(e) continue - # visibly add work from workers to the image gallery - for i in range(0, len(images)): - image_bytes = base64.b64decode(images[i]) - image = Image.open(io.BytesIO(image_bytes)) + # adding the images in + self.inject_job(job, p, pp) - # inject image - processed_inject_image(image=image, info_index=i, response=job.worker.response) - - if donor_worker is None: + # TODO fix controlnet masks returned via api having no generation info + if received_images is False: logger.critical("couldn't collect any responses, the extension will have no effect") return - # generate and inject grid - if opts.return_grid and len(processed.images) > 1: - grid = image_grid(processed.images, len(processed.images)) - processed_inject_image( - image=grid, - info_index=0, - save_path_override=p.outpath_grids, - grid=True, - response=donor_worker.response - ) - - # cleanup after we're doing using all the responses - for worker in self.world.get_workers(): - worker.response = None - - p.batch_size = len(processed.images) + p.batch_size = len(pp.images) + webui_state.textinfo = "" return # p's type is # "modules.processing.StableDiffusionProcessing*" def before_process(self, p, *args): - if not self.world.enabled: - logger.debug("extension is disabled") + is_img2img = getattr(p, 'init_images', False) + if is_img2img and self.world.enabled_i2i is False: + logger.debug("extension is disabled for i2i") + return + elif not is_img2img and self.world.enabled is False: + logger.debug("extension is disabled for t2i") return self.world.update(p) @@ -234,6 +203,14 @@ class DistributedScript(scripts.Script): continue title = script.title() + if title == "ADetailer": + adetailer_args = p.script_args[script.args_from:script.args_to] + + # InputAccordion main toggle, skip img2img toggle + if adetailer_args[0] and adetailer_args[1]: + logger.debug(f"adetailer is skipping img2img, returning control to wui") + return + # check for supported scripts if title == "ControlNet": # grab all controlnet units @@ -346,18 +323,34 @@ class DistributedScript(scripts.Script): p.batch_size = self.world.master_job().batch_size self.master_start = time.time() - # generate images assigned to local machine - p.do_not_save_grid = True # don't generate grid from master as we are doing this later. self.runs_since_init += 1 return - def postprocess(self, p, processed, *args): - if not self.world.enabled: + def postprocess_batch_list(self, p, pp, *args, **kwargs): + if not self.world.thin_client_mode and p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch + return + + is_img2img = getattr(p, 'init_images', False) + if is_img2img and self.world.enabled_i2i is False: + return + elif not is_img2img and self.world.enabled is False: return if self.master_start is not None: - self.add_to_gallery(p=p, processed=processed) + self.update_gallery(p=p, pp=pp) + + def postprocess(self, p, processed, *args): + for job in self.world.jobs: + if job.worker.response is not None: + for i, v in enumerate(job.gallery_map): + infotext = json.loads(job.worker.response['info'])['infotexts'][i] + infotext += f", Worker Label: {job.worker.label}" + processed.infotexts[v] = infotext + + # cleanup + for worker in self.world.get_workers(): + worker.response = None # restore process_images_inner if it was monkey-patched processing.process_images_inner = self.original_process_images_inner # save any dangling state to prevent load_config in next iteration overwriting it diff --git a/scripts/spartan/control_net.py b/scripts/spartan/control_net.py index dc55cd9..a029aea 100644 --- a/scripts/spartan/control_net.py +++ b/scripts/spartan/control_net.py @@ -1,9 +1,12 @@ +# https://github.com/Mikubill/sd-webui-controlnet/wiki/API#examples-1 + import copy from PIL import Image from modules.api.api import encode_pil_to_base64 from scripts.spartan.shared import logger import numpy as np import json +import enum def np_to_b64(image: np.ndarray): @@ -58,10 +61,14 @@ def pack_control_net(cn_units) -> dict: unit['mask'] = mask_b64 # mikubill unit['mask_image'] = mask_b64 # forge + + # serialize all enums + for k in unit.keys(): + if isinstance(unit[k], enum.Enum): + unit[k] = unit[k].value + # avoid returning duplicate detection maps since master should return the same one unit['save_detected_map'] = False - # remove anything unserializable - del unit['input_mode'] try: json.dumps(controlnet) diff --git a/scripts/spartan/pmodels.py b/scripts/spartan/pmodels.py index e349dc8..48fe7a6 100644 --- a/scripts/spartan/pmodels.py +++ b/scripts/spartan/pmodels.py @@ -29,17 +29,18 @@ class Worker_Model(BaseModel): default=False ) state: Optional[Any] = Field(default=1, description="The last known state of this worker") - user: Optional[str] = Field(description="The username to be used when authenticating with this worker") - password: Optional[str] = Field(description="The password to be used when authenticating with this worker") + user: Optional[str] = Field(description="The username to be used when authenticating with this worker", default=None) + password: Optional[str] = Field(description="The password to be used when authenticating with this worker", default=None) pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit") class ConfigModel(BaseModel): workers: List[Dict[str, Worker_Model]] - benchmark_payload: Dict = Field( + benchmark_payload: Benchmark_Payload = Field( default=Benchmark_Payload, description='the payload used when benchmarking a node' ) job_timeout: Optional[int] = Field(default=3) enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True) + enabled_i2i: Optional[bool] = Field(description="Same as above but for image to image", default=True) complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True) step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False) diff --git a/scripts/spartan/shared.py b/scripts/spartan/shared.py index 5400464..ead9f65 100644 --- a/scripts/spartan/shared.py +++ b/scripts/spartan/shared.py @@ -65,7 +65,7 @@ samples = 3 # number of times to benchmark worker after warmup benchmarks are c class BenchmarkPayload(BaseModel): - validate_assignment = True + # validate_assignment = True prompt: str = Field(default="A herd of cows grazing at the bottom of a sunny valley") negative_prompt: str = Field(default="") steps: int = Field(default=20) diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index 5aac350..b87c724 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -9,16 +9,17 @@ from .shared import logger, LOG_LEVEL, gui_handler from .worker import State from modules.call_queue import queue_lock from modules import progress +from modules.ui_components import InputAccordion worker_select_dropdown = None - class UI: """extension user interface related things""" - def __init__(self, world): + def __init__(self, world, is_img2img): self.world = world self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange + self.is_img2img = is_img2img # handlers @staticmethod @@ -184,12 +185,20 @@ class UI: worker.session.auth = (user, password) self.world.save_config() - def main_toggle_btn(self): - self.world.enabled = not self.world.enabled + def main_toggle_btn(self, state): + if self.is_img2img: + if self.world.enabled_i2i == state: # just prevents a redundant config save if ui desyncs + return + self.world.enabled_i2i = state + else: + if self.world.enabled == state: + return + self.world.enabled = state + self.world.save_config() # restore vanilla sdwui handler for model dropdown if extension is disabled or inject if otherwise - if not self.world.enabled: + if not self.world.enabled and not self.world.enabled_i2i: model_dropdown = opts.data_labels.get('sd_model_checkpoint') if self.original_model_dropdown_handler is not None: model_dropdown.onchange = self.original_model_dropdown_handler @@ -208,17 +217,14 @@ class UI: def create_ui(self): """creates the extension UI and returns relevant components""" components = [] + elem_id = 'enabled' + if self.is_img2img: + elem_id += '_i2i' with gradio.Blocks(variant='compact'): # Group() and Box() remove spacing - with gradio.Accordion(label='Distributed', open=False): - # https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/6109#issuecomment-1403315784 - main_toggle = gradio.Checkbox( # main on/off ext. toggle - elem_id='enable', - label='Enable', - value=self.world.enabled if self.world.enabled is not None else True, - interactive=True - ) - main_toggle.input(self.main_toggle_btn) + with InputAccordion(label='Distributed', open=False, value=self.world.config().get(elem_id), elem_id=elem_id) as main_toggle: + main_toggle.input(self.main_toggle_btn, inputs=[main_toggle]) + setattr(main_toggle.accordion, 'do_not_save_to_config', True) # InputAccordion is really a CheckBox components.append(main_toggle) with gradio.Tab('Status') as status_tab: @@ -236,8 +242,8 @@ class UI: info='top-most message is newest' ) - refresh_status_btn = gradio.Button(value='Refresh 🔄', size='sm') - refresh_status_btn.click(self.status_btn, inputs=[], outputs=[jobs, status, logs]) + refresh_status_btn = gradio.Button(value='Refresh 🔄', size='sm', elem_id='distributed-refresh-status', visible=False) + refresh_status_btn.click(self.status_btn, inputs=[], outputs=[jobs, status, logs], show_progress=False) status_tab.select(fn=self.status_btn, inputs=[], outputs=[jobs, status, logs]) components += [status, jobs, logs, refresh_status_btn] diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index f4eed0a..7bff6f3 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -15,6 +15,7 @@ from modules.shared import cmd_opts from modules.shared import state as master_state from . import shared as sh from .shared import logger, warmup_samples, LOG_LEVEL +from PIL import Image try: from webui import server_name @@ -40,6 +41,13 @@ class State(Enum): DISABLED = 5 +# looks redundant when encode_pil...() could be used, but it does not support all file formats. E.g. AVIF +def pil_to_64(image: Image) -> str: + buffer = io.BytesIO() + image.save(buffer, format="PNG") + return 'data:image/png;base64,' + str(base64.b64encode(buffer.getvalue()), 'utf-8') + + class Worker: """ This class represents a worker node in a distributed computing setup. @@ -361,10 +369,7 @@ class Worker: mode = 'img2img' # for use in checking script compat images = [] for image in init_images: - buffer = io.BytesIO() - image.save(buffer, format="PNG") - image = 'data:image/png;base64,' + str(base64.b64encode(buffer.getvalue()), 'utf-8') - images.append(image) + images.append(pil_to_64(image)) payload['init_images'] = images alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example @@ -401,9 +406,7 @@ class Worker: # if an image mask is present image_mask = payload.get('image_mask', None) if image_mask is not None: - image_b64 = encode_pil_to_base64(image_mask) - image_b64 = str(image_b64, 'utf-8') - payload['mask'] = image_b64 + payload['mask'] = pil_to_64(image_mask) del payload['image_mask'] # see if there is anything else wrong with serializing to payload @@ -606,6 +609,7 @@ class Worker: self.full_url("memory"), timeout=3 ) + self.response = response return response.status_code == 200 except requests.exceptions.ConnectionError as e: @@ -640,6 +644,7 @@ class Worker: return [] def load_options(self, model, vae=None): + failure_msg = f"failed to load options for worker '{self.label}'" if self.master: return @@ -653,17 +658,25 @@ class Worker: if vae is not None: payload['sd_vae'] = vae - self.set_state(State.WORKING) + state_cache = self.state + self.set_state(State.WORKING, expect_cycle=True) # may already be WORKING if called by worker.request() start = time.time() - response = self.session.post( - self.full_url("options"), - json=payload - ) + try: + response = self.session.post( + self.full_url("options"), + json=payload + ) + except requests.exceptions.RequestException: + self.set_state(State.UNAVAILABLE) + logger.error(f"{failure_msg} (connection error... OOM?)") + return + elapsed = time.time() - start - self.set_state(State.IDLE) + if state_cache != State.WORKING: # see above comment, this lets caller determine when worker is IDLE + self.set_state(State.IDLE) if response.status_code != 200: - logger.debug(f"failed to load options for worker '{self.label}'") + logger.debug(failure_msg) else: logger.debug(f"worker '{self.label}' loaded weights in {elapsed:.2f}s") self.loaded_model = model_name diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index f08faf3..92f153b 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -21,6 +21,9 @@ from .worker import Worker, State from modules.call_queue import wrap_queued_call, queue_lock from modules import processing from modules import progress +from modules.scripts import PostprocessBatchListArgs +from torchvision.transforms import ToPILImage +from modules.images import image_grid class NotBenchmarked(Exception): @@ -46,6 +49,7 @@ class Job: self.complementary: bool = False self.step_override = None self.thread = None + self.gallery_map: List[int] = [] def __str__(self): prefix = '' @@ -91,6 +95,7 @@ class World: self.verify_remotes = verify_remotes self.thin_client_mode = False self.enabled = True + self.enabled_i2i = True self.is_dropdown_handler_injected = False self.complement_production = True self.step_scaling = False @@ -103,7 +108,6 @@ class World: def __repr__(self): return f"{len(self._workers)} workers" - def default_batch_size(self) -> int: """the amount of images/total images requested that a worker would compute if conditions were perfect and each worker generated at the same speed. assumes one batch only""" @@ -113,7 +117,7 @@ class World: def size(self) -> int: """ Returns: - int: The number of nodes currently registered in the world. + int: The number of nodes currently registered in the world and in a valid state """ return len(self.get_workers()) @@ -273,7 +277,7 @@ class World: gradio.Info("Distributed: benchmarking complete!") self.save_config() - def get_current_output_size(self) -> int: + def num_requested(self) -> int: """ returns how many images would be returned from all jobs """ @@ -285,6 +289,11 @@ class World: return num_images + def num_gallery(self) -> int: + """How many images should appear in the gallery. This includes local generations and a grid(if enabled)""" + + return self.num_requested() * self.p.n_iter + shared.opts.return_grid + def speed_summary(self) -> str: """ Returns string listing workers by their ipm in descending order. @@ -393,7 +402,6 @@ class World: self.initialized = True logger.debug("world initialized!") - def get_workers(self): filtered: List[Worker] = [] for worker in self._workers: @@ -472,7 +480,7 @@ class World: ####################### # when total number of requested images was not cleanly divisible by world size then we tack the remainder on - remainder_images = self.p.batch_size - self.get_current_output_size() + remainder_images = self.p.batch_size - self.num_requested() if remainder_images >= 1: logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed") @@ -551,21 +559,37 @@ class World: else: logger.debug("complementary image production is disabled") - logger.info(self.distro_summary(payload)) + logger.info(self.distro_summary()) if self.thin_client_mode is True or self.master_job().batch_size == 0: # save original process_images_inner for later so we can restore once we're done logger.debug(f"bypassing local generation completely") def process_images_inner_bypass(p) -> processing.Processed: + p.seeds, p.subseeds, p.negative_prompts, p.prompts = [], [], [], [] + pp = PostprocessBatchListArgs(images=[]) + self.p.scripts.postprocess_batch_list(p, pp) + processed = processing.Processed(p, [], p.seed, info="") - processed.all_prompts = [] - processed.all_seeds = [] - processed.all_subseeds = [] - processed.all_negative_prompts = [] - processed.infotexts = [] - processed.prompt = None + processed.all_prompts = p.prompts + processed.all_seeds = p.seeds + processed.all_subseeds = p.subseeds + processed.all_negative_prompts = p.negative_prompts + processed.images = pp.images + processed.infotexts = [''] * self.num_requested() + + transform = ToPILImage() + for i, image in enumerate(processed.images): + processed.images[i] = transform(image) + self.p.scripts.postprocess(p, processed) + + # generate grid if enabled + if shared.opts.return_grid and len(processed.images) > 1: + grid = image_grid(processed.images, len(processed.images)) + processed.images.insert(0, grid) + processed.infotexts.insert(0, processed.infotexts[0]) + return processed processing.process_images_inner = process_images_inner_bypass @@ -576,18 +600,17 @@ class World: del self.jobs[last] last -= 1 - def distro_summary(self, payload): - # iterations = dict(payload)['n_iter'] - iterations = self.p.n_iter - num_returning = self.get_current_output_size() + def distro_summary(self): + num_returning = self.num_requested() num_complementary = num_returning - self.p.batch_size + distro_summary = "Job distribution:\n" - distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)" + distro_summary += f"{self.p.batch_size} * {self.p.n_iter} iteration(s)" if num_complementary > 0: distro_summary += f" + {num_complementary} complementary" - distro_summary += f": {num_returning} images total\n" + distro_summary += f": {num_returning * self.p.n_iter} images total\n" for job in self.jobs: - distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n" + distro_summary += f"'{job.worker.label}' - {job.batch_size * self.p.n_iter} image(s) @ {job.worker.avg_ipm:.2f} ipm\n" return distro_summary def config(self) -> dict: @@ -670,9 +693,10 @@ class World: self.add_worker(**fields) - sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload) + sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload.dict()) self.job_timeout = config.job_timeout self.enabled = config.enabled + self.enabled_i2i = config.enabled_i2i self.complement_production = config.complement_production self.step_scaling = config.step_scaling @@ -688,6 +712,7 @@ class World: benchmark_payload=sh.benchmark_payload, job_timeout=self.job_timeout, enabled=self.enabled, + enabled_i2i=self.enabled_i2i, complement_production=self.complement_production, step_scaling=self.step_scaling ) @@ -743,6 +768,9 @@ class World: worker.set_state(State.IDLE, expect_cycle=True) else: msg = f"worker '{worker.label}' is unreachable" + if worker.response.status_code is not None: + msg += f" <{worker.response.status_code}>" + logger.info(msg) gradio.Warning("Distributed: "+msg) worker.set_state(State.UNAVAILABLE) @@ -753,7 +781,6 @@ class World: for worker in self._workers: worker.restart() - def inject_model_dropdown_handler(self): if self.config().get('enabled', False): # TODO avoid access from config() return