fix controlnet inplace modification

pull/68/head
aria1th 2023-07-09 13:05:34 +09:00
parent aa8bc18eb3
commit f5d396b79b
2 changed files with 25 additions and 20 deletions

View File

@ -160,24 +160,34 @@ def deserialize_img2img_image_args(args: dict):
def serialize_controlnet_args(cnet_unit):
args: dict = cnet_unit.__dict__
args["is_cnet"] = True
new_args = {}
new_args["is_cnet"] = True
for k, v in args.items():
if k == "image" and hasattr(v, "image") and v.image is not None:
print("serialize image for controlnet")
args[k] = {
"image": serialize_image(v["image"]),
"mask": serialize_image(v["mask"])
if v.get("mask", None) is not None
else None,
}
if k == 'image':
if hasattr(v, "image") and v.image is not None:
print("serialize image for controlnet")
new_args[k] = {
"image": serialize_image(v["image"]),
"mask": serialize_image(v["mask"])
if v.get("mask", None) is not None
else None,
}
elif type(v) is dict and v.get('image', None) is not None:
new_args[k] = {
"image": serialize_image(v),
"mask": None if v.get("mask", None) is None else serialize_image(v["mask"]),
}
else:
print("Fallbacked for argument " + str(k) + " with value " + str(v) + " to serialize_image")
new_args[k] = serialize_image(v)
elif k == 'image' and isinstance(v, (np.ndarray, torch.Tensor, Image.Image)):
args[k] = serialize_image(v)
new_args[k] = serialize_image(v)
elif isinstance(v, Enum):
args[k] = serialize_image(v.value)
new_args[k] = serialize_image(v.value)
else:
args[k] = serialize_image(v)
new_args[k] = serialize_image(v)
return args
return new_args
def deserialize_controlnet_args(args: dict):

View File

@ -127,7 +127,7 @@ class TaskRunner:
return serialize_image(obj)
# controlnet
elif self.UiControlNetUnit and isinstance(obj, self.UiControlNetUnit):
return serialize_controlnet_args(obj)
return self.recursively_serialize(serialize_controlnet_args(obj))
else:
# check json.dumps
return obj
@ -163,12 +163,7 @@ class TaskRunner:
# loop through script_args and serialize images
serialized_args:list = [None] * len(script_args)
for i, a in enumerate(script_args):
if isinstance(a, (Image.Image, ndarray, Tensor)):
serialized_args[i] = serialize_image(a)
elif self.UiControlNetUnit and isinstance(a, self.UiControlNetUnit):
serialized_args[i] = serialize_controlnet_args(a)
else:
serialized_args[i] = self.recursively_serialize(a)
serialized_args[i] = self.recursively_serialize(a)
# assert each arguments is serializable
check_args = [named_args, serialized_args, checkpoint]
args_name = ["named_args", "script_args", "checkpoint"]