merge dev, making 2.3.0

master v2.3.0
papuSpartan 2024-10-26 21:46:01 -05:00
commit 8fd65ebdc1
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
9 changed files with 239 additions and 160 deletions

View File

@ -1,10 +1,23 @@
# Change Log
Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
## [2.3.0] - 2024-10-26
### Added
- Compatibility for some extensions which mostly only do postprocessing (e.g. Adetailer)
- Separate toggle state for img2img tab so txt2img can be enabled and t2i disabled or vice versa
### Changed
- Status tab will now automatically refresh
- Main toggle is now in the form of an InputAccordion
### Fixed
- An issue affecting controlnet and inpainting
- Toggle state sometimes desyncing when the page was refreshed
## [2.2.2] - 2024-08-30
### Fixed
- Unavailable state sometimes being ignored
## [2.2.1] - 2024-05-16
@ -82,4 +95,4 @@ Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic
- Worker randomly disconnecting when under high load due to handling a previous request
### Removed
- Certain superfluous warnings in logs related to third party extensions
- Certain superfluous warnings in logs related to third party extensions

View File

@ -1,4 +1,23 @@
// Shows a confirmation dialog before restarting remote workers.
// The single argument (event payload passed in by the UI) is unused.
// Returns true if the user confirmed, false otherwise.
function confirm_restart_workers(_) {
    return confirm('Restart remote workers?')
}
}
// Live updates: every 1.5s, find the hidden refresh button belonging to the
// currently visible tab and click it so the status display stays current.
function update() {
    try {
        const activeTab = get_uiCurrentTabContent()
        const refreshButtons = document.querySelectorAll('#distributed-refresh-status')
        for (const btn of refreshButtons) {
            if (activeTab.contains(btn)) {
                btn.click()
                break
            }
        }
    } catch (e) {
        // NOTE(review): presumably lookups return null before the UI finishes
        // loading, raising TypeError — only that case is swallowed here.
        if (!(e instanceof TypeError)) {
            throw e
        }
    }
}
setInterval(update, 1500)

View File

@ -13,6 +13,7 @@ import time
from threading import Thread
from typing import List
import gradio
from torchvision.transforms import ToTensor
import urllib3
from PIL import Image
from modules import processing
@ -24,7 +25,7 @@ from modules.shared import state as webui_state
from scripts.spartan.control_net import pack_control_net
from scripts.spartan.shared import logger
from scripts.spartan.ui import UI
from scripts.spartan.world import World, State
from scripts.spartan.world import World, State, Job
old_sigint_handler = signal.getsignal(signal.SIGINT)
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
@ -61,7 +62,7 @@ class DistributedScript(scripts.Script):
return scripts.AlwaysVisible
def ui(self, is_img2img):
extension_ui = UI(world=self.world)
extension_ui = UI(world=self.world, is_img2img=is_img2img)
# root, api_exposed = extension_ui.create_ui()
components = extension_ui.create_ui()
@ -71,77 +72,61 @@ class DistributedScript(scripts.Script):
# return some components that should be exposed to the api
return components
def add_to_gallery(self, processed, p):
"""adds generated images to the image gallery after waiting for all workers to finish"""
def api_to_internal(self, job) -> ([], [], [], [], []):
# takes worker response received from api and returns parsed objects in internal sdwui format. E.g. all_seeds
def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None):
image_params: json = response['parameters']
image_info_post: json = json.loads(response["info"]) # image info known after processing
num_response_images = image_params["batch_size"] * image_params["n_iter"]
seed = None
subseed = None
negative_prompt = None
pos_prompt = None
image_params: json = job.worker.response['parameters']
image_info_post: json = json.loads(job.worker.response["info"]) # image info known after processing
all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = [], [], [], [], []
for i in range(len(job.worker.response["images"])):
try:
if num_response_images > 1:
seed = image_info_post['all_seeds'][info_index]
subseed = image_info_post['all_subseeds'][info_index]
negative_prompt = image_info_post['all_negative_prompts'][info_index]
pos_prompt = image_info_post['all_prompts'][info_index]
else:
seed = image_info_post['seed']
subseed = image_info_post['subseed']
negative_prompt = image_info_post['negative_prompt']
pos_prompt = image_info_post['prompt']
if image_params["batch_size"] * image_params["n_iter"] > 1:
all_seeds.append(image_info_post['all_seeds'][i])
all_subseeds.append(image_info_post['all_subseeds'][i])
all_negative_prompts.append(image_info_post['all_negative_prompts'][i])
all_prompts.append(image_info_post['all_prompts'][i])
else: # only a single image received
all_seeds.append(image_info_post['seed'])
all_subseeds.append(image_info_post['subseed'])
all_negative_prompts.append(image_info_post['negative_prompt'])
all_prompts.append(image_info_post['prompt'])
except IndexError:
# like with controlnet masks, there isn't always full post-gen info, so we use the first images'
logger.debug(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data")
processed_inject_image(image=image, info_index=0, response=response)
return
# # like with controlnet masks, there isn't always full post-gen info, so we use the first images'
# logger.debug(f"Image at index {info_index} for '{job.worker.label}' was missing some post-generation data")
# self.processed_inject_image(image=image, info_index=0, job=job, p=p)
# return
logger.critical(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data")
continue
processed.all_seeds.append(seed)
processed.all_subseeds.append(subseed)
processed.all_negative_prompts.append(negative_prompt)
processed.all_prompts.append(pos_prompt)
processed.images.append(image) # actual received image
# parse image
image_bytes = base64.b64decode(job.worker.response["images"][i])
image = Image.open(io.BytesIO(image_bytes))
transform = ToTensor()
images.append(transform(image))
# generate info-text string
return all_seeds, all_subseeds, all_negative_prompts, all_prompts, images
def inject_job(self, job: Job, p, pp):
"""Adds the work completed by one Job via its worker response to the processing and postprocessing objects"""
all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = self.api_to_internal(job)
p.seeds.extend(all_seeds)
p.subseeds.extend(all_subseeds)
p.negative_prompts.extend(all_negative_prompts)
p.prompts.extend(all_prompts)
num_local = self.world.p.n_iter * self.world.p.batch_size + (opts.return_grid - self.world.thin_client_mode)
num_injected = len(pp.images) - self.world.p.batch_size
for i, image in enumerate(images):
# modules.ui_common -> update_generation_info renders to html below gallery
images_per_batch = p.n_iter * p.batch_size
# zero-indexed position of image in total batch (so including master results)
true_image_pos = len(processed.images) - 1
num_remote_images = images_per_batch * p.batch_size
if p.n_iter > 1: # if splitting by batch count
num_remote_images *= p.n_iter - 1
gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery
job.gallery_map.append(gallery_index) # so we know where to edit infotext
pp.images.append(image)
logger.debug(f"image {gallery_index + 1 + self.world.thin_client_mode}/{self.world.num_gallery()}")
logger.debug(f"image {true_image_pos + 1}/{self.world.p.batch_size * p.n_iter}, "
f"info-index: {info_index}")
if self.world.thin_client_mode:
p.all_negative_prompts = processed.all_negative_prompts
try:
info_text = image_info_post['infotexts'][i]
except IndexError:
if not grid:
logger.warning(f"image {true_image_pos + 1} was missing info-text")
info_text = processed.infotexts[0]
info_text += f", Worker Label: {job.worker.label}"
processed.infotexts.append(info_text)
# automatically save received image to local disk if desired
if cmd_opts.distributed_remotes_autosave:
save_image(
image=image,
path=p.outpath_samples if save_path_override is None else save_path_override,
basename="",
seed=seed,
prompt=pos_prompt,
info=info_text,
extension=opts.samples_format
)
def update_gallery(self, pp, p):
"""adds all remotely generated images to the image gallery after waiting for all workers to finish"""
# get master ipm by estimating based on worker speed
master_elapsed = time.time() - self.master_start
@ -158,8 +143,7 @@ class DistributedScript(scripts.Script):
logger.debug("all worker request threads returned")
webui_state.textinfo = "Distributed - injecting images"
# some worker which we know has a good response that we can use for generating the grid
donor_worker = None
received_images = False
for job in self.world.jobs:
if job.worker.response is None or job.batch_size < 1 or job.worker.master:
continue
@ -170,8 +154,7 @@ class DistributedScript(scripts.Script):
if (job.batch_size * p.n_iter) < len(images):
logger.debug(f"requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}")
if donor_worker is None:
donor_worker = job.worker
received_images = True
except KeyError:
if job.batch_size > 0:
logger.warning(f"Worker '{job.worker.label}' had no images")
@ -185,41 +168,27 @@ class DistributedScript(scripts.Script):
logger.exception(e)
continue
# visibly add work from workers to the image gallery
for i in range(0, len(images)):
image_bytes = base64.b64decode(images[i])
image = Image.open(io.BytesIO(image_bytes))
# adding the images in
self.inject_job(job, p, pp)
# inject image
processed_inject_image(image=image, info_index=i, response=job.worker.response)
if donor_worker is None:
# TODO fix controlnet masks returned via api having no generation info
if received_images is False:
logger.critical("couldn't collect any responses, the extension will have no effect")
return
# generate and inject grid
if opts.return_grid and len(processed.images) > 1:
grid = image_grid(processed.images, len(processed.images))
processed_inject_image(
image=grid,
info_index=0,
save_path_override=p.outpath_grids,
grid=True,
response=donor_worker.response
)
# cleanup after we're doing using all the responses
for worker in self.world.get_workers():
worker.response = None
p.batch_size = len(processed.images)
p.batch_size = len(pp.images)
webui_state.textinfo = ""
return
# p's type is
# "modules.processing.StableDiffusionProcessing*"
def before_process(self, p, *args):
if not self.world.enabled:
logger.debug("extension is disabled")
is_img2img = getattr(p, 'init_images', False)
if is_img2img and self.world.enabled_i2i is False:
logger.debug("extension is disabled for i2i")
return
elif not is_img2img and self.world.enabled is False:
logger.debug("extension is disabled for t2i")
return
self.world.update(p)
@ -234,6 +203,14 @@ class DistributedScript(scripts.Script):
continue
title = script.title()
if title == "ADetailer":
adetailer_args = p.script_args[script.args_from:script.args_to]
# InputAccordion main toggle, skip img2img toggle
if adetailer_args[0] and adetailer_args[1]:
logger.debug(f"adetailer is skipping img2img, returning control to wui")
return
# check for supported scripts
if title == "ControlNet":
# grab all controlnet units
@ -346,18 +323,34 @@ class DistributedScript(scripts.Script):
p.batch_size = self.world.master_job().batch_size
self.master_start = time.time()
# generate images assigned to local machine
p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
self.runs_since_init += 1
return
def postprocess(self, p, processed, *args):
if not self.world.enabled:
def postprocess_batch_list(self, p, pp, *args, **kwargs):
if not self.world.thin_client_mode and p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch
return
is_img2img = getattr(p, 'init_images', False)
if is_img2img and self.world.enabled_i2i is False:
return
elif not is_img2img and self.world.enabled is False:
return
if self.master_start is not None:
self.add_to_gallery(p=p, processed=processed)
self.update_gallery(p=p, pp=pp)
def postprocess(self, p, processed, *args):
for job in self.world.jobs:
if job.worker.response is not None:
for i, v in enumerate(job.gallery_map):
infotext = json.loads(job.worker.response['info'])['infotexts'][i]
infotext += f", Worker Label: {job.worker.label}"
processed.infotexts[v] = infotext
# cleanup
for worker in self.world.get_workers():
worker.response = None
# restore process_images_inner if it was monkey-patched
processing.process_images_inner = self.original_process_images_inner
# save any dangling state to prevent load_config in next iteration overwriting it

View File

@ -1,9 +1,12 @@
# https://github.com/Mikubill/sd-webui-controlnet/wiki/API#examples-1
import copy
from PIL import Image
from modules.api.api import encode_pil_to_base64
from scripts.spartan.shared import logger
import numpy as np
import json
import enum
def np_to_b64(image: np.ndarray):
@ -58,10 +61,14 @@ def pack_control_net(cn_units) -> dict:
unit['mask'] = mask_b64 # mikubill
unit['mask_image'] = mask_b64 # forge
# serialize all enums
for k in unit.keys():
if isinstance(unit[k], enum.Enum):
unit[k] = unit[k].value
# avoid returning duplicate detection maps since master should return the same one
unit['save_detected_map'] = False
# remove anything unserializable
del unit['input_mode']
try:
json.dumps(controlnet)

View File

@ -29,17 +29,18 @@ class Worker_Model(BaseModel):
default=False
)
state: Optional[Any] = Field(default=1, description="The last known state of this worker")
user: Optional[str] = Field(description="The username to be used when authenticating with this worker")
password: Optional[str] = Field(description="The password to be used when authenticating with this worker")
user: Optional[str] = Field(description="The username to be used when authenticating with this worker", default=None)
password: Optional[str] = Field(description="The password to be used when authenticating with this worker", default=None)
pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit")
class ConfigModel(BaseModel):
workers: List[Dict[str, Worker_Model]]
benchmark_payload: Dict = Field(
benchmark_payload: Benchmark_Payload = Field(
default=Benchmark_Payload,
description='the payload used when benchmarking a node'
)
job_timeout: Optional[int] = Field(default=3)
enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True)
enabled_i2i: Optional[bool] = Field(description="Same as above but for image to image", default=True)
complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True)
step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False)

View File

@ -65,7 +65,7 @@ samples = 3 # number of times to benchmark worker after warmup benchmarks are c
class BenchmarkPayload(BaseModel):
validate_assignment = True
# validate_assignment = True
prompt: str = Field(default="A herd of cows grazing at the bottom of a sunny valley")
negative_prompt: str = Field(default="")
steps: int = Field(default=20)

View File

@ -9,16 +9,17 @@ from .shared import logger, LOG_LEVEL, gui_handler
from .worker import State
from modules.call_queue import queue_lock
from modules import progress
from modules.ui_components import InputAccordion
worker_select_dropdown = None
class UI:
"""extension user interface related things"""
def __init__(self, world):
def __init__(self, world, is_img2img):
self.world = world
self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange
self.is_img2img = is_img2img
# handlers
@staticmethod
@ -184,12 +185,20 @@ class UI:
worker.session.auth = (user, password)
self.world.save_config()
def main_toggle_btn(self):
self.world.enabled = not self.world.enabled
def main_toggle_btn(self, state):
if self.is_img2img:
if self.world.enabled_i2i == state: # just prevents a redundant config save if ui desyncs
return
self.world.enabled_i2i = state
else:
if self.world.enabled == state:
return
self.world.enabled = state
self.world.save_config()
# restore vanilla sdwui handler for model dropdown if extension is disabled or inject if otherwise
if not self.world.enabled:
if not self.world.enabled and not self.world.enabled_i2i:
model_dropdown = opts.data_labels.get('sd_model_checkpoint')
if self.original_model_dropdown_handler is not None:
model_dropdown.onchange = self.original_model_dropdown_handler
@ -208,17 +217,14 @@ class UI:
def create_ui(self):
"""creates the extension UI and returns relevant components"""
components = []
elem_id = 'enabled'
if self.is_img2img:
elem_id += '_i2i'
with gradio.Blocks(variant='compact'): # Group() and Box() remove spacing
with gradio.Accordion(label='Distributed', open=False):
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/6109#issuecomment-1403315784
main_toggle = gradio.Checkbox( # main on/off ext. toggle
elem_id='enable',
label='Enable',
value=self.world.enabled if self.world.enabled is not None else True,
interactive=True
)
main_toggle.input(self.main_toggle_btn)
with InputAccordion(label='Distributed', open=False, value=self.world.config().get(elem_id), elem_id=elem_id) as main_toggle:
main_toggle.input(self.main_toggle_btn, inputs=[main_toggle])
setattr(main_toggle.accordion, 'do_not_save_to_config', True) # InputAccordion is really a CheckBox
components.append(main_toggle)
with gradio.Tab('Status') as status_tab:
@ -236,8 +242,8 @@ class UI:
info='top-most message is newest'
)
refresh_status_btn = gradio.Button(value='Refresh 🔄', size='sm')
refresh_status_btn.click(self.status_btn, inputs=[], outputs=[jobs, status, logs])
refresh_status_btn = gradio.Button(value='Refresh 🔄', size='sm', elem_id='distributed-refresh-status', visible=False)
refresh_status_btn.click(self.status_btn, inputs=[], outputs=[jobs, status, logs], show_progress=False)
status_tab.select(fn=self.status_btn, inputs=[], outputs=[jobs, status, logs])
components += [status, jobs, logs, refresh_status_btn]

View File

@ -15,6 +15,7 @@ from modules.shared import cmd_opts
from modules.shared import state as master_state
from . import shared as sh
from .shared import logger, warmup_samples, LOG_LEVEL
from PIL import Image
try:
from webui import server_name
@ -40,6 +41,13 @@ class State(Enum):
DISABLED = 5
# looks redundant when encode_pil...() could be used, but it does not support all file formats. E.g. AVIF
def pil_to_64(image: Image) -> str:
    """Encode a PIL image as a base64 PNG data string.

    Exists instead of encode_pil_to_base64() because that helper does not
    support all source file formats (e.g. AVIF); re-saving as PNG here does.

    Args:
        image: the PIL image to serialize.

    Returns:
        A 'data:image/png;base64,...' string containing the PNG-encoded image.
    """
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return 'data:image/png;base64,' + encoded
class Worker:
"""
This class represents a worker node in a distributed computing setup.
@ -361,10 +369,7 @@ class Worker:
mode = 'img2img' # for use in checking script compat
images = []
for image in init_images:
buffer = io.BytesIO()
image.save(buffer, format="PNG")
image = 'data:image/png;base64,' + str(base64.b64encode(buffer.getvalue()), 'utf-8')
images.append(image)
images.append(pil_to_64(image))
payload['init_images'] = images
alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example
@ -401,9 +406,7 @@ class Worker:
# if an image mask is present
image_mask = payload.get('image_mask', None)
if image_mask is not None:
image_b64 = encode_pil_to_base64(image_mask)
image_b64 = str(image_b64, 'utf-8')
payload['mask'] = image_b64
payload['mask'] = pil_to_64(image_mask)
del payload['image_mask']
# see if there is anything else wrong with serializing to payload
@ -606,6 +609,7 @@ class Worker:
self.full_url("memory"),
timeout=3
)
self.response = response
return response.status_code == 200
except requests.exceptions.ConnectionError as e:
@ -640,6 +644,7 @@ class Worker:
return []
def load_options(self, model, vae=None):
failure_msg = f"failed to load options for worker '{self.label}'"
if self.master:
return
@ -653,17 +658,25 @@ class Worker:
if vae is not None:
payload['sd_vae'] = vae
self.set_state(State.WORKING)
state_cache = self.state
self.set_state(State.WORKING, expect_cycle=True) # may already be WORKING if called by worker.request()
start = time.time()
response = self.session.post(
self.full_url("options"),
json=payload
)
try:
response = self.session.post(
self.full_url("options"),
json=payload
)
except requests.exceptions.RequestException:
self.set_state(State.UNAVAILABLE)
logger.error(f"{failure_msg} (connection error... OOM?)")
return
elapsed = time.time() - start
self.set_state(State.IDLE)
if state_cache != State.WORKING: # see above comment, this lets caller determine when worker is IDLE
self.set_state(State.IDLE)
if response.status_code != 200:
logger.debug(f"failed to load options for worker '{self.label}'")
logger.debug(failure_msg)
else:
logger.debug(f"worker '{self.label}' loaded weights in {elapsed:.2f}s")
self.loaded_model = model_name

View File

@ -21,6 +21,9 @@ from .worker import Worker, State
from modules.call_queue import wrap_queued_call, queue_lock
from modules import processing
from modules import progress
from modules.scripts import PostprocessBatchListArgs
from torchvision.transforms import ToPILImage
from modules.images import image_grid
class NotBenchmarked(Exception):
@ -46,6 +49,7 @@ class Job:
self.complementary: bool = False
self.step_override = None
self.thread = None
self.gallery_map: List[int] = []
def __str__(self):
prefix = ''
@ -91,6 +95,7 @@ class World:
self.verify_remotes = verify_remotes
self.thin_client_mode = False
self.enabled = True
self.enabled_i2i = True
self.is_dropdown_handler_injected = False
self.complement_production = True
self.step_scaling = False
@ -103,7 +108,6 @@ class World:
def __repr__(self):
return f"{len(self._workers)} workers"
def default_batch_size(self) -> int:
"""the amount of images/total images requested that a worker would compute if conditions were perfect and
each worker generated at the same speed. assumes one batch only"""
@ -113,7 +117,7 @@ class World:
def size(self) -> int:
"""
Returns:
int: The number of nodes currently registered in the world.
int: The number of nodes currently registered in the world and in a valid state
"""
return len(self.get_workers())
@ -273,7 +277,7 @@ class World:
gradio.Info("Distributed: benchmarking complete!")
self.save_config()
def get_current_output_size(self) -> int:
def num_requested(self) -> int:
"""
returns how many images would be returned from all jobs
"""
@ -285,6 +289,11 @@ class World:
return num_images
def num_gallery(self) -> int:
"""How many images should appear in the gallery. This includes local generations and a grid(if enabled)"""
return self.num_requested() * self.p.n_iter + shared.opts.return_grid
def speed_summary(self) -> str:
"""
Returns string listing workers by their ipm in descending order.
@ -393,7 +402,6 @@ class World:
self.initialized = True
logger.debug("world initialized!")
def get_workers(self):
filtered: List[Worker] = []
for worker in self._workers:
@ -472,7 +480,7 @@ class World:
#######################
# when total number of requested images was not cleanly divisible by world size then we tack the remainder on
remainder_images = self.p.batch_size - self.get_current_output_size()
remainder_images = self.p.batch_size - self.num_requested()
if remainder_images >= 1:
logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
@ -551,21 +559,37 @@ class World:
else:
logger.debug("complementary image production is disabled")
logger.info(self.distro_summary(payload))
logger.info(self.distro_summary())
if self.thin_client_mode is True or self.master_job().batch_size == 0:
# save original process_images_inner for later so we can restore once we're done
logger.debug(f"bypassing local generation completely")
def process_images_inner_bypass(p) -> processing.Processed:
p.seeds, p.subseeds, p.negative_prompts, p.prompts = [], [], [], []
pp = PostprocessBatchListArgs(images=[])
self.p.scripts.postprocess_batch_list(p, pp)
processed = processing.Processed(p, [], p.seed, info="")
processed.all_prompts = []
processed.all_seeds = []
processed.all_subseeds = []
processed.all_negative_prompts = []
processed.infotexts = []
processed.prompt = None
processed.all_prompts = p.prompts
processed.all_seeds = p.seeds
processed.all_subseeds = p.subseeds
processed.all_negative_prompts = p.negative_prompts
processed.images = pp.images
processed.infotexts = [''] * self.num_requested()
transform = ToPILImage()
for i, image in enumerate(processed.images):
processed.images[i] = transform(image)
self.p.scripts.postprocess(p, processed)
# generate grid if enabled
if shared.opts.return_grid and len(processed.images) > 1:
grid = image_grid(processed.images, len(processed.images))
processed.images.insert(0, grid)
processed.infotexts.insert(0, processed.infotexts[0])
return processed
processing.process_images_inner = process_images_inner_bypass
@ -576,18 +600,17 @@ class World:
del self.jobs[last]
last -= 1
def distro_summary(self, payload):
# iterations = dict(payload)['n_iter']
iterations = self.p.n_iter
num_returning = self.get_current_output_size()
def distro_summary(self):
num_returning = self.num_requested()
num_complementary = num_returning - self.p.batch_size
distro_summary = "Job distribution:\n"
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
distro_summary += f"{self.p.batch_size} * {self.p.n_iter} iteration(s)"
if num_complementary > 0:
distro_summary += f" + {num_complementary} complementary"
distro_summary += f": {num_returning} images total\n"
distro_summary += f": {num_returning * self.p.n_iter} images total\n"
for job in self.jobs:
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
distro_summary += f"'{job.worker.label}' - {job.batch_size * self.p.n_iter} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
return distro_summary
def config(self) -> dict:
@ -670,9 +693,10 @@ class World:
self.add_worker(**fields)
sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload)
sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload.dict())
self.job_timeout = config.job_timeout
self.enabled = config.enabled
self.enabled_i2i = config.enabled_i2i
self.complement_production = config.complement_production
self.step_scaling = config.step_scaling
@ -688,6 +712,7 @@ class World:
benchmark_payload=sh.benchmark_payload,
job_timeout=self.job_timeout,
enabled=self.enabled,
enabled_i2i=self.enabled_i2i,
complement_production=self.complement_production,
step_scaling=self.step_scaling
)
@ -743,6 +768,9 @@ class World:
worker.set_state(State.IDLE, expect_cycle=True)
else:
msg = f"worker '{worker.label}' is unreachable"
if worker.response.status_code is not None:
msg += f" <{worker.response.status_code}>"
logger.info(msg)
gradio.Warning("Distributed: "+msg)
worker.set_state(State.UNAVAILABLE)
@ -753,7 +781,6 @@ class World:
for worker in self._workers:
worker.restart()
def inject_model_dropdown_handler(self):
if self.config().get('enabled', False): # TODO avoid access from config()
return