132 lines
4.2 KiB
Python
132 lines
4.2 KiB
Python
from PIL.Image import Image
|
|
from typing_extensions import override, Tuple
|
|
from scripts.spartan.shared import logger
|
|
from scripts.spartan.control_net import pack_control_net
|
|
|
|
class Adapter:
    """Base hook that lets an extension script take part in worker requests.

    Subclasses override :meth:`early` and/or :meth:`late` to adjust the
    processing object ``p`` or the outgoing worker ``payload``.
    """

    def __init__(self):
        # extension script this adapter is bound to; assigned by early()
        self.script = None

    def early(self, p, world, script, *args) -> bool:
        """Make changes before any worker request objects are created.

        Binds *script* to this adapter. Return True to cede control back to
        webui; the base implementation always returns False.
        """
        self.script = script
        return False

    def late(self, p, world, payload, *args):
        """Make changes after the worker request object has been created and
        workloads have been manipulated.

        ``payload['alwayson_scripts']`` is guaranteed to exist, but may not be
        populated. The base implementation does nothing.
        """
        return None

    def script_args(self, p) -> Tuple:
        """Return the slice of ``p.script_args`` owned by the bound script."""
        start, stop = self.script.args_from, self.script.args_to
        return p.script_args[start:stop]
|
|
|
|
class DynamicPromptsAdapter(Adapter):
    """Adapter for the "Dynamic Prompts" extension.

    In ``late`` it re-runs the extension's ``process`` hook with
    ``p.batch_size`` temporarily set to ``world.num_requested()``, then copies
    the resulting prompt lists into the worker payload.
    """

    def __init__(self):
        super().__init__()
        self.title = "Dynamic Prompts"

    def early(self, p, world, script, *args):
        """Bind *script*; returns the base result (False), never ceding control.

        Dynamic prompts will run twice, but because of load order our call in
        ``late`` overwrites the original.
        """
        return super().early(p, world, script)

    # right before payload is cloned into a separate instance for each job
    def late(self, p, world, payload, *args):
        """Regenerate prompts for every requested image and store them in *payload*.

        Does nothing when neither ``p.prompt`` nor ``p.negative_prompt`` is
        longer than one character.
        """
        # dynamic clobbers the actual p even if we pass a different object
        # NOTE(review): only one list is reset — when both prompts are set,
        # all_negative_prompts is left as-is. `len(...) > 1` presumably means
        # "non-empty"; confirm both are intended.
        if len(p.prompt) > 1:
            target, source = p.all_prompts, p.prompt
        elif len(p.negative_prompt) > 1:
            target, source = p.all_negative_prompts, p.negative_prompt
        else:
            return

        num_requested = world.num_requested()
        # replace contents in place so the list object held by p is kept
        target[:] = [source] * num_requested

        # dynamic prompts overrides p.all_prompts if it doesn't match the batch
        # size, so run it with the batch size temporarily set to the number of
        # requested images; always restore the original, even on error
        original_batch_size = p.batch_size
        p.batch_size = num_requested
        try:
            self.script.process(p, *self.script_args(p))
        finally:
            p.batch_size = original_batch_size

        payload['all_prompts'] = p.all_prompts
        payload['all_negative_prompts'] = p.all_negative_prompts
|
|
|
|
class ControlNetAdapter(Adapter):
    """Adapter for the "ControlNet" extension.

    Collects the ControlNet unit objects from the script's arguments and
    merges their API-formatted form into ``payload['alwayson_scripts']``.
    """

    def __init__(self, *args):
        super().__init__()
        self.title = "ControlNet"

    def late(self, p, world, payload, *args):
        """Forward every detected ControlNet unit to the worker payload."""
        # units are recognised by a substring match on the type's name rather
        # than isinstance, so any class whose name contains "ControlNetUnit"
        # qualifies
        units = [
            arg
            for arg in self.script_args(p)
            if "ControlNetUnit" in type(arg).__name__
        ]
        logger.debug(f"Detected {len(units)} controlnet unit(s)")

        # get api formatted controlnet
        packed = pack_control_net(units)
        payload['alwayson_scripts'].update(packed)
|
|
|
|
class ADetailerAdapter(Adapter):
    """Adapter for the "ADetailer" extension."""

    def __init__(self, *args):
        super().__init__()
        self.title = "ADetailer"

    def early(self, p, world, script, *args) -> bool:
        """Bind *script* and decide whether to cede control back to webui.

        Returns True only when both of the first two script arguments are
        truthy (per the original comment: the InputAccordion main toggle and
        the skip-img2img toggle); otherwise False. A script-argument slice
        with fewer than two entries yields False instead of IndexError.
        """
        super().early(p, world, script)
        adetailer_args = self.script_args(p)

        # InputAccordion main toggle, skip img2img toggle
        if len(adetailer_args) < 2:
            return False
        return bool(adetailer_args[0] and adetailer_args[1])

    def late(self, p, world, payload, *args):
        """Blank out ``payload['_ad_orig']`` before the payload is sent."""
        payload['_ad_orig'] = None  # unserializable
|
|
|
|
|
|
class GenericAdapter(Adapter):
    """Fallback adapter that forwards a script's raw arguments unchanged.

    Each call to :meth:`early` records one packed entry; :meth:`late` merges
    every recorded entry into ``payload['alwayson_scripts']``.
    """

    def __init__(self, *args):
        super().__init__()
        # list of api formatted per-script argument objects, each shaped like
        # { "script_name": { "args": ["value1", "value2", ...] } }
        self.packed_script_args = []

    def early(self, p, world, script, *args) -> bool:
        """Bind *script* and record its arguments in API form; never cedes control."""
        # bind like every other adapter so self.script / script_args() work
        super().early(p, world, script)
        # https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/issues/12#issuecomment-1480382514
        args_script_pack = {script.title(): {"args": list(self.script_args(p))}}
        self.packed_script_args.append(args_script_pack)
        return False

    def late(self, p, world, payload, *args):
        """Merge the recorded script arguments into the worker payload."""
        for packed in self.packed_script_args:
            payload['alwayson_scripts'].update(packed)

        # only this key is cleared here; other values in payload are untouched
        if payload.get('init_images_original_md') is not None:  # multidiffusion
            payload['init_images_original_md'] = None
|
|
|
|
|
|
|
|
# built-in adapter instances, kept in this order; GenericAdapter is not listed
# here (presumably instantiated elsewhere — confirm against callers)
adapters = [
    ControlNetAdapter(),
    ADetailerAdapter(),
    DynamicPromptsAdapter(),
]
|