merge dev making 2.2.0

master v2.2.0
unknown 2024-05-12 00:24:20 -05:00
commit b1f9d4f769
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
8 changed files with 338 additions and 256 deletions

View File

@ -1,6 +1,27 @@
# Change Log
Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
## [2.2.0] - 2024-05-11
### Added
- Toggle for allowing automatic step scaling which can increase overall utilization
### Changed
- Adding workers which have the same socket definition as master will no longer be allowed and an error will show #28
- Workers in an invalid state should no longer be benchmarked
- The worker port under worker config will now default to 7860 to prevent mishaps
- Config should once again only be loaded once per session startup
- A warning will be shown when trying to use the user script button but no script exists
### Fixed
- Thin-client mode
- Some problems with sdwui forge branch
- Certificate verification setting sometimes not saving
- Master being assigned no work, which stopped generation (same root cause as the thin-client issue)
### Removed
- Adding workers using deprecated cmdline argument
## [2.1.0] - 2024-03-03
### Added

View File

@ -10,7 +10,7 @@ import re
import signal
import sys
import time
from threading import Thread, current_thread
from threading import Thread
from typing import List
import gradio
import urllib3
@ -18,28 +18,24 @@ from PIL import Image
from modules import processing
from modules import scripts
from modules.images import save_image
from modules.processing import fix_seed, Processed
from modules.processing import fix_seed
from modules.shared import opts, cmd_opts
from modules.shared import state as webui_state
from scripts.spartan.control_net import pack_control_net
from scripts.spartan.shared import logger
from scripts.spartan.ui import UI
from scripts.spartan.world import World, WorldAlreadyInitialized
from scripts.spartan.world import World
old_sigint_handler = signal.getsignal(signal.SIGINT)
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
# TODO implement advertisement of some sort in sdwui api to allow extension to automatically discover workers?
# noinspection PyMissingOrEmptyDocstring
class Script(scripts.Script):
class DistributedScript(scripts.Script):
# global old_sigterm_handler, old_sigterm_handler
worker_threads: List[Thread] = []
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
verify_remotes = not cmd_opts.distributed_skip_verify_remotes
is_img2img = True
is_txt2img = True
alwayson = True
master_start = None
runs_since_init = 0
name = "distributed"
@ -50,20 +46,15 @@ class Script(scripts.Script):
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# build world
world = World(initial_payload=None, verify_remotes=verify_remotes)
# add workers to the world
world = World(verify_remotes=verify_remotes)
world.load_config()
if cmd_opts.distributed_remotes is not None and len(cmd_opts.distributed_remotes) > 0:
logger.warning(f"--distributed-remotes is deprecated and may be removed in the future\n"
f"gui/external modification of {world.config_path} will be prioritized going forward")
for worker in cmd_opts.distributed_remotes:
world.add_worker(uuid=worker[0], address=worker[1], port=worker[2], tls=False)
world.save_config()
# do an early check to see which workers are online
logger.info("doing initial ping sweep to see which workers are reachable")
world.ping_remotes(indiscriminate=True)
# constructed for both txt2img and img2img
def __init__(self):
super().__init__()
def title(self):
return "Distribute"
@ -71,21 +62,18 @@ class Script(scripts.Script):
return scripts.AlwaysVisible
def ui(self, is_img2img):
self.world.load_config()
extension_ui = UI(script=Script, world=Script.world)
extension_ui = UI(world=self.world)
# root, api_exposed = extension_ui.create_ui()
components = extension_ui.create_ui()
# The first injection of handler for the models dropdown(sd_model_checkpoint) which is often present
# in the quick-settings bar of a user. Helps ensure model swaps propagate to all nodes ASAP.
Script.world.inject_model_dropdown_handler()
self.world.inject_model_dropdown_handler()
# return some components that should be exposed to the api
return components
@staticmethod
def add_to_gallery(processed, p):
def add_to_gallery(self, processed, p):
"""adds generated images to the image gallery after waiting for all workers to finish"""
webui_state.textinfo = "Distributed - injecting images"
def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None):
image_params: json = response['parameters']
@ -129,10 +117,10 @@ class Script(scripts.Script):
if p.n_iter > 1: # if splitting by batch count
num_remote_images *= p.n_iter - 1
logger.debug(f"image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, "
logger.debug(f"image {true_image_pos + 1}/{self.world.p.batch_size * p.n_iter}, "
f"info-index: {info_index}")
if Script.world.thin_client_mode:
if self.world.thin_client_mode:
p.all_negative_prompts = processed.all_negative_prompts
try:
@ -157,19 +145,21 @@ class Script(scripts.Script):
)
# get master ipm by estimating based on worker speed
master_elapsed = time.time() - Script.master_start
master_elapsed = time.time() - self.master_start
logger.debug(f"Took master {master_elapsed:.2f}s")
# wait for response from all workers
for thread in Script.worker_threads:
webui_state.textinfo = "Distributed - receiving results"
for thread in self.worker_threads:
logger.debug(f"waiting for worker thread '{thread.name}'")
thread.join()
Script.worker_threads.clear()
self.worker_threads.clear()
logger.debug("all worker request threads returned")
webui_state.textinfo = "Distributed - injecting images"
# some worker which we know has a good response that we can use for generating the grid
donor_worker = None
for job in Script.world.jobs:
for job in self.world.jobs:
if job.batch_size < 1 or job.worker.master:
continue
@ -177,7 +167,7 @@ class Script(scripts.Script):
images: json = job.worker.response["images"]
# if we for some reason get more than we asked for
if (job.batch_size * p.n_iter) < len(images):
logger.debug(f"Requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}")
logger.debug(f"requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}")
if donor_worker is None:
donor_worker = job.worker
@ -208,7 +198,7 @@ class Script(scripts.Script):
# generate and inject grid
if opts.return_grid:
grid = processing.images.image_grid(processed.images, len(processed.images))
grid = images.image_grid(processed.images, len(processed.images))
processed_inject_image(
image=grid,
info_index=0,
@ -218,35 +208,22 @@ class Script(scripts.Script):
)
# cleanup after we're doing using all the responses
for worker in Script.world.get_workers():
for worker in self.world.get_workers():
worker.response = None
p.batch_size = len(processed.images)
return
@staticmethod
def initialize(initial_payload):
# get default batch size
try:
batch_size = initial_payload.batch_size
except AttributeError:
batch_size = 1
try:
Script.world.initialize(batch_size)
logger.debug(f"World initialized!")
except WorldAlreadyInitialized:
Script.world.update_world(total_batch_size=batch_size)
# p's type is
# "modules.processing.StableDiffusionProcessingTxt2Img"
# "modules.processing.StableDiffusionProcessing*"
def before_process(self, p, *args):
if not self.world.enabled:
logger.debug("extension is disabled")
return
self.world.update(p)
current_thread().name = "distributed_main"
Script.initialize(initial_payload=p)
# save original process_images_inner function for later if we monkeypatch it
self.original_process_images_inner = processing.process_images_inner
# strip scripts that aren't yet supported and warn user
packed_script_args: List[dict] = [] # list of api formatted per-script argument objects
@ -261,8 +238,9 @@ class Script(scripts.Script):
# grab all controlnet units
cn_units = []
cn_args = p.script_args[script.args_from:script.args_to]
for cn_arg in cn_args:
if type(cn_arg).__name__ == "UiControlNetUnit":
if "ControlNetUnit" in type(cn_arg).__name__:
cn_units.append(cn_arg)
logger.debug(f"Detected {len(cn_units)} controlnet unit(s)")
@ -281,7 +259,7 @@ class Script(scripts.Script):
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works
# see test/basic_features/txt2img_test.py for an example
payload = copy.copy(p.__dict__)
payload['batch_size'] = Script.world.default_batch_size()
payload['batch_size'] = self.world.default_batch_size()
payload['scripts'] = None
try:
del payload['script_args']
@ -300,8 +278,9 @@ class Script(scripts.Script):
# TODO api for some reason returns 200 even if something failed to be set.
# for now we may have to make redundant GET requests to check if actually successful...
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146
name = re.sub(r'\s?\[[^]]*]$', '', opts.data["sd_model_checkpoint"])
vae = opts.data["sd_vae"]
vae = opts.data.get('sd_vae')
option_payload = {
"sd_model_checkpoint": name,
"sd_vae": vae
@ -309,11 +288,11 @@ class Script(scripts.Script):
# start generating images assigned to remote machines
sync = False # should only really need to sync once per job
Script.world.optimize_jobs(payload) # optimize work assignment before dispatching
self.world.optimize_jobs(payload) # optimize work assignment before dispatching
started_jobs = []
# check if anything even needs to be done
if len(Script.world.jobs) == 1 and Script.world.jobs[0].worker.master:
if len(self.world.jobs) == 1 and self.world.jobs[0].worker.master:
if payload['batch_size'] >= 2:
msg = f"all remote workers are offline or unreachable"
@ -324,7 +303,7 @@ class Script(scripts.Script):
return
for job in Script.world.jobs:
for job in self.world.jobs:
payload_temp = copy.copy(payload)
del payload_temp['scripts_value']
payload_temp = copy.deepcopy(payload_temp)
@ -339,11 +318,14 @@ class Script(scripts.Script):
prior_images += j.batch_size * p.n_iter
payload_temp['batch_size'] = job.batch_size
if job.step_override is not None:
payload_temp['steps'] = job.step_override
payload_temp['subseed'] += prior_images
payload_temp['seed'] += prior_images if payload_temp['subseed_strength'] == 0 else 0
logger.debug(
f"'{job.worker.label}' job's given starting seed is "
f"{payload_temp['seed']} with {prior_images} coming before it")
f"{payload_temp['seed']} with {prior_images} coming before it"
)
if job.worker.loaded_model != name or job.worker.loaded_vae != vae:
sync = True
@ -354,42 +336,34 @@ class Script(scripts.Script):
name=f"{job.worker.label}_request")
t.start()
Script.worker_threads.append(t)
self.worker_threads.append(t)
started_jobs.append(job)
# if master batch size was changed again due to optimization change it to the updated value
if not self.world.thin_client_mode:
p.batch_size = Script.world.master_job().batch_size
Script.master_start = time.time()
p.batch_size = self.world.master_job().batch_size
self.master_start = time.time()
# generate images assigned to local machine
p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
if Script.world.thin_client_mode:
p.batch_size = 0
processed = Processed(p=p, images_list=[])
processed.all_prompts = []
processed.all_seeds = []
processed.all_subseeds = []
processed.all_negative_prompts = []
processed.infotexts = []
processed.prompt = None
Script.runs_since_init += 1
self.runs_since_init += 1
return
@staticmethod
def postprocess(p, processed, *args):
if not Script.world.enabled:
def postprocess(self, p, processed, *args):
if not self.world.enabled:
return
if len(processed.images) >= 1 and Script.master_start is not None:
Script.add_to_gallery(p=p, processed=processed)
if self.master_start is not None:
self.add_to_gallery(p=p, processed=processed)
# restore process_images_inner if it was monkey-patched
processing.process_images_inner = self.original_process_images_inner
@staticmethod
def signal_handler(sig, frame):
logger.debug("handling interrupt signal")
# do cleanup
Script.world.save_config()
DistributedScript.world.save_config()
if sig == signal.SIGINT:
if callable(old_sigint_handler):

View File

@ -2,6 +2,16 @@ import copy
from PIL import Image
from modules.api.api import encode_pil_to_base64
from scripts.spartan.shared import logger
import numpy as np
import json
def np_to_b64(image: np.ndarray):
pil = Image.fromarray(image)
image_b64 = str(encode_pil_to_base64(pil), 'utf-8')
image_b64 = 'data:image/png;base64,' + image_b64
return image_b64
def pack_control_net(cn_units) -> dict:
@ -17,28 +27,46 @@ def pack_control_net(cn_units) -> dict:
cn_args = controlnet['controlnet']['args']
for i in range(0, len(cn_units)):
# copy control net unit to payload
cn_args.append(copy.copy(cn_units[i].__dict__))
if cn_units[i].enabled:
cn_args.append(copy.deepcopy(cn_units[i].__dict__))
else:
logger.debug(f"controlnet unit {i} is not enabled (ignoring)")
for i in range(0, len(cn_args)):
unit = cn_args[i]
# if unit isn't enabled then don't bother including
if not unit['enabled']:
del unit['input_mode']
del unit['image']
logger.debug(f"Controlnet unit {i} is not enabled. Ignoring")
continue
# serialize image
if unit['image'] is not None:
image = unit['image']['image']
# mask = unit['image']['mask']
pil = Image.fromarray(image)
image_b64 = encode_pil_to_base64(pil)
image_b64 = str(image_b64, 'utf-8')
unit['input_image'] = image_b64
image_pair = unit.get('image')
if image_pair is not None:
image_b64 = np_to_b64(image_pair['image'])
unit['input_image'] = image_b64 # mikubill
unit['image'] = image_b64 # forge
if np.all(image_pair['mask'] == 0):
# stand-alone mask from second gradio component
standalone_mask = unit.get('mask_image')
if standalone_mask is not None:
logger.debug(f"found stand-alone mask for controlnet unit {i}")
mask_b64 = np_to_b64(unit['mask_image']['mask'])
unit['mask'] = mask_b64 # mikubill
unit['mask_image'] = mask_b64 # forge
else:
# mask from singular gradio image component
logger.debug(f"found mask for controlnet unit {i}")
mask_b64 = np_to_b64(image_pair['mask'])
unit['mask'] = mask_b64 # mikubill
unit['mask_image'] = mask_b64 # forge
# avoid returning duplicate detection maps since master should return the same one
unit['save_detected_map'] = False
# remove anything unserializable
del unit['input_mode']
del unit['image']
try:
json.dumps(controlnet)
except Exception as e:
logger.error(f"failed to serialize controlnet\nfirst unit:\n{controlnet['controlnet']['args'][0]}")
return {}
return controlnet

View File

@ -42,3 +42,4 @@ class ConfigModel(BaseModel):
job_timeout: Optional[int] = Field(default=3)
enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True)
complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True)
step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False)

View File

@ -61,6 +61,7 @@ logger.addHandler(gui_handler)
# end logging
warmup_samples = 2 # number of samples to do before recording a valid benchmark sample
samples = 3 # number of times to benchmark worker after warmup benchmarks are completed
class BenchmarkPayload(BaseModel):

View File

@ -14,8 +14,7 @@ worker_select_dropdown = None
class UI:
"""extension user interface related things"""
def __init__(self, script, world):
self.script = script
def __init__(self, world):
self.world = world
self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange
@ -25,10 +24,17 @@ class UI:
"""executes a script placed by the user at <extension>/user/sync*"""
user_scripts = Path(os.path.abspath(__file__)).parent.parent.joinpath('user')
user_script = None
for file in user_scripts.iterdir():
logger.debug(f"found possible script {file.name}")
if file.is_file() and file.name.startswith('sync'):
user_script = file
if user_script is None:
logger.error(
"couldn't find user script\n"
"script must be placed under <extension>/user/ and filename must begin with sync"
)
return False
suffix = user_script.suffix[1:]
@ -74,15 +80,14 @@ class UI:
return 'No active jobs!', worker_status, logs
def save_btn(self, thin_client_mode, job_timeout, complement_production):
def save_btn(self, thin_client_mode, job_timeout, complement_production, step_scaling):
"""updates the options visible on the settings tab"""
self.world.thin_client_mode = thin_client_mode
logger.debug(f"thin client mode is now {thin_client_mode}")
job_timeout = int(job_timeout)
self.world.job_timeout = job_timeout
logger.debug(f"job timeout is now {job_timeout} seconds")
self.world.complement_production = complement_production
self.world.step_scaling = step_scaling
self.world.save_config()
def save_worker_btn(self, label, address, port, tls, disabled):
@ -102,7 +107,7 @@ class UI:
self.world.add_worker(
label=label,
address=address,
port=port,
port=port if len(port) > 0 else 7860,
tls=tls,
state=state
)
@ -208,7 +213,6 @@ class UI:
interactive=True
)
main_toggle.input(self.main_toggle_btn)
setattr(main_toggle, 'do_not_save_to_config', True) # ui_loadsave.py apply_field()
components.append(main_toggle)
with gradio.Tab('Status') as status_tab:
@ -240,7 +244,7 @@ class UI:
reload_config_btn = gradio.Button(value='📜 Reload config')
reload_config_btn.click(self.world.load_config)
redo_benchmarks_btn = gradio.Button(value='📊 Redo benchmarks', variant='stop')
redo_benchmarks_btn = gradio.Button(value='📊 Redo benchmarks')
redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[])
run_usr_btn = gradio.Button(value='⚙️ Run script')
@ -252,10 +256,10 @@ class UI:
reconnect_lost_workers_btn = gradio.Button(value='🔌 Reconnect workers')
reconnect_lost_workers_btn.click(self.world.ping_remotes)
interrupt_all_btn = gradio.Button(value='⏸️ Interrupt all', variant='stop')
interrupt_all_btn = gradio.Button(value='⏸️ Interrupt all')
interrupt_all_btn.click(self.world.interrupt_remotes)
restart_workers_btn = gradio.Button(value="🔁 Restart All", variant='stop')
restart_workers_btn = gradio.Button(value="🔁 Restart All")
restart_workers_btn.click(
_js="confirm_restart_workers",
fn=lambda confirmed: self.world.restart_all() if confirmed else None,
@ -305,7 +309,7 @@ class UI:
# API authentication
worker_api_auth_cbx = gradio.Checkbox(label='API Authentication')
worker_user_field = gradio.Textbox(label='Username')
worker_password_field = gradio.Textbox(label='Password')
worker_password_field = gradio.Textbox(label='Password', type='password')
update_credentials_btn = gradio.Button(value='Update API Credentials')
update_credentials_btn.click(self.update_credentials_btn, inputs=[
worker_api_auth_cbx,
@ -346,25 +350,33 @@ class UI:
with gradio.Tab('Settings'):
thin_client_cbx = gradio.Checkbox(
label='Thin-client mode (experimental)',
info="(BROKEN) Only generate images using remote workers. There will be no previews when enabled.",
label='Thin-client mode',
info="Only generate images remotely (no image previews yet)",
value=self.world.thin_client_mode
)
job_timeout = gradio.Number(
label='Job timeout', value=self.world.job_timeout,
info="Seconds until a worker is considered too slow to be assigned an"
" equal share of the total request. Longer than 2 seconds is recommended."
" equal share of the total request. Longer than 2 seconds is recommended"
)
complement_production = gradio.Checkbox(
label='Complement production',
info='Prevents under-utilization of hardware by requesting additional images',
info='Prevents under-utilization by requesting additional images when possible',
value=self.world.complement_production
)
# reduces image quality the more the sample-count must be reduced
# good for mixed setups where each worker may not be around the same speed
step_scaling = gradio.Checkbox(
label='Step scaling',
info='Prevents under-utilization via sample reduction in order to meet time constraints',
value=self.world.step_scaling
)
save_btn = gradio.Button(value='Update')
save_btn.click(fn=self.save_btn, inputs=[thin_client_cbx, job_timeout, complement_production])
components += [thin_client_cbx, job_timeout, complement_production, save_btn]
save_btn.click(fn=self.save_btn, inputs=[thin_client_cbx, job_timeout, complement_production, step_scaling])
components += [thin_client_cbx, job_timeout, complement_production, step_scaling, save_btn]
with gradio.Tab('Help'):
gradio.Markdown(
@ -374,4 +386,7 @@ class UI:
"""
)
# prevent wui from overriding any values
for component in components:
setattr(component, 'do_not_save_to_config', True) # ui_loadsave.py apply_field()
return components

View File

@ -156,6 +156,11 @@ class Worker:
def __repr__(self):
return f"'{self.label}'@{self.address}:{self.port}, speed: {self.avg_ipm} ipm, state: {self.state}"
def __eq__(self, other):
if isinstance(other, Worker) and other.label == self.label:
return True
return False
@property
def model(self) -> Worker_Model:
return Worker_Model(**self.__dict__)
@ -189,14 +194,14 @@ class Worker:
protocol = 'http' if not self.tls else 'https'
return f"{protocol}://{self.__str__()}/sdapi/v1/{route}"
def batch_eta_hr(self, payload: dict) -> float:
def eta_hr(self, payload: dict) -> float:
"""
takes a normal payload and returns the eta of a pseudo payload which mirrors the hr-fix parameters
This returns the eta of how long it would take to run hr-fix on the original image
"""
pseudo_payload = copy.copy(payload)
pseudo_payload['enable_hr'] = False # prevent overflow in self.batch_eta
pseudo_payload['enable_hr'] = False # prevent overflow in self.eta
res_ratio = pseudo_payload['hr_scale']
original_steps = pseudo_payload['steps']
second_pass_steps = pseudo_payload['hr_second_pass_steps']
@ -212,12 +217,11 @@ class Worker:
pseudo_payload['width'] = pseudo_width
pseudo_payload['height'] = pseudo_height
eta = self.batch_eta(payload=pseudo_payload, quiet=True)
return eta
return self.eta(payload=pseudo_payload, quiet=True)
def batch_eta(self, payload: dict, quiet: bool = False, batch_size: int = None) -> float:
def eta(self, payload: dict, quiet: bool = False, batch_size: int = None, samples: int = None) -> float:
"""
estimate how long it will take to generate <batch_size> images on a worker in seconds
estimate how long it will take to generate image(s) on a worker in seconds
Args:
payload: Sdwui api formatted payload
@ -225,7 +229,7 @@ class Worker:
batch_size: Overrides the batch_size parameter of the payload
"""
steps = payload['steps']
steps = payload['steps'] if samples is None else samples
num_images = payload['batch_size'] if batch_size is None else batch_size
# if worker has not yet been benchmarked then
@ -237,7 +241,7 @@ class Worker:
# show effect of high-res fix
hr = payload.get('enable_hr', False)
if hr:
eta += self.batch_eta_hr(payload=payload)
eta += self.eta_hr(payload=payload)
# show effect of image size
real_pix_to_benched = (payload['width'] * payload['height']) \
@ -331,7 +335,7 @@ class Worker:
self.load_options(model=option_payload['sd_model_checkpoint'], vae=option_payload['sd_vae'])
if self.benchmarked:
eta = self.batch_eta(payload=payload) * payload['n_iter']
eta = self.eta(payload=payload) * payload['n_iter']
logger.debug(f"worker '{self.label}' predicts it will take {eta:.3f}s to generate "
f"{payload['batch_size'] * payload['n_iter']} image(s) "
f"at a speed of {self.avg_ipm:.2f} ipm\n")
@ -471,7 +475,7 @@ class Worker:
self.response_time = time.time() - start
variance = ((eta - self.response_time) / self.response_time) * 100
logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%.\n"
logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%\n"
f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
# if the variance is greater than 500% then we ignore it to prevent variation inflation
@ -501,20 +505,20 @@ class Worker:
self.jobs_requested += 1
return
def benchmark(self) -> float:
def benchmark(self, sample_function: callable = None) -> float:
"""
given a worker, run a small benchmark and return its performance in images/minute
makes standard request(s) of 512x512 images and averages them to get the result
"""
t: Thread
samples = 2 # number of times to benchmark the remote / accuracy
if self.state == State.DISABLED or self.state == State.UNAVAILABLE:
if self.state in (State.DISABLED, State.UNAVAILABLE):
logger.debug(f"worker '{self.label}' is unavailable or disabled, refusing to benchmark")
return 0
if self.master is True:
if self.master and sample_function is None:
logger.critical(f"no function provided for benchmarking master")
return -1
def ipm(seconds: float) -> float:
@ -533,19 +537,24 @@ class Worker:
results: List[float] = []
# it used to be lower for the first couple of generations
# this was due to something torch does at startup according to auto and is now done at sdwui startup
self.state = State.WORKING
for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up"
for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up"
if self.state == State.UNAVAILABLE:
self.response = None
return 0
t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,),
name=f"{self.label}_benchmark_request")
try: # if the worker is unreachable/offline then handle that here
t.start()
start = time.time()
t.join()
elapsed = time.time() - start
elapsed = None
if not callable(sample_function):
start = time.time()
t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,),
name=f"{self.label}_benchmark_request")
t.start()
t.join()
elapsed = time.time() - start
else:
elapsed = sample_function()
sample_ipm = ipm(elapsed)
except InvalidWorkerResponse as e:
raise e
@ -558,15 +567,13 @@ class Worker:
logger.debug(f"{self.label} finished warming up\n")
# average the sample results for accuracy
ipm_sum = 0
for ipm_result in results:
ipm_sum += ipm_result
avg_ipm_result = ipm_sum / samples
avg_ipm_result = sum(results) / sh.samples
logger.debug(f"Worker '{self.label}' average ipm: {avg_ipm_result:.2f}")
self.avg_ipm = avg_ipm_result
self.response = None
self.benchmarked = True
self.eta_percent_error = [] # likely inaccurate after rebenching
self.state = State.IDLE
return avg_ipm_result
@ -610,6 +617,10 @@ class Worker:
except requests.exceptions.ConnectionError as e:
logger.error(e)
return False
except requests.ReadTimeout as e:
logger.critical(f"worker '{self.label}' is online but not responding (crashed?)")
logger.error(e)
return False
def mark_unreachable(self):
if self.state == State.DISABLED:
@ -677,6 +688,10 @@ class Worker:
if vae is not None:
self.loaded_vae = vae
self.response = response
return self
def restart(self) -> bool:
err_msg = f"could not restart worker '{self.label}'"
success_msg = f"worker '{self.label}' is restarting"

View File

@ -4,7 +4,7 @@ This module facilitates the creation of a stable-diffusion-webui centered distri
World:
The main class which should be instantiated in order to create a new sdwui distributed system.
"""
import concurrent.futures
import copy
import json
import os
@ -16,8 +16,10 @@ import modules.shared as shared
from modules.processing import process_images, StableDiffusionProcessingTxt2Img
from . import shared as sh
from .pmodels import ConfigModel, Benchmark_Payload
from .shared import logger, warmup_samples, extension_path
from .shared import logger, extension_path
from .worker import Worker, State
from modules.call_queue import wrap_queued_call
from modules import processing
class NotBenchmarked(Exception):
@ -28,13 +30,6 @@ class NotBenchmarked(Exception):
pass
class WorldAlreadyInitialized(Exception):
"""
Raised when attempting to initialize the World when it has already been initialized.
"""
pass
class Job:
"""
Keeps track of how much work a given worker should handle.
@ -48,6 +43,7 @@ class Job:
self.worker: Worker = worker
self.batch_size: int = batch_size
self.complementary: bool = False
self.step_override = None
def __str__(self):
prefix = ''
@ -75,7 +71,7 @@ class World:
The frame or "world" which holds all workers (including the local machine).
Args:
initial_payload: The original txt2img payload created by the user initiating the generation request on master.
p: The original processing state object created by the user initiating the generation request on master.
verify_remotes (bool): Whether to validate remote worker certificates.
"""
@ -83,19 +79,19 @@ class World:
config_path = shared.cmd_opts.distributed_config
old_config_path = worker_info_path = extension_path.joinpath('workers.json')
def __init__(self, initial_payload, verify_remotes: bool = True):
def __init__(self, verify_remotes: bool = True):
self.p = None
self.master_worker = Worker(master=True)
self.total_batch_size: int = 0
self._workers: List[Worker] = [self.master_worker]
self.jobs: List[Job] = []
self.job_timeout: int = 3 # seconds
self.initialized: bool = False
self.verify_remotes = verify_remotes
self.initial_payload = copy.copy(initial_payload)
self.thin_client_mode = False
self.enabled = True
self.is_dropdown_handler_injected = False
self.complement_production = True
self.step_scaling = False
def __getitem__(self, label: str) -> Worker:
for worker in self._workers:
@ -105,32 +101,12 @@ class World:
def __repr__(self):
return f"{len(self._workers)} workers"
def update_world(self, total_batch_size):
"""
Updates the world with information vital to handling the local generation request after
the world has already been initialized.
Args:
total_batch_size (int): The total number of images requested by the local/master sdwui instance.
"""
self.total_batch_size = total_batch_size
self.update_jobs()
def initialize(self, total_batch_size):
"""should be called before a world instance is used for anything"""
if self.initialized:
raise WorldAlreadyInitialized("This world instance was already initialized")
self.benchmark()
self.update_world(total_batch_size=total_batch_size)
self.initialized = True
def default_batch_size(self) -> int:
"""the amount of images/total images requested that a worker would compute if conditions were perfect and
each worker generated at the same speed. assumes one batch only"""
return self.total_batch_size // self.size()
return self.p.batch_size // self.size()
def size(self) -> int:
"""
@ -169,6 +145,14 @@ class World:
Worker: The worker object.
"""
# protect against user trying to make cyclical setups and connections
is_master = kwargs.get('master')
if is_master is None or not is_master:
m = self.master()
if kwargs['address'] == m.address and kwargs['port'] == m.port:
logger.error(f"refusing to add worker {kwargs['label']} as its socket definition({m.address}:{m.port}) matches master")
return None
original = self[kwargs['label']] # if worker doesn't already exist then just make a new one
if original is None:
new = Worker(**kwargs)
@ -192,16 +176,33 @@ class World:
if worker.master:
continue
t = Thread(target=worker.interrupt, args=())
t.start()
Thread(target=worker.interrupt, args=()).start()
def refresh_checkpoints(self):
for worker in self.get_workers():
if worker.master:
continue
t = Thread(target=worker.refresh_checkpoints, args=())
t.start()
Thread(target=worker.refresh_checkpoints, args=()).start()
def sample_master(self) -> float:
# wrap our benchmark payload
master_bench_payload = StableDiffusionProcessingTxt2Img()
d = sh.benchmark_payload.dict()
for key in d:
setattr(master_bench_payload, key, d[key])
# Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to.
master_bench_payload.do_not_save_samples = True
# shared.state.begin(job='distributed_master_bench')
wrapped = (wrap_queued_call(process_images))
start = time.time()
wrapped(master_bench_payload)
# wrap_gradio_gpu_call(process_images)(master_bench_payload)
# shared.state.end()
return time.time() - start
def benchmark(self, rebenchmark: bool = False):
"""
@ -209,14 +210,6 @@ class World:
"""
unbenched_workers = []
benchmark_threads: List[Thread] = []
sync_threads: List[Thread] = []
def benchmark_wrapped(worker):
bench_func = worker.benchmark if not worker.master else self.benchmark_master
worker.avg_ipm = bench_func()
worker.benchmarked = True
if rebenchmark:
for worker in self._workers:
worker.benchmarked = False
@ -231,28 +224,44 @@ class World:
else:
worker.benchmarked = True
# have every unbenched worker load the same weights before the benchmark
for worker in unbenched_workers:
if worker.master or worker.state == State.DISABLED:
continue
with concurrent.futures.ThreadPoolExecutor(thread_name_prefix='distributed_benchmark') as executor:
futures = []
sync_thread = Thread(target=worker.load_options, args=(shared.opts.sd_model_checkpoint, shared.opts.sd_vae))
sync_threads.append(sync_thread)
sync_thread.start()
for thread in sync_threads:
thread.join()
# have every unbenched worker load the same weights before the benchmark
for worker in unbenched_workers:
if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE):
continue
# benchmark those that haven't been
for worker in unbenched_workers:
t = Thread(target=benchmark_wrapped, args=(worker, ), name=f"{worker.label}_benchmark")
benchmark_threads.append(t)
t.start()
logger.info(f"benchmarking worker '{worker.label}'")
futures.append(
executor.submit(worker.load_options, model=shared.opts.sd_model_checkpoint, vae=shared.opts.sd_vae)
)
for future in concurrent.futures.as_completed(futures):
worker = future.result()
if worker is None:
continue
# wait for all benchmarks to finish and update stats on newly benchmarked workers
if len(benchmark_threads) > 0:
for t in benchmark_threads:
t.join()
if worker.response.status_code != 200:
logger.error(f"refusing to benchmark worker '{worker.label}' as it failed to load the selected model '{shared.opts.sd_model_checkpoint}'\n"
f"*you may circumvent this by using the per-worker model override setting but this is not recommended as the same benchmark model should be used for all workers")
unbenched_workers = list(filter(lambda w: w != worker, unbenched_workers))
futures.clear()
# benchmark those that haven't been
for worker in unbenched_workers:
if worker.state in (State.DISABLED, State.UNAVAILABLE):
logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark")
continue
if worker.model_override is not None:
logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
f"*all workers should be evaluated against the same model")
chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master)
futures.append(executor.submit(chosen, worker))
logger.info(f"benchmarking worker '{worker.label}'")
# wait for all benchmarks to finish and update stats on newly benchmarked workers
concurrent.futures.wait(futures)
logger.info("benchmarking finished")
# save benchmark results to workers.json
@ -348,43 +357,11 @@ class World:
if worker == fastest_worker:
return 0
lag = worker.batch_eta(payload=payload, quiet=True, batch_size=batch_size) - fastest_worker.batch_eta(payload=payload, quiet=True, batch_size=batch_size)
lag = worker.eta(payload=payload, quiet=True, batch_size=batch_size) - fastest_worker.eta(payload=payload, quiet=True, batch_size=batch_size)
return lag
def benchmark_master(self) -> float:
"""
Benchmarks the local/master worker.
Returns:
float: Local worker speed in ipm
"""
# wrap our benchmark payload
master_bench_payload = StableDiffusionProcessingTxt2Img()
d = sh.benchmark_payload.dict()
for key in d:
setattr(master_bench_payload, key, d[key])
# Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to.
master_bench_payload.do_not_save_samples = True
# "warm up" due to initial generation lag
for _ in range(warmup_samples):
process_images(master_bench_payload)
# get actual sample
start = time.time()
process_images(master_bench_payload)
elapsed = time.time() - start
ipm = sh.benchmark_payload.batch_size / (elapsed / 60)
logger.debug(f"Master benchmark took {elapsed:.2f}: {ipm:.2f} ipm")
self.master().benchmarked = True
return ipm
def update_jobs(self):
def make_jobs(self):
"""creates initial jobs (before optimization) """
# clear jobs if this is not the first time running
@ -398,6 +375,19 @@ class World:
worker.benchmark()
self.jobs.append(Job(worker=worker, batch_size=batch_size))
logger.debug(f"added job for worker {worker.label}")
def update(self, p):
"""preps world for another run"""
if not self.initialized:
self.benchmark()
self.initialized = True
logger.debug("world initialized!")
else:
logger.debug("world was already initialized")
self.p = p
self.make_jobs()
def get_workers(self):
filtered: List[Worker] = []
@ -434,7 +424,7 @@ class World:
logger.debug(f"worker '{job.worker.label}' would stall the image gallery by ~{lag:.2f}s\n")
job.complementary = True
if deferred_images + images_checked + payload['batch_size'] > self.total_batch_size:
if deferred_images + images_checked + payload['batch_size'] > self.p.batch_size:
logger.debug(f"would go over actual requested size")
else:
deferred_images += payload['batch_size']
@ -477,9 +467,9 @@ class World:
#######################
# when total number of requested images was not cleanly divisible by world size then we tack the remainder on
remainder_images = self.total_batch_size - self.get_current_output_size()
remainder_images = self.p.batch_size - self.get_current_output_size()
if remainder_images >= 1:
logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
realtime_jobs = self.realtime_jobs()
realtime_jobs.sort(key=lambda x: x.batch_size)
@ -521,16 +511,16 @@ class World:
fastest_active = self.fastest_realtime_job().worker
for j in self.jobs:
if j.worker.label == fastest_active.label:
slack_time = fastest_active.batch_eta(payload=payload, batch_size=j.batch_size) + self.job_timeout
slack_time = fastest_active.eta(payload=payload, batch_size=j.batch_size) + self.job_timeout
logger.debug(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.label}'")
# see how long it would take to produce only 1 image on this complementary worker
secs_per_batch_image = job.worker.batch_eta(payload=payload, batch_size=1)
secs_per_batch_image = job.worker.eta(payload=payload, batch_size=1)
num_images_compensate = int(slack_time / secs_per_batch_image)
logger.debug(
f"worker '{job.worker.label}':\n"
f"{num_images_compensate} complementary image(s) = {slack_time:.2f}s slack"
f"/ {secs_per_batch_image:.2f}s per requested image"
f" ÷ {secs_per_batch_image:.2f}s per requested image"
)
if not job.add_work(payload, batch_size=num_images_compensate):
@ -538,14 +528,29 @@ class World:
request_img_size = payload['width'] * payload['height']
max_images = job.worker.pixel_cap // request_img_size
job.add_work(payload, batch_size=max_images)
# when not even a singular image can be squeezed out
# if step scaling is enabled, then find how many samples would be considered realtime and adjust
if num_images_compensate == 0 and self.step_scaling:
seconds_per_sample = job.worker.eta(payload=payload, batch_size=1, samples=1)
realtime_samples = slack_time // seconds_per_sample
logger.debug(
f"job for '{job.worker.label}' downscaled to {realtime_samples} samples to meet time constraints\n"
f"{realtime_samples:.0f} samples = {slack_time:.2f}s slack ÷ {seconds_per_sample:.2f}s/sample\n"
f" step reduction: {payload['steps']} -> {realtime_samples:.0f}"
)
job.add_work(payload=payload, batch_size=1)
job.step_override = realtime_samples
else:
logger.debug("complementary image production is disabled")
iterations = payload['n_iter']
num_returning = self.get_current_output_size()
num_complementary = num_returning - self.total_batch_size
num_complementary = num_returning - self.p.batch_size
distro_summary = "Job distribution:\n"
distro_summary += f"{self.total_batch_size} * {iterations} iteration(s)"
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
if num_complementary > 0:
distro_summary += f" + {num_complementary} complementary"
distro_summary += f": {num_returning} images total\n"
@ -553,6 +558,22 @@ class World:
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
logger.info(distro_summary)
if self.thin_client_mode is True or self.master_job().batch_size == 0:
# save original process_images_inner for later so we can restore once we're done
logger.debug(f"bypassing local generation completely")
def process_images_inner_bypass(p) -> processing.Processed:
processed = processing.Processed(p, [], p.seed, info="")
processed.all_prompts = []
processed.all_seeds = []
processed.all_subseeds = []
processed.all_negative_prompts = []
processed.infotexts = []
processed.prompt = None
self.p.scripts.postprocess(p, processed)
return processed
processing.process_images_inner = process_images_inner_bypass
# delete any jobs that have no work
last = len(self.jobs) - 1
while last > 0:
@ -631,6 +652,8 @@ class World:
label = next(iter(w.keys()))
fields = w[label].__dict__
fields['label'] = label
# TODO must be overridden everytime here or later converted to a config file variable at some point
fields['verify_remotes'] = self.verify_remotes
self.add_worker(**fields)
@ -638,6 +661,7 @@ class World:
self.job_timeout = config.job_timeout
self.enabled = config.enabled
self.complement_production = config.complement_production
self.step_scaling = config.step_scaling
logger.debug("config loaded")
@ -651,7 +675,8 @@ class World:
benchmark_payload=sh.benchmark_payload,
job_timeout=self.job_timeout,
enabled=self.enabled,
complement_production=self.complement_production
complement_production=self.complement_production,
step_scaling=self.step_scaling
)
with open(self.config_path, 'w+') as config_file:
@ -679,22 +704,24 @@ class World:
if worker.queried and worker.state == State.IDLE: # TODO worker.queried
continue
# for now skip/remove scripts that are not "always on" since there is currently no way to run
# them at the same time as distributed
supported_scripts = {
'txt2img': [],
'img2img': []
}
script_info = worker.session.get(url=worker.full_url('script-info')).json()
for key in script_info:
name = key.get('name', None)
response = worker.session.get(url=worker.full_url('script-info'))
if response.status_code == 200:
script_info = response.json()
for key in script_info:
name = key.get('name', None)
if name is not None:
is_alwayson = key.get('is_alwayson', False)
is_img2img = key.get('is_img2img', False)
if is_alwayson:
supported_scripts['img2img' if is_img2img else 'txt2img'].append(name)
if name is not None:
is_alwayson = key.get('is_alwayson', False)
is_img2img = key.get('is_img2img', False)
if is_alwayson:
supported_scripts['img2img' if is_img2img else 'txt2img'].append(name)
else:
logger.error(f"failed to query script-info for worker '{worker.label}': {response}")
worker.supported_scripts = supported_scripts
msg = f"worker '{worker.label}' is online"