commit
8fd65ebdc1
15
CHANGELOG.md
15
CHANGELOG.md
|
|
@ -1,10 +1,23 @@
|
|||
# Change Log
|
||||
Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
|
||||
|
||||
## [2.3.0] - 2024-10-26
|
||||
|
||||
## Added
|
||||
- Compatibility for some extensions which mostly only do postprocessing (e.g. Adetailer)
|
||||
- Separate toggle state for img2img tab so txt2img can be enabled and t2i disabled or vice versa
|
||||
|
||||
## Changed
|
||||
- Status tab will now automatically refresh
|
||||
- Main toggle is now in the form of an InputAccordion
|
||||
|
||||
## Fixed
|
||||
- An issue affecting controlnet and inpainting
|
||||
- Toggle state sometimes desyncing when the page was refreshed
|
||||
|
||||
## [2.2.2] - 2024-8-30
|
||||
|
||||
### Fixed
|
||||
|
||||
- Unavailable state sometimes being ignored
|
||||
|
||||
## [2.2.1] - 2024-5-16
|
||||
|
|
|
|||
|
|
@ -2,3 +2,22 @@
|
|||
function confirm_restart_workers(_) {
|
||||
return confirm('Restart remote workers?')
|
||||
}
|
||||
|
||||
// live updates
|
||||
function update() {
|
||||
try {
|
||||
let currentTab = get_uiCurrentTabContent()
|
||||
let buttons = document.querySelectorAll('#distributed-refresh-status')
|
||||
for(let i = 0; i < buttons.length; i++) {
|
||||
if(currentTab.contains(buttons[i])) {
|
||||
buttons[i].click()
|
||||
break
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
if (!(e instanceof TypeError)) {
|
||||
throw e
|
||||
}
|
||||
}
|
||||
}
|
||||
setInterval(update, 1500)
|
||||
|
|
@ -13,6 +13,7 @@ import time
|
|||
from threading import Thread
|
||||
from typing import List
|
||||
import gradio
|
||||
from torchvision.transforms import ToTensor
|
||||
import urllib3
|
||||
from PIL import Image
|
||||
from modules import processing
|
||||
|
|
@ -24,7 +25,7 @@ from modules.shared import state as webui_state
|
|||
from scripts.spartan.control_net import pack_control_net
|
||||
from scripts.spartan.shared import logger
|
||||
from scripts.spartan.ui import UI
|
||||
from scripts.spartan.world import World, State
|
||||
from scripts.spartan.world import World, State, Job
|
||||
|
||||
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
||||
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||
|
|
@ -61,7 +62,7 @@ class DistributedScript(scripts.Script):
|
|||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
extension_ui = UI(world=self.world)
|
||||
extension_ui = UI(world=self.world, is_img2img=is_img2img)
|
||||
# root, api_exposed = extension_ui.create_ui()
|
||||
components = extension_ui.create_ui()
|
||||
|
||||
|
|
@ -71,77 +72,61 @@ class DistributedScript(scripts.Script):
|
|||
# return some components that should be exposed to the api
|
||||
return components
|
||||
|
||||
def add_to_gallery(self, processed, p):
|
||||
"""adds generated images to the image gallery after waiting for all workers to finish"""
|
||||
def api_to_internal(self, job) -> ([], [], [], [], []):
|
||||
# takes worker response received from api and returns parsed objects in internal sdwui format. E.g. all_seeds
|
||||
|
||||
def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None):
|
||||
image_params: json = response['parameters']
|
||||
image_info_post: json = json.loads(response["info"]) # image info known after processing
|
||||
num_response_images = image_params["batch_size"] * image_params["n_iter"]
|
||||
|
||||
seed = None
|
||||
subseed = None
|
||||
negative_prompt = None
|
||||
pos_prompt = None
|
||||
image_params: json = job.worker.response['parameters']
|
||||
image_info_post: json = json.loads(job.worker.response["info"]) # image info known after processing
|
||||
all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = [], [], [], [], []
|
||||
|
||||
for i in range(len(job.worker.response["images"])):
|
||||
try:
|
||||
if num_response_images > 1:
|
||||
seed = image_info_post['all_seeds'][info_index]
|
||||
subseed = image_info_post['all_subseeds'][info_index]
|
||||
negative_prompt = image_info_post['all_negative_prompts'][info_index]
|
||||
pos_prompt = image_info_post['all_prompts'][info_index]
|
||||
else:
|
||||
seed = image_info_post['seed']
|
||||
subseed = image_info_post['subseed']
|
||||
negative_prompt = image_info_post['negative_prompt']
|
||||
pos_prompt = image_info_post['prompt']
|
||||
if image_params["batch_size"] * image_params["n_iter"] > 1:
|
||||
all_seeds.append(image_info_post['all_seeds'][i])
|
||||
all_subseeds.append(image_info_post['all_subseeds'][i])
|
||||
all_negative_prompts.append(image_info_post['all_negative_prompts'][i])
|
||||
all_prompts.append(image_info_post['all_prompts'][i])
|
||||
else: # only a single image received
|
||||
all_seeds.append(image_info_post['seed'])
|
||||
all_subseeds.append(image_info_post['subseed'])
|
||||
all_negative_prompts.append(image_info_post['negative_prompt'])
|
||||
all_prompts.append(image_info_post['prompt'])
|
||||
except IndexError:
|
||||
# like with controlnet masks, there isn't always full post-gen info, so we use the first images'
|
||||
logger.debug(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data")
|
||||
processed_inject_image(image=image, info_index=0, response=response)
|
||||
return
|
||||
# # like with controlnet masks, there isn't always full post-gen info, so we use the first images'
|
||||
# logger.debug(f"Image at index {info_index} for '{job.worker.label}' was missing some post-generation data")
|
||||
# self.processed_inject_image(image=image, info_index=0, job=job, p=p)
|
||||
# return
|
||||
logger.critical(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data")
|
||||
continue
|
||||
|
||||
processed.all_seeds.append(seed)
|
||||
processed.all_subseeds.append(subseed)
|
||||
processed.all_negative_prompts.append(negative_prompt)
|
||||
processed.all_prompts.append(pos_prompt)
|
||||
processed.images.append(image) # actual received image
|
||||
# parse image
|
||||
image_bytes = base64.b64decode(job.worker.response["images"][i])
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
transform = ToTensor()
|
||||
images.append(transform(image))
|
||||
|
||||
# generate info-text string
|
||||
return all_seeds, all_subseeds, all_negative_prompts, all_prompts, images
|
||||
|
||||
def inject_job(self, job: Job, p, pp):
|
||||
"""Adds the work completed by one Job via its worker response to the processing and postprocessing objects"""
|
||||
all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = self.api_to_internal(job)
|
||||
|
||||
p.seeds.extend(all_seeds)
|
||||
p.subseeds.extend(all_subseeds)
|
||||
p.negative_prompts.extend(all_negative_prompts)
|
||||
p.prompts.extend(all_prompts)
|
||||
|
||||
num_local = self.world.p.n_iter * self.world.p.batch_size + (opts.return_grid - self.world.thin_client_mode)
|
||||
num_injected = len(pp.images) - self.world.p.batch_size
|
||||
for i, image in enumerate(images):
|
||||
# modules.ui_common -> update_generation_info renders to html below gallery
|
||||
images_per_batch = p.n_iter * p.batch_size
|
||||
# zero-indexed position of image in total batch (so including master results)
|
||||
true_image_pos = len(processed.images) - 1
|
||||
num_remote_images = images_per_batch * p.batch_size
|
||||
if p.n_iter > 1: # if splitting by batch count
|
||||
num_remote_images *= p.n_iter - 1
|
||||
gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery
|
||||
job.gallery_map.append(gallery_index) # so we know where to edit infotext
|
||||
pp.images.append(image)
|
||||
logger.debug(f"image {gallery_index + 1 + self.world.thin_client_mode}/{self.world.num_gallery()}")
|
||||
|
||||
logger.debug(f"image {true_image_pos + 1}/{self.world.p.batch_size * p.n_iter}, "
|
||||
f"info-index: {info_index}")
|
||||
|
||||
if self.world.thin_client_mode:
|
||||
p.all_negative_prompts = processed.all_negative_prompts
|
||||
|
||||
try:
|
||||
info_text = image_info_post['infotexts'][i]
|
||||
except IndexError:
|
||||
if not grid:
|
||||
logger.warning(f"image {true_image_pos + 1} was missing info-text")
|
||||
info_text = processed.infotexts[0]
|
||||
info_text += f", Worker Label: {job.worker.label}"
|
||||
processed.infotexts.append(info_text)
|
||||
|
||||
# automatically save received image to local disk if desired
|
||||
if cmd_opts.distributed_remotes_autosave:
|
||||
save_image(
|
||||
image=image,
|
||||
path=p.outpath_samples if save_path_override is None else save_path_override,
|
||||
basename="",
|
||||
seed=seed,
|
||||
prompt=pos_prompt,
|
||||
info=info_text,
|
||||
extension=opts.samples_format
|
||||
)
|
||||
def update_gallery(self, pp, p):
|
||||
"""adds all remotely generated images to the image gallery after waiting for all workers to finish"""
|
||||
|
||||
# get master ipm by estimating based on worker speed
|
||||
master_elapsed = time.time() - self.master_start
|
||||
|
|
@ -158,8 +143,7 @@ class DistributedScript(scripts.Script):
|
|||
logger.debug("all worker request threads returned")
|
||||
webui_state.textinfo = "Distributed - injecting images"
|
||||
|
||||
# some worker which we know has a good response that we can use for generating the grid
|
||||
donor_worker = None
|
||||
received_images = False
|
||||
for job in self.world.jobs:
|
||||
if job.worker.response is None or job.batch_size < 1 or job.worker.master:
|
||||
continue
|
||||
|
|
@ -170,8 +154,7 @@ class DistributedScript(scripts.Script):
|
|||
if (job.batch_size * p.n_iter) < len(images):
|
||||
logger.debug(f"requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}")
|
||||
|
||||
if donor_worker is None:
|
||||
donor_worker = job.worker
|
||||
received_images = True
|
||||
except KeyError:
|
||||
if job.batch_size > 0:
|
||||
logger.warning(f"Worker '{job.worker.label}' had no images")
|
||||
|
|
@ -185,41 +168,27 @@ class DistributedScript(scripts.Script):
|
|||
logger.exception(e)
|
||||
continue
|
||||
|
||||
# visibly add work from workers to the image gallery
|
||||
for i in range(0, len(images)):
|
||||
image_bytes = base64.b64decode(images[i])
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
# adding the images in
|
||||
self.inject_job(job, p, pp)
|
||||
|
||||
# inject image
|
||||
processed_inject_image(image=image, info_index=i, response=job.worker.response)
|
||||
|
||||
if donor_worker is None:
|
||||
# TODO fix controlnet masks returned via api having no generation info
|
||||
if received_images is False:
|
||||
logger.critical("couldn't collect any responses, the extension will have no effect")
|
||||
return
|
||||
|
||||
# generate and inject grid
|
||||
if opts.return_grid and len(processed.images) > 1:
|
||||
grid = image_grid(processed.images, len(processed.images))
|
||||
processed_inject_image(
|
||||
image=grid,
|
||||
info_index=0,
|
||||
save_path_override=p.outpath_grids,
|
||||
grid=True,
|
||||
response=donor_worker.response
|
||||
)
|
||||
|
||||
# cleanup after we're doing using all the responses
|
||||
for worker in self.world.get_workers():
|
||||
worker.response = None
|
||||
|
||||
p.batch_size = len(processed.images)
|
||||
p.batch_size = len(pp.images)
|
||||
webui_state.textinfo = ""
|
||||
return
|
||||
|
||||
# p's type is
|
||||
# "modules.processing.StableDiffusionProcessing*"
|
||||
def before_process(self, p, *args):
|
||||
if not self.world.enabled:
|
||||
logger.debug("extension is disabled")
|
||||
is_img2img = getattr(p, 'init_images', False)
|
||||
if is_img2img and self.world.enabled_i2i is False:
|
||||
logger.debug("extension is disabled for i2i")
|
||||
return
|
||||
elif not is_img2img and self.world.enabled is False:
|
||||
logger.debug("extension is disabled for t2i")
|
||||
return
|
||||
self.world.update(p)
|
||||
|
||||
|
|
@ -234,6 +203,14 @@ class DistributedScript(scripts.Script):
|
|||
continue
|
||||
title = script.title()
|
||||
|
||||
if title == "ADetailer":
|
||||
adetailer_args = p.script_args[script.args_from:script.args_to]
|
||||
|
||||
# InputAccordion main toggle, skip img2img toggle
|
||||
if adetailer_args[0] and adetailer_args[1]:
|
||||
logger.debug(f"adetailer is skipping img2img, returning control to wui")
|
||||
return
|
||||
|
||||
# check for supported scripts
|
||||
if title == "ControlNet":
|
||||
# grab all controlnet units
|
||||
|
|
@ -346,18 +323,34 @@ class DistributedScript(scripts.Script):
|
|||
p.batch_size = self.world.master_job().batch_size
|
||||
self.master_start = time.time()
|
||||
|
||||
# generate images assigned to local machine
|
||||
p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
|
||||
self.runs_since_init += 1
|
||||
return
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
if not self.world.enabled:
|
||||
def postprocess_batch_list(self, p, pp, *args, **kwargs):
|
||||
if not self.world.thin_client_mode and p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch
|
||||
return
|
||||
|
||||
is_img2img = getattr(p, 'init_images', False)
|
||||
if is_img2img and self.world.enabled_i2i is False:
|
||||
return
|
||||
elif not is_img2img and self.world.enabled is False:
|
||||
return
|
||||
|
||||
if self.master_start is not None:
|
||||
self.add_to_gallery(p=p, processed=processed)
|
||||
self.update_gallery(p=p, pp=pp)
|
||||
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
for job in self.world.jobs:
|
||||
if job.worker.response is not None:
|
||||
for i, v in enumerate(job.gallery_map):
|
||||
infotext = json.loads(job.worker.response['info'])['infotexts'][i]
|
||||
infotext += f", Worker Label: {job.worker.label}"
|
||||
processed.infotexts[v] = infotext
|
||||
|
||||
# cleanup
|
||||
for worker in self.world.get_workers():
|
||||
worker.response = None
|
||||
# restore process_images_inner if it was monkey-patched
|
||||
processing.process_images_inner = self.original_process_images_inner
|
||||
# save any dangling state to prevent load_config in next iteration overwriting it
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
# https://github.com/Mikubill/sd-webui-controlnet/wiki/API#examples-1
|
||||
|
||||
import copy
|
||||
from PIL import Image
|
||||
from modules.api.api import encode_pil_to_base64
|
||||
from scripts.spartan.shared import logger
|
||||
import numpy as np
|
||||
import json
|
||||
import enum
|
||||
|
||||
|
||||
def np_to_b64(image: np.ndarray):
|
||||
|
|
@ -58,10 +61,14 @@ def pack_control_net(cn_units) -> dict:
|
|||
unit['mask'] = mask_b64 # mikubill
|
||||
unit['mask_image'] = mask_b64 # forge
|
||||
|
||||
|
||||
# serialize all enums
|
||||
for k in unit.keys():
|
||||
if isinstance(unit[k], enum.Enum):
|
||||
unit[k] = unit[k].value
|
||||
|
||||
# avoid returning duplicate detection maps since master should return the same one
|
||||
unit['save_detected_map'] = False
|
||||
# remove anything unserializable
|
||||
del unit['input_mode']
|
||||
|
||||
try:
|
||||
json.dumps(controlnet)
|
||||
|
|
|
|||
|
|
@ -29,17 +29,18 @@ class Worker_Model(BaseModel):
|
|||
default=False
|
||||
)
|
||||
state: Optional[Any] = Field(default=1, description="The last known state of this worker")
|
||||
user: Optional[str] = Field(description="The username to be used when authenticating with this worker")
|
||||
password: Optional[str] = Field(description="The password to be used when authenticating with this worker")
|
||||
user: Optional[str] = Field(description="The username to be used when authenticating with this worker", default=None)
|
||||
password: Optional[str] = Field(description="The password to be used when authenticating with this worker", default=None)
|
||||
pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit")
|
||||
|
||||
class ConfigModel(BaseModel):
|
||||
workers: List[Dict[str, Worker_Model]]
|
||||
benchmark_payload: Dict = Field(
|
||||
benchmark_payload: Benchmark_Payload = Field(
|
||||
default=Benchmark_Payload,
|
||||
description='the payload used when benchmarking a node'
|
||||
)
|
||||
job_timeout: Optional[int] = Field(default=3)
|
||||
enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True)
|
||||
enabled_i2i: Optional[bool] = Field(description="Same as above but for image to image", default=True)
|
||||
complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True)
|
||||
step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False)
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ samples = 3 # number of times to benchmark worker after warmup benchmarks are c
|
|||
|
||||
|
||||
class BenchmarkPayload(BaseModel):
|
||||
validate_assignment = True
|
||||
# validate_assignment = True
|
||||
prompt: str = Field(default="A herd of cows grazing at the bottom of a sunny valley")
|
||||
negative_prompt: str = Field(default="")
|
||||
steps: int = Field(default=20)
|
||||
|
|
|
|||
|
|
@ -9,16 +9,17 @@ from .shared import logger, LOG_LEVEL, gui_handler
|
|||
from .worker import State
|
||||
from modules.call_queue import queue_lock
|
||||
from modules import progress
|
||||
from modules.ui_components import InputAccordion
|
||||
|
||||
worker_select_dropdown = None
|
||||
|
||||
|
||||
class UI:
|
||||
"""extension user interface related things"""
|
||||
|
||||
def __init__(self, world):
|
||||
def __init__(self, world, is_img2img):
|
||||
self.world = world
|
||||
self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange
|
||||
self.is_img2img = is_img2img
|
||||
|
||||
# handlers
|
||||
@staticmethod
|
||||
|
|
@ -184,12 +185,20 @@ class UI:
|
|||
worker.session.auth = (user, password)
|
||||
self.world.save_config()
|
||||
|
||||
def main_toggle_btn(self):
|
||||
self.world.enabled = not self.world.enabled
|
||||
def main_toggle_btn(self, state):
|
||||
if self.is_img2img:
|
||||
if self.world.enabled_i2i == state: # just prevents a redundant config save if ui desyncs
|
||||
return
|
||||
self.world.enabled_i2i = state
|
||||
else:
|
||||
if self.world.enabled == state:
|
||||
return
|
||||
self.world.enabled = state
|
||||
|
||||
self.world.save_config()
|
||||
|
||||
# restore vanilla sdwui handler for model dropdown if extension is disabled or inject if otherwise
|
||||
if not self.world.enabled:
|
||||
if not self.world.enabled and not self.world.enabled_i2i:
|
||||
model_dropdown = opts.data_labels.get('sd_model_checkpoint')
|
||||
if self.original_model_dropdown_handler is not None:
|
||||
model_dropdown.onchange = self.original_model_dropdown_handler
|
||||
|
|
@ -208,17 +217,14 @@ class UI:
|
|||
def create_ui(self):
|
||||
"""creates the extension UI and returns relevant components"""
|
||||
components = []
|
||||
elem_id = 'enabled'
|
||||
if self.is_img2img:
|
||||
elem_id += '_i2i'
|
||||
|
||||
with gradio.Blocks(variant='compact'): # Group() and Box() remove spacing
|
||||
with gradio.Accordion(label='Distributed', open=False):
|
||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/6109#issuecomment-1403315784
|
||||
main_toggle = gradio.Checkbox( # main on/off ext. toggle
|
||||
elem_id='enable',
|
||||
label='Enable',
|
||||
value=self.world.enabled if self.world.enabled is not None else True,
|
||||
interactive=True
|
||||
)
|
||||
main_toggle.input(self.main_toggle_btn)
|
||||
with InputAccordion(label='Distributed', open=False, value=self.world.config().get(elem_id), elem_id=elem_id) as main_toggle:
|
||||
main_toggle.input(self.main_toggle_btn, inputs=[main_toggle])
|
||||
setattr(main_toggle.accordion, 'do_not_save_to_config', True) # InputAccordion is really a CheckBox
|
||||
components.append(main_toggle)
|
||||
|
||||
with gradio.Tab('Status') as status_tab:
|
||||
|
|
@ -236,8 +242,8 @@ class UI:
|
|||
info='top-most message is newest'
|
||||
)
|
||||
|
||||
refresh_status_btn = gradio.Button(value='Refresh 🔄', size='sm')
|
||||
refresh_status_btn.click(self.status_btn, inputs=[], outputs=[jobs, status, logs])
|
||||
refresh_status_btn = gradio.Button(value='Refresh 🔄', size='sm', elem_id='distributed-refresh-status', visible=False)
|
||||
refresh_status_btn.click(self.status_btn, inputs=[], outputs=[jobs, status, logs], show_progress=False)
|
||||
|
||||
status_tab.select(fn=self.status_btn, inputs=[], outputs=[jobs, status, logs])
|
||||
components += [status, jobs, logs, refresh_status_btn]
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from modules.shared import cmd_opts
|
|||
from modules.shared import state as master_state
|
||||
from . import shared as sh
|
||||
from .shared import logger, warmup_samples, LOG_LEVEL
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
from webui import server_name
|
||||
|
|
@ -40,6 +41,13 @@ class State(Enum):
|
|||
DISABLED = 5
|
||||
|
||||
|
||||
# looks redundant when encode_pil...() could be used, but it does not support all file formats. E.g. AVIF
|
||||
def pil_to_64(image: Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
return 'data:image/png;base64,' + str(base64.b64encode(buffer.getvalue()), 'utf-8')
|
||||
|
||||
|
||||
class Worker:
|
||||
"""
|
||||
This class represents a worker node in a distributed computing setup.
|
||||
|
|
@ -361,10 +369,7 @@ class Worker:
|
|||
mode = 'img2img' # for use in checking script compat
|
||||
images = []
|
||||
for image in init_images:
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
image = 'data:image/png;base64,' + str(base64.b64encode(buffer.getvalue()), 'utf-8')
|
||||
images.append(image)
|
||||
images.append(pil_to_64(image))
|
||||
payload['init_images'] = images
|
||||
|
||||
alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example
|
||||
|
|
@ -401,9 +406,7 @@ class Worker:
|
|||
# if an image mask is present
|
||||
image_mask = payload.get('image_mask', None)
|
||||
if image_mask is not None:
|
||||
image_b64 = encode_pil_to_base64(image_mask)
|
||||
image_b64 = str(image_b64, 'utf-8')
|
||||
payload['mask'] = image_b64
|
||||
payload['mask'] = pil_to_64(image_mask)
|
||||
del payload['image_mask']
|
||||
|
||||
# see if there is anything else wrong with serializing to payload
|
||||
|
|
@ -606,6 +609,7 @@ class Worker:
|
|||
self.full_url("memory"),
|
||||
timeout=3
|
||||
)
|
||||
self.response = response
|
||||
return response.status_code == 200
|
||||
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
|
|
@ -640,6 +644,7 @@ class Worker:
|
|||
return []
|
||||
|
||||
def load_options(self, model, vae=None):
|
||||
failure_msg = f"failed to load options for worker '{self.label}'"
|
||||
if self.master:
|
||||
return
|
||||
|
||||
|
|
@ -653,17 +658,25 @@ class Worker:
|
|||
if vae is not None:
|
||||
payload['sd_vae'] = vae
|
||||
|
||||
self.set_state(State.WORKING)
|
||||
state_cache = self.state
|
||||
self.set_state(State.WORKING, expect_cycle=True) # may already be WORKING if called by worker.request()
|
||||
start = time.time()
|
||||
response = self.session.post(
|
||||
self.full_url("options"),
|
||||
json=payload
|
||||
)
|
||||
try:
|
||||
response = self.session.post(
|
||||
self.full_url("options"),
|
||||
json=payload
|
||||
)
|
||||
except requests.exceptions.RequestException:
|
||||
self.set_state(State.UNAVAILABLE)
|
||||
logger.error(f"{failure_msg} (connection error... OOM?)")
|
||||
return
|
||||
|
||||
elapsed = time.time() - start
|
||||
self.set_state(State.IDLE)
|
||||
if state_cache != State.WORKING: # see above comment, this lets caller determine when worker is IDLE
|
||||
self.set_state(State.IDLE)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.debug(f"failed to load options for worker '{self.label}'")
|
||||
logger.debug(failure_msg)
|
||||
else:
|
||||
logger.debug(f"worker '{self.label}' loaded weights in {elapsed:.2f}s")
|
||||
self.loaded_model = model_name
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ from .worker import Worker, State
|
|||
from modules.call_queue import wrap_queued_call, queue_lock
|
||||
from modules import processing
|
||||
from modules import progress
|
||||
from modules.scripts import PostprocessBatchListArgs
|
||||
from torchvision.transforms import ToPILImage
|
||||
from modules.images import image_grid
|
||||
|
||||
|
||||
class NotBenchmarked(Exception):
|
||||
|
|
@ -46,6 +49,7 @@ class Job:
|
|||
self.complementary: bool = False
|
||||
self.step_override = None
|
||||
self.thread = None
|
||||
self.gallery_map: List[int] = []
|
||||
|
||||
def __str__(self):
|
||||
prefix = ''
|
||||
|
|
@ -91,6 +95,7 @@ class World:
|
|||
self.verify_remotes = verify_remotes
|
||||
self.thin_client_mode = False
|
||||
self.enabled = True
|
||||
self.enabled_i2i = True
|
||||
self.is_dropdown_handler_injected = False
|
||||
self.complement_production = True
|
||||
self.step_scaling = False
|
||||
|
|
@ -103,7 +108,6 @@ class World:
|
|||
def __repr__(self):
|
||||
return f"{len(self._workers)} workers"
|
||||
|
||||
|
||||
def default_batch_size(self) -> int:
|
||||
"""the amount of images/total images requested that a worker would compute if conditions were perfect and
|
||||
each worker generated at the same speed. assumes one batch only"""
|
||||
|
|
@ -113,7 +117,7 @@ class World:
|
|||
def size(self) -> int:
|
||||
"""
|
||||
Returns:
|
||||
int: The number of nodes currently registered in the world.
|
||||
int: The number of nodes currently registered in the world and in a valid state
|
||||
"""
|
||||
return len(self.get_workers())
|
||||
|
||||
|
|
@ -273,7 +277,7 @@ class World:
|
|||
gradio.Info("Distributed: benchmarking complete!")
|
||||
self.save_config()
|
||||
|
||||
def get_current_output_size(self) -> int:
|
||||
def num_requested(self) -> int:
|
||||
"""
|
||||
returns how many images would be returned from all jobs
|
||||
"""
|
||||
|
|
@ -285,6 +289,11 @@ class World:
|
|||
|
||||
return num_images
|
||||
|
||||
def num_gallery(self) -> int:
|
||||
"""How many images should appear in the gallery. This includes local generations and a grid(if enabled)"""
|
||||
|
||||
return self.num_requested() * self.p.n_iter + shared.opts.return_grid
|
||||
|
||||
def speed_summary(self) -> str:
|
||||
"""
|
||||
Returns string listing workers by their ipm in descending order.
|
||||
|
|
@ -393,7 +402,6 @@ class World:
|
|||
self.initialized = True
|
||||
logger.debug("world initialized!")
|
||||
|
||||
|
||||
def get_workers(self):
|
||||
filtered: List[Worker] = []
|
||||
for worker in self._workers:
|
||||
|
|
@ -472,7 +480,7 @@ class World:
|
|||
#######################
|
||||
|
||||
# when total number of requested images was not cleanly divisible by world size then we tack the remainder on
|
||||
remainder_images = self.p.batch_size - self.get_current_output_size()
|
||||
remainder_images = self.p.batch_size - self.num_requested()
|
||||
if remainder_images >= 1:
|
||||
logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
|
||||
|
||||
|
|
@ -551,21 +559,37 @@ class World:
|
|||
else:
|
||||
logger.debug("complementary image production is disabled")
|
||||
|
||||
logger.info(self.distro_summary(payload))
|
||||
logger.info(self.distro_summary())
|
||||
|
||||
if self.thin_client_mode is True or self.master_job().batch_size == 0:
|
||||
# save original process_images_inner for later so we can restore once we're done
|
||||
logger.debug(f"bypassing local generation completely")
|
||||
def process_images_inner_bypass(p) -> processing.Processed:
|
||||
p.seeds, p.subseeds, p.negative_prompts, p.prompts = [], [], [], []
|
||||
pp = PostprocessBatchListArgs(images=[])
|
||||
self.p.scripts.postprocess_batch_list(p, pp)
|
||||
|
||||
processed = processing.Processed(p, [], p.seed, info="")
|
||||
processed.all_prompts = []
|
||||
processed.all_seeds = []
|
||||
processed.all_subseeds = []
|
||||
processed.all_negative_prompts = []
|
||||
processed.infotexts = []
|
||||
processed.prompt = None
|
||||
processed.all_prompts = p.prompts
|
||||
processed.all_seeds = p.seeds
|
||||
processed.all_subseeds = p.subseeds
|
||||
processed.all_negative_prompts = p.negative_prompts
|
||||
processed.images = pp.images
|
||||
processed.infotexts = [''] * self.num_requested()
|
||||
|
||||
transform = ToPILImage()
|
||||
for i, image in enumerate(processed.images):
|
||||
processed.images[i] = transform(image)
|
||||
|
||||
|
||||
self.p.scripts.postprocess(p, processed)
|
||||
|
||||
# generate grid if enabled
|
||||
if shared.opts.return_grid and len(processed.images) > 1:
|
||||
grid = image_grid(processed.images, len(processed.images))
|
||||
processed.images.insert(0, grid)
|
||||
processed.infotexts.insert(0, processed.infotexts[0])
|
||||
|
||||
return processed
|
||||
processing.process_images_inner = process_images_inner_bypass
|
||||
|
||||
|
|
@ -576,18 +600,17 @@ class World:
|
|||
del self.jobs[last]
|
||||
last -= 1
|
||||
|
||||
def distro_summary(self, payload):
|
||||
# iterations = dict(payload)['n_iter']
|
||||
iterations = self.p.n_iter
|
||||
num_returning = self.get_current_output_size()
|
||||
def distro_summary(self):
|
||||
num_returning = self.num_requested()
|
||||
num_complementary = num_returning - self.p.batch_size
|
||||
|
||||
distro_summary = "Job distribution:\n"
|
||||
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
|
||||
distro_summary += f"{self.p.batch_size} * {self.p.n_iter} iteration(s)"
|
||||
if num_complementary > 0:
|
||||
distro_summary += f" + {num_complementary} complementary"
|
||||
distro_summary += f": {num_returning} images total\n"
|
||||
distro_summary += f": {num_returning * self.p.n_iter} images total\n"
|
||||
for job in self.jobs:
|
||||
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
|
||||
distro_summary += f"'{job.worker.label}' - {job.batch_size * self.p.n_iter} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
|
||||
return distro_summary
|
||||
|
||||
def config(self) -> dict:
|
||||
|
|
@ -670,9 +693,10 @@ class World:
|
|||
|
||||
self.add_worker(**fields)
|
||||
|
||||
sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload)
|
||||
sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload.dict())
|
||||
self.job_timeout = config.job_timeout
|
||||
self.enabled = config.enabled
|
||||
self.enabled_i2i = config.enabled_i2i
|
||||
self.complement_production = config.complement_production
|
||||
self.step_scaling = config.step_scaling
|
||||
|
||||
|
|
@ -688,6 +712,7 @@ class World:
|
|||
benchmark_payload=sh.benchmark_payload,
|
||||
job_timeout=self.job_timeout,
|
||||
enabled=self.enabled,
|
||||
enabled_i2i=self.enabled_i2i,
|
||||
complement_production=self.complement_production,
|
||||
step_scaling=self.step_scaling
|
||||
)
|
||||
|
|
@ -743,6 +768,9 @@ class World:
|
|||
worker.set_state(State.IDLE, expect_cycle=True)
|
||||
else:
|
||||
msg = f"worker '{worker.label}' is unreachable"
|
||||
if worker.response.status_code is not None:
|
||||
msg += f" <{worker.response.status_code}>"
|
||||
|
||||
logger.info(msg)
|
||||
gradio.Warning("Distributed: "+msg)
|
||||
worker.set_state(State.UNAVAILABLE)
|
||||
|
|
@ -753,7 +781,6 @@ class World:
|
|||
for worker in self._workers:
|
||||
worker.restart()
|
||||
|
||||
|
||||
def inject_model_dropdown_handler(self):
|
||||
if self.config().get('enabled', False): # TODO avoid access from config()
|
||||
return
|
||||
|
|
|
|||
Loading…
Reference in New Issue