95 lines
3.3 KiB
Python
95 lines
3.3 KiB
Python
import os
import unittest
from typing import Dict

import cv2
import numpy as np
import importlib

# Load the extension's shared test helpers by their dotted module path. The
# path is rooted at `extensions.sd-webui-controlnet`, so this presumably expects
# the tests to be launched from the webui root directory — TODO confirm.
# NOTE(review): the second argument ('utils') is import_module's `package`
# parameter, which Python only consults for relative module names; it has no
# effect on this absolute name.
utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'utils')

# Must run before the `annotator` import below; presumably it prepares the
# import path / environment so that package resolves — verify in tests/utils.
utils.setup_test_env()

from annotator.openpose import OpenposeDetector
class TestOpenposeDetector(unittest.TestCase):
    """Golden-image regression tests for ``OpenposeDetector``.

    Each test runs the detector on an input image and compares the rendered
    canvas with a stored expectation image located under ``image_path``.
    """

    # Directory holding input images and expected outputs. This is a relative
    # path, so it is resolved against the current working directory.
    image_path = './tests/images'

    def setUp(self) -> None:
        """Create a detector and load its model before each test."""
        self.detector = OpenposeDetector()
        self.detector.load_model()

    def tearDown(self) -> None:
        """Unload the detector model after each test."""
        self.detector.unload_model()

    @staticmethod
    def _diff_image_path(expected_image: str) -> str:
        """Return the path where the diff mask for ``expected_image`` is saved.

        ``expected.png`` becomes ``expected_diff.png``. Only the final
        extension is split off, so the result always differs from
        ``expected_image``. A failing comparison therefore cannot overwrite
        the expectation file, even when it is not a ``.png``.
        """
        root, ext = os.path.splitext(expected_image)
        return f'{root}_diff{ext}'

    def expect_same_image(self, img1, img2, diff_img_path: str):
        """Assert that two images are numerically identical.

        Args:
            img1: Actual image array.
            img2: Expected image array.
            diff_img_path: Where to write a highlight mask when the images
                differ.

        The mask is 255 wherever the per-channel absolute difference exceeds
        30 and 0 elsewhere. Note that differences of 30 or less still fail the
        assertion (the tolerances below are effectively exact for integer
        pixel data) but appear black in the mask.
        """
        self.assertIsNotNone(img1, 'Actual image is None')
        self.assertIsNotNone(img2, 'Expected image is None')
        # cv2.absdiff raises an opaque cv2.error on shape mismatch; fail
        # with a readable message instead.
        self.assertEqual(
            img1.shape, img2.shape,
            f'Image shapes differ: {img1.shape} vs {img2.shape}',
        )

        # Assert that the two images are similar within a tolerance.
        similar = np.allclose(img1, img2, rtol=1e-05, atol=1e-08)
        if not similar:
            # Only build the diff mask when it is actually needed.
            diff = cv2.absdiff(img1, img2)
            threshold = 30
            diff_highlighted = np.where(diff > threshold, 255, 0).astype(np.uint8)
            # Save the highlighted diff image to inspect the differences.
            cv2.imwrite(diff_img_path, diff_highlighted)

        self.assertTrue(similar, f'Images differ; diff mask saved to {diff_img_path}')

    # Save expectation image as png so that no compression issue happens.
    def template(self, test_image: str, expected_image: str, detector_config: Dict, overwrite_expectation: bool = False):
        """Run the detector on ``test_image`` and compare with ``expected_image``.

        Args:
            test_image: Path of the input image.
            expected_image: Path of the expected canvas. Use PNG (lossless)
                so that re-reading it reproduces the pixels exactly.
            detector_config: Keyword arguments forwarded to the detector call.
            overwrite_expectation: When True, write the detector output to
                ``expected_image`` instead of comparing against it.
        """
        ori_img = cv2.imread(test_image)
        # cv2.imread signals failure by returning None rather than raising.
        self.assertIsNotNone(ori_img, f'Failed to read test image: {test_image}')
        canvas = self.detector(ori_img, **detector_config)

        if overwrite_expectation:
            # Regenerate the expectation file instead of comparing.
            self.assertTrue(
                cv2.imwrite(expected_image, canvas),
                f'Failed to write expectation image: {expected_image}',
            )
            return

        expected_canvas = cv2.imread(expected_image)
        self.assertIsNotNone(expected_canvas, f'Failed to read expected image: {expected_image}')
        self.expect_same_image(
            canvas,
            expected_canvas,
            diff_img_path=self._diff_image_path(expected_image),
        )

    def test_body(self):
        """Default detector configuration on ``ski.jpg``."""
        self.template(
            test_image=f'{TestOpenposeDetector.image_path}/ski.jpg',
            expected_image=f'{TestOpenposeDetector.image_path}/expected_ski_output.png',
            detector_config=dict(),
            overwrite_expectation=False,
        )

    def test_hand(self):
        """Hands only on ``woman.jpeg``."""
        self.template(
            test_image=f'{TestOpenposeDetector.image_path}/woman.jpeg',
            expected_image=f'{TestOpenposeDetector.image_path}/expected_woman_hand_output.png',
            detector_config=dict(
                include_body=False,
                include_face=False,
                include_hand=True,
            ),
            overwrite_expectation=False,
        )

    def test_face(self):
        """Face only on ``woman.jpeg``."""
        self.template(
            test_image=f'{TestOpenposeDetector.image_path}/woman.jpeg',
            expected_image=f'{TestOpenposeDetector.image_path}/expected_woman_face_output.png',
            detector_config=dict(
                include_body=False,
                include_face=True,
                include_hand=False,
            ),
            overwrite_expectation=False,
        )

    def test_all(self):
        """Body, face and hands together on ``woman.jpeg``."""
        self.template(
            test_image=f'{TestOpenposeDetector.image_path}/woman.jpeg',
            expected_image=f'{TestOpenposeDetector.image_path}/expected_woman_all_output.png',
            detector_config=dict(
                include_body=True,
                include_face=True,
                include_hand=True,
            ),
            overwrite_expectation=False,
        )
if __name__ == "__main__":
    # Executed only when this module is run directly, not when imported by a
    # test runner: hand control to unittest's command-line entry point.
    unittest.main()