merge dev, making 2.3.0

master v2.3.0
papuSpartan 2024-10-26 21:46:01 -05:00
commit 8fd65ebdc1
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
9 changed files with 239 additions and 160 deletions

View File

@ -1,10 +1,23 @@
# Change Log # Change Log
Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html) Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
## [2.3.0] - 2024-10-26
### Added
- Compatibility for some extensions which mostly only do postprocessing (e.g. ADetailer)
- Separate toggle state for the img2img tab, so txt2img can be enabled while img2img is disabled, or vice versa
### Changed
- Status tab will now automatically refresh
- Main toggle is now in the form of an InputAccordion
### Fixed
- An issue affecting ControlNet and inpainting
- Toggle state sometimes desyncing when the page was refreshed
## [2.2.2] - 2024-8-30 ## [2.2.2] - 2024-8-30
### Fixed ### Fixed
- Unavailable state sometimes being ignored - Unavailable state sometimes being ignored
## [2.2.1] - 2024-5-16 ## [2.2.1] - 2024-5-16
@ -82,4 +95,4 @@ Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic
- Worker randomly disconnecting when under high load due to handling a previous request - Worker randomly disconnecting when under high load due to handling a previous request
### Removed ### Removed
- Certain superfluous warnings in logs related to third party extensions - Certain superfluous warnings in logs related to third party extensions

View File

@ -1,4 +1,23 @@
function confirm_restart_workers(_) { function confirm_restart_workers(_) {
return confirm('Restart remote workers?') return confirm('Restart remote workers?')
} }
// live updates
/**
 * Click the status refresh button that belongs to the tab currently on screen.
 *
 * Every element with the id `distributed-refresh-status` is considered, and
 * only the first one contained in the current tab's content is clicked.
 * A TypeError raised along the way is swallowed (presumably this covers the
 * case where no tab content is available yet -- TODO confirm); any other
 * error is re-thrown.
 */
function update() {
    try {
        const activeTab = get_uiCurrentTabContent()
        const refreshButtons = document.querySelectorAll('#distributed-refresh-status')
        for (const button of refreshButtons) {
            if (!activeTab.contains(button)) {
                continue
            }
            // only one refresh button can live in the visible tab
            button.click()
            break
        }
    } catch (err) {
        if (err instanceof TypeError) {
            return
        }
        throw err
    }
}
// Run update() every 1500 ms so the refresh button is clicked without user interaction.
setInterval(update, 1500)

View File

@ -13,6 +13,7 @@ import time
from threading import Thread from threading import Thread
from typing import List from typing import List
import gradio import gradio
from torchvision.transforms import ToTensor
import urllib3 import urllib3
from PIL import Image from PIL import Image
from modules import processing from modules import processing
@ -24,7 +25,7 @@ from modules.shared import state as webui_state
from scripts.spartan.control_net import pack_control_net from scripts.spartan.control_net import pack_control_net
from scripts.spartan.shared import logger from scripts.spartan.shared import logger
from scripts.spartan.ui import UI from scripts.spartan.ui import UI
from scripts.spartan.world import World, State from scripts.spartan.world import World, State, Job
old_sigint_handler = signal.getsignal(signal.SIGINT) old_sigint_handler = signal.getsignal(signal.SIGINT)
old_sigterm_handler = signal.getsignal(signal.SIGTERM) old_sigterm_handler = signal.getsignal(signal.SIGTERM)
@ -61,7 +62,7 @@ class DistributedScript(scripts.Script):
return scripts.AlwaysVisible return scripts.AlwaysVisible
def ui(self, is_img2img): def ui(self, is_img2img):
extension_ui = UI(world=self.world) extension_ui = UI(world=self.world, is_img2img=is_img2img)
# root, api_exposed = extension_ui.create_ui() # root, api_exposed = extension_ui.create_ui()
components = extension_ui.create_ui() components = extension_ui.create_ui()
@ -71,77 +72,61 @@ class DistributedScript(scripts.Script):
# return some components that should be exposed to the api # return some components that should be exposed to the api
return components return components
def add_to_gallery(self, processed, p): def api_to_internal(self, job) -> ([], [], [], [], []):
"""adds generated images to the image gallery after waiting for all workers to finish""" # takes worker response received from api and returns parsed objects in internal sdwui format. E.g. all_seeds
def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None): image_params: json = job.worker.response['parameters']
image_params: json = response['parameters'] image_info_post: json = json.loads(job.worker.response["info"]) # image info known after processing
image_info_post: json = json.loads(response["info"]) # image info known after processing all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = [], [], [], [], []
num_response_images = image_params["batch_size"] * image_params["n_iter"]
seed = None
subseed = None
negative_prompt = None
pos_prompt = None
for i in range(len(job.worker.response["images"])):
try: try:
if num_response_images > 1: if image_params["batch_size"] * image_params["n_iter"] > 1:
seed = image_info_post['all_seeds'][info_index] all_seeds.append(image_info_post['all_seeds'][i])
subseed = image_info_post['all_subseeds'][info_index] all_subseeds.append(image_info_post['all_subseeds'][i])
negative_prompt = image_info_post['all_negative_prompts'][info_index] all_negative_prompts.append(image_info_post['all_negative_prompts'][i])
pos_prompt = image_info_post['all_prompts'][info_index] all_prompts.append(image_info_post['all_prompts'][i])
else: else: # only a single image received
seed = image_info_post['seed'] all_seeds.append(image_info_post['seed'])
subseed = image_info_post['subseed'] all_subseeds.append(image_info_post['subseed'])
negative_prompt = image_info_post['negative_prompt'] all_negative_prompts.append(image_info_post['negative_prompt'])
pos_prompt = image_info_post['prompt'] all_prompts.append(image_info_post['prompt'])
except IndexError: except IndexError:
# like with controlnet masks, there isn't always full post-gen info, so we use the first images' # # like with controlnet masks, there isn't always full post-gen info, so we use the first images'
logger.debug(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data") # logger.debug(f"Image at index {info_index} for '{job.worker.label}' was missing some post-generation data")
processed_inject_image(image=image, info_index=0, response=response) # self.processed_inject_image(image=image, info_index=0, job=job, p=p)
return # return
logger.critical(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data")
continue
processed.all_seeds.append(seed) # parse image
processed.all_subseeds.append(subseed) image_bytes = base64.b64decode(job.worker.response["images"][i])
processed.all_negative_prompts.append(negative_prompt) image = Image.open(io.BytesIO(image_bytes))
processed.all_prompts.append(pos_prompt) transform = ToTensor()
processed.images.append(image) # actual received image images.append(transform(image))
# generate info-text string return all_seeds, all_subseeds, all_negative_prompts, all_prompts, images
def inject_job(self, job: Job, p, pp):
"""Adds the work completed by one Job via its worker response to the processing and postprocessing objects"""
all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = self.api_to_internal(job)
p.seeds.extend(all_seeds)
p.subseeds.extend(all_subseeds)
p.negative_prompts.extend(all_negative_prompts)
p.prompts.extend(all_prompts)
num_local = self.world.p.n_iter * self.world.p.batch_size + (opts.return_grid - self.world.thin_client_mode)
num_injected = len(pp.images) - self.world.p.batch_size
for i, image in enumerate(images):
# modules.ui_common -> update_generation_info renders to html below gallery # modules.ui_common -> update_generation_info renders to html below gallery
images_per_batch = p.n_iter * p.batch_size gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery
# zero-indexed position of image in total batch (so including master results) job.gallery_map.append(gallery_index) # so we know where to edit infotext
true_image_pos = len(processed.images) - 1 pp.images.append(image)
num_remote_images = images_per_batch * p.batch_size logger.debug(f"image {gallery_index + 1 + self.world.thin_client_mode}/{self.world.num_gallery()}")
if p.n_iter > 1: # if splitting by batch count
num_remote_images *= p.n_iter - 1
logger.debug(f"image {true_image_pos + 1}/{self.world.p.batch_size * p.n_iter}, " def update_gallery(self, pp, p):
f"info-index: {info_index}") """adds all remotely generated images to the image gallery after waiting for all workers to finish"""
if self.world.thin_client_mode:
p.all_negative_prompts = processed.all_negative_prompts
try:
info_text = image_info_post['infotexts'][i]
except IndexError:
if not grid:
logger.warning(f"image {true_image_pos + 1} was missing info-text")
info_text = processed.infotexts[0]
info_text += f", Worker Label: {job.worker.label}"
processed.infotexts.append(info_text)
# automatically save received image to local disk if desired
if cmd_opts.distributed_remotes_autosave:
save_image(
image=image,
path=p.outpath_samples if save_path_override is None else save_path_override,
basename="",
seed=seed,
prompt=pos_prompt,
info=info_text,
extension=opts.samples_format
)
# get master ipm by estimating based on worker speed # get master ipm by estimating based on worker speed
master_elapsed = time.time() - self.master_start master_elapsed = time.time() - self.master_start
@ -158,8 +143,7 @@ class DistributedScript(scripts.Script):
logger.debug("all worker request threads returned") logger.debug("all worker request threads returned")
webui_state.textinfo = "Distributed - injecting images" webui_state.textinfo = "Distributed - injecting images"
# some worker which we know has a good response that we can use for generating the grid received_images = False
donor_worker = None
for job in self.world.jobs: for job in self.world.jobs:
if job.worker.response is None or job.batch_size < 1 or job.worker.master: if job.worker.response is None or job.batch_size < 1 or job.worker.master:
continue continue
@ -170,8 +154,7 @@ class DistributedScript(scripts.Script):
if (job.batch_size * p.n_iter) < len(images): if (job.batch_size * p.n_iter) < len(images):
logger.debug(f"requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}") logger.debug(f"requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}")
if donor_worker is None: received_images = True
donor_worker = job.worker
except KeyError: except KeyError:
if job.batch_size > 0: if job.batch_size > 0:
logger.warning(f"Worker '{job.worker.label}' had no images") logger.warning(f"Worker '{job.worker.label}' had no images")
@ -185,41 +168,27 @@ class DistributedScript(scripts.Script):
logger.exception(e) logger.exception(e)
continue continue
# visibly add work from workers to the image gallery # adding the images in
for i in range(0, len(images)): self.inject_job(job, p, pp)
image_bytes = base64.b64decode(images[i])
image = Image.open(io.BytesIO(image_bytes))
# inject image # TODO fix controlnet masks returned via api having no generation info
processed_inject_image(image=image, info_index=i, response=job.worker.response) if received_images is False:
if donor_worker is None:
logger.critical("couldn't collect any responses, the extension will have no effect") logger.critical("couldn't collect any responses, the extension will have no effect")
return return
# generate and inject grid p.batch_size = len(pp.images)
if opts.return_grid and len(processed.images) > 1: webui_state.textinfo = ""
grid = image_grid(processed.images, len(processed.images))
processed_inject_image(
image=grid,
info_index=0,
save_path_override=p.outpath_grids,
grid=True,
response=donor_worker.response
)
# cleanup after we're doing using all the responses
for worker in self.world.get_workers():
worker.response = None
p.batch_size = len(processed.images)
return return
# p's type is # p's type is
# "modules.processing.StableDiffusionProcessing*" # "modules.processing.StableDiffusionProcessing*"
def before_process(self, p, *args): def before_process(self, p, *args):
if not self.world.enabled: is_img2img = getattr(p, 'init_images', False)
logger.debug("extension is disabled") if is_img2img and self.world.enabled_i2i is False:
logger.debug("extension is disabled for i2i")
return
elif not is_img2img and self.world.enabled is False:
logger.debug("extension is disabled for t2i")
return return
self.world.update(p) self.world.update(p)
@ -234,6 +203,14 @@ class DistributedScript(scripts.Script):
continue continue
title = script.title() title = script.title()
if title == "ADetailer":
adetailer_args = p.script_args[script.args_from:script.args_to]
# InputAccordion main toggle, skip img2img toggle
if adetailer_args[0] and adetailer_args[1]:
logger.debug(f"adetailer is skipping img2img, returning control to wui")
return
# check for supported scripts # check for supported scripts
if title == "ControlNet": if title == "ControlNet":
# grab all controlnet units # grab all controlnet units
@ -346,18 +323,34 @@ class DistributedScript(scripts.Script):
p.batch_size = self.world.master_job().batch_size p.batch_size = self.world.master_job().batch_size
self.master_start = time.time() self.master_start = time.time()
# generate images assigned to local machine
p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
self.runs_since_init += 1 self.runs_since_init += 1
return return
def postprocess(self, p, processed, *args): def postprocess_batch_list(self, p, pp, *args, **kwargs):
if not self.world.enabled: if not self.world.thin_client_mode and p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch
return
is_img2img = getattr(p, 'init_images', False)
if is_img2img and self.world.enabled_i2i is False:
return
elif not is_img2img and self.world.enabled is False:
return return
if self.master_start is not None: if self.master_start is not None:
self.add_to_gallery(p=p, processed=processed) self.update_gallery(p=p, pp=pp)
def postprocess(self, p, processed, *args):
for job in self.world.jobs:
if job.worker.response is not None:
for i, v in enumerate(job.gallery_map):
infotext = json.loads(job.worker.response['info'])['infotexts'][i]
infotext += f", Worker Label: {job.worker.label}"
processed.infotexts[v] = infotext
# cleanup
for worker in self.world.get_workers():
worker.response = None
# restore process_images_inner if it was monkey-patched # restore process_images_inner if it was monkey-patched
processing.process_images_inner = self.original_process_images_inner processing.process_images_inner = self.original_process_images_inner
# save any dangling state to prevent load_config in next iteration overwriting it # save any dangling state to prevent load_config in next iteration overwriting it

View File

@ -1,9 +1,12 @@
# https://github.com/Mikubill/sd-webui-controlnet/wiki/API#examples-1
import copy import copy
from PIL import Image from PIL import Image
from modules.api.api import encode_pil_to_base64 from modules.api.api import encode_pil_to_base64
from scripts.spartan.shared import logger from scripts.spartan.shared import logger
import numpy as np import numpy as np
import json import json
import enum
def np_to_b64(image: np.ndarray): def np_to_b64(image: np.ndarray):
@ -58,10 +61,14 @@ def pack_control_net(cn_units) -> dict:
unit['mask'] = mask_b64 # mikubill unit['mask'] = mask_b64 # mikubill
unit['mask_image'] = mask_b64 # forge unit['mask_image'] = mask_b64 # forge
# serialize all enums
for k in unit.keys():
if isinstance(unit[k], enum.Enum):
unit[k] = unit[k].value
# avoid returning duplicate detection maps since master should return the same one # avoid returning duplicate detection maps since master should return the same one
unit['save_detected_map'] = False unit['save_detected_map'] = False
# remove anything unserializable
del unit['input_mode']
try: try:
json.dumps(controlnet) json.dumps(controlnet)

View File

@ -29,17 +29,18 @@ class Worker_Model(BaseModel):
default=False default=False
) )
state: Optional[Any] = Field(default=1, description="The last known state of this worker") state: Optional[Any] = Field(default=1, description="The last known state of this worker")
user: Optional[str] = Field(description="The username to be used when authenticating with this worker") user: Optional[str] = Field(description="The username to be used when authenticating with this worker", default=None)
password: Optional[str] = Field(description="The password to be used when authenticating with this worker") password: Optional[str] = Field(description="The password to be used when authenticating with this worker", default=None)
pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit") pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit")
class ConfigModel(BaseModel): class ConfigModel(BaseModel):
workers: List[Dict[str, Worker_Model]] workers: List[Dict[str, Worker_Model]]
benchmark_payload: Dict = Field( benchmark_payload: Benchmark_Payload = Field(
default=Benchmark_Payload, default=Benchmark_Payload,
description='the payload used when benchmarking a node' description='the payload used when benchmarking a node'
) )
job_timeout: Optional[int] = Field(default=3) job_timeout: Optional[int] = Field(default=3)
enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True) enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True)
enabled_i2i: Optional[bool] = Field(description="Same as above but for image to image", default=True)
complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True) complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True)
step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False) step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False)

View File

@ -65,7 +65,7 @@ samples = 3 # number of times to benchmark worker after warmup benchmarks are c
class BenchmarkPayload(BaseModel): class BenchmarkPayload(BaseModel):
validate_assignment = True # validate_assignment = True
prompt: str = Field(default="A herd of cows grazing at the bottom of a sunny valley") prompt: str = Field(default="A herd of cows grazing at the bottom of a sunny valley")
negative_prompt: str = Field(default="") negative_prompt: str = Field(default="")
steps: int = Field(default=20) steps: int = Field(default=20)

View File

@ -9,16 +9,17 @@ from .shared import logger, LOG_LEVEL, gui_handler
from .worker import State from .worker import State
from modules.call_queue import queue_lock from modules.call_queue import queue_lock
from modules import progress from modules import progress
from modules.ui_components import InputAccordion
worker_select_dropdown = None worker_select_dropdown = None
class UI: class UI:
"""extension user interface related things""" """extension user interface related things"""
def __init__(self, world): def __init__(self, world, is_img2img):
self.world = world self.world = world
self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange
self.is_img2img = is_img2img
# handlers # handlers
@staticmethod @staticmethod
@ -184,12 +185,20 @@ class UI:
worker.session.auth = (user, password) worker.session.auth = (user, password)
self.world.save_config() self.world.save_config()
def main_toggle_btn(self): def main_toggle_btn(self, state):
self.world.enabled = not self.world.enabled if self.is_img2img:
if self.world.enabled_i2i == state: # just prevents a redundant config save if ui desyncs
return
self.world.enabled_i2i = state
else:
if self.world.enabled == state:
return
self.world.enabled = state
self.world.save_config() self.world.save_config()
# restore vanilla sdwui handler for model dropdown if extension is disabled or inject if otherwise # restore vanilla sdwui handler for model dropdown if extension is disabled or inject if otherwise
if not self.world.enabled: if not self.world.enabled and not self.world.enabled_i2i:
model_dropdown = opts.data_labels.get('sd_model_checkpoint') model_dropdown = opts.data_labels.get('sd_model_checkpoint')
if self.original_model_dropdown_handler is not None: if self.original_model_dropdown_handler is not None:
model_dropdown.onchange = self.original_model_dropdown_handler model_dropdown.onchange = self.original_model_dropdown_handler
@ -208,17 +217,14 @@ class UI:
def create_ui(self): def create_ui(self):
"""creates the extension UI and returns relevant components""" """creates the extension UI and returns relevant components"""
components = [] components = []
elem_id = 'enabled'
if self.is_img2img:
elem_id += '_i2i'
with gradio.Blocks(variant='compact'): # Group() and Box() remove spacing with gradio.Blocks(variant='compact'): # Group() and Box() remove spacing
with gradio.Accordion(label='Distributed', open=False): with InputAccordion(label='Distributed', open=False, value=self.world.config().get(elem_id), elem_id=elem_id) as main_toggle:
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/6109#issuecomment-1403315784 main_toggle.input(self.main_toggle_btn, inputs=[main_toggle])
main_toggle = gradio.Checkbox( # main on/off ext. toggle setattr(main_toggle.accordion, 'do_not_save_to_config', True) # InputAccordion is really a CheckBox
elem_id='enable',
label='Enable',
value=self.world.enabled if self.world.enabled is not None else True,
interactive=True
)
main_toggle.input(self.main_toggle_btn)
components.append(main_toggle) components.append(main_toggle)
with gradio.Tab('Status') as status_tab: with gradio.Tab('Status') as status_tab:
@ -236,8 +242,8 @@ class UI:
info='top-most message is newest' info='top-most message is newest'
) )
refresh_status_btn = gradio.Button(value='Refresh 🔄', size='sm') refresh_status_btn = gradio.Button(value='Refresh 🔄', size='sm', elem_id='distributed-refresh-status', visible=False)
refresh_status_btn.click(self.status_btn, inputs=[], outputs=[jobs, status, logs]) refresh_status_btn.click(self.status_btn, inputs=[], outputs=[jobs, status, logs], show_progress=False)
status_tab.select(fn=self.status_btn, inputs=[], outputs=[jobs, status, logs]) status_tab.select(fn=self.status_btn, inputs=[], outputs=[jobs, status, logs])
components += [status, jobs, logs, refresh_status_btn] components += [status, jobs, logs, refresh_status_btn]

View File

@ -15,6 +15,7 @@ from modules.shared import cmd_opts
from modules.shared import state as master_state from modules.shared import state as master_state
from . import shared as sh from . import shared as sh
from .shared import logger, warmup_samples, LOG_LEVEL from .shared import logger, warmup_samples, LOG_LEVEL
from PIL import Image
try: try:
from webui import server_name from webui import server_name
@ -40,6 +41,13 @@ class State(Enum):
DISABLED = 5 DISABLED = 5
# looks redundant when encode_pil...() could be used, but it does not support all file formats. E.g. AVIF
def pil_to_64(image: "Image.Image") -> str:
    """Encode a PIL image as a base64 PNG data URI.

    The image is written to an in-memory buffer as PNG, whatever format it
    was loaded from, and the resulting bytes are base64-encoded.

    Args:
        image: The image to serialize. Only its ``save(fp, format=...)``
            method is used.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    # NOTE: the annotation refers to the ``Image.Image`` class; a bare ``Image``
    # would name the PIL module itself. It is quoted so it is not evaluated at
    # definition time.
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded
class Worker: class Worker:
""" """
This class represents a worker node in a distributed computing setup. This class represents a worker node in a distributed computing setup.
@ -361,10 +369,7 @@ class Worker:
mode = 'img2img' # for use in checking script compat mode = 'img2img' # for use in checking script compat
images = [] images = []
for image in init_images: for image in init_images:
buffer = io.BytesIO() images.append(pil_to_64(image))
image.save(buffer, format="PNG")
image = 'data:image/png;base64,' + str(base64.b64encode(buffer.getvalue()), 'utf-8')
images.append(image)
payload['init_images'] = images payload['init_images'] = images
alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example
@ -401,9 +406,7 @@ class Worker:
# if an image mask is present # if an image mask is present
image_mask = payload.get('image_mask', None) image_mask = payload.get('image_mask', None)
if image_mask is not None: if image_mask is not None:
image_b64 = encode_pil_to_base64(image_mask) payload['mask'] = pil_to_64(image_mask)
image_b64 = str(image_b64, 'utf-8')
payload['mask'] = image_b64
del payload['image_mask'] del payload['image_mask']
# see if there is anything else wrong with serializing to payload # see if there is anything else wrong with serializing to payload
@ -606,6 +609,7 @@ class Worker:
self.full_url("memory"), self.full_url("memory"),
timeout=3 timeout=3
) )
self.response = response
return response.status_code == 200 return response.status_code == 200
except requests.exceptions.ConnectionError as e: except requests.exceptions.ConnectionError as e:
@ -640,6 +644,7 @@ class Worker:
return [] return []
def load_options(self, model, vae=None): def load_options(self, model, vae=None):
failure_msg = f"failed to load options for worker '{self.label}'"
if self.master: if self.master:
return return
@ -653,17 +658,25 @@ class Worker:
if vae is not None: if vae is not None:
payload['sd_vae'] = vae payload['sd_vae'] = vae
self.set_state(State.WORKING) state_cache = self.state
self.set_state(State.WORKING, expect_cycle=True) # may already be WORKING if called by worker.request()
start = time.time() start = time.time()
response = self.session.post( try:
self.full_url("options"), response = self.session.post(
json=payload self.full_url("options"),
) json=payload
)
except requests.exceptions.RequestException:
self.set_state(State.UNAVAILABLE)
logger.error(f"{failure_msg} (connection error... OOM?)")
return
elapsed = time.time() - start elapsed = time.time() - start
self.set_state(State.IDLE) if state_cache != State.WORKING: # see above comment, this lets caller determine when worker is IDLE
self.set_state(State.IDLE)
if response.status_code != 200: if response.status_code != 200:
logger.debug(f"failed to load options for worker '{self.label}'") logger.debug(failure_msg)
else: else:
logger.debug(f"worker '{self.label}' loaded weights in {elapsed:.2f}s") logger.debug(f"worker '{self.label}' loaded weights in {elapsed:.2f}s")
self.loaded_model = model_name self.loaded_model = model_name

View File

@ -21,6 +21,9 @@ from .worker import Worker, State
from modules.call_queue import wrap_queued_call, queue_lock from modules.call_queue import wrap_queued_call, queue_lock
from modules import processing from modules import processing
from modules import progress from modules import progress
from modules.scripts import PostprocessBatchListArgs
from torchvision.transforms import ToPILImage
from modules.images import image_grid
class NotBenchmarked(Exception): class NotBenchmarked(Exception):
@ -46,6 +49,7 @@ class Job:
self.complementary: bool = False self.complementary: bool = False
self.step_override = None self.step_override = None
self.thread = None self.thread = None
self.gallery_map: List[int] = []
def __str__(self): def __str__(self):
prefix = '' prefix = ''
@ -91,6 +95,7 @@ class World:
self.verify_remotes = verify_remotes self.verify_remotes = verify_remotes
self.thin_client_mode = False self.thin_client_mode = False
self.enabled = True self.enabled = True
self.enabled_i2i = True
self.is_dropdown_handler_injected = False self.is_dropdown_handler_injected = False
self.complement_production = True self.complement_production = True
self.step_scaling = False self.step_scaling = False
@ -103,7 +108,6 @@ class World:
def __repr__(self): def __repr__(self):
return f"{len(self._workers)} workers" return f"{len(self._workers)} workers"
def default_batch_size(self) -> int: def default_batch_size(self) -> int:
"""the amount of images/total images requested that a worker would compute if conditions were perfect and """the amount of images/total images requested that a worker would compute if conditions were perfect and
each worker generated at the same speed. assumes one batch only""" each worker generated at the same speed. assumes one batch only"""
@ -113,7 +117,7 @@ class World:
def size(self) -> int: def size(self) -> int:
""" """
Returns: Returns:
int: The number of nodes currently registered in the world. int: The number of nodes currently registered in the world and in a valid state
""" """
return len(self.get_workers()) return len(self.get_workers())
@ -273,7 +277,7 @@ class World:
gradio.Info("Distributed: benchmarking complete!") gradio.Info("Distributed: benchmarking complete!")
self.save_config() self.save_config()
def get_current_output_size(self) -> int: def num_requested(self) -> int:
""" """
returns how many images would be returned from all jobs returns how many images would be returned from all jobs
""" """
@ -285,6 +289,11 @@ class World:
return num_images return num_images
def num_gallery(self) -> int:
    """Return the number of images expected to show up in the gallery.

    This is the requested image count multiplied by the iteration count,
    plus one more when the ``return_grid`` option is enabled.
    """
    generated = self.num_requested() * self.p.n_iter
    # return_grid is added directly, contributing 1 when truthy and 0 otherwise
    return generated + shared.opts.return_grid
def speed_summary(self) -> str: def speed_summary(self) -> str:
""" """
Returns string listing workers by their ipm in descending order. Returns string listing workers by their ipm in descending order.
@ -393,7 +402,6 @@ class World:
self.initialized = True self.initialized = True
logger.debug("world initialized!") logger.debug("world initialized!")
def get_workers(self): def get_workers(self):
filtered: List[Worker] = [] filtered: List[Worker] = []
for worker in self._workers: for worker in self._workers:
@ -472,7 +480,7 @@ class World:
####################### #######################
# when total number of requested images was not cleanly divisible by world size then we tack the remainder on # when total number of requested images was not cleanly divisible by world size then we tack the remainder on
remainder_images = self.p.batch_size - self.get_current_output_size() remainder_images = self.p.batch_size - self.num_requested()
if remainder_images >= 1: if remainder_images >= 1:
logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed") logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
@ -551,21 +559,37 @@ class World:
else: else:
logger.debug("complementary image production is disabled") logger.debug("complementary image production is disabled")
logger.info(self.distro_summary(payload)) logger.info(self.distro_summary())
if self.thin_client_mode is True or self.master_job().batch_size == 0: if self.thin_client_mode is True or self.master_job().batch_size == 0:
# save original process_images_inner for later so we can restore once we're done # save original process_images_inner for later so we can restore once we're done
logger.debug(f"bypassing local generation completely") logger.debug(f"bypassing local generation completely")
def process_images_inner_bypass(p) -> processing.Processed:
    """
    Stand-in for processing.process_images_inner used when local generation is bypassed.

    Produces no images locally: the per-image lists on `p` are reset to empty, the
    postprocess_batch_list and postprocess script hooks are still invoked, and a
    Processed object is returned whose images are whatever ended up in the batch args.
    """
    # start from empty per-image metadata since nothing is generated locally
    p.seeds = []
    p.subseeds = []
    p.negative_prompts = []
    p.prompts = []

    # NOTE(review): presumably scripts may add images to batch_args.images in this hook - confirm
    batch_args = PostprocessBatchListArgs(images=[])
    self.p.scripts.postprocess_batch_list(p, batch_args)

    result = processing.Processed(p, [], p.seed, info="")
    result.all_prompts = p.prompts
    result.all_seeds = p.seeds
    result.all_subseeds = p.subseeds
    result.all_negative_prompts = p.negative_prompts
    result.images = batch_args.images
    # one blank infotext per requested image
    result.infotexts = ['' for _ in range(self.num_requested())]

    # convert every image in place (same list object) using ToPILImage
    to_pil = ToPILImage()
    result.images[:] = [to_pil(img) for img in result.images]

    self.p.scripts.postprocess(p, result)

    # generate grid if enabled
    image_count = len(result.images)
    if shared.opts.return_grid and image_count > 1:
        result.images.insert(0, image_grid(result.images, image_count))
        # the grid reuses the first infotext so images and infotexts stay aligned
        result.infotexts.insert(0, result.infotexts[0])

    return result
processing.process_images_inner = process_images_inner_bypass processing.process_images_inner = process_images_inner_bypass
@ -576,18 +600,17 @@ class World:
del self.jobs[last] del self.jobs[last]
last -= 1 last -= 1
def distro_summary(self):
    """
    Returns a multi-line string describing how the current request is split across jobs.

    The first line is a fixed title. The second gives the requested batch size, the
    iteration count, any complementary images and the overall image total. After that
    comes one line per job with the worker's label, its image count across all
    iterations and its average ipm. Every line, including the last, ends in a newline.
    """
    iterations = self.p.n_iter
    num_returning = self.num_requested()
    # whatever is returned beyond the requested batch size is reported as complementary
    num_complementary = num_returning - self.p.batch_size

    header = f"{self.p.batch_size} * {iterations} iteration(s)"
    if num_complementary > 0:
        header += f" + {num_complementary} complementary"
    header += f": {num_returning * iterations} images total"

    # collect lines and join once instead of growing a string with += per job
    lines = ["Job distribution:", header]
    lines.extend(
        f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm"
        for job in self.jobs
    )
    return "\n".join(lines) + "\n"
def config(self) -> dict: def config(self) -> dict:
@ -670,9 +693,10 @@ class World:
self.add_worker(**fields) self.add_worker(**fields)
sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload) sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload.dict())
self.job_timeout = config.job_timeout self.job_timeout = config.job_timeout
self.enabled = config.enabled self.enabled = config.enabled
self.enabled_i2i = config.enabled_i2i
self.complement_production = config.complement_production self.complement_production = config.complement_production
self.step_scaling = config.step_scaling self.step_scaling = config.step_scaling
@ -688,6 +712,7 @@ class World:
benchmark_payload=sh.benchmark_payload, benchmark_payload=sh.benchmark_payload,
job_timeout=self.job_timeout, job_timeout=self.job_timeout,
enabled=self.enabled, enabled=self.enabled,
enabled_i2i=self.enabled_i2i,
complement_production=self.complement_production, complement_production=self.complement_production,
step_scaling=self.step_scaling step_scaling=self.step_scaling
) )
@ -743,6 +768,9 @@ class World:
worker.set_state(State.IDLE, expect_cycle=True) worker.set_state(State.IDLE, expect_cycle=True)
else: else:
msg = f"worker '{worker.label}' is unreachable" msg = f"worker '{worker.label}' is unreachable"
if worker.response.status_code is not None:
msg += f" <{worker.response.status_code}>"
logger.info(msg) logger.info(msg)
gradio.Warning("Distributed: "+msg) gradio.Warning("Distributed: "+msg)
worker.set_state(State.UNAVAILABLE) worker.set_state(State.UNAVAILABLE)
@ -753,7 +781,6 @@ class World:
for worker in self._workers: for worker in self._workers:
worker.restart() worker.restart()
def inject_model_dropdown_handler(self): def inject_model_dropdown_handler(self):
if self.config().get('enabled', False): # TODO avoid access from config() if self.config().get('enabled', False): # TODO avoid access from config()
return return