# stable-diffusion-webui-distributed/scripts/distributed.py
# (source listing: 377 lines, 15 KiB, Python)

"""
https://github.com/papuSpartan/stable-diffusion-webui-distributed
"""
import base64
import copy
import io
import json
import re
import signal
import sys
import time
from threading import Thread
from typing import List
import gradio
from torchvision.transforms import ToTensor
import urllib3
from PIL import Image
from modules import processing
from modules import scripts
from modules.processing import fix_seed
from modules.shared import opts, cmd_opts
from modules.shared import state as webui_state
from scripts.spartan.shared import logger
from scripts.spartan.ui import UI
from scripts.spartan.world import World, State, Job
from scripts.spartan.adapters import adapters, GenericAdapter
# Signal handlers that were installed before this module registers its own; the class's
# signal_handler hands control back to them after saving the config.
old_sigint_handler, old_sigterm_handler = (
    signal.getsignal(signal.SIGINT),
    signal.getsignal(signal.SIGTERM),
)
# noinspection PyMissingOrEmptyDocstring
class DistributedScript(scripts.Script):
    """webui ``Script`` that spreads one generation request across the configured workers.

    The class body runs once when the module is imported: it reads the TLS-verification flag,
    builds the shared ``World`` from the saved config and pings every remote once. Instances are
    constructed for both txt2img and img2img and share the class-level ``world``.
    """

    # Whether to verify worker TLS certificates. Turning verification off (via
    # cmd_opts.distributed_skip_verify_remotes) can be useful if your remotes are self-signed.
    verify_remotes = not cmd_opts.distributed_skip_verify_remotes
    # time.time() taken once before_process has dispatched the current run's remote requests;
    # None while no requests have been dispatched for the current run
    master_start = None
    runs_since_init = 0
    name = "distributed"
    is_dropdown_handler_injected = False

    if verify_remotes is False:
        logger.warning("You have chosen to forego the verification of worker TLS certificates")
        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    # build world
    world = World(verify_remotes=verify_remotes)
    world.load_config()
    logger.info("doing initial ping sweep to see which workers are reachable")
    world.ping_remotes(indiscriminate=True)

    # constructed for both txt2img and img2img
    def __init__(self):
        super().__init__()

    def title(self):
        """Return the script title shown by webui."""
        return "Distribute"

    def show(self, is_img2img):
        """Show the script on both the txt2img and img2img tabs."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the extension UI for one tab and return its components.

        Returns whatever ``UI.create_ui()`` produces; these are the components exposed to the API.
        """
        extension_ui = UI(world=self.world, is_img2img=is_img2img)
        components = extension_ui.create_ui()

        # The first injection of handler for the models dropdown(sd_model_checkpoint) which is often present
        # in the quick-settings bar of a user. Helps ensure model swaps propagate to all nodes ASAP.
        self.world.inject_model_dropdown_handler()

        return components

    def enabled(self, p):
        """Return False when the extension is switched off for the tab that produced ``p``.

        A truthy ``p.init_images`` marks an img2img request (checked against
        ``world.enabled_i2i``); anything else is txt2img (checked against ``world.enabled``).
        """
        is_img2img = getattr(p, 'init_images', False)
        if is_img2img and self.world.enabled_i2i is False:
            return False
        elif not is_img2img and self.world.enabled is False:
            return False
        return True

    def api_to_internal(self, job) -> tuple:
        """Parse a worker's API response into webui-internal per-image lists (e.g. all_seeds).

        Returns ``(all_seeds, all_subseeds, all_negative_prompts, all_prompts, images)``, where
        each image is converted with torchvision's ``ToTensor``. An image whose post-generation
        info is missing (IndexError) is skipped entirely, so all five lists stay the same length.
        """
        response = job.worker.response
        image_params: dict = response['parameters']
        image_info_post: dict = json.loads(response["info"])  # image info known after processing
        transform = ToTensor()

        all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = [], [], [], [], []
        for i, encoded_image in enumerate(response["images"]):
            try:
                # read every field before appending anything so a partial failure
                # cannot leave the lists misaligned
                if image_params["batch_size"] * image_params["n_iter"] > 1:
                    seed = image_info_post['all_seeds'][i]
                    subseed = image_info_post['all_subseeds'][i]
                    negative_prompt = image_info_post['all_negative_prompts'][i]
                    prompt = image_info_post['all_prompts'][i]
                else:  # only a single image received
                    seed = image_info_post['seed']
                    subseed = image_info_post['subseed']
                    negative_prompt = image_info_post['negative_prompt']
                    prompt = image_info_post['prompt']
            except IndexError:
                # like with controlnet masks, there isn't always full post-gen info for every image
                logger.critical(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data")
                continue

            all_seeds.append(seed)
            all_subseeds.append(subseed)
            all_negative_prompts.append(negative_prompt)
            all_prompts.append(prompt)

            # parse image
            image = Image.open(io.BytesIO(base64.b64decode(encoded_image)))
            images.append(transform(image))

        return all_seeds, all_subseeds, all_negative_prompts, all_prompts, images

    def inject_job(self, job: Job, p, pp):
        """Adds the work completed by one Job via its worker response to the processing and postprocessing objects.

        Extends ``p``'s seed/prompt lists, appends the images to ``pp.images`` and records each
        image's gallery position in ``job.gallery_map`` so ``postprocess`` can fix its infotext.
        """
        all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = self.api_to_internal(job)
        p.seeds.extend(all_seeds)
        p.subseeds.extend(all_subseeds)
        p.negative_prompts.extend(all_negative_prompts)
        p.prompts.extend(all_prompts)

        num_local = self.world.p.n_iter * self.world.p.batch_size + (opts.return_grid - self.world.thin_client_mode)
        num_injected = len(pp.images) - self.world.p.batch_size
        for i, image in enumerate(images):
            # modules.ui_common -> update_generation_info renders to html below gallery
            gallery_index = num_local + num_injected + i  # zero-indexed point of image in total gallery
            job.gallery_map.append(gallery_index)  # so we know where to edit infotext
            pp.images.append(image)
            logger.debug(f"image {gallery_index + 1 + self.world.thin_client_mode}/{self.world.num_gallery()}")

    def update_gallery(self, pp, p):
        """Adds all remotely generated images to the image gallery after waiting for all workers to finish.

        Blocks until every dispatched request thread has returned, then injects each remote job's
        images. On success ``p.batch_size`` is set to the total number of images in ``pp``.
        """
        # how long the master took for its own share, measured from the end of dispatch
        master_elapsed = time.time() - self.master_start
        logger.debug(f"Took master {master_elapsed:.2f}s")

        # wait for response from all workers
        webui_state.textinfo = "Distributed - receiving results"
        for job in self.world.jobs:
            if job.thread is None:
                continue
            logger.debug(f"waiting for worker thread '{job.thread.name}'")
            job.thread.join()
        logger.debug("all worker request threads returned")
        webui_state.textinfo = "Distributed - injecting images"

        received_images = False
        for job in self.world.jobs:
            if not isinstance(job.worker.response, dict) or job.batch_size < 1 or job.worker.master:
                continue

            try:
                images = job.worker.response["images"]
                # if we for some reason get more than we asked for
                if (job.batch_size * p.n_iter) < len(images):
                    logger.debug(f"requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}")
                received_images = True
            except KeyError:
                if job.batch_size > 0:
                    logger.warning(f"Worker '{job.worker.label}' had no images")
                continue
            except TypeError as e:
                if job.worker.response is None:
                    msg = f"worker '{job.worker.label}' had no response"
                    logger.error(msg)
                    gradio.Warning("Distributed: "+msg)
                else:
                    logger.exception(e)
                continue

            # adding the images in
            self.inject_job(job, p, pp)

        # TODO fix controlnet masks returned via api having no generation info
        if not received_images:
            logger.critical("couldn't collect any responses, the extension will have no effect")
            webui_state.textinfo = ""  # don't leave the progress text stuck on "injecting images"
            return

        p.batch_size = len(pp.images)
        webui_state.textinfo = ""

    def before_process(self, p, *args):
        """Decide how to distribute work, apply adaptations for extensions and dispatch requests.

        ``p`` is a ``modules.processing.StableDiffusionProcessing*`` object. Each remote job is
        sent on its own background thread; ``update_gallery`` later joins those threads.
        """
        # Forget the previous run's dispatch time so that postprocess_batch_list only collects
        # results when requests were actually dispatched for *this* run (early returns below
        # would otherwise leave a stale value behind).
        self.master_start = None

        if not self.enabled(p):
            is_img2img = getattr(p, 'init_images', False)
            logger.debug("extension is disabled for i2i" if is_img2img else "extension is disabled for t2i")
            return

        self.active_adapters = []
        if p.all_prompts is None:
            p.all_prompts = []
        if p.all_negative_prompts is None:
            p.all_negative_prompts = []

        self.world.update(p)

        # save original process_images_inner function for later if we monkeypatch it
        self.original_process_images_inner = processing.process_images_inner

        generic_adapter = GenericAdapter()
        for script in p.scripts.scripts:
            if script.alwayson is not True:
                continue

            title = script.title().lower()
            adapter = next((a for a in adapters if a.title.lower() in title), None)
            if adapter is None:  # shoehorn scripts which we don't explicitly support
                generic_adapter.early(p, self.world, script)
                continue

            self.active_adapters.append(adapter)
            if adapter.early(p, self.world, script):
                logger.debug(f"adapter for '{adapter.title}' cedes control back to wui")
                return
        logger.debug(f"activated {len(self.active_adapters)} adapters: {[a.title for a in self.active_adapters]}")

        # generate seed early for master so that we can calculate the successive seeds for each slave
        fix_seed(p)

        payload = copy.copy(p.__dict__)
        payload['alwayson_scripts'] = {}
        payload['batch_size'] = self.world.default_batch_size()
        payload['scripts'] = None
        payload['scripts_value'] = None
        # depending on the webui version the script args are stored under one of these keys;
        # tolerate neither being present instead of raising KeyError
        if 'script_args' in payload:
            del payload['script_args']
        else:
            payload.pop('script_args_value', None)

        name = re.sub(r'\s?\[[^]]*]$', '', opts.data["sd_model_checkpoint"])
        vae = opts.data.get('sd_vae')
        option_payload = {
            "sd_model_checkpoint": name,
            "sd_vae": vae
        }

        self.world.optimize_jobs(payload)
        for adapter in self.active_adapters:
            adapter.late(p, self.world, payload, option_payload)
        generic_adapter.late(p, self.world, payload, option_payload)

        # check if anything even needs to be done
        if len(self.world.jobs) == 1 and self.world.jobs[0].worker.master:
            if payload['batch_size'] >= 2:
                msg = "all remote workers are offline or unreachable"
                gradio.Info("Distributed: "+msg)
                logger.critical(msg)
            logger.debug("distributed has nothing to do, returning control to webui")
            return

        # start generating images assigned to remote machines
        sync = False  # should only really need to sync once per job
        started_jobs = []
        for job in self.world.jobs:
            if job.worker.state in (State.UNAVAILABLE, State.DISABLED):
                continue
            if job.worker.master:
                # the master's images come first in the seed/prompt ordering
                started_jobs.append(job)
                continue
            if job.batch_size < 1:
                continue

            # only copy the payload for jobs that are actually sent
            payload_worker = copy.deepcopy(payload)
            prior_images = sum(j.batch_size * p.n_iter for j in started_jobs)
            payload_worker['batch_size'] = job.batch_size

            if len(payload_worker['all_prompts']) == self.world.num_requested():
                payload_worker['prompt'] = payload_worker['all_prompts'][prior_images]
            if len(payload_worker['all_negative_prompts']) == self.world.num_requested():
                payload_worker['negative_prompt'] = payload_worker['all_negative_prompts'][prior_images]
            if job.step_override is not None:
                payload_worker['steps'] = job.step_override

            payload_worker['subseed'] += prior_images
            if not self.world.comparison_mode:
                payload_worker['seed'] += prior_images if payload_worker['subseed_strength'] == 0 else 0
            logger.debug(
                f"'{job.worker.label}' job's given starting seed is "
                f"{payload_worker['seed']} with {prior_images} coming before it"
            )

            # NOTE(review): once True, sync stays True for every later worker in this loop,
            # including workers whose loaded model/vae already match
            if job.worker.loaded_model != name or job.worker.loaded_vae != vae:
                sync = True
                job.worker.loaded_model = name
                job.worker.loaded_vae = vae

            job.thread = Thread(target=job.worker.request, args=(payload_worker, option_payload, sync,),
                                name=f"{job.worker.label}_request")
            job.thread.start()
            started_jobs.append(job)

        # if master batch size was changed again due to optimization change it to the updated value
        if not self.world.thin_client_mode:
            p.batch_size = self.world.master_job().batch_size
        self.master_start = time.time()
        self.runs_since_init += 1

    def postprocess_batch_list(self, p, pp, *args, **kwargs):
        """Inject remote images into ``pp``.

        Runs only for the final local batch, or for every batch in thin-client mode, and only when
        ``before_process`` dispatched requests for this run.
        """
        if not self.world.thin_client_mode and p.n_iter != kwargs['batch_number'] + 1:  # skip if not the final batch
            return
        if not self.enabled(p):
            return

        if self.master_start is not None:
            self.update_gallery(p=p, pp=pp)

    def postprocess(self, p, processed, *args):
        """Overwrite injected images' infotext with the remote results, then clean up per-run state."""
        if not self.enabled(p):
            return

        for job in self.world.jobs:
            if job.worker.master or job.worker.response is None or not job.gallery_map:
                continue
            # parse the worker's info once per job rather than once per image
            infotexts = json.loads(job.worker.response['info'])['infotexts']
            for i, gallery_index in enumerate(job.gallery_map):
                processed.infotexts[gallery_index] = infotexts[i] + f", Worker Label: {job.worker.label}"

        # cleanup
        for worker in self.world.get_workers():
            worker.response = None
        # restore process_images_inner if it was monkey-patched
        processing.process_images_inner = self.original_process_images_inner
        for adapter in self.active_adapters:
            adapter.cleanup()
        # save any dangling state to prevent load_config in next iteration overwriting it
        self.world.save_config()

    @staticmethod
    def signal_handler(sig, frame):
        """Save the config, then defer to the handler that was installed before this module.

        On SIGTERM with no callable previous handler, exit with status 0.
        """
        logger.debug("handling interrupt signal")
        # do cleanup
        DistributedScript.world.save_config()

        if sig == signal.SIGINT:
            if callable(old_sigint_handler):
                old_sigint_handler(sig, frame)
        else:
            if callable(old_sigterm_handler):
                old_sigterm_handler(sig, frame)
            else:
                sys.exit(0)

    # Inside the class body `signal_handler` is still a staticmethod object, which is only
    # callable on Python 3.10+; register the underlying function so any version works.
    signal.signal(signal.SIGINT, signal_handler.__func__)
    signal.signal(signal.SIGTERM, signal_handler.__func__)