early support for controlnet

pull/7/head
unknown 2023-05-27 01:55:36 -05:00
parent d3b790d709
commit 04f8807df5
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
3 changed files with 103 additions and 13 deletions

View File

@ -24,6 +24,8 @@ from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
from scripts.spartan.Worker import Worker, State
from modules.shared import opts
from scripts.spartan.shared import logger
from scripts.spartan.control_net import pack_control_net
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
# TODO see if the current api has some sort of UUID generation functionality.
@ -174,8 +176,16 @@ class Script(scripts.Script):
for worker in Script.world.workers:
# if it fails here then that means that the response_cache global var is not being filled for some reason
expected_images = 1
for job in Script.world.jobs:
if job.worker == worker:
expected_images = job.batch_size
try:
images: json = worker.response["images"]
# if we for some reason get more than we asked for
if expected_images < len(images):
logger.debug(f"Requested {expected_images} images from '{worker.uuid}', got {len(images)}")
except Exception:
if worker.master is False:
logger.warn(f"Worker '{worker.uuid}' had nothing")
@ -190,12 +200,17 @@ class Script(scripts.Script):
image = Image.open(io.BytesIO(image_bytes))
processed.images.append(image)
# params
processed.all_prompts.append(image_params["prompt"])
# post-generation
processed.all_seeds.append(image_info_post["all_seeds"][i])
processed.all_subseeds.append(image_info_post["all_subseeds"][i])
processed.all_negative_prompts.append(image_info_post["all_negative_prompts"][i])
try:
# post-generation
processed.all_seeds.append(image_info_post["all_seeds"][i])
processed.all_subseeds.append(image_info_post["all_subseeds"][i])
processed.all_negative_prompts.append(image_info_post["all_negative_prompts"][i])
except Exception as e:
logger.debug(f"Image at index {i} for '{worker.uuid}' was missing some post-generation data")
processed.all_seeds.append(image_info_post["all_seeds"][0])
processed.all_subseeds.append(image_info_post["all_subseeds"][0])
processed.all_negative_prompts.append(image_info_post["all_negative_prompts"][0])
# generate info-text string (mostly for user use)
this_info_text = processing.create_infotext(
@ -265,25 +280,54 @@ class Script(scripts.Script):
Script.initialize(initial_payload=p)
# strip scripts that aren't yet supported and warn user
for arg in range(0, len(p.script_args) - 1):
controlnet = None
arg = 0
while arg < len(p.script_args) - 1:
try:
json.dumps(p.script_args[arg])
except Exception:
sanitized_script_args = p.script_args[:arg] + p.script_args[arg + 1:]
p.script_args = sanitized_script_args
# find which script owns the offending arguments
# find which script owns the offending arguments and fix if supported
script_args_upper = None # the upper index of the offending scripts args so that we can skip the rest
for script in p.scripts.scripts:
title = script.title()
# check for supported scripts
if title == "ControlNet" and script.alwayson is True:
# grab all controlnet units
cn_units = []
for cn_arg in range(script.args_from, script.args_to + 1):
if isinstance(p.script_args[cn_arg], type(p.script_args[arg])):
cn_units.append(p.script_args[cn_arg])
logger.debug(f"Detected {len(cn_units)} controlnet unit(s)")
# get api formatted controlnet
controlnet: dict = pack_control_net(cn_units)
# ensure we don't do this more than once
script_args_upper = script.args_to
continue
# clean unsupported scripts
if script.args_from <= arg <= script.args_to:
logger.warn(f"Distributed does not yet support '{title}'")
sanitized_script_args = p.script_args[:arg] + p.script_args[arg + 1:]
p.script_args = sanitized_script_args
if script_args_upper is not None:
arg = script_args_upper
arg += 1
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works
# see test/basic_features/txt2img_test.py for an example
payload = p.__dict__
payload = copy.copy(p.__dict__)
payload['batch_size'] = Script.world.get_default_worker_batch_size()
payload['scripts'] = None
del payload['script_args']
if controlnet is not None:
payload['alwayson_scripts'] = controlnet
# TODO api for some reason returns 200 even if something failed to be set.
# for now we may have to make redundant GET requests to check if actually successful...
@ -303,15 +347,14 @@ class Script(scripts.Script):
if job.batch_size < 1 or job.worker.master:
continue
new_payload = copy.copy(payload) # prevent race condition instead of sharing the payload object
new_payload['batch_size'] = job.batch_size
payload['batch_size'] = job.batch_size
if job.worker.loaded_model != name or job.worker.loaded_vae != vae:
sync = True
job.worker.loaded_model = name
job.worker.loaded_vae = vae
t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,))
t = Thread(target=job.worker.request, args=(payload, option_payload, sync,))
t.start()
Script.worker_threads.append(t)

View File

@ -334,6 +334,9 @@ class Worker:
verify=self.verify_remotes
)
self.response = response.json()
if response.status_code != 200:
logger.error(f"'{self.uuid}' response: Code <{response.status_code}> {str(response.content, 'utf-8')}")
raise InvalidWorkerResponse()
# update list of ETA accuracy
if self.benchmarked and not self.state == State.INTERRUPTED:

View File

@ -0,0 +1,44 @@
from modules.api.api import encode_pil_to_base64
from PIL import Image
import copy
from scripts.spartan.shared import logger
def pack_control_net(cn_units) -> dict:
"""
Given the control-net units, return the enveloping controlnet dict to be used with the api
"""
controlnet = {
'controlnet':
{
'args': []
}
}
cn_args = controlnet['controlnet']['args']
for i in range(0, len(cn_units)):
# copy control net unit to payload
cn_args.append(copy.copy(cn_units[i].__dict__))
unit = cn_args[i]
# if unit isn't enabled then don't bother including
if not unit['enabled']:
del unit['input_mode']
del unit['image']
logger.debug(f"Controlnet unit {i} is not enabled. Ignoring")
continue
# serialize image
if unit['image'] is not None:
image = unit['image']['image']
# mask = unit['image']['mask']
pil = Image.fromarray(image)
image_b64 = encode_pil_to_base64(pil)
image_b64 = str(image_b64, 'utf-8')
unit['input_image'] = image_b64
# remove anything unserializable
del unit['input_mode']
del unit['image']
return controlnet