diff --git a/agent_scheduler/task_helpers.py b/agent_scheduler/task_helpers.py
index a991fe1..9737997 100644
--- a/agent_scheduler/task_helpers.py
+++ b/agent_scheduler/task_helpers.py
@@ -162,28 +162,40 @@ def serialize_controlnet_args(cnet_unit):
     args: dict = cnet_unit.__dict__
     args["is_cnet"] = True
     for k, v in args.items():
-        if k == "image" and v is not None:
+        # The unit image is a mapping with an "image" entry and an optional
+        # "mask" entry; the body below subscripts it, so test it as a dict.
+        if k == "image" and isinstance(v, dict) and v.get("image", None) is not None:
+            print("serialize image for controlnet")
             args[k] = {
                 "image": serialize_image(v["image"]),
                 "mask": serialize_image(v["mask"])
                 if v.get("mask", None) is not None
                 else None,
             }
-        if isinstance(v, Enum):
-            args[k] = v.value
+        elif isinstance(v, (np.ndarray, torch.Tensor, Image.Image)):
+            # Bare image-like values, under any key, are serialized directly.
+            args[k] = serialize_image(v)
+        elif isinstance(v, Enum):
+            # Enum members are stored as their plain value, not as images.
+            args[k] = v.value
 
     return args
 
 
 def deserialize_controlnet_args(args: dict):
     for k, v in args.items():
-        if k == "image" and v is not None:
+        # Tagged single images are restored first; the paired form written by
+        # serialize_controlnet_args is a dict with "image" and optional "mask".
+        if isinstance(v, dict) and v.get("cls", None) in ["Image", "ndarray", "Tensor"]:
+            args[k] = deserialize_image(v)
+        elif k == "image" and isinstance(v, dict) and v.get("image", None) is not None:
+            print("deserialize image for controlnet")
             args[k] = {
                 "image": deserialize_image(v["image"]),
                 "mask": deserialize_image(v["mask"])
                 if v.get("mask", None) is not None
                 else None,
             }
 
     return args
 
diff --git a/agent_scheduler/task_runner.py b/agent_scheduler/task_runner.py
index 82c7bd8..2a89879 100644
--- a/agent_scheduler/task_runner.py
+++ b/agent_scheduler/task_runner.py
@@ -105,6 +105,46 @@ class TaskRunner:
     @property
     def paused(self) -> bool:
         return getattr(shared.opts, "queue_paused", False)
+
+    def recursively_serialize(self, obj):
+        """
+        Recursively convert obj into a structure that json.dumps accepts.
+
+        Dicts, lists and tuples are rebuilt (tuples become lists) with every
+        element converted. Images, tensors and ndarrays are passed to
+        serialize_image and ControlNet units to serialize_controlnet_args.
+        Any other value is returned unchanged. Self-referencing containers
+        are not detected and end in RecursionError.
+        """
+        if isinstance(obj, dict):
+            return {k: self.recursively_serialize(v) for k, v in obj.items()}
+        if isinstance(obj, (list, tuple)):
+            return [self.recursively_serialize(v) for v in obj]
+        if isinstance(obj, (Image.Image, Tensor, ndarray)):
+            return serialize_image(obj)
+        if self.UiControlNetUnit and isinstance(obj, self.UiControlNetUnit):
+            return serialize_controlnet_args(obj)
+        return obj
+
+    def recursively_deserialize(self, obj):
+        """
+        Recursively restore a structure built by recursively_serialize.
+
+        Tagged dicts are recognised before plain containers are descended
+        into: a dict flagged "is_cnet" is passed to
+        deserialize_controlnet_args and a dict whose "cls" names an image
+        type is passed to deserialize_image. Other dicts and lists are
+        rebuilt element by element; any other value is returned unchanged.
+        """
+        if isinstance(obj, dict):
+            if obj.get("is_cnet", False):
+                return deserialize_controlnet_args(obj)
+            if obj.get("cls", "") in {"Image", "ndarray", "Tensor"}:
+                return deserialize_image(obj)
+            return {k: self.recursively_deserialize(v) for k, v in obj.items()}
+        if isinstance(obj, list):
+            return [self.recursively_deserialize(v) for v in obj]
+        return obj
 
     def __serialize_ui_task_args(self, is_img2img: bool, *args, checkpoint: str = None):
         named_args, script_args = map_ui_task_args_list_to_named_args(
@@ -116,16 +156,30 @@ class TaskRunner:
             serialize_img2img_image_args(named_args)
 
         # loop through script_args and serialize images
+        # Collect results in a separate list instead of overwriting script_args.
+        serialized_args: list = [None] * len(script_args)
         for i, a in enumerate(script_args):
             if isinstance(a, (Image.Image, ndarray, Tensor)):
-                script_args[i] = serialize_image(a)
+                serialized_args[i] = serialize_image(a)
             elif self.UiControlNetUnit and isinstance(a, self.UiControlNetUnit):
-                script_args[i] = serialize_controlnet_args(a)
-
+                serialized_args[i] = serialize_controlnet_args(a)
+            else:
+                serialized_args[i] = self.recursively_serialize(a)
+
+        # Dump each part separately first so a failure reports which part
+        # is not JSON-serializable before the error propagates.
+        check_args = [named_args, serialized_args, checkpoint]
+        args_name = ["named_args", "script_args", "checkpoint"]
+        for value, name in zip(check_args, args_name):
+            try:
+                json.dumps(value)
+            except Exception:
+                print(f"Cannot serialize args: {value} with name: {name}")
+                raise
         return json.dumps(
             {
                 "args": named_args,
-                "script_args": script_args,
+                "script_args": serialized_args,
                 "checkpoint": checkpoint,
                 "is_ui": True,
                 "is_img2img": is_img2img,
@@ -159,16 +213,26 @@ class TaskRunner:
     def __deserialize_ui_task_args(
         self, is_img2img: bool, named_args: dict, script_args: list
     ):
+        """
+        Deserialize UI task arguments.
+
+        Updates named_args and script_args in place; returns nothing.
+        """
         # loop through image_args and deserialize images
         if is_img2img:
             deserialize_img2img_image_args(named_args)
 
+        deserialized_args: list = [None] * len(script_args)
         # loop through script_args and deserialize images
         for i, arg in enumerate(script_args):
             if isinstance(arg, dict) and arg.get("is_cnet", False):
-                script_args[i] = deserialize_controlnet_args(arg)
+                deserialized_args[i] = deserialize_controlnet_args(arg)
             elif isinstance(arg, dict) and arg.get("cls", "") in {"Image", "ndarray", "Tensor"}:
-                script_args[i] = deserialize_image(arg)
+                deserialized_args[i] = deserialize_image(arg)
+            else:
+                deserialized_args[i] = self.recursively_deserialize(arg)
+        # Write the results back into the list object that was passed in.
+        script_args[:] = deserialized_args
 
     def __deserialize_api_task_args(self, is_img2img: bool, named_args: dict):
         # load images from disk