stable-diffusion-webui-dist.../scripts/extension.py

421 lines
17 KiB
Python

"""
https://github.com/papuSpartan/stable-diffusion-webui-distributed
"""
import base64
import io
import json
import re
import threading
import gradio
from modules import scripts
from modules import processing
from threading import Thread, current_thread
from PIL import Image
from typing import List
import urllib3
import copy
from modules.images import save_image
from modules.shared import cmd_opts
import time
from pathlib import Path
import os
import subprocess
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
from scripts.spartan.Worker import Worker, State
from modules.shared import opts
from scripts.spartan.shared import logger
from scripts.spartan.control_net import pack_control_net
from modules.processing import fix_seed
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
# TODO see if the current api has some sort of UUID generation functionality.
# noinspection PyMissingOrEmptyDocstring
class Script(scripts.Script):
# Script that splits a generation request between this machine and remote workers.
# All state below is class-level and is read/written as ``Script.<name>`` by the methods of this class.
# Threads started by run() for remote worker requests; joined and cleared by add_to_gallery().
worker_threads: List[Thread] = []
# Whether to verify worker certificates. Verification is turned off when
# cmd_opts.distributed_skip_verify_remotes is set, which can be useful if your remotes are self-signed.
verify_remotes = False if cmd_opts.distributed_skip_verify_remotes else True
# NOTE(review): presumably these mark the script as usable from both the img2img and txt2img tabs - confirm
is_img2img = True
is_txt2img = True
# NOTE(review): run() only packs other scripts whose `alwayson` is True; presumably False here keeps
# this script from being treated as an always-on script itself - confirm against the webui
alwayson = False
# Not read by any method visible in this file.
first_run = True
# time.time() stored by run() right before local generation starts; add_to_gallery() logs the elapsed time.
master_start = None
# World instance created lazily by initialize(); None until then.
world = None
# Original author's note, which refers to run(): its `p` argument's type is
# "modules.processing.StableDiffusionProcessingTxt2Img"
# and run() runs every time the generate button is hit
def title(self):
    """Return this script's title string, ``"Distribute"``."""
    script_title = "Distribute"
    return script_title
def show(self, is_img2img):
    """Report whether the script should be shown; always ``True``, whatever ``is_img2img`` is."""
    # The original author left `return scripts.AlwaysVisible` disabled here.
    visible = True
    return visible
def ui(self, is_img2img):
    """Build the 'Distributed' accordion with its 'Status' and 'Utils' tabs.

    The Status tab holds two textboxes that are filled by ``Script.ui_connect_status`` when the
    'Refresh' button is clicked or the tab is selected. The Utils tab holds one button per
    maintenance action. Returns ``None`` (no components are handed back to the caller).
    """
    with gradio.Box():  # adds padding so our components don't look out of place
        with gradio.Accordion(label='Distributed', open=False):

            with gradio.Tab('Status') as status_tab:
                status_box = gradio.Textbox(elem_id='status', show_label=False)
                status_box.placeholder = 'Refresh!'
                jobs_box = gradio.Textbox(elem_id='jobs', label='Jobs', show_label=True)
                jobs_box.placeholder = 'Refresh!'

                refresh_btn = gradio.Button(value='Refresh')
                refresh_btn.style(size='sm')
                # the button and selecting the tab both refresh the same two textboxes
                refresh_btn.click(Script.ui_connect_status, inputs=[], outputs=[jobs_box, status_box])
                status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[jobs_box, status_box])

            with gradio.Tab('Utils'):
                # (Button keyword arguments, click handler), in display order
                utility_buttons = (
                    ({'value': 'Refresh checkpoints'}, Script.ui_connect_refresh_ckpts_btn),
                    ({'value': 'Synchronize models'}, Script.user_sync_script),
                    ({'value': 'Interrupt all', 'variant': 'stop'}, Script.ui_connect_interrupt_btn),
                    # redo benchmarks button
                    ({'value': 'Redo benchmarks', 'variant': 'stop'}, Script.ui_connect_benchmark_button),
                )
                for button_kwargs, on_click in utility_buttons:
                    button = gradio.Button(**button_kwargs)
                    button.style(full_width=False)
                    button.click(on_click, inputs=[], outputs=[])

    return None
@staticmethod
def ui_connect_benchmark_button():
    """Click handler for the 'Redo benchmarks' button.

    Calls ``Script.world.benchmark(rebenchmark=True)``. ``Script.world`` stays ``None`` until
    ``Script.initialize`` has run, so in that state the click is ignored with a debug message
    instead of raising ``AttributeError``, like the other button handlers of this class.
    """
    if Script.world is None:
        logger.debug("Cannot redo benchmarks, Distributed system not initialized")
        return

    logger.info("Redoing benchmarks...")
    Script.world.benchmark(rebenchmark=True)
@staticmethod
def user_sync_script():
    """Run the first runnable ``sync*`` file found in the ``user`` directory next to this file.

    A ``.ps1`` file is run through ``powershell``. Any other file is run through the interpreter
    named on its ``#!`` first line; files without such a line are skipped and the search continues.

    Returns:
        ``True`` after a script has been run (its exit status is not inspected), ``False`` when
        the directory does not exist or contains no runnable ``sync*`` file.
    """
    import shlex  # only needed here, kept local so the file-level imports stay untouched

    user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user')
    if not user_scripts.is_dir():
        logger.debug(f"No user script directory found at '{user_scripts}'")
        return False

    for file in user_scripts.iterdir():
        if not (file.is_file() and file.name.startswith('sync')):
            continue

        if file.suffix == '.ps1':
            subprocess.call(['powershell', str(file)])
            return True

        # errors='replace' keeps a binary file named sync* from raising while we peek at it
        with open(file, 'r', errors='replace') as script_file:
            first_line = script_file.readline().strip()

        if not first_line.startswith('#!'):
            continue

        # a shebang may carry arguments ("#!/usr/bin/env bash"), so split it into argv parts;
        # non-posix splitting on Windows keeps backslashes in interpreter paths intact
        interpreter = shlex.split(first_line[2:], posix=(os.name != 'nt'))
        if not interpreter:
            continue

        subprocess.call([*interpreter, str(file)])
        return True

    return False
# World is not constructed until the first generation job, so I use an intermediary call
@staticmethod
def ui_connect_interrupt_btn():
    """Click handler for 'Interrupt all': forwards to ``interrupt_remotes()`` of the current world.

    An ``AttributeError`` (raised, for instance, while ``Script.world`` is still ``None``) is
    logged at debug level and otherwise ignored.
    """
    current_world = Script.world
    try:
        current_world.interrupt_remotes()
    except AttributeError:
        logger.debug("Nothing to interrupt, Distributed system not initialized")
@staticmethod
def ui_connect_refresh_ckpts_btn():
    """Click handler for 'Refresh checkpoints': forwards to ``refresh_checkpoints()`` of the current world.

    An ``AttributeError`` (raised, for instance, while ``Script.world`` is still ``None``) is
    logged at debug level and otherwise ignored.
    """
    current_world = Script.world
    try:
        current_world.refresh_checkpoints()
    except AttributeError:
        logger.debug("Distributed system not initialized")
@staticmethod
def ui_connect_status():
    """Produce the two strings shown in the Status tab.

    Initializes the distributed system first if ``Script.world`` does not exist yet. The
    initialization is done once, up front, rather than by catching ``AttributeError`` and
    recursing, so an unrelated attribute error can no longer cause unbounded recursion.

    Returns:
        A ``(jobs, status)`` tuple: ``jobs`` is ``str(Script.world)`` while any worker is in
        ``State.WORKING`` and ``'No active jobs!'`` otherwise; ``status`` has one line per
        non-master worker giving its uuid, address and state name.
    """
    if Script.world is None:
        # batch size will be clobbered later once an actual request is made anyway
        Script.initialize(initial_payload=None)

    world = Script.world
    worker_status = ''.join(
        f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
        for worker in world.workers
        if not worker.master
    )

    # TODO replace this with a single check to a state flag that we should make in the world class
    if any(worker.state == State.WORKING for worker in world.workers):
        return str(world), worker_status

    return 'No active jobs!', worker_status
@staticmethod
def add_to_gallery(processed, p):
    """Add the images returned by the workers to ``processed`` after waiting for all workers to finish.

    Joins and clears ``Script.worker_threads``, decodes every worker's base64 ``images``, records
    each image with its metadata and infotext on ``processed``, optionally appends a grid, resets
    every worker's ``response`` and finally sets ``p.batch_size`` to the number of collected images.

    Args:
        processed: result object of the local generation. Its ``images``, ``infotexts``,
            ``all_seeds``, ``all_subseeds``, ``all_prompts`` and ``all_negative_prompts`` lists
            are extended in place.
        p: processing object of the current request (``n_iter``, ``batch_size``,
            ``outpath_samples`` and ``outpath_grids`` are read, ``batch_size`` is written).
    """

    def processed_inject_image(image, info_index, iteration: int, save_path_override=None, grid=False, response=None):
        """Record one image (or the grid) on ``processed`` using metadata taken from ``response``.

        ``info_index`` selects the entry of the response's post-generation lists. When that entry
        is missing, the first entry is used instead; when the first entry is missing as well the
        image is skipped with a warning.
        """
        image_params: dict = response["parameters"]
        image_info_post: dict = json.loads(response["info"])  # image info known after processing

        try:
            # look everything up before appending anything, so that a partial failure cannot
            # leave the metadata lists of `processed` with different lengths
            seed = image_info_post["all_seeds"][info_index]
            subseed = image_info_post["all_subseeds"][info_index]
            negative_prompt = image_info_post["all_negative_prompts"][info_index]
        except (KeyError, IndexError, TypeError):
            if info_index == 0:
                logger.warning("Response has no post-generation data for its first image, skipping image")
                return

            # like with controlnet masks, there isn't always full post-gen info, so we use the first images'
            logger.debug(f"Image at index {info_index} was missing some post-generation data")
            processed_inject_image(
                image=image,
                info_index=0,
                iteration=iteration,
                save_path_override=save_path_override,
                grid=grid,
                response=response
            )
            return

        # some metadata
        processed.all_seeds.append(seed)
        processed.all_subseeds.append(subseed)
        processed.all_negative_prompts.append(negative_prompt)
        processed.all_prompts.append(image_params["prompt"])
        processed.images.append(image)  # actual received image

        # generate info-text string
        # zero-indexed position of image in total batch (so including master results)
        true_image_pos = len(processed.images) - 1
        info_text_used_seed_index = info_index + p.n_iter * p.batch_size if not grid else 0

        if iteration != 0:
            logger.debug(f"iteration {iteration}/{p.n_iter}, image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, info-index: {info_index}, used seed index {info_text_used_seed_index}")

        info_text = processing.create_infotext(
            p=p,
            all_prompts=processed.all_prompts,
            all_seeds=processed.all_seeds,
            all_subseeds=processed.all_subseeds,
            # comments=[""], # unimplemented upstream :(
            position_in_batch=true_image_pos if not grid else 0,
            iteration=0
        )
        processed.infotexts.append(info_text)

        # automatically save received image to local disk if desired
        if cmd_opts.distributed_remotes_autosave:
            save_image(
                image=image,
                path=p.outpath_samples if save_path_override is None else save_path_override,
                basename="",
                seed=processed.all_seeds[-1],
                prompt=processed.all_prompts[-1],
                info=info_text,
                extension=opts.samples_format
            )

    # time elapsed since run() stored Script.master_start
    master_elapsed = time.time() - Script.master_start
    logger.debug(f"Took master {master_elapsed:.2f}s")

    # wait for response from all workers
    for thread in Script.worker_threads:
        logger.debug(f"waiting for worker thread '{thread.name}'")
        thread.join()
    Script.worker_threads.clear()
    logger.debug("all worker request threads returned")

    # some worker which we know has a good response that we can use for generating the grid
    donor_worker = None
    spoofed_iteration = p.n_iter
    for worker in Script.world.workers:

        expected_images = 1
        for job in Script.world.jobs:
            if job.worker == worker:
                expected_images = job.batch_size * p.n_iter

        try:
            images = worker.response["images"]
            received_images = len(images)
        except (AttributeError, KeyError, TypeError):
            # no usable response (e.g. response is None)
            if worker.master is False:
                logger.warning(f"Worker '{worker.uuid}' had nothing")
            continue

        # if we for some reason get more than we asked for
        if expected_images < received_images:
            logger.debug(f"Requested {expected_images} images from '{worker.uuid}', got {received_images}")

        if donor_worker is None:
            donor_worker = worker

        injected_to_iteration = 0
        images_per_iteration = Script.world.get_current_output_size()
        # visibly add work from workers to the image gallery
        for index, encoded_image in enumerate(images):
            image_bytes = base64.b64decode(encoded_image)
            image = Image.open(io.BytesIO(image_bytes))

            # inject image
            processed_inject_image(image=image, info_index=index, iteration=spoofed_iteration, response=worker.response)

            if injected_to_iteration >= images_per_iteration - 1:
                spoofed_iteration += 1
                injected_to_iteration = 0
            else:
                injected_to_iteration += 1

    # generate and inject grid
    if opts.return_grid:
        if donor_worker is None or not processed.images:
            # the grid's metadata is taken from a worker response, so without one there is nothing to inject
            logger.warning("Skipping grid generation, no worker response is available to build it from")
        else:
            grid = processing.images.image_grid(processed.images, len(processed.images))
            processed_inject_image(
                image=grid,
                info_index=0,
                save_path_override=p.outpath_grids,
                iteration=spoofed_iteration,
                grid=True,
                response=donor_worker.response
            )

    # cleanup after we're doing using all the responses
    for worker in Script.world.workers:
        worker.response = None

    p.batch_size = len(processed.images)
@staticmethod
def initialize(initial_payload):
    """Create ``Script.world`` on first use, then initialize or update it for the requested batch size.

    Args:
        initial_payload: object whose ``batch_size`` attribute gives the total batch size. When
            it has no such attribute (e.g. ``None``), a batch size of 1 is used.

    Raises:
        RuntimeError: if the world still has to be built and ``cmd_opts.distributed_remotes``
            is ``None``.
    """
    # get default batch size
    batch_size = getattr(initial_payload, 'batch_size', 1)

    if Script.world is None:
        remotes = cmd_opts.distributed_remotes
        if remotes is None:
            raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")

        if Script.verify_remotes is False:
            logger.warning("You have chosen to forego the verification of worker TLS certificates")
            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

        # construct World; build it in a local and publish it only once every worker was added,
        # so a failure part-way through cannot leave a half-built world behind for later calls
        world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes)

        # add workers to the world
        for worker in remotes:
            world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])

        Script.world = world

    try:
        Script.world.initialize(batch_size)
        logger.debug("World initialized!")
    except WorldAlreadyInitialized:
        Script.world.update_world(total_batch_size=batch_size)
def run(self, p, *args):
    """Split the request ``p`` between the remote workers and this machine and collect all results.

    Builds an API payload from ``p``, starts one request thread per remote job (recorded in
    ``Script.worker_threads``), generates the master's share locally and then merges the remote
    images through ``Script.add_to_gallery``.

    Args:
        p: processing object of the current request; its ``seed``, ``subseed``, ``batch_size``
            and ``do_not_save_grid`` are modified.
        *args: forwarded unchanged to ``processing.process_images``.

    Returns:
        The object returned by ``processing.process_images`` after ``Script.add_to_gallery``
        has extended it.

    Raises:
        RuntimeError: if ``cmd_opts.distributed_remotes`` is ``None``.
    """
    current_thread().name = "distributed_main"

    if cmd_opts.distributed_remotes is None:
        raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")

    Script.initialize(initial_payload=p)

    # strip scripts that aren't yet supported and warn user
    packed_script_args: List[dict] = []  # list of api formatted per-script argument objects
    for script in p.scripts.scripts:
        if script.alwayson is not True:
            continue
        title = script.title()

        # check for supported scripts
        if title == "ControlNet":
            # grab all controlnet units
            cn_args = p.script_args[script.args_from:script.args_to]
            cn_units = [cn_arg for cn_arg in cn_args if type(cn_arg).__name__ == "UiControlNetUnit"]
            logger.debug(f"Detected {len(cn_units)} controlnet unit(s)")

            # get api formatted controlnet
            packed_script_args.append(pack_control_net(cn_units))
            continue

        # https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/issues/12#issuecomment-1480382514
        logger.warning(f"Distributed doesn't yet support '{title}'")

    # encapsulating the request object within a txt2imgreq object is deprecated and no longer works
    # see test/basic_features/txt2img_test.py for an example
    payload = copy.copy(p.__dict__)
    payload['batch_size'] = Script.world.get_default_worker_batch_size()
    payload['scripts'] = None
    payload.pop('script_args', None)  # tolerate a request object that carries no script_args

    payload['alwayson_scripts'] = {}
    for packed in packed_script_args:
        payload['alwayson_scripts'].update(packed)

    # generate seed early for master so that we can calculate the successive seeds for each slave
    fix_seed(p)
    payload['seed'] = p.seed
    payload['subseed'] = p.subseed

    # TODO api for some reason returns 200 even if something failed to be set.
    #  for now we may have to make redundant GET requests to check if actually successful...
    #  https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146
    # drop a trailing "[...]" suffix from the checkpoint title
    name = re.sub(r'\s?\[[^\]]*\]$', '', opts.data["sd_model_checkpoint"])
    vae = opts.data["sd_vae"]
    option_payload = {
        "sd_model_checkpoint": name,
        "sd_vae": vae
    }

    # start generating images assigned to remote machines
    Script.world.optimize_jobs(payload)  # optimize work assignment before dispatching
    started_jobs = []
    for job in Script.world.jobs:
        if job.worker.master:
            # the master's share is generated locally below, but it still counts towards the seed offsets
            started_jobs.append(job)
            continue
        if job.batch_size < 1:
            continue

        prior_images = sum(started.batch_size * p.n_iter for started in started_jobs)

        # copy only for jobs that are actually dispatched
        payload_temp = copy.deepcopy(payload)
        payload_temp['batch_size'] = job.batch_size
        payload_temp['subseed'] += prior_images
        payload_temp['seed'] += prior_images if payload_temp['subseed_strength'] == 0 else 0
        logger.debug(f"'{job.worker.uuid}' job's given starting seed is {payload_temp['seed']} with {prior_images} coming before it")

        # decided per worker: only a worker whose recorded model/vae differs is asked to sync
        sync = job.worker.loaded_model != name or job.worker.loaded_vae != vae
        if sync:
            job.worker.loaded_model = name
            job.worker.loaded_vae = vae

        t = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync), name=f"{job.worker.uuid}_request")
        t.start()
        Script.worker_threads.append(t)
        started_jobs.append(job)

    # if master batch size was changed again due to optimization change it to the updated value
    p.batch_size = Script.world.get_master_batch_size()
    Script.master_start = time.time()

    # generate images assigned to local machine
    p.do_not_save_grid = True  # don't generate grid from master as we are doing this later.
    processed = processing.process_images(p, *args)
    Script.add_to_gallery(processed, p)

    return processed