diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 51f19e2..99da3b9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,6 +4,9 @@ on: - push - pull_request +env: + FORGE_CQ_TEST: "True" + jobs: build: runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 820d680..b54b90d 100644 --- a/.gitignore +++ b/.gitignore @@ -175,6 +175,8 @@ web_tests/results/ web_tests/expectations/ tests/web_api/full_coverage/results/ tests/web_api/full_coverage/expectations/ +tests/web_api/results/ +tests/web_api/expectations/ *_diff.png diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..c8792cd --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,7 @@ +import os + + +def pytest_configure(config): + # We don't want to fail on Py.test command line arguments being + # parsed by webui: + os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1") diff --git a/tests/web_api/detect_test.py b/tests/web_api/detect_test.py index f10c41e..e51c398 100644 --- a/tests/web_api/detect_test.py +++ b/tests/web_api/detect_test.py @@ -1,47 +1,92 @@ +import pytest import requests -import unittest -import importlib -utils = importlib.import_module( - 'extensions.sd-webui-controlnet.tests.utils', 'utils') +from typing import List + +from .template import ( + APITestTemplate, + realistic_girl_face_img, + save_base64, + get_dest_dir, + disable_in_cq, +) -class TestDetectEndpointWorking(unittest.TestCase): - def setUp(self): - self.base_detect_args = { - "controlnet_module": "canny", - "controlnet_input_images": [utils.readImage("test/test_files/img2img_basic.png")], - "controlnet_processor_res": 512, - "controlnet_threshold_a": 0, - "controlnet_threshold_b": 0, - } - - def test_detect_with_invalid_module_performed(self): - detect_args = self.base_detect_args.copy() - detect_args.update({ - "controlnet_module": "INVALID", - }) - self.assertEqual(utils.detect(detect_args).status_code, 422) - - def test_detect_with_no_input_images_performed(self): - detect_args = self.base_detect_args.copy() - detect_args.update({ - "controlnet_input_images": [], - }) - self.assertEqual(utils.detect(detect_args).status_code, 422) - - def test_detect_with_valid_args_performed(self): - detect_args = self.base_detect_args - response = utils.detect(detect_args) - - self.assertEqual(response.status_code, 200) - - def test_detect_invert(self): - detect_args = self.base_detect_args.copy() - detect_args["controlnet_module"] = "invert" - response = utils.detect(detect_args) - self.assertEqual(response.status_code, 200) - self.assertNotEqual(response.json()['images'], [""]) +def get_modules() -> List[str]: + return requests.get(APITestTemplate.BASE_URL + "controlnet/module_list").json()[ + "module_list" + ] -if __name__ == "__main__": - unittest.main() +def detect_template(payload, output_name: str, status: int = 200): + url = APITestTemplate.BASE_URL + "controlnet/detect" + resp = requests.post(url, json=payload) + assert resp.status_code == status + if status != 200: + return + + resp_json = resp.json() + assert "images" in resp_json + assert len(resp_json["images"]) == len(payload["controlnet_input_images"]) + if not APITestTemplate.is_cq_run: + for i, img in enumerate(resp_json["images"]): + if img == "Detect result is not image": + continue + dest = get_dest_dir() / f"{output_name}_{i}.png" + save_base64(img, dest) + return resp_json + + +# Need to allow detect of CLIP preprocessor result. +# https://github.com/Mikubill/sd-webui-controlnet/pull/2590 +# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[clip_vision] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589ADD1210> +# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[revision_clipvision] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AFB00E0> +# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[revision_ignore_prompt] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AF3C9A0> +# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_clip_sd15] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AF5B740> +# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_clip_sdxl_plus_vith] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AF3D0D0> +# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_clip_sdxl] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x00000158FF7753F0> +# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_face_id] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589B0414E0> +# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_face_id_plus] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AEE3100> +# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[instant_id_face_embedding] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AFF6CF0> + +# https://github.com/Mikubill/sd-webui-controlnet/issues/2693 +# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[segmentation] - assert 500 == 200 + +# TODO: file issue on these failures. +# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[depth_zoe] - assert 500 == 200 +# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[inpaint_only+lama] - assert 500 == 200 +@disable_in_cq +@pytest.mark.parametrize("module", get_modules()) +def test_detect_all_modules(module: str): + payload = dict( + controlnet_input_images=[realistic_girl_face_img], + controlnet_module=module, + ) + detect_template(payload, f"detect_{module}") + + +def test_detect_simple(): + detect_template( + dict( + controlnet_input_images=[realistic_girl_face_img], + controlnet_module="canny", # Canny does not require model download. + ), + "simple_detect", + ) + + +def test_detect_multiple_inputs(): + detect_template( + dict( + controlnet_input_images=[realistic_girl_face_img, realistic_girl_face_img], + controlnet_module="canny", # Canny does not require model download. + ), + "multiple_inputs_detect", + ) + + +def test_detect_with_invalid_module(): + detect_template({"controlnet_module": "INVALID"}, "invalid module", 422) + + +def test_detect_with_no_input_images(): + detect_template({"controlnet_input_images": []}, "invalid module", 422) diff --git a/tests/web_api/template.py b/tests/web_api/template.py new file mode 100644 index 0000000..bdd0235 --- /dev/null +++ b/tests/web_api/template.py @@ -0,0 +1,360 @@ +import io +import os +import cv2 +import base64 +import functools +from typing import Dict, Any, List, Union, Literal, Optional +from pathlib import Path +import datetime +from enum import Enum +import numpy as np +import pytest + +import requests +from PIL import Image + + +def disable_in_cq(func): + """Skips the decorated test func in CQ run.""" + + @functools.wraps(func) + def wrapped_func(*args, **kwargs): + if APITestTemplate.is_cq_run: + pytest.skip() + return func(*args, **kwargs) + + return wrapped_func + + +PayloadOverrideType = Dict[str, Any] + +timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") +test_result_dir = Path(__file__).parent / "results" / f"test_result_{timestamp}" +test_expectation_dir = Path(__file__).parent / "expectations" +os.makedirs(test_expectation_dir, exist_ok=True) +resource_dir = Path(__file__).parents[1] / "images" + + +def get_dest_dir(): + if APITestTemplate.is_set_expectation_run: + return test_expectation_dir + else: + return test_result_dir + + +def save_base64(base64img: str, dest: Path): + Image.open(io.BytesIO(base64.b64decode(base64img.split(",", 1)[0]))).save(dest) + + +def read_image(img_path: Path) -> str: + img = cv2.imread(str(img_path)) + _, bytes = cv2.imencode(".png", img) + encoded_image = base64.b64encode(bytes).decode("utf-8") + return encoded_image + + +def read_image_dir( + img_dir: Path, suffixes=(".png", ".jpg", ".jpeg", ".webp") +) -> List[str]: + """Try read all images in given img_dir.""" + img_dir = str(img_dir) + images = [] + for filename in os.listdir(img_dir): + if filename.endswith(suffixes): + img_path = os.path.join(img_dir, filename) + try: + images.append(read_image(img_path)) + except IOError: + print(f"Error opening {img_path}") + return images + + +girl_img = read_image(resource_dir / "1girl.png") +mask_img = read_image(resource_dir / "mask.png") +mask_small_img = read_image(resource_dir / "mask_small.png") +portrait_imgs = read_image_dir(resource_dir / "portrait") +realistic_girl_face_img = portrait_imgs[0] + + +general_negative_prompt = """ +(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, +((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, +backlight,(ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), +(tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, (bad anatomy:1.21), +(bad proportions:1.331), extra limbs, (missing arms:1.331), (extra legs:1.331), +(fused fingers:1.61051), (too many fingers:1.61051), (unclear eyes:1.331), bad hands, +missing fingers, extra digit, bad body, easynegative, nsfw""" + + +class StableDiffusionVersion(Enum): + """The version family of stable diffusion model.""" + + UNKNOWN = 0 + SD1x = 1 + SD2x = 2 + SDXL = 3 + + +sd_version = StableDiffusionVersion( + int(os.environ.get("CONTROLNET_TEST_SD_VERSION", StableDiffusionVersion.SD1x.value)) +) + +is_full_coverage = os.environ.get("CONTROLNET_TEST_FULL_COVERAGE", None) is not None + + +class APITestTemplate: + is_set_expectation_run = os.environ.get("CONTROLNET_SET_EXP", "True") == "True" + is_cq_run = os.environ.get("FORGE_CQ_TEST", "False") == "True" + BASE_URL = "http://localhost:7860/" + + def __init__( + self, + name: str, + gen_type: Union[Literal["img2img"], Literal["txt2img"]], + payload_overrides: PayloadOverrideType, + unit_overrides: Union[PayloadOverrideType, List[PayloadOverrideType]], + input_image: Optional[str] = None, + ): + self.name = name + self.url = APITestTemplate.BASE_URL + "sdapi/v1/" + gen_type + self.payload = { + **(txt2img_payload if gen_type == "txt2img" else img2img_payload), + **payload_overrides, + } + if gen_type == "img2img" and input_image is not None: + self.payload["init_images"] = [input_image] + + # CQ runs on CPU. Reduce steps to increase test speed. + if "steps" not in payload_overrides and APITestTemplate.is_cq_run: + self.payload["steps"] = 3 + + unit_overrides = ( + unit_overrides + if isinstance(unit_overrides, (list, tuple)) + else [unit_overrides] + ) + self.payload["alwayson_scripts"]["ControlNet"]["args"] = [ + { + **default_unit, + **unit_override, + **( + {"image": input_image} + if gen_type == "txt2img" and input_image is not None + else {} + ), + } + for unit_override in unit_overrides + ] + self.active_unit_count = len(unit_overrides) + + def exec(self, *args, **kwargs) -> bool: + if APITestTemplate.is_cq_run: + return self.exec_cq(*args, **kwargs) + else: + return self.exec_local(*args, **kwargs) + + def exec_cq( + self, expected_output_num: Optional[int] = None, *args, **kwargs + ) -> bool: + """Execute test in CQ environment.""" + res = requests.post(url=self.url, json=self.payload) + if res.status_code != 200: + print(f"Unexpected status code {res.status_code}") + return False + + response = res.json() + if "images" not in response: + print(response.keys()) + return False + + if expected_output_num is None: + expected_output_num = ( + self.payload["n_iter"] * self.payload["batch_size"] + + self.active_unit_count + ) + + if len(response["images"]) != expected_output_num: + print(f"{len(response['images'])} != {expected_output_num}") + return False + + return True + + def exec_local(self, result_only: bool = True, *args, **kwargs) -> bool: + """Execute test in local environment.""" + if not APITestTemplate.is_set_expectation_run: + os.makedirs(test_result_dir, exist_ok=True) + + failed = False + + response = requests.post(url=self.url, json=self.payload).json() + if "images" not in response: + print(response.keys()) + return False + + dest_dir = get_dest_dir() + results = response["images"][:1] if result_only else response["images"] + for i, base64image in enumerate(results): + img_file_name = f"{self.name}_{i}.png" + save_base64(base64image, dest_dir / img_file_name) + + if not APITestTemplate.is_set_expectation_run: + try: + img1 = cv2.imread(os.path.join(test_expectation_dir, img_file_name)) + img2 = cv2.imread(os.path.join(test_result_dir, img_file_name)) + except Exception as e: + print(f"Get exception reading imgs: {e}") + failed = True + continue + + if img1 is None: + print(f"Warn: No expectation file found {img_file_name}.") + continue + + if not expect_same_image( + img1, + img2, + diff_img_path=str( + test_result_dir / img_file_name.replace(".png", "_diff.png") + ), + ): + failed = True + return not failed + + +def expect_same_image(img1, img2, diff_img_path: str) -> bool: + # Calculate the difference between the two images + diff = cv2.absdiff(img1, img2) + + # Set a threshold to highlight the different pixels + threshold = 30 + diff_highlighted = np.where(diff > threshold, 255, 0).astype(np.uint8) + + # Assert that the two images are similar within a tolerance + similar = np.allclose(img1, img2, rtol=0.5, atol=1) + if not similar: + # Save the diff_highlighted image to inspect the differences + cv2.imwrite(diff_img_path, diff_highlighted) + + matching_pixels = np.isclose(img1, img2, rtol=0.5, atol=1) + similar_in_general = (matching_pixels.sum() / matching_pixels.size) >= 0.95 + return similar_in_general + + +def get_model(model_name: str) -> str: + """Find an available model with specified model name.""" + if model_name.lower() == "none": + return "None" + + r = requests.get(APITestTemplate.BASE_URL + "controlnet/model_list") + result = r.json() + if "model_list" not in result: + raise ValueError("No model available") + + candidates = [ + model for model in result["model_list"] if model_name.lower() in model.lower() + ] + + if not candidates: + raise ValueError("No suitable model available") + + return candidates[0] + + +default_unit = { + "control_mode": 0, + "enabled": True, + "guidance_end": 1, + "guidance_start": 0, + "pixel_perfect": True, + "processor_res": 512, + "resize_mode": 1, + "threshold_a": 64, + "threshold_b": 64, + "weight": 1, + "module": "canny", + "model": get_model("sd15_canny"), +} + +img2img_payload = { + "batch_size": 1, + "cfg_scale": 7, + "height": 768, + "width": 512, + "n_iter": 1, + "steps": 10, + "sampler_name": "Euler a", + "prompt": "(masterpiece: 1.3), (highres: 1.3), best quality,", + "negative_prompt": "", + "seed": 42, + "seed_enable_extras": False, + "seed_resize_from_h": 0, + "seed_resize_from_w": 0, + "subseed": -1, + "subseed_strength": 0, + "override_settings": {}, + "override_settings_restore_afterwards": False, + "do_not_save_grid": False, + "do_not_save_samples": False, + "s_churn": 0, + "s_min_uncond": 0, + "s_noise": 1, + "s_tmax": None, + "s_tmin": 0, + "script_args": [], + "script_name": None, + "styles": [], + "alwayson_scripts": {"ControlNet": {"args": [default_unit]}}, + "denoising_strength": 0.75, + "initial_noise_multiplier": 1, + "inpaint_full_res": 0, + "inpaint_full_res_padding": 32, + "inpainting_fill": 1, + "inpainting_mask_invert": 0, + "mask_blur_x": 4, + "mask_blur_y": 4, + "mask_blur": 4, + "resize_mode": 0, +} + +txt2img_payload = { + "alwayson_scripts": {"ControlNet": {"args": [default_unit]}}, + "batch_size": 1, + "cfg_scale": 7, + "comments": {}, + "disable_extra_networks": False, + "do_not_save_grid": False, + "do_not_save_samples": False, + "enable_hr": False, + "height": 768, + "hr_negative_prompt": "", + "hr_prompt": "", + "hr_resize_x": 0, + "hr_resize_y": 0, + "hr_scale": 2, + "hr_second_pass_steps": 0, + "hr_upscaler": "Latent", + "n_iter": 1, + "negative_prompt": "", + "override_settings": {}, + "override_settings_restore_afterwards": True, + "prompt": "(masterpiece: 1.3), (highres: 1.3), best quality,", + "restore_faces": False, + "s_churn": 0.0, + "s_min_uncond": 0, + "s_noise": 1.0, + "s_tmax": None, + "s_tmin": 0.0, + "sampler_name": "Euler a", + "script_args": [], + "script_name": None, + "seed": 42, + "seed_enable_extras": True, + "seed_resize_from_h": -1, + "seed_resize_from_w": -1, + "steps": 10, + "styles": [], + "subseed": -1, + "subseed_strength": 0, + "tiling": False, + "width": 512, +}