early support for controlnet
parent
d3b790d709
commit
04f8807df5
|
|
@ -24,6 +24,8 @@ from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
|
|||
from scripts.spartan.Worker import Worker, State
|
||||
from modules.shared import opts
|
||||
from scripts.spartan.shared import logger
|
||||
from scripts.spartan.control_net import pack_control_net
|
||||
|
||||
|
||||
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
|
||||
# TODO see if the current api has some sort of UUID generation functionality.
|
||||
|
|
@ -174,8 +176,16 @@ class Script(scripts.Script):
|
|||
|
||||
for worker in Script.world.workers:
|
||||
# if it fails here then that means that the response_cache global var is not being filled for some reason
|
||||
expected_images = 1
|
||||
for job in Script.world.jobs:
|
||||
if job.worker == worker:
|
||||
expected_images = job.batch_size
|
||||
|
||||
try:
|
||||
images: json = worker.response["images"]
|
||||
# if we for some reason get more than we asked for
|
||||
if expected_images < len(images):
|
||||
logger.debug(f"Requested {expected_images} images from '{worker.uuid}', got {len(images)}")
|
||||
except Exception:
|
||||
if worker.master is False:
|
||||
logger.warn(f"Worker '{worker.uuid}' had nothing")
|
||||
|
|
@ -190,12 +200,17 @@ class Script(scripts.Script):
|
|||
image = Image.open(io.BytesIO(image_bytes))
|
||||
processed.images.append(image)
|
||||
|
||||
# params
|
||||
processed.all_prompts.append(image_params["prompt"])
|
||||
# post-generation
|
||||
processed.all_seeds.append(image_info_post["all_seeds"][i])
|
||||
processed.all_subseeds.append(image_info_post["all_subseeds"][i])
|
||||
processed.all_negative_prompts.append(image_info_post["all_negative_prompts"][i])
|
||||
try:
|
||||
# post-generation
|
||||
processed.all_seeds.append(image_info_post["all_seeds"][i])
|
||||
processed.all_subseeds.append(image_info_post["all_subseeds"][i])
|
||||
processed.all_negative_prompts.append(image_info_post["all_negative_prompts"][i])
|
||||
except Exception as e:
|
||||
logger.debug(f"Image at index {i} for '{worker.uuid}' was missing some post-generation data")
|
||||
processed.all_seeds.append(image_info_post["all_seeds"][0])
|
||||
processed.all_subseeds.append(image_info_post["all_subseeds"][0])
|
||||
processed.all_negative_prompts.append(image_info_post["all_negative_prompts"][0])
|
||||
|
||||
# generate info-text string (mostly for user use)
|
||||
this_info_text = processing.create_infotext(
|
||||
|
|
@ -265,25 +280,54 @@ class Script(scripts.Script):
|
|||
Script.initialize(initial_payload=p)
|
||||
|
||||
# strip scripts that aren't yet supported and warn user
|
||||
for arg in range(0, len(p.script_args) - 1):
|
||||
controlnet = None
|
||||
arg = 0
|
||||
while arg < len(p.script_args) - 1:
|
||||
try:
|
||||
json.dumps(p.script_args[arg])
|
||||
except Exception:
|
||||
sanitized_script_args = p.script_args[:arg] + p.script_args[arg + 1:]
|
||||
p.script_args = sanitized_script_args
|
||||
|
||||
# find which script owns the offending arguments
|
||||
# find which script owns the offending arguments and fix if supported
|
||||
script_args_upper = None # the upper index of the offending scripts args so that we can skip the rest
|
||||
for script in p.scripts.scripts:
|
||||
title = script.title()
|
||||
|
||||
# check for supported scripts
|
||||
if title == "ControlNet" and script.alwayson is True:
|
||||
# grab all controlnet units
|
||||
cn_units = []
|
||||
for cn_arg in range(script.args_from, script.args_to + 1):
|
||||
if isinstance(p.script_args[cn_arg], type(p.script_args[arg])):
|
||||
cn_units.append(p.script_args[cn_arg])
|
||||
logger.debug(f"Detected {len(cn_units)} controlnet unit(s)")
|
||||
|
||||
# get api formatted controlnet
|
||||
controlnet: dict = pack_control_net(cn_units)
|
||||
|
||||
# ensure we don't do this more than once
|
||||
script_args_upper = script.args_to
|
||||
continue
|
||||
|
||||
# clean unsupported scripts
|
||||
if script.args_from <= arg <= script.args_to:
|
||||
logger.warn(f"Distributed does not yet support '{title}'")
|
||||
sanitized_script_args = p.script_args[:arg] + p.script_args[arg + 1:]
|
||||
p.script_args = sanitized_script_args
|
||||
|
||||
if script_args_upper is not None:
|
||||
arg = script_args_upper
|
||||
|
||||
arg += 1
|
||||
|
||||
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works
|
||||
# see test/basic_features/txt2img_test.py for an example
|
||||
payload = p.__dict__
|
||||
payload = copy.copy(p.__dict__)
|
||||
payload['batch_size'] = Script.world.get_default_worker_batch_size()
|
||||
payload['scripts'] = None
|
||||
del payload['script_args']
|
||||
|
||||
if controlnet is not None:
|
||||
payload['alwayson_scripts'] = controlnet
|
||||
|
||||
# TODO api for some reason returns 200 even if something failed to be set.
|
||||
# for now we may have to make redundant GET requests to check if actually successful...
|
||||
|
|
@ -303,15 +347,14 @@ class Script(scripts.Script):
|
|||
if job.batch_size < 1 or job.worker.master:
|
||||
continue
|
||||
|
||||
new_payload = copy.copy(payload) # prevent race condition instead of sharing the payload object
|
||||
new_payload['batch_size'] = job.batch_size
|
||||
payload['batch_size'] = job.batch_size
|
||||
|
||||
if job.worker.loaded_model != name or job.worker.loaded_vae != vae:
|
||||
sync = True
|
||||
job.worker.loaded_model = name
|
||||
job.worker.loaded_vae = vae
|
||||
|
||||
t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,))
|
||||
t = Thread(target=job.worker.request, args=(payload, option_payload, sync,))
|
||||
|
||||
t.start()
|
||||
Script.worker_threads.append(t)
|
||||
|
|
|
|||
|
|
@ -334,6 +334,9 @@ class Worker:
|
|||
verify=self.verify_remotes
|
||||
)
|
||||
self.response = response.json()
|
||||
if response.status_code != 200:
|
||||
logger.error(f"'{self.uuid}' response: Code <{response.status_code}> {str(response.content, 'utf-8')}")
|
||||
raise InvalidWorkerResponse()
|
||||
|
||||
# update list of ETA accuracy
|
||||
if self.benchmarked and not self.state == State.INTERRUPTED:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
from modules.api.api import encode_pil_to_base64
|
||||
from PIL import Image
|
||||
import copy
|
||||
from scripts.spartan.shared import logger
|
||||
|
||||
|
||||
def pack_control_net(cn_units) -> dict:
|
||||
"""
|
||||
Given the control-net units, return the enveloping controlnet dict to be used with the api
|
||||
"""
|
||||
controlnet = {
|
||||
'controlnet':
|
||||
{
|
||||
'args': []
|
||||
}
|
||||
}
|
||||
cn_args = controlnet['controlnet']['args']
|
||||
|
||||
for i in range(0, len(cn_units)):
|
||||
# copy control net unit to payload
|
||||
cn_args.append(copy.copy(cn_units[i].__dict__))
|
||||
unit = cn_args[i]
|
||||
|
||||
# if unit isn't enabled then don't bother including
|
||||
if not unit['enabled']:
|
||||
del unit['input_mode']
|
||||
del unit['image']
|
||||
logger.debug(f"Controlnet unit {i} is not enabled. Ignoring")
|
||||
continue
|
||||
|
||||
# serialize image
|
||||
if unit['image'] is not None:
|
||||
image = unit['image']['image']
|
||||
# mask = unit['image']['mask']
|
||||
pil = Image.fromarray(image)
|
||||
image_b64 = encode_pil_to_base64(pil)
|
||||
image_b64 = str(image_b64, 'utf-8')
|
||||
unit['input_image'] = image_b64
|
||||
|
||||
# remove anything unserializable
|
||||
del unit['input_mode']
|
||||
del unit['image']
|
||||
|
||||
return controlnet
|
||||
Loading…
Reference in New Issue