commit
8fd65ebdc1
15
CHANGELOG.md
15
CHANGELOG.md
|
|
@ -1,10 +1,23 @@
|
||||||
# Change Log
|
# Change Log
|
||||||
Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
|
Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
|
||||||
|
|
||||||
|
## [2.3.0] - 2024-10-26
|
||||||
|
|
||||||
|
## Added
|
||||||
|
- Compatibility for some extensions which mostly only do postprocessing (e.g. Adetailer)
|
||||||
|
- Separate toggle state for img2img tab so txt2img can be enabled and t2i disabled or vice versa
|
||||||
|
|
||||||
|
## Changed
|
||||||
|
- Status tab will now automatically refresh
|
||||||
|
- Main toggle is now in the form of an InputAccordion
|
||||||
|
|
||||||
|
## Fixed
|
||||||
|
- An issue affecting controlnet and inpainting
|
||||||
|
- Toggle state sometimes desyncing when the page was refreshed
|
||||||
|
|
||||||
## [2.2.2] - 2024-8-30
|
## [2.2.2] - 2024-8-30
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
- Unavailable state sometimes being ignored
|
- Unavailable state sometimes being ignored
|
||||||
|
|
||||||
## [2.2.1] - 2024-5-16
|
## [2.2.1] - 2024-5-16
|
||||||
|
|
|
||||||
|
|
@ -2,3 +2,22 @@
|
||||||
function confirm_restart_workers(_) {
|
function confirm_restart_workers(_) {
|
||||||
return confirm('Restart remote workers?')
|
return confirm('Restart remote workers?')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// live updates
|
||||||
|
function update() {
|
||||||
|
try {
|
||||||
|
let currentTab = get_uiCurrentTabContent()
|
||||||
|
let buttons = document.querySelectorAll('#distributed-refresh-status')
|
||||||
|
for(let i = 0; i < buttons.length; i++) {
|
||||||
|
if(currentTab.contains(buttons[i])) {
|
||||||
|
buttons[i].click()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
if (!(e instanceof TypeError)) {
|
||||||
|
throw e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
setInterval(update, 1500)
|
||||||
|
|
@ -13,6 +13,7 @@ import time
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import List
|
from typing import List
|
||||||
import gradio
|
import gradio
|
||||||
|
from torchvision.transforms import ToTensor
|
||||||
import urllib3
|
import urllib3
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from modules import processing
|
from modules import processing
|
||||||
|
|
@ -24,7 +25,7 @@ from modules.shared import state as webui_state
|
||||||
from scripts.spartan.control_net import pack_control_net
|
from scripts.spartan.control_net import pack_control_net
|
||||||
from scripts.spartan.shared import logger
|
from scripts.spartan.shared import logger
|
||||||
from scripts.spartan.ui import UI
|
from scripts.spartan.ui import UI
|
||||||
from scripts.spartan.world import World, State
|
from scripts.spartan.world import World, State, Job
|
||||||
|
|
||||||
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
||||||
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||||
|
|
@ -61,7 +62,7 @@ class DistributedScript(scripts.Script):
|
||||||
return scripts.AlwaysVisible
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
extension_ui = UI(world=self.world)
|
extension_ui = UI(world=self.world, is_img2img=is_img2img)
|
||||||
# root, api_exposed = extension_ui.create_ui()
|
# root, api_exposed = extension_ui.create_ui()
|
||||||
components = extension_ui.create_ui()
|
components = extension_ui.create_ui()
|
||||||
|
|
||||||
|
|
@ -71,77 +72,61 @@ class DistributedScript(scripts.Script):
|
||||||
# return some components that should be exposed to the api
|
# return some components that should be exposed to the api
|
||||||
return components
|
return components
|
||||||
|
|
||||||
def add_to_gallery(self, processed, p):
|
def api_to_internal(self, job) -> ([], [], [], [], []):
|
||||||
"""adds generated images to the image gallery after waiting for all workers to finish"""
|
# takes worker response received from api and returns parsed objects in internal sdwui format. E.g. all_seeds
|
||||||
|
|
||||||
def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None):
|
image_params: json = job.worker.response['parameters']
|
||||||
image_params: json = response['parameters']
|
image_info_post: json = json.loads(job.worker.response["info"]) # image info known after processing
|
||||||
image_info_post: json = json.loads(response["info"]) # image info known after processing
|
all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = [], [], [], [], []
|
||||||
num_response_images = image_params["batch_size"] * image_params["n_iter"]
|
|
||||||
|
|
||||||
seed = None
|
|
||||||
subseed = None
|
|
||||||
negative_prompt = None
|
|
||||||
pos_prompt = None
|
|
||||||
|
|
||||||
|
for i in range(len(job.worker.response["images"])):
|
||||||
try:
|
try:
|
||||||
if num_response_images > 1:
|
if image_params["batch_size"] * image_params["n_iter"] > 1:
|
||||||
seed = image_info_post['all_seeds'][info_index]
|
all_seeds.append(image_info_post['all_seeds'][i])
|
||||||
subseed = image_info_post['all_subseeds'][info_index]
|
all_subseeds.append(image_info_post['all_subseeds'][i])
|
||||||
negative_prompt = image_info_post['all_negative_prompts'][info_index]
|
all_negative_prompts.append(image_info_post['all_negative_prompts'][i])
|
||||||
pos_prompt = image_info_post['all_prompts'][info_index]
|
all_prompts.append(image_info_post['all_prompts'][i])
|
||||||
else:
|
else: # only a single image received
|
||||||
seed = image_info_post['seed']
|
all_seeds.append(image_info_post['seed'])
|
||||||
subseed = image_info_post['subseed']
|
all_subseeds.append(image_info_post['subseed'])
|
||||||
negative_prompt = image_info_post['negative_prompt']
|
all_negative_prompts.append(image_info_post['negative_prompt'])
|
||||||
pos_prompt = image_info_post['prompt']
|
all_prompts.append(image_info_post['prompt'])
|
||||||
except IndexError:
|
except IndexError:
|
||||||
# like with controlnet masks, there isn't always full post-gen info, so we use the first images'
|
# # like with controlnet masks, there isn't always full post-gen info, so we use the first images'
|
||||||
logger.debug(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data")
|
# logger.debug(f"Image at index {info_index} for '{job.worker.label}' was missing some post-generation data")
|
||||||
processed_inject_image(image=image, info_index=0, response=response)
|
# self.processed_inject_image(image=image, info_index=0, job=job, p=p)
|
||||||
return
|
# return
|
||||||
|
logger.critical(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data")
|
||||||
|
continue
|
||||||
|
|
||||||
processed.all_seeds.append(seed)
|
# parse image
|
||||||
processed.all_subseeds.append(subseed)
|
image_bytes = base64.b64decode(job.worker.response["images"][i])
|
||||||
processed.all_negative_prompts.append(negative_prompt)
|
image = Image.open(io.BytesIO(image_bytes))
|
||||||
processed.all_prompts.append(pos_prompt)
|
transform = ToTensor()
|
||||||
processed.images.append(image) # actual received image
|
images.append(transform(image))
|
||||||
|
|
||||||
# generate info-text string
|
return all_seeds, all_subseeds, all_negative_prompts, all_prompts, images
|
||||||
|
|
||||||
|
def inject_job(self, job: Job, p, pp):
|
||||||
|
"""Adds the work completed by one Job via its worker response to the processing and postprocessing objects"""
|
||||||
|
all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = self.api_to_internal(job)
|
||||||
|
|
||||||
|
p.seeds.extend(all_seeds)
|
||||||
|
p.subseeds.extend(all_subseeds)
|
||||||
|
p.negative_prompts.extend(all_negative_prompts)
|
||||||
|
p.prompts.extend(all_prompts)
|
||||||
|
|
||||||
|
num_local = self.world.p.n_iter * self.world.p.batch_size + (opts.return_grid - self.world.thin_client_mode)
|
||||||
|
num_injected = len(pp.images) - self.world.p.batch_size
|
||||||
|
for i, image in enumerate(images):
|
||||||
# modules.ui_common -> update_generation_info renders to html below gallery
|
# modules.ui_common -> update_generation_info renders to html below gallery
|
||||||
images_per_batch = p.n_iter * p.batch_size
|
gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery
|
||||||
# zero-indexed position of image in total batch (so including master results)
|
job.gallery_map.append(gallery_index) # so we know where to edit infotext
|
||||||
true_image_pos = len(processed.images) - 1
|
pp.images.append(image)
|
||||||
num_remote_images = images_per_batch * p.batch_size
|
logger.debug(f"image {gallery_index + 1 + self.world.thin_client_mode}/{self.world.num_gallery()}")
|
||||||
if p.n_iter > 1: # if splitting by batch count
|
|
||||||
num_remote_images *= p.n_iter - 1
|
|
||||||
|
|
||||||
logger.debug(f"image {true_image_pos + 1}/{self.world.p.batch_size * p.n_iter}, "
|
def update_gallery(self, pp, p):
|
||||||
f"info-index: {info_index}")
|
"""adds all remotely generated images to the image gallery after waiting for all workers to finish"""
|
||||||
|
|
||||||
if self.world.thin_client_mode:
|
|
||||||
p.all_negative_prompts = processed.all_negative_prompts
|
|
||||||
|
|
||||||
try:
|
|
||||||
info_text = image_info_post['infotexts'][i]
|
|
||||||
except IndexError:
|
|
||||||
if not grid:
|
|
||||||
logger.warning(f"image {true_image_pos + 1} was missing info-text")
|
|
||||||
info_text = processed.infotexts[0]
|
|
||||||
info_text += f", Worker Label: {job.worker.label}"
|
|
||||||
processed.infotexts.append(info_text)
|
|
||||||
|
|
||||||
# automatically save received image to local disk if desired
|
|
||||||
if cmd_opts.distributed_remotes_autosave:
|
|
||||||
save_image(
|
|
||||||
image=image,
|
|
||||||
path=p.outpath_samples if save_path_override is None else save_path_override,
|
|
||||||
basename="",
|
|
||||||
seed=seed,
|
|
||||||
prompt=pos_prompt,
|
|
||||||
info=info_text,
|
|
||||||
extension=opts.samples_format
|
|
||||||
)
|
|
||||||
|
|
||||||
# get master ipm by estimating based on worker speed
|
# get master ipm by estimating based on worker speed
|
||||||
master_elapsed = time.time() - self.master_start
|
master_elapsed = time.time() - self.master_start
|
||||||
|
|
@ -158,8 +143,7 @@ class DistributedScript(scripts.Script):
|
||||||
logger.debug("all worker request threads returned")
|
logger.debug("all worker request threads returned")
|
||||||
webui_state.textinfo = "Distributed - injecting images"
|
webui_state.textinfo = "Distributed - injecting images"
|
||||||
|
|
||||||
# some worker which we know has a good response that we can use for generating the grid
|
received_images = False
|
||||||
donor_worker = None
|
|
||||||
for job in self.world.jobs:
|
for job in self.world.jobs:
|
||||||
if job.worker.response is None or job.batch_size < 1 or job.worker.master:
|
if job.worker.response is None or job.batch_size < 1 or job.worker.master:
|
||||||
continue
|
continue
|
||||||
|
|
@ -170,8 +154,7 @@ class DistributedScript(scripts.Script):
|
||||||
if (job.batch_size * p.n_iter) < len(images):
|
if (job.batch_size * p.n_iter) < len(images):
|
||||||
logger.debug(f"requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}")
|
logger.debug(f"requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}")
|
||||||
|
|
||||||
if donor_worker is None:
|
received_images = True
|
||||||
donor_worker = job.worker
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if job.batch_size > 0:
|
if job.batch_size > 0:
|
||||||
logger.warning(f"Worker '{job.worker.label}' had no images")
|
logger.warning(f"Worker '{job.worker.label}' had no images")
|
||||||
|
|
@ -185,41 +168,27 @@ class DistributedScript(scripts.Script):
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# visibly add work from workers to the image gallery
|
# adding the images in
|
||||||
for i in range(0, len(images)):
|
self.inject_job(job, p, pp)
|
||||||
image_bytes = base64.b64decode(images[i])
|
|
||||||
image = Image.open(io.BytesIO(image_bytes))
|
|
||||||
|
|
||||||
# inject image
|
# TODO fix controlnet masks returned via api having no generation info
|
||||||
processed_inject_image(image=image, info_index=i, response=job.worker.response)
|
if received_images is False:
|
||||||
|
|
||||||
if donor_worker is None:
|
|
||||||
logger.critical("couldn't collect any responses, the extension will have no effect")
|
logger.critical("couldn't collect any responses, the extension will have no effect")
|
||||||
return
|
return
|
||||||
|
|
||||||
# generate and inject grid
|
p.batch_size = len(pp.images)
|
||||||
if opts.return_grid and len(processed.images) > 1:
|
webui_state.textinfo = ""
|
||||||
grid = image_grid(processed.images, len(processed.images))
|
|
||||||
processed_inject_image(
|
|
||||||
image=grid,
|
|
||||||
info_index=0,
|
|
||||||
save_path_override=p.outpath_grids,
|
|
||||||
grid=True,
|
|
||||||
response=donor_worker.response
|
|
||||||
)
|
|
||||||
|
|
||||||
# cleanup after we're doing using all the responses
|
|
||||||
for worker in self.world.get_workers():
|
|
||||||
worker.response = None
|
|
||||||
|
|
||||||
p.batch_size = len(processed.images)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# p's type is
|
# p's type is
|
||||||
# "modules.processing.StableDiffusionProcessing*"
|
# "modules.processing.StableDiffusionProcessing*"
|
||||||
def before_process(self, p, *args):
|
def before_process(self, p, *args):
|
||||||
if not self.world.enabled:
|
is_img2img = getattr(p, 'init_images', False)
|
||||||
logger.debug("extension is disabled")
|
if is_img2img and self.world.enabled_i2i is False:
|
||||||
|
logger.debug("extension is disabled for i2i")
|
||||||
|
return
|
||||||
|
elif not is_img2img and self.world.enabled is False:
|
||||||
|
logger.debug("extension is disabled for t2i")
|
||||||
return
|
return
|
||||||
self.world.update(p)
|
self.world.update(p)
|
||||||
|
|
||||||
|
|
@ -234,6 +203,14 @@ class DistributedScript(scripts.Script):
|
||||||
continue
|
continue
|
||||||
title = script.title()
|
title = script.title()
|
||||||
|
|
||||||
|
if title == "ADetailer":
|
||||||
|
adetailer_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
|
||||||
|
# InputAccordion main toggle, skip img2img toggle
|
||||||
|
if adetailer_args[0] and adetailer_args[1]:
|
||||||
|
logger.debug(f"adetailer is skipping img2img, returning control to wui")
|
||||||
|
return
|
||||||
|
|
||||||
# check for supported scripts
|
# check for supported scripts
|
||||||
if title == "ControlNet":
|
if title == "ControlNet":
|
||||||
# grab all controlnet units
|
# grab all controlnet units
|
||||||
|
|
@ -346,18 +323,34 @@ class DistributedScript(scripts.Script):
|
||||||
p.batch_size = self.world.master_job().batch_size
|
p.batch_size = self.world.master_job().batch_size
|
||||||
self.master_start = time.time()
|
self.master_start = time.time()
|
||||||
|
|
||||||
# generate images assigned to local machine
|
|
||||||
p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
|
|
||||||
self.runs_since_init += 1
|
self.runs_since_init += 1
|
||||||
return
|
return
|
||||||
|
|
||||||
def postprocess(self, p, processed, *args):
|
def postprocess_batch_list(self, p, pp, *args, **kwargs):
|
||||||
if not self.world.enabled:
|
if not self.world.thin_client_mode and p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch
|
||||||
|
return
|
||||||
|
|
||||||
|
is_img2img = getattr(p, 'init_images', False)
|
||||||
|
if is_img2img and self.world.enabled_i2i is False:
|
||||||
|
return
|
||||||
|
elif not is_img2img and self.world.enabled is False:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.master_start is not None:
|
if self.master_start is not None:
|
||||||
self.add_to_gallery(p=p, processed=processed)
|
self.update_gallery(p=p, pp=pp)
|
||||||
|
|
||||||
|
|
||||||
|
def postprocess(self, p, processed, *args):
|
||||||
|
for job in self.world.jobs:
|
||||||
|
if job.worker.response is not None:
|
||||||
|
for i, v in enumerate(job.gallery_map):
|
||||||
|
infotext = json.loads(job.worker.response['info'])['infotexts'][i]
|
||||||
|
infotext += f", Worker Label: {job.worker.label}"
|
||||||
|
processed.infotexts[v] = infotext
|
||||||
|
|
||||||
|
# cleanup
|
||||||
|
for worker in self.world.get_workers():
|
||||||
|
worker.response = None
|
||||||
# restore process_images_inner if it was monkey-patched
|
# restore process_images_inner if it was monkey-patched
|
||||||
processing.process_images_inner = self.original_process_images_inner
|
processing.process_images_inner = self.original_process_images_inner
|
||||||
# save any dangling state to prevent load_config in next iteration overwriting it
|
# save any dangling state to prevent load_config in next iteration overwriting it
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
|
# https://github.com/Mikubill/sd-webui-controlnet/wiki/API#examples-1
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from modules.api.api import encode_pil_to_base64
|
from modules.api.api import encode_pil_to_base64
|
||||||
from scripts.spartan.shared import logger
|
from scripts.spartan.shared import logger
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import json
|
import json
|
||||||
|
import enum
|
||||||
|
|
||||||
|
|
||||||
def np_to_b64(image: np.ndarray):
|
def np_to_b64(image: np.ndarray):
|
||||||
|
|
@ -58,10 +61,14 @@ def pack_control_net(cn_units) -> dict:
|
||||||
unit['mask'] = mask_b64 # mikubill
|
unit['mask'] = mask_b64 # mikubill
|
||||||
unit['mask_image'] = mask_b64 # forge
|
unit['mask_image'] = mask_b64 # forge
|
||||||
|
|
||||||
|
|
||||||
|
# serialize all enums
|
||||||
|
for k in unit.keys():
|
||||||
|
if isinstance(unit[k], enum.Enum):
|
||||||
|
unit[k] = unit[k].value
|
||||||
|
|
||||||
# avoid returning duplicate detection maps since master should return the same one
|
# avoid returning duplicate detection maps since master should return the same one
|
||||||
unit['save_detected_map'] = False
|
unit['save_detected_map'] = False
|
||||||
# remove anything unserializable
|
|
||||||
del unit['input_mode']
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
json.dumps(controlnet)
|
json.dumps(controlnet)
|
||||||
|
|
|
||||||
|
|
@ -29,17 +29,18 @@ class Worker_Model(BaseModel):
|
||||||
default=False
|
default=False
|
||||||
)
|
)
|
||||||
state: Optional[Any] = Field(default=1, description="The last known state of this worker")
|
state: Optional[Any] = Field(default=1, description="The last known state of this worker")
|
||||||
user: Optional[str] = Field(description="The username to be used when authenticating with this worker")
|
user: Optional[str] = Field(description="The username to be used when authenticating with this worker", default=None)
|
||||||
password: Optional[str] = Field(description="The password to be used when authenticating with this worker")
|
password: Optional[str] = Field(description="The password to be used when authenticating with this worker", default=None)
|
||||||
pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit")
|
pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit")
|
||||||
|
|
||||||
class ConfigModel(BaseModel):
|
class ConfigModel(BaseModel):
|
||||||
workers: List[Dict[str, Worker_Model]]
|
workers: List[Dict[str, Worker_Model]]
|
||||||
benchmark_payload: Dict = Field(
|
benchmark_payload: Benchmark_Payload = Field(
|
||||||
default=Benchmark_Payload,
|
default=Benchmark_Payload,
|
||||||
description='the payload used when benchmarking a node'
|
description='the payload used when benchmarking a node'
|
||||||
)
|
)
|
||||||
job_timeout: Optional[int] = Field(default=3)
|
job_timeout: Optional[int] = Field(default=3)
|
||||||
enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True)
|
enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True)
|
||||||
|
enabled_i2i: Optional[bool] = Field(description="Same as above but for image to image", default=True)
|
||||||
complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True)
|
complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True)
|
||||||
step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False)
|
step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False)
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ samples = 3 # number of times to benchmark worker after warmup benchmarks are c
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkPayload(BaseModel):
|
class BenchmarkPayload(BaseModel):
|
||||||
validate_assignment = True
|
# validate_assignment = True
|
||||||
prompt: str = Field(default="A herd of cows grazing at the bottom of a sunny valley")
|
prompt: str = Field(default="A herd of cows grazing at the bottom of a sunny valley")
|
||||||
negative_prompt: str = Field(default="")
|
negative_prompt: str = Field(default="")
|
||||||
steps: int = Field(default=20)
|
steps: int = Field(default=20)
|
||||||
|
|
|
||||||
|
|
@ -9,16 +9,17 @@ from .shared import logger, LOG_LEVEL, gui_handler
|
||||||
from .worker import State
|
from .worker import State
|
||||||
from modules.call_queue import queue_lock
|
from modules.call_queue import queue_lock
|
||||||
from modules import progress
|
from modules import progress
|
||||||
|
from modules.ui_components import InputAccordion
|
||||||
|
|
||||||
worker_select_dropdown = None
|
worker_select_dropdown = None
|
||||||
|
|
||||||
|
|
||||||
class UI:
|
class UI:
|
||||||
"""extension user interface related things"""
|
"""extension user interface related things"""
|
||||||
|
|
||||||
def __init__(self, world):
|
def __init__(self, world, is_img2img):
|
||||||
self.world = world
|
self.world = world
|
||||||
self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange
|
self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange
|
||||||
|
self.is_img2img = is_img2img
|
||||||
|
|
||||||
# handlers
|
# handlers
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -184,12 +185,20 @@ class UI:
|
||||||
worker.session.auth = (user, password)
|
worker.session.auth = (user, password)
|
||||||
self.world.save_config()
|
self.world.save_config()
|
||||||
|
|
||||||
def main_toggle_btn(self):
|
def main_toggle_btn(self, state):
|
||||||
self.world.enabled = not self.world.enabled
|
if self.is_img2img:
|
||||||
|
if self.world.enabled_i2i == state: # just prevents a redundant config save if ui desyncs
|
||||||
|
return
|
||||||
|
self.world.enabled_i2i = state
|
||||||
|
else:
|
||||||
|
if self.world.enabled == state:
|
||||||
|
return
|
||||||
|
self.world.enabled = state
|
||||||
|
|
||||||
self.world.save_config()
|
self.world.save_config()
|
||||||
|
|
||||||
# restore vanilla sdwui handler for model dropdown if extension is disabled or inject if otherwise
|
# restore vanilla sdwui handler for model dropdown if extension is disabled or inject if otherwise
|
||||||
if not self.world.enabled:
|
if not self.world.enabled and not self.world.enabled_i2i:
|
||||||
model_dropdown = opts.data_labels.get('sd_model_checkpoint')
|
model_dropdown = opts.data_labels.get('sd_model_checkpoint')
|
||||||
if self.original_model_dropdown_handler is not None:
|
if self.original_model_dropdown_handler is not None:
|
||||||
model_dropdown.onchange = self.original_model_dropdown_handler
|
model_dropdown.onchange = self.original_model_dropdown_handler
|
||||||
|
|
@ -208,17 +217,14 @@ class UI:
|
||||||
def create_ui(self):
|
def create_ui(self):
|
||||||
"""creates the extension UI and returns relevant components"""
|
"""creates the extension UI and returns relevant components"""
|
||||||
components = []
|
components = []
|
||||||
|
elem_id = 'enabled'
|
||||||
|
if self.is_img2img:
|
||||||
|
elem_id += '_i2i'
|
||||||
|
|
||||||
with gradio.Blocks(variant='compact'): # Group() and Box() remove spacing
|
with gradio.Blocks(variant='compact'): # Group() and Box() remove spacing
|
||||||
with gradio.Accordion(label='Distributed', open=False):
|
with InputAccordion(label='Distributed', open=False, value=self.world.config().get(elem_id), elem_id=elem_id) as main_toggle:
|
||||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/6109#issuecomment-1403315784
|
main_toggle.input(self.main_toggle_btn, inputs=[main_toggle])
|
||||||
main_toggle = gradio.Checkbox( # main on/off ext. toggle
|
setattr(main_toggle.accordion, 'do_not_save_to_config', True) # InputAccordion is really a CheckBox
|
||||||
elem_id='enable',
|
|
||||||
label='Enable',
|
|
||||||
value=self.world.enabled if self.world.enabled is not None else True,
|
|
||||||
interactive=True
|
|
||||||
)
|
|
||||||
main_toggle.input(self.main_toggle_btn)
|
|
||||||
components.append(main_toggle)
|
components.append(main_toggle)
|
||||||
|
|
||||||
with gradio.Tab('Status') as status_tab:
|
with gradio.Tab('Status') as status_tab:
|
||||||
|
|
@ -236,8 +242,8 @@ class UI:
|
||||||
info='top-most message is newest'
|
info='top-most message is newest'
|
||||||
)
|
)
|
||||||
|
|
||||||
refresh_status_btn = gradio.Button(value='Refresh 🔄', size='sm')
|
refresh_status_btn = gradio.Button(value='Refresh 🔄', size='sm', elem_id='distributed-refresh-status', visible=False)
|
||||||
refresh_status_btn.click(self.status_btn, inputs=[], outputs=[jobs, status, logs])
|
refresh_status_btn.click(self.status_btn, inputs=[], outputs=[jobs, status, logs], show_progress=False)
|
||||||
|
|
||||||
status_tab.select(fn=self.status_btn, inputs=[], outputs=[jobs, status, logs])
|
status_tab.select(fn=self.status_btn, inputs=[], outputs=[jobs, status, logs])
|
||||||
components += [status, jobs, logs, refresh_status_btn]
|
components += [status, jobs, logs, refresh_status_btn]
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ from modules.shared import cmd_opts
|
||||||
from modules.shared import state as master_state
|
from modules.shared import state as master_state
|
||||||
from . import shared as sh
|
from . import shared as sh
|
||||||
from .shared import logger, warmup_samples, LOG_LEVEL
|
from .shared import logger, warmup_samples, LOG_LEVEL
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from webui import server_name
|
from webui import server_name
|
||||||
|
|
@ -40,6 +41,13 @@ class State(Enum):
|
||||||
DISABLED = 5
|
DISABLED = 5
|
||||||
|
|
||||||
|
|
||||||
|
# looks redundant when encode_pil...() could be used, but it does not support all file formats. E.g. AVIF
|
||||||
|
def pil_to_64(image: Image) -> str:
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
image.save(buffer, format="PNG")
|
||||||
|
return 'data:image/png;base64,' + str(base64.b64encode(buffer.getvalue()), 'utf-8')
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
class Worker:
|
||||||
"""
|
"""
|
||||||
This class represents a worker node in a distributed computing setup.
|
This class represents a worker node in a distributed computing setup.
|
||||||
|
|
@ -361,10 +369,7 @@ class Worker:
|
||||||
mode = 'img2img' # for use in checking script compat
|
mode = 'img2img' # for use in checking script compat
|
||||||
images = []
|
images = []
|
||||||
for image in init_images:
|
for image in init_images:
|
||||||
buffer = io.BytesIO()
|
images.append(pil_to_64(image))
|
||||||
image.save(buffer, format="PNG")
|
|
||||||
image = 'data:image/png;base64,' + str(base64.b64encode(buffer.getvalue()), 'utf-8')
|
|
||||||
images.append(image)
|
|
||||||
payload['init_images'] = images
|
payload['init_images'] = images
|
||||||
|
|
||||||
alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example
|
alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example
|
||||||
|
|
@ -401,9 +406,7 @@ class Worker:
|
||||||
# if an image mask is present
|
# if an image mask is present
|
||||||
image_mask = payload.get('image_mask', None)
|
image_mask = payload.get('image_mask', None)
|
||||||
if image_mask is not None:
|
if image_mask is not None:
|
||||||
image_b64 = encode_pil_to_base64(image_mask)
|
payload['mask'] = pil_to_64(image_mask)
|
||||||
image_b64 = str(image_b64, 'utf-8')
|
|
||||||
payload['mask'] = image_b64
|
|
||||||
del payload['image_mask']
|
del payload['image_mask']
|
||||||
|
|
||||||
# see if there is anything else wrong with serializing to payload
|
# see if there is anything else wrong with serializing to payload
|
||||||
|
|
@ -606,6 +609,7 @@ class Worker:
|
||||||
self.full_url("memory"),
|
self.full_url("memory"),
|
||||||
timeout=3
|
timeout=3
|
||||||
)
|
)
|
||||||
|
self.response = response
|
||||||
return response.status_code == 200
|
return response.status_code == 200
|
||||||
|
|
||||||
except requests.exceptions.ConnectionError as e:
|
except requests.exceptions.ConnectionError as e:
|
||||||
|
|
@ -640,6 +644,7 @@ class Worker:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def load_options(self, model, vae=None):
|
def load_options(self, model, vae=None):
|
||||||
|
failure_msg = f"failed to load options for worker '{self.label}'"
|
||||||
if self.master:
|
if self.master:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -653,17 +658,25 @@ class Worker:
|
||||||
if vae is not None:
|
if vae is not None:
|
||||||
payload['sd_vae'] = vae
|
payload['sd_vae'] = vae
|
||||||
|
|
||||||
self.set_state(State.WORKING)
|
state_cache = self.state
|
||||||
|
self.set_state(State.WORKING, expect_cycle=True) # may already be WORKING if called by worker.request()
|
||||||
start = time.time()
|
start = time.time()
|
||||||
response = self.session.post(
|
try:
|
||||||
self.full_url("options"),
|
response = self.session.post(
|
||||||
json=payload
|
self.full_url("options"),
|
||||||
)
|
json=payload
|
||||||
|
)
|
||||||
|
except requests.exceptions.RequestException:
|
||||||
|
self.set_state(State.UNAVAILABLE)
|
||||||
|
logger.error(f"{failure_msg} (connection error... OOM?)")
|
||||||
|
return
|
||||||
|
|
||||||
elapsed = time.time() - start
|
elapsed = time.time() - start
|
||||||
self.set_state(State.IDLE)
|
if state_cache != State.WORKING: # see above comment, this lets caller determine when worker is IDLE
|
||||||
|
self.set_state(State.IDLE)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
logger.debug(f"failed to load options for worker '{self.label}'")
|
logger.debug(failure_msg)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"worker '{self.label}' loaded weights in {elapsed:.2f}s")
|
logger.debug(f"worker '{self.label}' loaded weights in {elapsed:.2f}s")
|
||||||
self.loaded_model = model_name
|
self.loaded_model = model_name
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,9 @@ from .worker import Worker, State
|
||||||
from modules.call_queue import wrap_queued_call, queue_lock
|
from modules.call_queue import wrap_queued_call, queue_lock
|
||||||
from modules import processing
|
from modules import processing
|
||||||
from modules import progress
|
from modules import progress
|
||||||
|
from modules.scripts import PostprocessBatchListArgs
|
||||||
|
from torchvision.transforms import ToPILImage
|
||||||
|
from modules.images import image_grid
|
||||||
|
|
||||||
|
|
||||||
class NotBenchmarked(Exception):
|
class NotBenchmarked(Exception):
|
||||||
|
|
@ -46,6 +49,7 @@ class Job:
|
||||||
self.complementary: bool = False
|
self.complementary: bool = False
|
||||||
self.step_override = None
|
self.step_override = None
|
||||||
self.thread = None
|
self.thread = None
|
||||||
|
self.gallery_map: List[int] = []
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
prefix = ''
|
prefix = ''
|
||||||
|
|
@ -91,6 +95,7 @@ class World:
|
||||||
self.verify_remotes = verify_remotes
|
self.verify_remotes = verify_remotes
|
||||||
self.thin_client_mode = False
|
self.thin_client_mode = False
|
||||||
self.enabled = True
|
self.enabled = True
|
||||||
|
self.enabled_i2i = True
|
||||||
self.is_dropdown_handler_injected = False
|
self.is_dropdown_handler_injected = False
|
||||||
self.complement_production = True
|
self.complement_production = True
|
||||||
self.step_scaling = False
|
self.step_scaling = False
|
||||||
|
|
@ -103,7 +108,6 @@ class World:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{len(self._workers)} workers"
|
return f"{len(self._workers)} workers"
|
||||||
|
|
||||||
|
|
||||||
def default_batch_size(self) -> int:
|
def default_batch_size(self) -> int:
|
||||||
"""the amount of images/total images requested that a worker would compute if conditions were perfect and
|
"""the amount of images/total images requested that a worker would compute if conditions were perfect and
|
||||||
each worker generated at the same speed. assumes one batch only"""
|
each worker generated at the same speed. assumes one batch only"""
|
||||||
|
|
@ -113,7 +117,7 @@ class World:
|
||||||
def size(self) -> int:
|
def size(self) -> int:
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
int: The number of nodes currently registered in the world.
|
int: The number of nodes currently registered in the world and in a valid state
|
||||||
"""
|
"""
|
||||||
return len(self.get_workers())
|
return len(self.get_workers())
|
||||||
|
|
||||||
|
|
@ -273,7 +277,7 @@ class World:
|
||||||
gradio.Info("Distributed: benchmarking complete!")
|
gradio.Info("Distributed: benchmarking complete!")
|
||||||
self.save_config()
|
self.save_config()
|
||||||
|
|
||||||
def get_current_output_size(self) -> int:
|
def num_requested(self) -> int:
|
||||||
"""
|
"""
|
||||||
returns how many images would be returned from all jobs
|
returns how many images would be returned from all jobs
|
||||||
"""
|
"""
|
||||||
|
|
@ -285,6 +289,11 @@ class World:
|
||||||
|
|
||||||
return num_images
|
return num_images
|
||||||
|
|
||||||
|
def num_gallery(self) -> int:
|
||||||
|
"""How many images should appear in the gallery. This includes local generations and a grid(if enabled)"""
|
||||||
|
|
||||||
|
return self.num_requested() * self.p.n_iter + shared.opts.return_grid
|
||||||
|
|
||||||
def speed_summary(self) -> str:
|
def speed_summary(self) -> str:
|
||||||
"""
|
"""
|
||||||
Returns string listing workers by their ipm in descending order.
|
Returns string listing workers by their ipm in descending order.
|
||||||
|
|
@ -393,7 +402,6 @@ class World:
|
||||||
self.initialized = True
|
self.initialized = True
|
||||||
logger.debug("world initialized!")
|
logger.debug("world initialized!")
|
||||||
|
|
||||||
|
|
||||||
def get_workers(self):
|
def get_workers(self):
|
||||||
filtered: List[Worker] = []
|
filtered: List[Worker] = []
|
||||||
for worker in self._workers:
|
for worker in self._workers:
|
||||||
|
|
@ -472,7 +480,7 @@ class World:
|
||||||
#######################
|
#######################
|
||||||
|
|
||||||
# when total number of requested images was not cleanly divisible by world size then we tack the remainder on
|
# when total number of requested images was not cleanly divisible by world size then we tack the remainder on
|
||||||
remainder_images = self.p.batch_size - self.get_current_output_size()
|
remainder_images = self.p.batch_size - self.num_requested()
|
||||||
if remainder_images >= 1:
|
if remainder_images >= 1:
|
||||||
logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
|
logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
|
||||||
|
|
||||||
|
|
@ -551,21 +559,37 @@ class World:
|
||||||
else:
|
else:
|
||||||
logger.debug("complementary image production is disabled")
|
logger.debug("complementary image production is disabled")
|
||||||
|
|
||||||
logger.info(self.distro_summary(payload))
|
logger.info(self.distro_summary())
|
||||||
|
|
||||||
if self.thin_client_mode is True or self.master_job().batch_size == 0:
|
if self.thin_client_mode is True or self.master_job().batch_size == 0:
|
||||||
# save original process_images_inner for later so we can restore once we're done
|
# save original process_images_inner for later so we can restore once we're done
|
||||||
logger.debug(f"bypassing local generation completely")
|
logger.debug(f"bypassing local generation completely")
|
||||||
def process_images_inner_bypass(p) -> processing.Processed:
|
def process_images_inner_bypass(p) -> processing.Processed:
|
||||||
|
p.seeds, p.subseeds, p.negative_prompts, p.prompts = [], [], [], []
|
||||||
|
pp = PostprocessBatchListArgs(images=[])
|
||||||
|
self.p.scripts.postprocess_batch_list(p, pp)
|
||||||
|
|
||||||
processed = processing.Processed(p, [], p.seed, info="")
|
processed = processing.Processed(p, [], p.seed, info="")
|
||||||
processed.all_prompts = []
|
processed.all_prompts = p.prompts
|
||||||
processed.all_seeds = []
|
processed.all_seeds = p.seeds
|
||||||
processed.all_subseeds = []
|
processed.all_subseeds = p.subseeds
|
||||||
processed.all_negative_prompts = []
|
processed.all_negative_prompts = p.negative_prompts
|
||||||
processed.infotexts = []
|
processed.images = pp.images
|
||||||
processed.prompt = None
|
processed.infotexts = [''] * self.num_requested()
|
||||||
|
|
||||||
|
transform = ToPILImage()
|
||||||
|
for i, image in enumerate(processed.images):
|
||||||
|
processed.images[i] = transform(image)
|
||||||
|
|
||||||
|
|
||||||
self.p.scripts.postprocess(p, processed)
|
self.p.scripts.postprocess(p, processed)
|
||||||
|
|
||||||
|
# generate grid if enabled
|
||||||
|
if shared.opts.return_grid and len(processed.images) > 1:
|
||||||
|
grid = image_grid(processed.images, len(processed.images))
|
||||||
|
processed.images.insert(0, grid)
|
||||||
|
processed.infotexts.insert(0, processed.infotexts[0])
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
processing.process_images_inner = process_images_inner_bypass
|
processing.process_images_inner = process_images_inner_bypass
|
||||||
|
|
||||||
|
|
@ -576,18 +600,17 @@ class World:
|
||||||
del self.jobs[last]
|
del self.jobs[last]
|
||||||
last -= 1
|
last -= 1
|
||||||
|
|
||||||
def distro_summary(self, payload):
|
def distro_summary(self):
|
||||||
# iterations = dict(payload)['n_iter']
|
num_returning = self.num_requested()
|
||||||
iterations = self.p.n_iter
|
|
||||||
num_returning = self.get_current_output_size()
|
|
||||||
num_complementary = num_returning - self.p.batch_size
|
num_complementary = num_returning - self.p.batch_size
|
||||||
|
|
||||||
distro_summary = "Job distribution:\n"
|
distro_summary = "Job distribution:\n"
|
||||||
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
|
distro_summary += f"{self.p.batch_size} * {self.p.n_iter} iteration(s)"
|
||||||
if num_complementary > 0:
|
if num_complementary > 0:
|
||||||
distro_summary += f" + {num_complementary} complementary"
|
distro_summary += f" + {num_complementary} complementary"
|
||||||
distro_summary += f": {num_returning} images total\n"
|
distro_summary += f": {num_returning * self.p.n_iter} images total\n"
|
||||||
for job in self.jobs:
|
for job in self.jobs:
|
||||||
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
|
distro_summary += f"'{job.worker.label}' - {job.batch_size * self.p.n_iter} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
|
||||||
return distro_summary
|
return distro_summary
|
||||||
|
|
||||||
def config(self) -> dict:
|
def config(self) -> dict:
|
||||||
|
|
@ -670,9 +693,10 @@ class World:
|
||||||
|
|
||||||
self.add_worker(**fields)
|
self.add_worker(**fields)
|
||||||
|
|
||||||
sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload)
|
sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload.dict())
|
||||||
self.job_timeout = config.job_timeout
|
self.job_timeout = config.job_timeout
|
||||||
self.enabled = config.enabled
|
self.enabled = config.enabled
|
||||||
|
self.enabled_i2i = config.enabled_i2i
|
||||||
self.complement_production = config.complement_production
|
self.complement_production = config.complement_production
|
||||||
self.step_scaling = config.step_scaling
|
self.step_scaling = config.step_scaling
|
||||||
|
|
||||||
|
|
@ -688,6 +712,7 @@ class World:
|
||||||
benchmark_payload=sh.benchmark_payload,
|
benchmark_payload=sh.benchmark_payload,
|
||||||
job_timeout=self.job_timeout,
|
job_timeout=self.job_timeout,
|
||||||
enabled=self.enabled,
|
enabled=self.enabled,
|
||||||
|
enabled_i2i=self.enabled_i2i,
|
||||||
complement_production=self.complement_production,
|
complement_production=self.complement_production,
|
||||||
step_scaling=self.step_scaling
|
step_scaling=self.step_scaling
|
||||||
)
|
)
|
||||||
|
|
@ -743,6 +768,9 @@ class World:
|
||||||
worker.set_state(State.IDLE, expect_cycle=True)
|
worker.set_state(State.IDLE, expect_cycle=True)
|
||||||
else:
|
else:
|
||||||
msg = f"worker '{worker.label}' is unreachable"
|
msg = f"worker '{worker.label}' is unreachable"
|
||||||
|
if worker.response.status_code is not None:
|
||||||
|
msg += f" <{worker.response.status_code}>"
|
||||||
|
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
gradio.Warning("Distributed: "+msg)
|
gradio.Warning("Distributed: "+msg)
|
||||||
worker.set_state(State.UNAVAILABLE)
|
worker.set_state(State.UNAVAILABLE)
|
||||||
|
|
@ -753,7 +781,6 @@ class World:
|
||||||
for worker in self._workers:
|
for worker in self._workers:
|
||||||
worker.restart()
|
worker.restart()
|
||||||
|
|
||||||
|
|
||||||
def inject_model_dropdown_handler(self):
|
def inject_model_dropdown_handler(self):
|
||||||
if self.config().get('enabled', False): # TODO avoid access from config()
|
if self.config().get('enabled', False): # TODO avoid access from config()
|
||||||
return
|
return
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue