diff --git a/scripts/extension.py b/scripts/extension.py index 9693adc..ad3ff80 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -24,6 +24,8 @@ from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized from scripts.spartan.Worker import Worker, State from modules.shared import opts from scripts.spartan.shared import logger +from scripts.spartan.control_net import pack_control_net + # TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers? # TODO see if the current api has some sort of UUID generation functionality. @@ -174,8 +176,16 @@ class Script(scripts.Script): for worker in Script.world.workers: # if it fails here then that means that the response_cache global var is not being filled for some reason + expected_images = 1 + for job in Script.world.jobs: + if job.worker == worker: + expected_images = job.batch_size + try: images: json = worker.response["images"] + # if we for some reason get more than we asked for + if expected_images < len(images): + logger.debug(f"Requested {expected_images} images from '{worker.uuid}', got {len(images)}") except Exception: if worker.master is False: logger.warn(f"Worker '{worker.uuid}' had nothing") @@ -190,12 +200,17 @@ class Script(scripts.Script): image = Image.open(io.BytesIO(image_bytes)) processed.images.append(image) - # params processed.all_prompts.append(image_params["prompt"]) - # post-generation - processed.all_seeds.append(image_info_post["all_seeds"][i]) - processed.all_subseeds.append(image_info_post["all_subseeds"][i]) - processed.all_negative_prompts.append(image_info_post["all_negative_prompts"][i]) + try: + # post-generation + processed.all_seeds.append(image_info_post["all_seeds"][i]) + processed.all_subseeds.append(image_info_post["all_subseeds"][i]) + processed.all_negative_prompts.append(image_info_post["all_negative_prompts"][i]) + except Exception as e: + logger.debug(f"Image at index {i} for '{worker.uuid}' was missing some post-generation data") + processed.all_seeds.append(image_info_post["all_seeds"][0]) + processed.all_subseeds.append(image_info_post["all_subseeds"][0]) + processed.all_negative_prompts.append(image_info_post["all_negative_prompts"][0]) # generate info-text string (mostly for user use) this_info_text = processing.create_infotext( @@ -265,25 +280,54 @@ class Script(scripts.Script): Script.initialize(initial_payload=p) # strip scripts that aren't yet supported and warn user - for arg in range(0, len(p.script_args) - 1): + controlnet = None + arg = 0 + while arg < len(p.script_args) - 1: try: json.dumps(p.script_args[arg]) except Exception: - sanitized_script_args = p.script_args[:arg] + p.script_args[arg + 1:] - p.script_args = sanitized_script_args - # find which script owns the offending arguments + # find which script owns the offending arguments and fix if supported + script_args_upper = None # the upper index of the offending scripts args so that we can skip the rest for script in p.scripts.scripts: title = script.title() + # check for supported scripts + if title == "ControlNet" and script.alwayson is True: + # grab all controlnet units + cn_units = [] + for cn_arg in range(script.args_from, script.args_to + 1): + if isinstance(p.script_args[cn_arg], type(p.script_args[arg])): + cn_units.append(p.script_args[cn_arg]) + logger.debug(f"Detected {len(cn_units)} controlnet unit(s)") + + # get api formatted controlnet + controlnet: dict = pack_control_net(cn_units) + + # ensure we don't do this more than once + script_args_upper = script.args_to + continue + + # clean unsupported scripts if script.args_from <= arg <= script.args_to: logger.warn(f"Distributed does not yet support '{title}'") + sanitized_script_args = p.script_args[:arg] + p.script_args[arg + 1:] + p.script_args = sanitized_script_args + + if script_args_upper is not None: + arg = script_args_upper + + arg += 1 # encapsulating the request object within a txt2imgreq object is deprecated and no longer works # see test/basic_features/txt2img_test.py for an example - payload = p.__dict__ + payload = copy.copy(p.__dict__) payload['batch_size'] = Script.world.get_default_worker_batch_size() payload['scripts'] = None + del payload['script_args'] + + if controlnet is not None: + payload['alwayson_scripts'] = controlnet # TODO api for some reason returns 200 even if something failed to be set. # for now we may have to make redundant GET requests to check if actually successful... @@ -303,15 +347,14 @@ class Script(scripts.Script): if job.batch_size < 1 or job.worker.master: continue - new_payload = copy.copy(payload) # prevent race condition instead of sharing the payload object - new_payload['batch_size'] = job.batch_size + payload['batch_size'] = job.batch_size if job.worker.loaded_model != name or job.worker.loaded_vae != vae: sync = True job.worker.loaded_model = name job.worker.loaded_vae = vae - t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,)) + t = Thread(target=job.worker.request, args=(payload, option_payload, sync,)) t.start() Script.worker_threads.append(t) diff --git a/scripts/spartan/Worker.py b/scripts/spartan/Worker.py index d7826e0..5819fbe 100644 --- a/scripts/spartan/Worker.py +++ b/scripts/spartan/Worker.py @@ -334,6 +334,9 @@ class Worker: verify=self.verify_remotes ) self.response = response.json() + if response.status_code != 200: + logger.error(f"'{self.uuid}' response: Code <{response.status_code}> {str(response.content, 'utf-8')}") + raise InvalidWorkerResponse() # update list of ETA accuracy if self.benchmarked and not self.state == State.INTERRUPTED: diff --git a/scripts/spartan/control_net.py b/scripts/spartan/control_net.py new file mode 100644 index 0000000..fe2d99a --- /dev/null +++ b/scripts/spartan/control_net.py @@ -0,0 +1,44 @@ +from modules.api.api import encode_pil_to_base64 +from PIL import Image +import copy +from scripts.spartan.shared import logger + + +def pack_control_net(cn_units) -> dict: + """ + Given the control-net units, return the enveloping controlnet dict to be used with the api + """ + controlnet = { + 'controlnet': + { + 'args': [] + } + } + cn_args = controlnet['controlnet']['args'] + + for i in range(0, len(cn_units)): + # copy control net unit to payload + cn_args.append(copy.copy(cn_units[i].__dict__)) + unit = cn_args[i] + + # if unit isn't enabled then don't bother including + if not unit['enabled']: + del unit['input_mode'] + del unit['image'] + logger.debug(f"Controlnet unit {i} is not enabled. Ignoring") + continue + + # serialize image + if unit['image'] is not None: + image = unit['image']['image'] + # mask = unit['image']['mask'] + pil = Image.fromarray(image) + image_b64 = encode_pil_to_base64(pil) + image_b64 = str(image_b64, 'utf-8') + unit['input_image'] = image_b64 + + # remove anything unserializable + del unit['input_mode'] + del unit['image'] + + return controlnet