Add PuLID VRAM leak reproduction tests (#2891)

* Add PuLID detect test

* Add PuLID generation test
pull/2896/head
Chenlei Hu 2024-05-14 22:15:27 -04:00 committed by GitHub
parent 5d7529915b
commit 58f620c921
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 0 deletions

View File

@ -5,6 +5,7 @@ from typing import List
from .template import (
APITestTemplate,
realistic_girl_face_img,
portrait_imgs,
girl_img,
mask_img,
save_base64,
@ -91,6 +92,17 @@ def test_inpaint_mask(module: str):
detect_template(payload, f"detect_inpaint_mask_{module}")
@disable_in_cq
@pytest.mark.parametrize("img_index", [i for i, _ in enumerate(portrait_imgs)])
def test_pulid(img_index: int):
"""PuLID preprocessor should not memory leak."""
payload = dict(
controlnet_input_images=[portrait_imgs[img_index]],
controlnet_module="ip-adapter_pulid",
)
detect_template(payload, f"detect_pulid_{img_index}")
@pytest.mark.parametrize("module", [m for m in UNSUPPORTED_PREPROCESSORS])
def test_unsupported_modules(module: str):
payload = dict(

View File

@ -2,6 +2,7 @@ import pytest
from .template import (
APITestTemplate,
portrait_imgs,
girl_img,
mask_img,
disable_in_cq,
@ -305,3 +306,22 @@ def test_ip_adapter_auto():
).exec()
assert log_context.is_in_console_logs(["ip-adapter-auto => ip-adapter_clip_h"])
@disable_in_cq
@pytest.mark.parametrize("img_index", [i for i, _ in enumerate(portrait_imgs)])
def test_pulid(img_index: int):
"""PuLID should not memory leak."""
assert APITestTemplate(
f"txt2img_pulid_{img_index}",
"txt2img",
payload_overrides={
"width": 768,
"height": 768,
},
unit_overrides={
"image": portrait_imgs[img_index],
"model": get_model("ip-adapter_pulid_sdxl_fp16"),
"module": "ip-adapter_pulid",
},
).exec()