223 lines
8.9 KiB
Python
223 lines
8.9 KiB
Python
"""
|
|
https://github.com/papuSpartan/stable-diffusion-webui-distributed
|
|
"""
|
|
|
|
import base64
|
|
import io
|
|
import json
|
|
import re
|
|
import gradio
|
|
from modules import scripts, script_callbacks
|
|
from modules import processing
|
|
from threading import Thread
|
|
from PIL import Image
|
|
from typing import List
|
|
import urllib3
|
|
import copy
|
|
from modules.images import save_image
|
|
from modules.shared import cmd_opts
|
|
import time
|
|
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
|
|
from scripts.spartan.Worker import Worker
|
|
from modules.shared import opts
|
|
|
|
# TODO implement some form of SSDP advertisement in the sdwui API so the extension can discover workers automatically?

# TODO check whether the current API offers any UUID generation functionality.
|
|
|
|
# noinspection PyMissingOrEmptyDocstring
|
|
class Script(scripts.Script):
    """Selectable "Distribute" script that splits a generation job between the master and remote workers."""

    # Not referenced anywhere else in this class.
    response_cache: "dict | None" = None

    # Threads currently requesting images from remote workers; joined and cleared in add_to_gallery.
    worker_threads: List[Thread] = []

    # Whether to verify worker TLS certificates. Disabled when cmd_opts.distributed_skip_verify_remotes
    # is set, which can be useful if your remotes use self-signed certificates.
    verify_remotes = False if cmd_opts.distributed_skip_verify_remotes else True

    is_img2img = True
    is_txt2img = True
    alwayson = False
    # Not referenced anywhere else in this class; kept for compatibility.
    first_run = True
    # time.time() at which the master started its own share of the job (set at the end of run).
    master_start = None

    # The distributed World; None until the first generation job calls run().
    world = None

    def title(self):
        """Return the name shown in the webui script dropdown."""
        return "Distribute"

    def show(self, is_img2img):
        """Show this script in both the txt2img and img2img tabs."""
        # return scripts.AlwaysVisible
        return True

    def ui(self, is_img2img):
        """Build the script's UI: a single button that interrupts all remote workers."""
        with gradio.Box():  # adds padding so our components don't look out of place
            interrupt_all_btn = gradio.Button(value="Interrupt all remote workers")
            interrupt_all_btn.style(full_width=False)
            interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])

        return [interrupt_all_btn]

    @staticmethod
    def ui_connect_interrupt_btn():
        """Interrupt every remote worker, if the distributed world exists yet.

        The World is not constructed until the first generation job, so the button is
        wired to this intermediary instead of directly to a World method.
        """
        # Explicit check instead of catching AttributeError, which would also hide
        # genuine AttributeErrors raised inside interrupt_remotes().
        if Script.world is None:
            print("Nothing to interrupt, Distributed system not initialized")
            return

        Script.world.interrupt_remotes()

    @staticmethod
    def add_to_gallery(processed, p):
        """Add remote workers' images to the gallery after waiting for all workers to finish.

        Registered by run() through script_callbacks.on_after_image_processed, and
        unregisters the script's callbacks once done.

        :param processed: the master's result object; its images, all_prompts, all_seeds,
            all_subseeds, all_negative_prompts and infotexts lists are extended in place.
        :param p: the master's processing object; used for infotext generation and the
            output path, and its batch_size is reset to the world's current output size.
        """
        master_elapsed = time.time() - Script.master_start
        print(f"Took master {master_elapsed}s")

        # Wait for the response from every worker, then drop the finished threads so the
        # class-level list does not keep growing across jobs.
        for thread in Script.worker_threads:
            thread.join()
        Script.worker_threads.clear()

        for worker in Script.world.workers:
            # If this fails, the worker's response is not being filled for some reason.
            try:
                images = worker.response["images"]
            except TypeError:
                if worker.master is False:
                    print(f"Worker '{worker.uuid}' had nothing")
                continue

            Script._add_worker_images(processed, p, worker.response, images)

        p.batch_size = Script.world.get_current_output_size()

        # Clear EVERY worker's response (not only the last one iterated). This avoids
        # redundant outputs in the following case:
        # we have 3 workers and get 3 responses back. The user then requests another 3,
        # but due to job optimization one of the workers does not produce anything new.
        # If that response is not emptied, the user gets the two images they requested
        # plus one from before.
        for worker in Script.world.workers:
            worker.response = None

        Script.unregister_callbacks()

    @staticmethod
    def _add_worker_images(processed, p, response, images):
        """Decode one worker's base64 images and append them, plus their metadata, to `processed`.

        :param response: the worker's response dict; its "parameters" and JSON-encoded
            "info" entries are read.
        :param images: the base64-encoded images taken from response["images"].
        """
        image_params = response["parameters"]
        image_info_post = json.loads(response["info"])  # image info known only after processing

        # visibly add work from the worker to the gallery
        for i, encoded in enumerate(images):
            image = Image.open(io.BytesIO(base64.b64decode(encoded)))
            processed.images.append(image)

            # params
            prompt = image_params["prompt"]
            processed.all_prompts.append(prompt)

            # post-generation
            seed = image_info_post["all_seeds"][i]
            processed.all_seeds.append(seed)
            processed.all_subseeds.append(image_info_post["all_subseeds"][i])
            processed.all_negative_prompts.append(image_info_post["all_negative_prompts"][i])

            # Position of this image in the processed.all_* lists just appended to. Using it
            # (rather than i + p.batch_size) keeps the infotext pointing at this image even
            # when several remote workers contributed images.
            index = len(processed.all_seeds) - 1

            # generate info-text string (mostly for user use)
            this_info_text = processing.create_infotext(
                p,
                processed.all_prompts,
                processed.all_seeds,
                processed.all_subseeds,
                comments=[""],
                position_in_batch=index,
                iteration=0,  # not sure exactly what to do with this yet
            )
            processed.infotexts.append(this_info_text)

            # Save the image to local disk if desired. Uses this image's own seed and
            # prompt, not processed.all_*[i], which would belong to the master's images.
            if cmd_opts.distributed_remotes_autosave:
                save_image(
                    image,
                    p.outpath_samples,
                    "",
                    seed,
                    prompt,
                    opts.samples_format,
                    info=this_info_text,
                )

    def run(self, p, *args):
        """Split the current generation job between the master and the remote workers.

        Runs every time the generate button is hit with this script selected; `p` is the
        webui processing object (e.g. modules.processing.StableDiffusionProcessingTxt2Img).
        Remote requests are started on background threads and merged later by
        add_to_gallery. p.batch_size is rewritten to the master's share of the job.
        Returns None (presumably so webui then runs the master's own processing -- confirm).

        :raises RuntimeError: if no remotes were passed via --distributed-remotes.
        """
        if cmd_opts.distributed_remotes is None:
            raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")

        # NOTE(review): the World is rebuilt on every run, so the WorldAlreadyInitialized
        # branch below looks unreachable, and any per-worker loaded_model/loaded_vae
        # tracking starts over each job -- confirm whether the world should persist.
        Script.world = World(initial_payload=p, verify_remotes=Script.verify_remotes)
        # add workers to the world
        for remote in cmd_opts.distributed_remotes:
            Script.world.add_worker(uuid=remote[0], address=remote[1], port=remote[2])
        # register gallery callback
        script_callbacks.on_after_image_processed(Script.add_to_gallery)

        if self.verify_remotes is False:
            print("WARNING: you have chosen to forego the verification of worker TLS certificates")
            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

        try:
            Script.world.initialize(p.batch_size)
            print("World initialized!")
        except WorldAlreadyInitialized:
            Script.world.update_world(p.batch_size)

        # Encapsulating the request object within a txt2imgreq object is deprecated and no
        # longer works; see test/basic_features/txt2img_test.py for an example.
        # NOTE(review): this aliases p.__dict__, so the two assignments below also set
        # p.batch_size and p.scripts on the master's own processing object -- confirm
        # that clearing p.scripts for the master is intended.
        payload = p.__dict__
        payload['batch_size'] = Script.world.get_default_worker_batch_size()
        payload['scripts'] = None

        # TODO the api returns 200 even if something failed to be set, so for now we may
        # have to make redundant GET requests to check if it was actually successful:
        # https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146
        # strip a trailing " [hash]" suffix from the checkpoint title
        name = re.sub(r'\s?\[[^\]]*\]$', '', opts.data["sd_model_checkpoint"])
        vae = opts.data["sd_vae"]
        option_payload = {
            "sd_model_checkpoint": name,
            "sd_vae": vae
        }

        Script.world.optimize_jobs(payload)  # optimize work assignment before dispatching
        for job in Script.world.jobs:
            if job.batch_size < 1 or job.worker.master:
                continue

            new_payload = copy.copy(payload)  # prevent race condition instead of sharing the payload object
            new_payload['batch_size'] = job.batch_size

            # Decide per job whether THIS worker needs to switch model/VAE; a previous
            # worker needing a sync must not force one on workers already up to date.
            sync = job.worker.loaded_model != name or job.worker.loaded_vae != vae
            if sync:
                job.worker.loaded_model = name
                job.worker.loaded_vae = vae

            t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,))
            t.start()
            Script.worker_threads.append(t)

        # if master batch size was changed again due to optimization change it to the updated value
        p.batch_size = Script.world.get_master_batch_size()
        Script.master_start = time.time()

    @staticmethod
    def unregister_callbacks():
        """Remove every callback registered by this script through script_callbacks."""
        script_callbacks.remove_current_script_callbacks()
|
|
|
|
|
|
# Ask webui to run Script.unregister_callbacks when scripts are unloaded, presumably so
# that a still-registered add_to_gallery hook from Script.run is removed -- confirm.
# NOTE: this is not actually called when the user selects a different script in the UI
# dropdown, so cleanup in that situation must not rely on it.
script_callbacks.on_script_unloaded(Script.unregister_callbacks)
|