diff --git a/agent_scheduler/task_helpers.py b/agent_scheduler/task_helpers.py index 9737997..056d237 100644 --- a/agent_scheduler/task_helpers.py +++ b/agent_scheduler/task_helpers.py @@ -160,24 +160,34 @@ def deserialize_img2img_image_args(args: dict): def serialize_controlnet_args(cnet_unit): args: dict = cnet_unit.__dict__ - args["is_cnet"] = True + new_args = {} + new_args["is_cnet"] = True for k, v in args.items(): - if k == "image" and hasattr(v, "image") and v.image is not None: - print("serialize image for controlnet") - args[k] = { - "image": serialize_image(v["image"]), - "mask": serialize_image(v["mask"]) - if v.get("mask", None) is not None - else None, - } + if k == 'image': + if hasattr(v, "image") and v.image is not None: + print("serialize image for controlnet") + new_args[k] = { + "image": serialize_image(v["image"]), + "mask": serialize_image(v["mask"]) + if v.get("mask", None) is not None + else None, + } + elif type(v) is dict and v.get('image', None) is not None: + new_args[k] = { + "image": serialize_image(v), + "mask": None if v.get("mask", None) is None else serialize_image(v["mask"]), + } + else: + print("Fallbacked for argument " + str(k) + " with value " + str(v) + " to serialize_image") + new_args[k] = serialize_image(v) elif k == 'image' and isinstance(v, (np.ndarray, torch.Tensor, Image.Image)): - args[k] = serialize_image(v) + new_args[k] = serialize_image(v) elif isinstance(v, Enum): - args[k] = serialize_image(v.value) + new_args[k] = serialize_image(v.value) else: - args[k] = serialize_image(v) + new_args[k] = serialize_image(v) - return args + return new_args def deserialize_controlnet_args(args: dict): diff --git a/agent_scheduler/task_runner.py b/agent_scheduler/task_runner.py index 2a89879..594263d 100644 --- a/agent_scheduler/task_runner.py +++ b/agent_scheduler/task_runner.py @@ -127,7 +127,7 @@ class TaskRunner: return serialize_image(obj) # controlnet elif self.UiControlNetUnit and isinstance(obj, self.UiControlNetUnit): - return serialize_controlnet_args(obj) + return self.recursively_serialize(serialize_controlnet_args(obj)) else: # check json.dumps return obj @@ -163,12 +163,7 @@ class TaskRunner: # loop through script_args and serialize images serialized_args:list = [None] * len(script_args) for i, a in enumerate(script_args): - if isinstance(a, (Image.Image, ndarray, Tensor)): - serialized_args[i] = serialize_image(a) - elif self.UiControlNetUnit and isinstance(a, self.UiControlNetUnit): - serialized_args[i] = serialize_controlnet_args(a) - else: - serialized_args[i] = self.recursively_serialize(a) + serialized_args[i] = self.recursively_serialize(a) # assert each arguments is serializable check_args = [named_args, serialized_args, checkpoint] args_name = ["named_args", "script_args", "checkpoint"]