Fix controlnet usage
parent
66d74824ae
commit
8a0bd7b359
|
|
@ -162,28 +162,38 @@ def serialize_controlnet_args(cnet_unit):
|
|||
args: dict = cnet_unit.__dict__
|
||||
args["is_cnet"] = True
|
||||
for k, v in args.items():
|
||||
if k == "image" and v is not None:
|
||||
if k == "image" and hasattr(v, "image") and v.image is not None:
|
||||
print("serialize image for controlnet")
|
||||
args[k] = {
|
||||
"image": serialize_image(v["image"]),
|
||||
"mask": serialize_image(v["mask"])
|
||||
if v.get("mask", None) is not None
|
||||
else None,
|
||||
}
|
||||
if isinstance(v, Enum):
|
||||
args[k] = v.value
|
||||
elif k == 'image' and isinstance(v, (np.ndarray, torch.Tensor, Image.Image)):
|
||||
args[k] = serialize_image(v)
|
||||
elif isinstance(v, Enum):
|
||||
args[k] = serialize_image(v.value)
|
||||
else:
|
||||
args[k] = serialize_image(v)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def deserialize_controlnet_args(args: dict):
|
||||
for k, v in args.items():
|
||||
if k == "image" and v is not None:
|
||||
if k == "image" and hasattr(v, "image") and v.image is not None:
|
||||
print("deserialize image for controlnet")
|
||||
args[k] = {
|
||||
"image": deserialize_image(v["image"]),
|
||||
"mask": deserialize_image(v["mask"])
|
||||
if v.get("mask", None) is not None
|
||||
else None,
|
||||
}
|
||||
elif isinstance(v, dict) and v.get("cls", None) in ["Image", "ndarray", "Tensor"]:
|
||||
args[k] = deserialize_image(v)
|
||||
else:
|
||||
args[k] = v
|
||||
|
||||
return args
|
||||
|
||||
|
|
|
|||
|
|
@ -65,7 +65,6 @@ class ParsedTaskArgs(BaseModel):
|
|||
script_args: list[Any]
|
||||
checkpoint: Optional[str] = None
|
||||
|
||||
|
||||
class TaskRunner:
|
||||
instance = None
|
||||
|
||||
|
|
@ -105,6 +104,52 @@ class TaskRunner:
|
|||
@property
|
||||
def paused(self) -> bool:
|
||||
return getattr(shared.opts, "queue_paused", False)
|
||||
|
||||
def recursively_serialize(self, obj):
|
||||
"""
|
||||
Recursively serialize an object to JSON
|
||||
"""
|
||||
# dict
|
||||
if isinstance(obj, dict):
|
||||
new_obj = {}
|
||||
for k, v in obj.items():
|
||||
assert k not in new_obj, "Cannot serialize recursive dict"
|
||||
new_obj[k] = self.recursively_serialize(v)
|
||||
return new_obj
|
||||
elif isinstance(obj, list):
|
||||
new_obj = []
|
||||
for v in obj:
|
||||
assert v is not obj, "Cannot serialize recursive list"
|
||||
new_obj.append(self.recursively_serialize(v))
|
||||
return new_obj
|
||||
# image or tensor or ndarray
|
||||
elif isinstance(obj, (Image.Image, Tensor, ndarray)):
|
||||
return serialize_image(obj)
|
||||
# controlnet
|
||||
elif self.UiControlNetUnit and isinstance(obj, self.UiControlNetUnit):
|
||||
return serialize_controlnet_args(obj)
|
||||
else:
|
||||
# check json.dumps
|
||||
return obj
|
||||
|
||||
def recursively_deserialize(self, obj):
|
||||
"""
|
||||
Recursively deserialize an object from JSON
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
new_obj = {}
|
||||
for k, v in obj.items():
|
||||
new_obj[k] = self.recursively_deserialize(v)
|
||||
return new_obj
|
||||
elif isinstance(obj, list):
|
||||
new_obj = []
|
||||
for v in obj:
|
||||
new_obj.append(self.recursively_deserialize(v))
|
||||
return new_obj
|
||||
elif isinstance(obj, dict) and obj.get("is_cnet", False):
|
||||
return deserialize_controlnet_args(obj)
|
||||
else:
|
||||
return deserialize_image(obj)
|
||||
|
||||
def __serialize_ui_task_args(self, is_img2img: bool, *args, checkpoint: str = None):
|
||||
named_args, script_args = map_ui_task_args_list_to_named_args(
|
||||
|
|
@ -116,16 +161,27 @@ class TaskRunner:
|
|||
serialize_img2img_image_args(named_args)
|
||||
|
||||
# loop through script_args and serialize images
|
||||
serialized_args:list = [None] * len(script_args)
|
||||
for i, a in enumerate(script_args):
|
||||
if isinstance(a, (Image.Image, ndarray, Tensor)):
|
||||
script_args[i] = serialize_image(a)
|
||||
serialized_args[i] = serialize_image(a)
|
||||
elif self.UiControlNetUnit and isinstance(a, self.UiControlNetUnit):
|
||||
script_args[i] = serialize_controlnet_args(a)
|
||||
|
||||
serialized_args[i] = serialize_controlnet_args(a)
|
||||
else:
|
||||
serialized_args[i] = self.recursively_serialize(a)
|
||||
# assert each arguments is serializable
|
||||
check_args = [named_args, serialized_args, checkpoint]
|
||||
args_name = ["named_args", "script_args", "checkpoint"]
|
||||
for args, name in zip(check_args, args_name):
|
||||
try:
|
||||
json.dumps(args)
|
||||
except Exception as e:
|
||||
print(f"Cannot serialize args: {args} with name: {name}")
|
||||
raise e
|
||||
return json.dumps(
|
||||
{
|
||||
"args": named_args,
|
||||
"script_args": script_args,
|
||||
"script_args": serialized_args,
|
||||
"checkpoint": checkpoint,
|
||||
"is_ui": True,
|
||||
"is_img2img": is_img2img,
|
||||
|
|
@ -159,16 +215,26 @@ class TaskRunner:
|
|||
def __deserialize_ui_task_args(
|
||||
self, is_img2img: bool, named_args: dict, script_args: list
|
||||
):
|
||||
"""
|
||||
Deserialize UI task arguments
|
||||
In-place update named_args and script_args
|
||||
"""
|
||||
# loop through image_args and deserialize images
|
||||
if is_img2img:
|
||||
deserialize_img2img_image_args(named_args)
|
||||
|
||||
|
||||
deserialized_args:list = [None] * len(script_args)
|
||||
# loop through script_args and deserialize images
|
||||
for i, arg in enumerate(script_args):
|
||||
if isinstance(arg, dict) and arg.get("is_cnet", False):
|
||||
script_args[i] = deserialize_controlnet_args(arg)
|
||||
deserialized_args[i] = deserialize_controlnet_args(arg)
|
||||
elif isinstance(arg, dict) and arg.get("cls", "") in {"Image", "ndarray", "Tensor"}:
|
||||
script_args[i] = deserialize_image(arg)
|
||||
deserialized_args[i] = deserialize_image(arg)
|
||||
else:
|
||||
deserialized_args[i] = self.recursively_deserialize(arg)
|
||||
for i, arg in enumerate(deserialized_args):
|
||||
script_args[i] = arg
|
||||
|
||||
|
||||
def __deserialize_api_task_args(self, is_img2img: bool, named_args: dict):
|
||||
# load images from disk
|
||||
|
|
|
|||
Loading…
Reference in New Issue