Fix controlnet usage

pull/68/head
aria1th 2023-07-09 01:49:32 +09:00
parent 66d74824ae
commit 8a0bd7b359
2 changed files with 88 additions and 12 deletions

View File

@ -162,28 +162,38 @@ def serialize_controlnet_args(cnet_unit):
args: dict = cnet_unit.__dict__
args["is_cnet"] = True
for k, v in args.items():
if k == "image" and v is not None:
if k == "image" and hasattr(v, "image") and v.image is not None:
print("serialize image for controlnet")
args[k] = {
"image": serialize_image(v["image"]),
"mask": serialize_image(v["mask"])
if v.get("mask", None) is not None
else None,
}
if isinstance(v, Enum):
args[k] = v.value
elif k == 'image' and isinstance(v, (np.ndarray, torch.Tensor, Image.Image)):
args[k] = serialize_image(v)
elif isinstance(v, Enum):
args[k] = serialize_image(v.value)
else:
args[k] = serialize_image(v)
return args
def deserialize_controlnet_args(args: dict):
for k, v in args.items():
if k == "image" and v is not None:
if k == "image" and hasattr(v, "image") and v.image is not None:
print("deserialize image for controlnet")
args[k] = {
"image": deserialize_image(v["image"]),
"mask": deserialize_image(v["mask"])
if v.get("mask", None) is not None
else None,
}
elif isinstance(v, dict) and v.get("cls", None) in ["Image", "ndarray", "Tensor"]:
args[k] = deserialize_image(v)
else:
args[k] = v
return args

View File

@ -65,7 +65,6 @@ class ParsedTaskArgs(BaseModel):
script_args: list[Any]
checkpoint: Optional[str] = None
class TaskRunner:
instance = None
@ -105,6 +104,52 @@ class TaskRunner:
@property
def paused(self) -> bool:
return getattr(shared.opts, "queue_paused", False)
def recursively_serialize(self, obj):
"""
Recursively serialize an object to JSON
"""
# dict
if isinstance(obj, dict):
new_obj = {}
for k, v in obj.items():
assert k not in new_obj, "Cannot serialize recursive dict"
new_obj[k] = self.recursively_serialize(v)
return new_obj
elif isinstance(obj, list):
new_obj = []
for v in obj:
assert v is not obj, "Cannot serialize recursive list"
new_obj.append(self.recursively_serialize(v))
return new_obj
# image or tensor or ndarray
elif isinstance(obj, (Image.Image, Tensor, ndarray)):
return serialize_image(obj)
# controlnet
elif self.UiControlNetUnit and isinstance(obj, self.UiControlNetUnit):
return serialize_controlnet_args(obj)
else:
# check json.dumps
return obj
def recursively_deserialize(self, obj):
"""
Recursively deserialize an object from JSON
"""
if isinstance(obj, dict):
new_obj = {}
for k, v in obj.items():
new_obj[k] = self.recursively_deserialize(v)
return new_obj
elif isinstance(obj, list):
new_obj = []
for v in obj:
new_obj.append(self.recursively_deserialize(v))
return new_obj
elif isinstance(obj, dict) and obj.get("is_cnet", False):
return deserialize_controlnet_args(obj)
else:
return deserialize_image(obj)
def __serialize_ui_task_args(self, is_img2img: bool, *args, checkpoint: str = None):
named_args, script_args = map_ui_task_args_list_to_named_args(
@ -116,16 +161,27 @@ class TaskRunner:
serialize_img2img_image_args(named_args)
# loop through script_args and serialize images
serialized_args:list = [None] * len(script_args)
for i, a in enumerate(script_args):
if isinstance(a, (Image.Image, ndarray, Tensor)):
script_args[i] = serialize_image(a)
serialized_args[i] = serialize_image(a)
elif self.UiControlNetUnit and isinstance(a, self.UiControlNetUnit):
script_args[i] = serialize_controlnet_args(a)
serialized_args[i] = serialize_controlnet_args(a)
else:
serialized_args[i] = self.recursively_serialize(a)
# assert each arguments is serializable
check_args = [named_args, serialized_args, checkpoint]
args_name = ["named_args", "script_args", "checkpoint"]
for args, name in zip(check_args, args_name):
try:
json.dumps(args)
except Exception as e:
print(f"Cannot serialize args: {args} with name: {name}")
raise e
return json.dumps(
{
"args": named_args,
"script_args": script_args,
"script_args": serialized_args,
"checkpoint": checkpoint,
"is_ui": True,
"is_img2img": is_img2img,
@ -159,16 +215,26 @@ class TaskRunner:
def __deserialize_ui_task_args(
self, is_img2img: bool, named_args: dict, script_args: list
):
"""
Deserialize UI task arguments
In-place update named_args and script_args
"""
# loop through image_args and deserialize images
if is_img2img:
deserialize_img2img_image_args(named_args)
deserialized_args:list = [None] * len(script_args)
# loop through script_args and deserialize images
for i, arg in enumerate(script_args):
if isinstance(arg, dict) and arg.get("is_cnet", False):
script_args[i] = deserialize_controlnet_args(arg)
deserialized_args[i] = deserialize_controlnet_args(arg)
elif isinstance(arg, dict) and arg.get("cls", "") in {"Image", "ndarray", "Tensor"}:
script_args[i] = deserialize_image(arg)
deserialized_args[i] = deserialize_image(arg)
else:
deserialized_args[i] = self.recursively_deserialize(arg)
for i, arg in enumerate(deserialized_args):
script_args[i] = arg
def __deserialize_api_task_args(self, is_img2img: bool, named_args: dict):
# load images from disk