fix controlnet inplace modification
parent
aa8bc18eb3
commit
f5d396b79b
|
|
@ -160,24 +160,34 @@ def deserialize_img2img_image_args(args: dict):
|
|||
|
||||
def serialize_controlnet_args(cnet_unit):
|
||||
args: dict = cnet_unit.__dict__
|
||||
args["is_cnet"] = True
|
||||
new_args = {}
|
||||
new_args["is_cnet"] = True
|
||||
for k, v in args.items():
|
||||
if k == "image" and hasattr(v, "image") and v.image is not None:
|
||||
print("serialize image for controlnet")
|
||||
args[k] = {
|
||||
"image": serialize_image(v["image"]),
|
||||
"mask": serialize_image(v["mask"])
|
||||
if v.get("mask", None) is not None
|
||||
else None,
|
||||
}
|
||||
if k == 'image':
|
||||
if hasattr(v, "image") and v.image is not None:
|
||||
print("serialize image for controlnet")
|
||||
new_args[k] = {
|
||||
"image": serialize_image(v["image"]),
|
||||
"mask": serialize_image(v["mask"])
|
||||
if v.get("mask", None) is not None
|
||||
else None,
|
||||
}
|
||||
elif type(v) is dict and v.get('image', None) is not None:
|
||||
new_args[k] = {
|
||||
"image": serialize_image(v),
|
||||
"mask": None if v.get("mask", None) is None else serialize_image(v["mask"]),
|
||||
}
|
||||
else:
|
||||
print("Fallbacked for argument " + str(k) + " with value " + str(v) + " to serialize_image")
|
||||
new_args[k] = serialize_image(v)
|
||||
elif k == 'image' and isinstance(v, (np.ndarray, torch.Tensor, Image.Image)):
|
||||
args[k] = serialize_image(v)
|
||||
new_args[k] = serialize_image(v)
|
||||
elif isinstance(v, Enum):
|
||||
args[k] = serialize_image(v.value)
|
||||
new_args[k] = serialize_image(v.value)
|
||||
else:
|
||||
args[k] = serialize_image(v)
|
||||
new_args[k] = serialize_image(v)
|
||||
|
||||
return args
|
||||
return new_args
|
||||
|
||||
|
||||
def deserialize_controlnet_args(args: dict):
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ class TaskRunner:
|
|||
return serialize_image(obj)
|
||||
# controlnet
|
||||
elif self.UiControlNetUnit and isinstance(obj, self.UiControlNetUnit):
|
||||
return serialize_controlnet_args(obj)
|
||||
return self.recursively_serialize(serialize_controlnet_args(obj))
|
||||
else:
|
||||
# check json.dumps
|
||||
return obj
|
||||
|
|
@ -163,12 +163,7 @@ class TaskRunner:
|
|||
# loop through script_args and serialize images
|
||||
serialized_args:list = [None] * len(script_args)
|
||||
for i, a in enumerate(script_args):
|
||||
if isinstance(a, (Image.Image, ndarray, Tensor)):
|
||||
serialized_args[i] = serialize_image(a)
|
||||
elif self.UiControlNetUnit and isinstance(a, self.UiControlNetUnit):
|
||||
serialized_args[i] = serialize_controlnet_args(a)
|
||||
else:
|
||||
serialized_args[i] = self.recursively_serialize(a)
|
||||
serialized_args[i] = self.recursively_serialize(a)
|
||||
# assert each arguments is serializable
|
||||
check_args = [named_args, serialized_args, checkpoint]
|
||||
args_name = ["named_args", "script_args", "checkpoint"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue