commit
b1f9d4f769
21
CHANGELOG.md
21
CHANGELOG.md
|
|
@ -1,6 +1,27 @@
|
|||
# Change Log
|
||||
Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
|
||||
|
||||
## [2.2.0] - 2024-5-11
|
||||
|
||||
### Added
|
||||
- Toggle for allowing automatic step scaling which can increase overall utilization
|
||||
|
||||
### Changed
|
||||
- Adding workers which have the same socket definition as master will no longer be allowed and an error will show #28
|
||||
- Workers in an invalid state should no longer be benchmarked
|
||||
- The worker port under worker config will now default to 7860 to prevent mishaps
|
||||
- Config should once again only be loaded once per session startup
|
||||
- A warning will be shown when trying to use the user script button but no script exists
|
||||
|
||||
### Fixed
|
||||
- Thin-client mode
|
||||
- Some problems with sdwui forge branch
|
||||
- Certificate verification setting sometimes not saving
|
||||
- Master being assigned no work stopping generation (same problem as thin-client)
|
||||
|
||||
### Removed
|
||||
- Adding workers using deprecated cmdline argument
|
||||
|
||||
## [2.1.0] - 2024-3-03
|
||||
|
||||
### Added
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import re
|
|||
import signal
|
||||
import sys
|
||||
import time
|
||||
from threading import Thread, current_thread
|
||||
from threading import Thread
|
||||
from typing import List
|
||||
import gradio
|
||||
import urllib3
|
||||
|
|
@ -18,28 +18,24 @@ from PIL import Image
|
|||
from modules import processing
|
||||
from modules import scripts
|
||||
from modules.images import save_image
|
||||
from modules.processing import fix_seed, Processed
|
||||
from modules.processing import fix_seed
|
||||
from modules.shared import opts, cmd_opts
|
||||
from modules.shared import state as webui_state
|
||||
from scripts.spartan.control_net import pack_control_net
|
||||
from scripts.spartan.shared import logger
|
||||
from scripts.spartan.ui import UI
|
||||
from scripts.spartan.world import World, WorldAlreadyInitialized
|
||||
from scripts.spartan.world import World
|
||||
|
||||
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
||||
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||
|
||||
|
||||
# TODO implement advertisement of some sort in sdwui api to allow extension to automatically discover workers?
|
||||
# noinspection PyMissingOrEmptyDocstring
|
||||
class Script(scripts.Script):
|
||||
class DistributedScript(scripts.Script):
|
||||
# global old_sigterm_handler, old_sigterm_handler
|
||||
worker_threads: List[Thread] = []
|
||||
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
|
||||
verify_remotes = not cmd_opts.distributed_skip_verify_remotes
|
||||
|
||||
is_img2img = True
|
||||
is_txt2img = True
|
||||
alwayson = True
|
||||
master_start = None
|
||||
runs_since_init = 0
|
||||
name = "distributed"
|
||||
|
|
@ -50,20 +46,15 @@ class Script(scripts.Script):
|
|||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
# build world
|
||||
world = World(initial_payload=None, verify_remotes=verify_remotes)
|
||||
# add workers to the world
|
||||
world = World(verify_remotes=verify_remotes)
|
||||
world.load_config()
|
||||
if cmd_opts.distributed_remotes is not None and len(cmd_opts.distributed_remotes) > 0:
|
||||
logger.warning(f"--distributed-remotes is deprecated and may be removed in the future\n"
|
||||
f"gui/external modification of {world.config_path} will be prioritized going forward")
|
||||
|
||||
for worker in cmd_opts.distributed_remotes:
|
||||
world.add_worker(uuid=worker[0], address=worker[1], port=worker[2], tls=False)
|
||||
world.save_config()
|
||||
# do an early check to see which workers are online
|
||||
logger.info("doing initial ping sweep to see which workers are reachable")
|
||||
world.ping_remotes(indiscriminate=True)
|
||||
|
||||
# constructed for both txt2img and img2img
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def title(self):
|
||||
return "Distribute"
|
||||
|
||||
|
|
@ -71,21 +62,18 @@ class Script(scripts.Script):
|
|||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
self.world.load_config()
|
||||
extension_ui = UI(script=Script, world=Script.world)
|
||||
extension_ui = UI(world=self.world)
|
||||
# root, api_exposed = extension_ui.create_ui()
|
||||
components = extension_ui.create_ui()
|
||||
|
||||
# The first injection of handler for the models dropdown(sd_model_checkpoint) which is often present
|
||||
# in the quick-settings bar of a user. Helps ensure model swaps propagate to all nodes ASAP.
|
||||
Script.world.inject_model_dropdown_handler()
|
||||
self.world.inject_model_dropdown_handler()
|
||||
# return some components that should be exposed to the api
|
||||
return components
|
||||
|
||||
@staticmethod
|
||||
def add_to_gallery(processed, p):
|
||||
def add_to_gallery(self, processed, p):
|
||||
"""adds generated images to the image gallery after waiting for all workers to finish"""
|
||||
webui_state.textinfo = "Distributed - injecting images"
|
||||
|
||||
def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None):
|
||||
image_params: json = response['parameters']
|
||||
|
|
@ -129,10 +117,10 @@ class Script(scripts.Script):
|
|||
if p.n_iter > 1: # if splitting by batch count
|
||||
num_remote_images *= p.n_iter - 1
|
||||
|
||||
logger.debug(f"image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, "
|
||||
logger.debug(f"image {true_image_pos + 1}/{self.world.p.batch_size * p.n_iter}, "
|
||||
f"info-index: {info_index}")
|
||||
|
||||
if Script.world.thin_client_mode:
|
||||
if self.world.thin_client_mode:
|
||||
p.all_negative_prompts = processed.all_negative_prompts
|
||||
|
||||
try:
|
||||
|
|
@ -157,19 +145,21 @@ class Script(scripts.Script):
|
|||
)
|
||||
|
||||
# get master ipm by estimating based on worker speed
|
||||
master_elapsed = time.time() - Script.master_start
|
||||
master_elapsed = time.time() - self.master_start
|
||||
logger.debug(f"Took master {master_elapsed:.2f}s")
|
||||
|
||||
# wait for response from all workers
|
||||
for thread in Script.worker_threads:
|
||||
webui_state.textinfo = "Distributed - receiving results"
|
||||
for thread in self.worker_threads:
|
||||
logger.debug(f"waiting for worker thread '{thread.name}'")
|
||||
thread.join()
|
||||
Script.worker_threads.clear()
|
||||
self.worker_threads.clear()
|
||||
logger.debug("all worker request threads returned")
|
||||
webui_state.textinfo = "Distributed - injecting images"
|
||||
|
||||
# some worker which we know has a good response that we can use for generating the grid
|
||||
donor_worker = None
|
||||
for job in Script.world.jobs:
|
||||
for job in self.world.jobs:
|
||||
if job.batch_size < 1 or job.worker.master:
|
||||
continue
|
||||
|
||||
|
|
@ -177,7 +167,7 @@ class Script(scripts.Script):
|
|||
images: json = job.worker.response["images"]
|
||||
# if we for some reason get more than we asked for
|
||||
if (job.batch_size * p.n_iter) < len(images):
|
||||
logger.debug(f"Requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}")
|
||||
logger.debug(f"requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}")
|
||||
|
||||
if donor_worker is None:
|
||||
donor_worker = job.worker
|
||||
|
|
@ -208,7 +198,7 @@ class Script(scripts.Script):
|
|||
|
||||
# generate and inject grid
|
||||
if opts.return_grid:
|
||||
grid = processing.images.image_grid(processed.images, len(processed.images))
|
||||
grid = images.image_grid(processed.images, len(processed.images))
|
||||
processed_inject_image(
|
||||
image=grid,
|
||||
info_index=0,
|
||||
|
|
@ -218,35 +208,22 @@ class Script(scripts.Script):
|
|||
)
|
||||
|
||||
# cleanup after we're doing using all the responses
|
||||
for worker in Script.world.get_workers():
|
||||
for worker in self.world.get_workers():
|
||||
worker.response = None
|
||||
|
||||
p.batch_size = len(processed.images)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def initialize(initial_payload):
|
||||
# get default batch size
|
||||
try:
|
||||
batch_size = initial_payload.batch_size
|
||||
except AttributeError:
|
||||
batch_size = 1
|
||||
|
||||
try:
|
||||
Script.world.initialize(batch_size)
|
||||
logger.debug(f"World initialized!")
|
||||
except WorldAlreadyInitialized:
|
||||
Script.world.update_world(total_batch_size=batch_size)
|
||||
|
||||
# p's type is
|
||||
# "modules.processing.StableDiffusionProcessingTxt2Img"
|
||||
# "modules.processing.StableDiffusionProcessing*"
|
||||
def before_process(self, p, *args):
|
||||
if not self.world.enabled:
|
||||
logger.debug("extension is disabled")
|
||||
return
|
||||
self.world.update(p)
|
||||
|
||||
current_thread().name = "distributed_main"
|
||||
Script.initialize(initial_payload=p)
|
||||
# save original process_images_inner function for later if we monkeypatch it
|
||||
self.original_process_images_inner = processing.process_images_inner
|
||||
|
||||
# strip scripts that aren't yet supported and warn user
|
||||
packed_script_args: List[dict] = [] # list of api formatted per-script argument objects
|
||||
|
|
@ -261,8 +238,9 @@ class Script(scripts.Script):
|
|||
# grab all controlnet units
|
||||
cn_units = []
|
||||
cn_args = p.script_args[script.args_from:script.args_to]
|
||||
|
||||
for cn_arg in cn_args:
|
||||
if type(cn_arg).__name__ == "UiControlNetUnit":
|
||||
if "ControlNetUnit" in type(cn_arg).__name__:
|
||||
cn_units.append(cn_arg)
|
||||
logger.debug(f"Detected {len(cn_units)} controlnet unit(s)")
|
||||
|
||||
|
|
@ -281,7 +259,7 @@ class Script(scripts.Script):
|
|||
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works
|
||||
# see test/basic_features/txt2img_test.py for an example
|
||||
payload = copy.copy(p.__dict__)
|
||||
payload['batch_size'] = Script.world.default_batch_size()
|
||||
payload['batch_size'] = self.world.default_batch_size()
|
||||
payload['scripts'] = None
|
||||
try:
|
||||
del payload['script_args']
|
||||
|
|
@ -300,8 +278,9 @@ class Script(scripts.Script):
|
|||
# TODO api for some reason returns 200 even if something failed to be set.
|
||||
# for now we may have to make redundant GET requests to check if actually successful...
|
||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146
|
||||
|
||||
name = re.sub(r'\s?\[[^]]*]$', '', opts.data["sd_model_checkpoint"])
|
||||
vae = opts.data["sd_vae"]
|
||||
vae = opts.data.get('sd_vae')
|
||||
option_payload = {
|
||||
"sd_model_checkpoint": name,
|
||||
"sd_vae": vae
|
||||
|
|
@ -309,11 +288,11 @@ class Script(scripts.Script):
|
|||
|
||||
# start generating images assigned to remote machines
|
||||
sync = False # should only really need to sync once per job
|
||||
Script.world.optimize_jobs(payload) # optimize work assignment before dispatching
|
||||
self.world.optimize_jobs(payload) # optimize work assignment before dispatching
|
||||
started_jobs = []
|
||||
|
||||
# check if anything even needs to be done
|
||||
if len(Script.world.jobs) == 1 and Script.world.jobs[0].worker.master:
|
||||
if len(self.world.jobs) == 1 and self.world.jobs[0].worker.master:
|
||||
|
||||
if payload['batch_size'] >= 2:
|
||||
msg = f"all remote workers are offline or unreachable"
|
||||
|
|
@ -324,7 +303,7 @@ class Script(scripts.Script):
|
|||
|
||||
return
|
||||
|
||||
for job in Script.world.jobs:
|
||||
for job in self.world.jobs:
|
||||
payload_temp = copy.copy(payload)
|
||||
del payload_temp['scripts_value']
|
||||
payload_temp = copy.deepcopy(payload_temp)
|
||||
|
|
@ -339,11 +318,14 @@ class Script(scripts.Script):
|
|||
prior_images += j.batch_size * p.n_iter
|
||||
|
||||
payload_temp['batch_size'] = job.batch_size
|
||||
if job.step_override is not None:
|
||||
payload_temp['steps'] = job.step_override
|
||||
payload_temp['subseed'] += prior_images
|
||||
payload_temp['seed'] += prior_images if payload_temp['subseed_strength'] == 0 else 0
|
||||
logger.debug(
|
||||
f"'{job.worker.label}' job's given starting seed is "
|
||||
f"{payload_temp['seed']} with {prior_images} coming before it")
|
||||
f"{payload_temp['seed']} with {prior_images} coming before it"
|
||||
)
|
||||
|
||||
if job.worker.loaded_model != name or job.worker.loaded_vae != vae:
|
||||
sync = True
|
||||
|
|
@ -354,42 +336,34 @@ class Script(scripts.Script):
|
|||
name=f"{job.worker.label}_request")
|
||||
|
||||
t.start()
|
||||
Script.worker_threads.append(t)
|
||||
self.worker_threads.append(t)
|
||||
started_jobs.append(job)
|
||||
|
||||
# if master batch size was changed again due to optimization change it to the updated value
|
||||
if not self.world.thin_client_mode:
|
||||
p.batch_size = Script.world.master_job().batch_size
|
||||
Script.master_start = time.time()
|
||||
p.batch_size = self.world.master_job().batch_size
|
||||
self.master_start = time.time()
|
||||
|
||||
# generate images assigned to local machine
|
||||
p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
|
||||
if Script.world.thin_client_mode:
|
||||
p.batch_size = 0
|
||||
processed = Processed(p=p, images_list=[])
|
||||
processed.all_prompts = []
|
||||
processed.all_seeds = []
|
||||
processed.all_subseeds = []
|
||||
processed.all_negative_prompts = []
|
||||
processed.infotexts = []
|
||||
processed.prompt = None
|
||||
|
||||
Script.runs_since_init += 1
|
||||
self.runs_since_init += 1
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def postprocess(p, processed, *args):
|
||||
if not Script.world.enabled:
|
||||
def postprocess(self, p, processed, *args):
|
||||
if not self.world.enabled:
|
||||
return
|
||||
|
||||
if len(processed.images) >= 1 and Script.master_start is not None:
|
||||
Script.add_to_gallery(p=p, processed=processed)
|
||||
if self.master_start is not None:
|
||||
self.add_to_gallery(p=p, processed=processed)
|
||||
|
||||
# restore process_images_inner if it was monkey-patched
|
||||
processing.process_images_inner = self.original_process_images_inner
|
||||
|
||||
@staticmethod
|
||||
def signal_handler(sig, frame):
|
||||
logger.debug("handling interrupt signal")
|
||||
# do cleanup
|
||||
Script.world.save_config()
|
||||
DistributedScript.world.save_config()
|
||||
|
||||
if sig == signal.SIGINT:
|
||||
if callable(old_sigint_handler):
|
||||
|
|
|
|||
|
|
@ -2,6 +2,16 @@ import copy
|
|||
from PIL import Image
|
||||
from modules.api.api import encode_pil_to_base64
|
||||
from scripts.spartan.shared import logger
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
|
||||
def np_to_b64(image: np.ndarray):
|
||||
pil = Image.fromarray(image)
|
||||
image_b64 = str(encode_pil_to_base64(pil), 'utf-8')
|
||||
image_b64 = 'data:image/png;base64,' + image_b64
|
||||
|
||||
return image_b64
|
||||
|
||||
|
||||
def pack_control_net(cn_units) -> dict:
|
||||
|
|
@ -17,28 +27,46 @@ def pack_control_net(cn_units) -> dict:
|
|||
cn_args = controlnet['controlnet']['args']
|
||||
|
||||
for i in range(0, len(cn_units)):
|
||||
# copy control net unit to payload
|
||||
cn_args.append(copy.copy(cn_units[i].__dict__))
|
||||
if cn_units[i].enabled:
|
||||
cn_args.append(copy.deepcopy(cn_units[i].__dict__))
|
||||
else:
|
||||
logger.debug(f"controlnet unit {i} is not enabled (ignoring)")
|
||||
|
||||
for i in range(0, len(cn_args)):
|
||||
unit = cn_args[i]
|
||||
|
||||
# if unit isn't enabled then don't bother including
|
||||
if not unit['enabled']:
|
||||
del unit['input_mode']
|
||||
del unit['image']
|
||||
logger.debug(f"Controlnet unit {i} is not enabled. Ignoring")
|
||||
continue
|
||||
|
||||
# serialize image
|
||||
if unit['image'] is not None:
|
||||
image = unit['image']['image']
|
||||
# mask = unit['image']['mask']
|
||||
pil = Image.fromarray(image)
|
||||
image_b64 = encode_pil_to_base64(pil)
|
||||
image_b64 = str(image_b64, 'utf-8')
|
||||
unit['input_image'] = image_b64
|
||||
image_pair = unit.get('image')
|
||||
if image_pair is not None:
|
||||
image_b64 = np_to_b64(image_pair['image'])
|
||||
unit['input_image'] = image_b64 # mikubill
|
||||
unit['image'] = image_b64 # forge
|
||||
|
||||
if np.all(image_pair['mask'] == 0):
|
||||
# stand-alone mask from second gradio component
|
||||
standalone_mask = unit.get('mask_image')
|
||||
if standalone_mask is not None:
|
||||
logger.debug(f"found stand-alone mask for controlnet unit {i}")
|
||||
mask_b64 = np_to_b64(unit['mask_image']['mask'])
|
||||
unit['mask'] = mask_b64 # mikubill
|
||||
unit['mask_image'] = mask_b64 # forge
|
||||
|
||||
else:
|
||||
# mask from singular gradio image component
|
||||
logger.debug(f"found mask for controlnet unit {i}")
|
||||
mask_b64 = np_to_b64(image_pair['mask'])
|
||||
unit['mask'] = mask_b64 # mikubill
|
||||
unit['mask_image'] = mask_b64 # forge
|
||||
|
||||
# avoid returning duplicate detection maps since master should return the same one
|
||||
unit['save_detected_map'] = False
|
||||
# remove anything unserializable
|
||||
del unit['input_mode']
|
||||
del unit['image']
|
||||
|
||||
try:
|
||||
json.dumps(controlnet)
|
||||
except Exception as e:
|
||||
logger.error(f"failed to serialize controlnet\nfirst unit:\n{controlnet['controlnet']['args'][0]}")
|
||||
return {}
|
||||
|
||||
return controlnet
|
||||
|
|
|
|||
|
|
@ -42,3 +42,4 @@ class ConfigModel(BaseModel):
|
|||
job_timeout: Optional[int] = Field(default=3)
|
||||
enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True)
|
||||
complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True)
|
||||
step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False)
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ logger.addHandler(gui_handler)
|
|||
# end logging
|
||||
|
||||
warmup_samples = 2 # number of samples to do before recording a valid benchmark sample
|
||||
samples = 3 # number of times to benchmark worker after warmup benchmarks are completed
|
||||
|
||||
|
||||
class BenchmarkPayload(BaseModel):
|
||||
|
|
|
|||
|
|
@ -14,8 +14,7 @@ worker_select_dropdown = None
|
|||
class UI:
|
||||
"""extension user interface related things"""
|
||||
|
||||
def __init__(self, script, world):
|
||||
self.script = script
|
||||
def __init__(self, world):
|
||||
self.world = world
|
||||
self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange
|
||||
|
||||
|
|
@ -25,10 +24,17 @@ class UI:
|
|||
"""executes a script placed by the user at <extension>/user/sync*"""
|
||||
user_scripts = Path(os.path.abspath(__file__)).parent.parent.joinpath('user')
|
||||
|
||||
user_script = None
|
||||
for file in user_scripts.iterdir():
|
||||
logger.debug(f"found possible script {file.name}")
|
||||
if file.is_file() and file.name.startswith('sync'):
|
||||
user_script = file
|
||||
if user_script is None:
|
||||
logger.error(
|
||||
"couldn't find user script\n"
|
||||
"script must be placed under <extension>/user/ and filename must begin with sync"
|
||||
)
|
||||
return False
|
||||
|
||||
suffix = user_script.suffix[1:]
|
||||
|
||||
|
|
@ -74,15 +80,14 @@ class UI:
|
|||
|
||||
return 'No active jobs!', worker_status, logs
|
||||
|
||||
def save_btn(self, thin_client_mode, job_timeout, complement_production):
|
||||
def save_btn(self, thin_client_mode, job_timeout, complement_production, step_scaling):
|
||||
"""updates the options visible on the settings tab"""
|
||||
|
||||
self.world.thin_client_mode = thin_client_mode
|
||||
logger.debug(f"thin client mode is now {thin_client_mode}")
|
||||
job_timeout = int(job_timeout)
|
||||
self.world.job_timeout = job_timeout
|
||||
logger.debug(f"job timeout is now {job_timeout} seconds")
|
||||
self.world.complement_production = complement_production
|
||||
self.world.step_scaling = step_scaling
|
||||
self.world.save_config()
|
||||
|
||||
def save_worker_btn(self, label, address, port, tls, disabled):
|
||||
|
|
@ -102,7 +107,7 @@ class UI:
|
|||
self.world.add_worker(
|
||||
label=label,
|
||||
address=address,
|
||||
port=port,
|
||||
port=port if len(port) > 0 else 7860,
|
||||
tls=tls,
|
||||
state=state
|
||||
)
|
||||
|
|
@ -208,7 +213,6 @@ class UI:
|
|||
interactive=True
|
||||
)
|
||||
main_toggle.input(self.main_toggle_btn)
|
||||
setattr(main_toggle, 'do_not_save_to_config', True) # ui_loadsave.py apply_field()
|
||||
components.append(main_toggle)
|
||||
|
||||
with gradio.Tab('Status') as status_tab:
|
||||
|
|
@ -240,7 +244,7 @@ class UI:
|
|||
reload_config_btn = gradio.Button(value='📜 Reload config')
|
||||
reload_config_btn.click(self.world.load_config)
|
||||
|
||||
redo_benchmarks_btn = gradio.Button(value='📊 Redo benchmarks', variant='stop')
|
||||
redo_benchmarks_btn = gradio.Button(value='📊 Redo benchmarks')
|
||||
redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[])
|
||||
|
||||
run_usr_btn = gradio.Button(value='⚙️ Run script')
|
||||
|
|
@ -252,10 +256,10 @@ class UI:
|
|||
reconnect_lost_workers_btn = gradio.Button(value='🔌 Reconnect workers')
|
||||
reconnect_lost_workers_btn.click(self.world.ping_remotes)
|
||||
|
||||
interrupt_all_btn = gradio.Button(value='⏸️ Interrupt all', variant='stop')
|
||||
interrupt_all_btn = gradio.Button(value='⏸️ Interrupt all')
|
||||
interrupt_all_btn.click(self.world.interrupt_remotes)
|
||||
|
||||
restart_workers_btn = gradio.Button(value="🔁 Restart All", variant='stop')
|
||||
restart_workers_btn = gradio.Button(value="🔁 Restart All")
|
||||
restart_workers_btn.click(
|
||||
_js="confirm_restart_workers",
|
||||
fn=lambda confirmed: self.world.restart_all() if confirmed else None,
|
||||
|
|
@ -305,7 +309,7 @@ class UI:
|
|||
# API authentication
|
||||
worker_api_auth_cbx = gradio.Checkbox(label='API Authentication')
|
||||
worker_user_field = gradio.Textbox(label='Username')
|
||||
worker_password_field = gradio.Textbox(label='Password')
|
||||
worker_password_field = gradio.Textbox(label='Password', type='password')
|
||||
update_credentials_btn = gradio.Button(value='Update API Credentials')
|
||||
update_credentials_btn.click(self.update_credentials_btn, inputs=[
|
||||
worker_api_auth_cbx,
|
||||
|
|
@ -346,25 +350,33 @@ class UI:
|
|||
|
||||
with gradio.Tab('Settings'):
|
||||
thin_client_cbx = gradio.Checkbox(
|
||||
label='Thin-client mode (experimental)',
|
||||
info="(BROKEN) Only generate images using remote workers. There will be no previews when enabled.",
|
||||
label='Thin-client mode',
|
||||
info="Only generate images remotely (no image previews yet)",
|
||||
value=self.world.thin_client_mode
|
||||
)
|
||||
job_timeout = gradio.Number(
|
||||
label='Job timeout', value=self.world.job_timeout,
|
||||
info="Seconds until a worker is considered too slow to be assigned an"
|
||||
" equal share of the total request. Longer than 2 seconds is recommended."
|
||||
" equal share of the total request. Longer than 2 seconds is recommended"
|
||||
)
|
||||
|
||||
complement_production = gradio.Checkbox(
|
||||
label='Complement production',
|
||||
info='Prevents under-utilization of hardware by requesting additional images',
|
||||
info='Prevents under-utilization by requesting additional images when possible',
|
||||
value=self.world.complement_production
|
||||
)
|
||||
|
||||
# reduces image quality the more the sample-count must be reduced
|
||||
# good for mixed setups where each worker may not be around the same speed
|
||||
step_scaling = gradio.Checkbox(
|
||||
label='Step scaling',
|
||||
info='Prevents under-utilization via sample reduction in order to meet time constraints',
|
||||
value=self.world.step_scaling
|
||||
)
|
||||
|
||||
save_btn = gradio.Button(value='Update')
|
||||
save_btn.click(fn=self.save_btn, inputs=[thin_client_cbx, job_timeout, complement_production])
|
||||
components += [thin_client_cbx, job_timeout, complement_production, save_btn]
|
||||
save_btn.click(fn=self.save_btn, inputs=[thin_client_cbx, job_timeout, complement_production, step_scaling])
|
||||
components += [thin_client_cbx, job_timeout, complement_production, step_scaling, save_btn]
|
||||
|
||||
with gradio.Tab('Help'):
|
||||
gradio.Markdown(
|
||||
|
|
@ -374,4 +386,7 @@ class UI:
|
|||
"""
|
||||
)
|
||||
|
||||
# prevent wui from overriding any values
|
||||
for component in components:
|
||||
setattr(component, 'do_not_save_to_config', True) # ui_loadsave.py apply_field()
|
||||
return components
|
||||
|
|
|
|||
|
|
@ -156,6 +156,11 @@ class Worker:
|
|||
def __repr__(self):
|
||||
return f"'{self.label}'@{self.address}:{self.port}, speed: {self.avg_ipm} ipm, state: {self.state}"
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Worker) and other.label == self.label:
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def model(self) -> Worker_Model:
|
||||
return Worker_Model(**self.__dict__)
|
||||
|
|
@ -189,14 +194,14 @@ class Worker:
|
|||
protocol = 'http' if not self.tls else 'https'
|
||||
return f"{protocol}://{self.__str__()}/sdapi/v1/{route}"
|
||||
|
||||
def batch_eta_hr(self, payload: dict) -> float:
|
||||
def eta_hr(self, payload: dict) -> float:
|
||||
"""
|
||||
takes a normal payload and returns the eta of a pseudo payload which mirrors the hr-fix parameters
|
||||
This returns the eta of how long it would take to run hr-fix on the original image
|
||||
"""
|
||||
|
||||
pseudo_payload = copy.copy(payload)
|
||||
pseudo_payload['enable_hr'] = False # prevent overflow in self.batch_eta
|
||||
pseudo_payload['enable_hr'] = False # prevent overflow in self.eta
|
||||
res_ratio = pseudo_payload['hr_scale']
|
||||
original_steps = pseudo_payload['steps']
|
||||
second_pass_steps = pseudo_payload['hr_second_pass_steps']
|
||||
|
|
@ -212,12 +217,11 @@ class Worker:
|
|||
pseudo_payload['width'] = pseudo_width
|
||||
pseudo_payload['height'] = pseudo_height
|
||||
|
||||
eta = self.batch_eta(payload=pseudo_payload, quiet=True)
|
||||
return eta
|
||||
return self.eta(payload=pseudo_payload, quiet=True)
|
||||
|
||||
def batch_eta(self, payload: dict, quiet: bool = False, batch_size: int = None) -> float:
|
||||
def eta(self, payload: dict, quiet: bool = False, batch_size: int = None, samples: int = None) -> float:
|
||||
"""
|
||||
estimate how long it will take to generate <batch_size> images on a worker in seconds
|
||||
estimate how long it will take to generate image(s) on a worker in seconds
|
||||
|
||||
Args:
|
||||
payload: Sdwui api formatted payload
|
||||
|
|
@ -225,7 +229,7 @@ class Worker:
|
|||
batch_size: Overrides the batch_size parameter of the payload
|
||||
"""
|
||||
|
||||
steps = payload['steps']
|
||||
steps = payload['steps'] if samples is None else samples
|
||||
num_images = payload['batch_size'] if batch_size is None else batch_size
|
||||
|
||||
# if worker has not yet been benchmarked then
|
||||
|
|
@ -237,7 +241,7 @@ class Worker:
|
|||
# show effect of high-res fix
|
||||
hr = payload.get('enable_hr', False)
|
||||
if hr:
|
||||
eta += self.batch_eta_hr(payload=payload)
|
||||
eta += self.eta_hr(payload=payload)
|
||||
|
||||
# show effect of image size
|
||||
real_pix_to_benched = (payload['width'] * payload['height']) \
|
||||
|
|
@ -331,7 +335,7 @@ class Worker:
|
|||
self.load_options(model=option_payload['sd_model_checkpoint'], vae=option_payload['sd_vae'])
|
||||
|
||||
if self.benchmarked:
|
||||
eta = self.batch_eta(payload=payload) * payload['n_iter']
|
||||
eta = self.eta(payload=payload) * payload['n_iter']
|
||||
logger.debug(f"worker '{self.label}' predicts it will take {eta:.3f}s to generate "
|
||||
f"{payload['batch_size'] * payload['n_iter']} image(s) "
|
||||
f"at a speed of {self.avg_ipm:.2f} ipm\n")
|
||||
|
|
@ -471,7 +475,7 @@ class Worker:
|
|||
self.response_time = time.time() - start
|
||||
variance = ((eta - self.response_time) / self.response_time) * 100
|
||||
|
||||
logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%.\n"
|
||||
logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%\n"
|
||||
f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
|
||||
|
||||
# if the variance is greater than 500% then we ignore it to prevent variation inflation
|
||||
|
|
@ -501,20 +505,20 @@ class Worker:
|
|||
self.jobs_requested += 1
|
||||
return
|
||||
|
||||
def benchmark(self) -> float:
|
||||
def benchmark(self, sample_function: callable = None) -> float:
|
||||
"""
|
||||
given a worker, run a small benchmark and return its performance in images/minute
|
||||
makes standard request(s) of 512x512 images and averages them to get the result
|
||||
"""
|
||||
|
||||
t: Thread
|
||||
samples = 2 # number of times to benchmark the remote / accuracy
|
||||
|
||||
if self.state == State.DISABLED or self.state == State.UNAVAILABLE:
|
||||
if self.state in (State.DISABLED, State.UNAVAILABLE):
|
||||
logger.debug(f"worker '{self.label}' is unavailable or disabled, refusing to benchmark")
|
||||
return 0
|
||||
|
||||
if self.master is True:
|
||||
if self.master and sample_function is None:
|
||||
logger.critical(f"no function provided for benchmarking master")
|
||||
return -1
|
||||
|
||||
def ipm(seconds: float) -> float:
|
||||
|
|
@ -533,19 +537,24 @@ class Worker:
|
|||
results: List[float] = []
|
||||
# it used to be lower for the first couple of generations
|
||||
# this was due to something torch does at startup according to auto and is now done at sdwui startup
|
||||
self.state = State.WORKING
|
||||
for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up"
|
||||
for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up"
|
||||
if self.state == State.UNAVAILABLE:
|
||||
self.response = None
|
||||
return 0
|
||||
|
||||
t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,),
|
||||
name=f"{self.label}_benchmark_request")
|
||||
try: # if the worker is unreachable/offline then handle that here
|
||||
t.start()
|
||||
start = time.time()
|
||||
t.join()
|
||||
elapsed = time.time() - start
|
||||
elapsed = None
|
||||
|
||||
if not callable(sample_function):
|
||||
start = time.time()
|
||||
t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,),
|
||||
name=f"{self.label}_benchmark_request")
|
||||
t.start()
|
||||
t.join()
|
||||
elapsed = time.time() - start
|
||||
else:
|
||||
elapsed = sample_function()
|
||||
|
||||
sample_ipm = ipm(elapsed)
|
||||
except InvalidWorkerResponse as e:
|
||||
raise e
|
||||
|
|
@ -558,15 +567,13 @@ class Worker:
|
|||
logger.debug(f"{self.label} finished warming up\n")
|
||||
|
||||
# average the sample results for accuracy
|
||||
ipm_sum = 0
|
||||
for ipm_result in results:
|
||||
ipm_sum += ipm_result
|
||||
avg_ipm_result = ipm_sum / samples
|
||||
avg_ipm_result = sum(results) / sh.samples
|
||||
|
||||
logger.debug(f"Worker '{self.label}' average ipm: {avg_ipm_result:.2f}")
|
||||
self.avg_ipm = avg_ipm_result
|
||||
self.response = None
|
||||
self.benchmarked = True
|
||||
self.eta_percent_error = [] # likely inaccurate after rebenching
|
||||
self.state = State.IDLE
|
||||
return avg_ipm_result
|
||||
|
||||
|
|
@ -610,6 +617,10 @@ class Worker:
|
|||
except requests.exceptions.ConnectionError as e:
|
||||
logger.error(e)
|
||||
return False
|
||||
except requests.ReadTimeout as e:
|
||||
logger.critical(f"worker '{self.label}' is online but not responding (crashed?)")
|
||||
logger.error(e)
|
||||
return False
|
||||
|
||||
def mark_unreachable(self):
|
||||
if self.state == State.DISABLED:
|
||||
|
|
@ -677,6 +688,10 @@ class Worker:
|
|||
if vae is not None:
|
||||
self.loaded_vae = vae
|
||||
|
||||
self.response = response
|
||||
|
||||
return self
|
||||
|
||||
def restart(self) -> bool:
|
||||
err_msg = f"could not restart worker '{self.label}'"
|
||||
success_msg = f"worker '{self.label}' is restarting"
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ This module facilitates the creation of a stable-diffusion-webui centered distri
|
|||
World:
|
||||
The main class which should be instantiated in order to create a new sdwui distributed system.
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
|
|
@ -16,8 +16,10 @@ import modules.shared as shared
|
|||
from modules.processing import process_images, StableDiffusionProcessingTxt2Img
|
||||
from . import shared as sh
|
||||
from .pmodels import ConfigModel, Benchmark_Payload
|
||||
from .shared import logger, warmup_samples, extension_path
|
||||
from .shared import logger, extension_path
|
||||
from .worker import Worker, State
|
||||
from modules.call_queue import wrap_queued_call
|
||||
from modules import processing
|
||||
|
||||
|
||||
class NotBenchmarked(Exception):
|
||||
|
|
@ -28,13 +30,6 @@ class NotBenchmarked(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class WorldAlreadyInitialized(Exception):
|
||||
"""
|
||||
Raised when attempting to initialize the World when it has already been initialized.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Job:
|
||||
"""
|
||||
Keeps track of how much work a given worker should handle.
|
||||
|
|
@ -48,6 +43,7 @@ class Job:
|
|||
self.worker: Worker = worker
|
||||
self.batch_size: int = batch_size
|
||||
self.complementary: bool = False
|
||||
self.step_override = None
|
||||
|
||||
def __str__(self):
|
||||
prefix = ''
|
||||
|
|
@ -75,7 +71,7 @@ class World:
|
|||
The frame or "world" which holds all workers (including the local machine).
|
||||
|
||||
Args:
|
||||
initial_payload: The original txt2img payload created by the user initiating the generation request on master.
|
||||
p: The original processing state object created by the user initiating the generation request on master.
|
||||
verify_remotes (bool): Whether to validate remote worker certificates.
|
||||
"""
|
||||
|
||||
|
|
@ -83,19 +79,19 @@ class World:
|
|||
config_path = shared.cmd_opts.distributed_config
|
||||
old_config_path = worker_info_path = extension_path.joinpath('workers.json')
|
||||
|
||||
def __init__(self, initial_payload, verify_remotes: bool = True):
|
||||
def __init__(self, verify_remotes: bool = True):
|
||||
self.p = None
|
||||
self.master_worker = Worker(master=True)
|
||||
self.total_batch_size: int = 0
|
||||
self._workers: List[Worker] = [self.master_worker]
|
||||
self.jobs: List[Job] = []
|
||||
self.job_timeout: int = 3 # seconds
|
||||
self.initialized: bool = False
|
||||
self.verify_remotes = verify_remotes
|
||||
self.initial_payload = copy.copy(initial_payload)
|
||||
self.thin_client_mode = False
|
||||
self.enabled = True
|
||||
self.is_dropdown_handler_injected = False
|
||||
self.complement_production = True
|
||||
self.step_scaling = False
|
||||
|
||||
def __getitem__(self, label: str) -> Worker:
|
||||
for worker in self._workers:
|
||||
|
|
@ -105,32 +101,12 @@ class World:
|
|||
def __repr__(self):
|
||||
return f"{len(self._workers)} workers"
|
||||
|
||||
def update_world(self, total_batch_size):
|
||||
"""
|
||||
Updates the world with information vital to handling the local generation request after
|
||||
the world has already been initialized.
|
||||
|
||||
Args:
|
||||
total_batch_size (int): The total number of images requested by the local/master sdwui instance.
|
||||
"""
|
||||
|
||||
self.total_batch_size = total_batch_size
|
||||
self.update_jobs()
|
||||
|
||||
def initialize(self, total_batch_size):
|
||||
"""should be called before a world instance is used for anything"""
|
||||
if self.initialized:
|
||||
raise WorldAlreadyInitialized("This world instance was already initialized")
|
||||
|
||||
self.benchmark()
|
||||
self.update_world(total_batch_size=total_batch_size)
|
||||
self.initialized = True
|
||||
|
||||
def default_batch_size(self) -> int:
|
||||
"""the amount of images/total images requested that a worker would compute if conditions were perfect and
|
||||
each worker generated at the same speed. assumes one batch only"""
|
||||
|
||||
return self.total_batch_size // self.size()
|
||||
return self.p.batch_size // self.size()
|
||||
|
||||
def size(self) -> int:
|
||||
"""
|
||||
|
|
@ -169,6 +145,14 @@ class World:
|
|||
Worker: The worker object.
|
||||
"""
|
||||
|
||||
# protect against user trying to make cyclical setups and connections
|
||||
is_master = kwargs.get('master')
|
||||
if is_master is None or not is_master:
|
||||
m = self.master()
|
||||
if kwargs['address'] == m.address and kwargs['port'] == m.port:
|
||||
logger.error(f"refusing to add worker {kwargs['label']} as its socket definition({m.address}:{m.port}) matches master")
|
||||
return None
|
||||
|
||||
original = self[kwargs['label']] # if worker doesn't already exist then just make a new one
|
||||
if original is None:
|
||||
new = Worker(**kwargs)
|
||||
|
|
@ -192,16 +176,33 @@ class World:
|
|||
if worker.master:
|
||||
continue
|
||||
|
||||
t = Thread(target=worker.interrupt, args=())
|
||||
t.start()
|
||||
Thread(target=worker.interrupt, args=()).start()
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
for worker in self.get_workers():
|
||||
if worker.master:
|
||||
continue
|
||||
|
||||
t = Thread(target=worker.refresh_checkpoints, args=())
|
||||
t.start()
|
||||
Thread(target=worker.refresh_checkpoints, args=()).start()
|
||||
|
||||
def sample_master(self) -> float:
|
||||
# wrap our benchmark payload
|
||||
master_bench_payload = StableDiffusionProcessingTxt2Img()
|
||||
d = sh.benchmark_payload.dict()
|
||||
for key in d:
|
||||
setattr(master_bench_payload, key, d[key])
|
||||
|
||||
# Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to.
|
||||
master_bench_payload.do_not_save_samples = True
|
||||
# shared.state.begin(job='distributed_master_bench')
|
||||
wrapped = (wrap_queued_call(process_images))
|
||||
start = time.time()
|
||||
wrapped(master_bench_payload)
|
||||
# wrap_gradio_gpu_call(process_images)(master_bench_payload)
|
||||
# shared.state.end()
|
||||
|
||||
return time.time() - start
|
||||
|
||||
|
||||
def benchmark(self, rebenchmark: bool = False):
|
||||
"""
|
||||
|
|
@ -209,14 +210,6 @@ class World:
|
|||
"""
|
||||
|
||||
unbenched_workers = []
|
||||
benchmark_threads: List[Thread] = []
|
||||
sync_threads: List[Thread] = []
|
||||
|
||||
def benchmark_wrapped(worker):
|
||||
bench_func = worker.benchmark if not worker.master else self.benchmark_master
|
||||
worker.avg_ipm = bench_func()
|
||||
worker.benchmarked = True
|
||||
|
||||
if rebenchmark:
|
||||
for worker in self._workers:
|
||||
worker.benchmarked = False
|
||||
|
|
@ -231,28 +224,44 @@ class World:
|
|||
else:
|
||||
worker.benchmarked = True
|
||||
|
||||
# have every unbenched worker load the same weights before the benchmark
|
||||
for worker in unbenched_workers:
|
||||
if worker.master or worker.state == State.DISABLED:
|
||||
continue
|
||||
with concurrent.futures.ThreadPoolExecutor(thread_name_prefix='distributed_benchmark') as executor:
|
||||
futures = []
|
||||
|
||||
sync_thread = Thread(target=worker.load_options, args=(shared.opts.sd_model_checkpoint, shared.opts.sd_vae))
|
||||
sync_threads.append(sync_thread)
|
||||
sync_thread.start()
|
||||
for thread in sync_threads:
|
||||
thread.join()
|
||||
# have every unbenched worker load the same weights before the benchmark
|
||||
for worker in unbenched_workers:
|
||||
if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE):
|
||||
continue
|
||||
|
||||
# benchmark those that haven't been
|
||||
for worker in unbenched_workers:
|
||||
t = Thread(target=benchmark_wrapped, args=(worker, ), name=f"{worker.label}_benchmark")
|
||||
benchmark_threads.append(t)
|
||||
t.start()
|
||||
logger.info(f"benchmarking worker '{worker.label}'")
|
||||
futures.append(
|
||||
executor.submit(worker.load_options, model=shared.opts.sd_model_checkpoint, vae=shared.opts.sd_vae)
|
||||
)
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
worker = future.result()
|
||||
if worker is None:
|
||||
continue
|
||||
|
||||
# wait for all benchmarks to finish and update stats on newly benchmarked workers
|
||||
if len(benchmark_threads) > 0:
|
||||
for t in benchmark_threads:
|
||||
t.join()
|
||||
if worker.response.status_code != 200:
|
||||
logger.error(f"refusing to benchmark worker '{worker.label}' as it failed to load the selected model '{shared.opts.sd_model_checkpoint}'\n"
|
||||
f"*you may circumvent this by using the per-worker model override setting but this is not recommended as the same benchmark model should be used for all workers")
|
||||
unbenched_workers = list(filter(lambda w: w != worker, unbenched_workers))
|
||||
futures.clear()
|
||||
|
||||
# benchmark those that haven't been
|
||||
for worker in unbenched_workers:
|
||||
if worker.state in (State.DISABLED, State.UNAVAILABLE):
|
||||
logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark")
|
||||
continue
|
||||
|
||||
if worker.model_override is not None:
|
||||
logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
|
||||
f"*all workers should be evaluated against the same model")
|
||||
|
||||
chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master)
|
||||
futures.append(executor.submit(chosen, worker))
|
||||
logger.info(f"benchmarking worker '{worker.label}'")
|
||||
|
||||
# wait for all benchmarks to finish and update stats on newly benchmarked workers
|
||||
concurrent.futures.wait(futures)
|
||||
logger.info("benchmarking finished")
|
||||
|
||||
# save benchmark results to workers.json
|
||||
|
|
@ -348,43 +357,11 @@ class World:
|
|||
if worker == fastest_worker:
|
||||
return 0
|
||||
|
||||
lag = worker.batch_eta(payload=payload, quiet=True, batch_size=batch_size) - fastest_worker.batch_eta(payload=payload, quiet=True, batch_size=batch_size)
|
||||
lag = worker.eta(payload=payload, quiet=True, batch_size=batch_size) - fastest_worker.eta(payload=payload, quiet=True, batch_size=batch_size)
|
||||
|
||||
return lag
|
||||
|
||||
def benchmark_master(self) -> float:
|
||||
"""
|
||||
Benchmarks the local/master worker.
|
||||
|
||||
Returns:
|
||||
float: Local worker speed in ipm
|
||||
"""
|
||||
|
||||
# wrap our benchmark payload
|
||||
master_bench_payload = StableDiffusionProcessingTxt2Img()
|
||||
d = sh.benchmark_payload.dict()
|
||||
for key in d:
|
||||
setattr(master_bench_payload, key, d[key])
|
||||
|
||||
# Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to.
|
||||
master_bench_payload.do_not_save_samples = True
|
||||
|
||||
# "warm up" due to initial generation lag
|
||||
for _ in range(warmup_samples):
|
||||
process_images(master_bench_payload)
|
||||
|
||||
# get actual sample
|
||||
start = time.time()
|
||||
process_images(master_bench_payload)
|
||||
elapsed = time.time() - start
|
||||
|
||||
ipm = sh.benchmark_payload.batch_size / (elapsed / 60)
|
||||
|
||||
logger.debug(f"Master benchmark took {elapsed:.2f}: {ipm:.2f} ipm")
|
||||
self.master().benchmarked = True
|
||||
return ipm
|
||||
|
||||
def update_jobs(self):
|
||||
def make_jobs(self):
|
||||
"""creates initial jobs (before optimization) """
|
||||
|
||||
# clear jobs if this is not the first time running
|
||||
|
|
@ -398,6 +375,19 @@ class World:
|
|||
worker.benchmark()
|
||||
|
||||
self.jobs.append(Job(worker=worker, batch_size=batch_size))
|
||||
logger.debug(f"added job for worker {worker.label}")
|
||||
|
||||
def update(self, p):
|
||||
"""preps world for another run"""
|
||||
if not self.initialized:
|
||||
self.benchmark()
|
||||
self.initialized = True
|
||||
logger.debug("world initialized!")
|
||||
else:
|
||||
logger.debug("world was already initialized")
|
||||
|
||||
self.p = p
|
||||
self.make_jobs()
|
||||
|
||||
def get_workers(self):
|
||||
filtered: List[Worker] = []
|
||||
|
|
@ -434,7 +424,7 @@ class World:
|
|||
|
||||
logger.debug(f"worker '{job.worker.label}' would stall the image gallery by ~{lag:.2f}s\n")
|
||||
job.complementary = True
|
||||
if deferred_images + images_checked + payload['batch_size'] > self.total_batch_size:
|
||||
if deferred_images + images_checked + payload['batch_size'] > self.p.batch_size:
|
||||
logger.debug(f"would go over actual requested size")
|
||||
else:
|
||||
deferred_images += payload['batch_size']
|
||||
|
|
@ -477,9 +467,9 @@ class World:
|
|||
#######################
|
||||
|
||||
# when total number of requested images was not cleanly divisible by world size then we tack the remainder on
|
||||
remainder_images = self.total_batch_size - self.get_current_output_size()
|
||||
remainder_images = self.p.batch_size - self.get_current_output_size()
|
||||
if remainder_images >= 1:
|
||||
logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
|
||||
logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
|
||||
|
||||
realtime_jobs = self.realtime_jobs()
|
||||
realtime_jobs.sort(key=lambda x: x.batch_size)
|
||||
|
|
@ -521,16 +511,16 @@ class World:
|
|||
fastest_active = self.fastest_realtime_job().worker
|
||||
for j in self.jobs:
|
||||
if j.worker.label == fastest_active.label:
|
||||
slack_time = fastest_active.batch_eta(payload=payload, batch_size=j.batch_size) + self.job_timeout
|
||||
slack_time = fastest_active.eta(payload=payload, batch_size=j.batch_size) + self.job_timeout
|
||||
logger.debug(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.label}'")
|
||||
|
||||
# see how long it would take to produce only 1 image on this complementary worker
|
||||
secs_per_batch_image = job.worker.batch_eta(payload=payload, batch_size=1)
|
||||
secs_per_batch_image = job.worker.eta(payload=payload, batch_size=1)
|
||||
num_images_compensate = int(slack_time / secs_per_batch_image)
|
||||
logger.debug(
|
||||
f"worker '{job.worker.label}':\n"
|
||||
f"{num_images_compensate} complementary image(s) = {slack_time:.2f}s slack"
|
||||
f"/ {secs_per_batch_image:.2f}s per requested image"
|
||||
f" ÷ {secs_per_batch_image:.2f}s per requested image"
|
||||
)
|
||||
|
||||
if not job.add_work(payload, batch_size=num_images_compensate):
|
||||
|
|
@ -538,14 +528,29 @@ class World:
|
|||
request_img_size = payload['width'] * payload['height']
|
||||
max_images = job.worker.pixel_cap // request_img_size
|
||||
job.add_work(payload, batch_size=max_images)
|
||||
|
||||
# when not even a singular image can be squeezed out
|
||||
# if step scaling is enabled, then find how many samples would be considered realtime and adjust
|
||||
if num_images_compensate == 0 and self.step_scaling:
|
||||
seconds_per_sample = job.worker.eta(payload=payload, batch_size=1, samples=1)
|
||||
realtime_samples = slack_time // seconds_per_sample
|
||||
logger.debug(
|
||||
f"job for '{job.worker.label}' downscaled to {realtime_samples} samples to meet time constraints\n"
|
||||
f"{realtime_samples:.0f} samples = {slack_time:.2f}s slack ÷ {seconds_per_sample:.2f}s/sample\n"
|
||||
f" step reduction: {payload['steps']} -> {realtime_samples:.0f}"
|
||||
)
|
||||
|
||||
job.add_work(payload=payload, batch_size=1)
|
||||
job.step_override = realtime_samples
|
||||
|
||||
else:
|
||||
logger.debug("complementary image production is disabled")
|
||||
|
||||
iterations = payload['n_iter']
|
||||
num_returning = self.get_current_output_size()
|
||||
num_complementary = num_returning - self.total_batch_size
|
||||
num_complementary = num_returning - self.p.batch_size
|
||||
distro_summary = "Job distribution:\n"
|
||||
distro_summary += f"{self.total_batch_size} * {iterations} iteration(s)"
|
||||
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
|
||||
if num_complementary > 0:
|
||||
distro_summary += f" + {num_complementary} complementary"
|
||||
distro_summary += f": {num_returning} images total\n"
|
||||
|
|
@ -553,6 +558,22 @@ class World:
|
|||
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
|
||||
logger.info(distro_summary)
|
||||
|
||||
if self.thin_client_mode is True or self.master_job().batch_size == 0:
|
||||
# save original process_images_inner for later so we can restore once we're done
|
||||
logger.debug(f"bypassing local generation completely")
|
||||
def process_images_inner_bypass(p) -> processing.Processed:
|
||||
processed = processing.Processed(p, [], p.seed, info="")
|
||||
processed.all_prompts = []
|
||||
processed.all_seeds = []
|
||||
processed.all_subseeds = []
|
||||
processed.all_negative_prompts = []
|
||||
processed.infotexts = []
|
||||
processed.prompt = None
|
||||
|
||||
self.p.scripts.postprocess(p, processed)
|
||||
return processed
|
||||
processing.process_images_inner = process_images_inner_bypass
|
||||
|
||||
# delete any jobs that have no work
|
||||
last = len(self.jobs) - 1
|
||||
while last > 0:
|
||||
|
|
@ -631,6 +652,8 @@ class World:
|
|||
label = next(iter(w.keys()))
|
||||
fields = w[label].__dict__
|
||||
fields['label'] = label
|
||||
# TODO must be overridden everytime here or later converted to a config file variable at some point
|
||||
fields['verify_remotes'] = self.verify_remotes
|
||||
|
||||
self.add_worker(**fields)
|
||||
|
||||
|
|
@ -638,6 +661,7 @@ class World:
|
|||
self.job_timeout = config.job_timeout
|
||||
self.enabled = config.enabled
|
||||
self.complement_production = config.complement_production
|
||||
self.step_scaling = config.step_scaling
|
||||
|
||||
logger.debug("config loaded")
|
||||
|
||||
|
|
@ -651,7 +675,8 @@ class World:
|
|||
benchmark_payload=sh.benchmark_payload,
|
||||
job_timeout=self.job_timeout,
|
||||
enabled=self.enabled,
|
||||
complement_production=self.complement_production
|
||||
complement_production=self.complement_production,
|
||||
step_scaling=self.step_scaling
|
||||
)
|
||||
|
||||
with open(self.config_path, 'w+') as config_file:
|
||||
|
|
@ -679,22 +704,24 @@ class World:
|
|||
if worker.queried and worker.state == State.IDLE: # TODO worker.queried
|
||||
continue
|
||||
|
||||
# for now skip/remove scripts that are not "always on" since there is currently no way to run
|
||||
# them at the same time as distributed
|
||||
supported_scripts = {
|
||||
'txt2img': [],
|
||||
'img2img': []
|
||||
}
|
||||
script_info = worker.session.get(url=worker.full_url('script-info')).json()
|
||||
|
||||
for key in script_info:
|
||||
name = key.get('name', None)
|
||||
response = worker.session.get(url=worker.full_url('script-info'))
|
||||
if response.status_code == 200:
|
||||
script_info = response.json()
|
||||
for key in script_info:
|
||||
name = key.get('name', None)
|
||||
|
||||
if name is not None:
|
||||
is_alwayson = key.get('is_alwayson', False)
|
||||
is_img2img = key.get('is_img2img', False)
|
||||
if is_alwayson:
|
||||
supported_scripts['img2img' if is_img2img else 'txt2img'].append(name)
|
||||
if name is not None:
|
||||
is_alwayson = key.get('is_alwayson', False)
|
||||
is_img2img = key.get('is_img2img', False)
|
||||
if is_alwayson:
|
||||
supported_scripts['img2img' if is_img2img else 'txt2img'].append(name)
|
||||
else:
|
||||
logger.error(f"failed to query script-info for worker '{worker.label}': {response}")
|
||||
worker.supported_scripts = supported_scripts
|
||||
|
||||
msg = f"worker '{worker.label}' is online"
|
||||
|
|
|
|||
Loading…
Reference in New Issue