173 lines
6.0 KiB
Python
173 lines
6.0 KiB
Python
from typing import Any, Dict, List
|
|
import unittest
|
|
from PIL import Image
|
|
import numpy as np
|
|
|
|
import importlib
|
|
|
|
# Load the extension's test helpers and bootstrap the test environment.
# This runs before the `scripts` / `modules` imports that follow; presumably
# those imports rely on the environment being set up first — keep the order.
# The module path is absolute, so no `package` argument is passed:
# importlib.import_module only consults `package` for relative names.
utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils")
utils.setup_test_env()
|
|
|
|
from scripts import external_code, processor
|
|
from scripts.controlnet import prepare_mask, Script, set_numpy_seed
|
|
from modules import processing
|
|
|
|
|
|
class TestPrepareMask(unittest.TestCase):
    """Tests for ``prepare_mask``: grayscale conversion, inversion and blur."""

    def test_prepare_mask(self):
        p = processing.StableDiffusionProcessing()
        p.inpainting_mask_invert = True
        p.mask_blur = 5

        mask = Image.new("RGB", (10, 10), color="white")

        processed_mask = prepare_mask(mask, p)

        # Check that mask is correctly converted to grayscale.
        # Must be assertEqual: assertTrue(mode, "L") treats "L" as the failure
        # message and passes for any non-empty mode string.
        self.assertEqual(processed_mask.mode, "L")

        # Check that mask colors are correctly inverted.
        self.assertEqual(
            processed_mask.getpixel((0, 0)), 0
        )  # inverted white should be black

        p.inpainting_mask_invert = False
        processed_mask = prepare_mask(mask, p)

        # Check that mask colors are not inverted when 'inpainting_mask_invert' is False.
        self.assertEqual(
            processed_mask.getpixel((0, 0)), 255
        )  # white should remain white

        p.mask_blur = 0
        # A uniform mask looks identical blurred or not, so it cannot detect
        # blurring. Use a hard vertical edge instead: columns 0-4 white,
        # columns 5-9 black. Any blur would pull the pixels adjacent to the
        # edge away from 255 / 0.
        mask = Image.new("RGB", (10, 10), color="black")
        mask.paste((255, 255, 255), (0, 0, 5, 10))
        processed_mask = prepare_mask(mask, p)

        # Check that mask is not blurred when 'mask_blur' is 0.
        self.assertEqual(processed_mask.getpixel((4, 0)), 255)  # white edge stays white
        self.assertEqual(processed_mask.getpixel((5, 0)), 0)  # black edge stays black
        self.assertEqual(processed_mask.getpixel((9, 0)), 0)  # black should remain black
|
|
|
|
|
|
class TestSetNumpySeed(unittest.TestCase):
    """Tests for ``set_numpy_seed`` and its effect on numpy's global RNG."""

    def setUp(self):
        # These tests (and set_numpy_seed itself) reseed numpy's *global*
        # RNG. Snapshot it so running this class does not change the random
        # sequence seen by other tests.
        self._np_random_state = np.random.get_state()

    def tearDown(self):
        np.random.set_state(self._np_random_state)

    def test_seed_subseed_minus_one(self):
        p = processing.StableDiffusionProcessing()
        p.seed = -1
        p.subseed = -1
        p.all_seeds = [123, 456]
        # -1 means "use all_seeds[0]" for both seed and subseed.
        expected_seed = (123 + 123) & 0xFFFFFFFF
        self.assertEqual(set_numpy_seed(p), expected_seed)

    def test_valid_seed_subseed(self):
        p = processing.StableDiffusionProcessing()
        p.seed = 50
        p.subseed = 100
        p.all_seeds = [123, 456]
        expected_seed = (50 + 100) & 0xFFFFFFFF
        self.assertEqual(set_numpy_seed(p), expected_seed)

    def test_invalid_seed_subseed(self):
        p = processing.StableDiffusionProcessing()
        p.seed = "invalid"
        p.subseed = 2.5
        p.all_seeds = [123, 456]
        self.assertIsNone(set_numpy_seed(p))

    def test_empty_all_seeds(self):
        p = processing.StableDiffusionProcessing()
        p.seed = -1
        p.subseed = 2
        p.all_seeds = []
        self.assertIsNone(set_numpy_seed(p))

    def test_random_state_change(self):
        p = processing.StableDiffusionProcessing()
        p.seed = 50
        p.subseed = 100
        p.all_seeds = [123, 456]
        expected_seed = (50 + 100) & 0xFFFFFFFF

        # Reference draw: what numpy's global RNG yields right after being
        # seeded with expected_seed.
        np.random.seed(expected_seed)
        expected_draw = np.random.randint(0, 1000)

        np.random.seed(0)  # move the global RNG to an unrelated, known state

        seed = set_numpy_seed(p)
        self.assertEqual(seed, expected_seed)

        # set_numpy_seed must have reseeded the global RNG with the returned
        # seed, so the next draw matches the reference exactly. Comparing a
        # "before" and "after" draw for inequality only holds by coincidence
        # and does not show which seed was applied.
        self.assertEqual(np.random.randint(0, 1000), expected_draw)
|
|
|
|
|
|
class MockImg2ImgProcessing(processing.StableDiffusionProcessing):
    """Lightweight stand-in for the WebUI img2img processing object.

    The real WebUI img2img class has a dependency on ``sd_model``; this mock
    avoids that by building on the base processing class and only adding an
    ``init_images`` attribute.
    """

    def __init__(self, init_images, *base_args, **base_kwargs):
        # Everything except init_images is forwarded untouched to the base class.
        super(MockImg2ImgProcessing, self).__init__(*base_args, **base_kwargs)
        self.init_images = init_images
|
|
|
|
|
|
class TestScript(unittest.TestCase):
    """Tests for static helpers on the ControlNet ``Script`` class."""

    sample_base64_image = (
        "data:image/png;base64,"
        "iVBORw0KGgoAAAANSUhEUgAAARMAAAC3CAIAAAC+MS2jAAAAqUlEQVR4nO3BAQ"
        "0AAADCoPdPbQ8HFAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
        "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
        "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
        "AAAAAAAAAAAAAAAAAAAAAAAA/wZOlAAB5tU+nAAAAABJRU5ErkJggg=="
    )

    sample_np_image = np.array(
        [[100, 200, 50], [150, 75, 225], [30, 120, 180]], dtype=np.uint8
    )

    # Slider parameter names, in the positional order used by each entry of
    # processor.preprocessor_sliders_config[module]. Defined once so the
    # "which slot does this param live in" lookup and the iteration below
    # cannot drift apart.
    SLIDER_PARAMS = ("processor_res", "threshold_a", "threshold_b")

    def test_bound_check_params(self):
        """Negative slider values are fixed up only for params a module uses."""

        def param_required(module: str, param: str) -> bool:
            # A param is in use when the module's slider config has a
            # non-None entry at that param's position.
            configs = processor.preprocessor_sliders_config[module]
            config_index = self.SLIDER_PARAMS.index(param)
            return config_index < len(configs) and configs[config_index] is not None

        for module in processor.preprocessor_sliders_config:
            for param in self.SLIDER_PARAMS:
                with self.subTest(param=param, module=module):
                    unit = external_code.ControlNetUnit(
                        module=module,
                        **{param: -100},
                    )
                    Script.bound_check_params(unit)
                    if param_required(module, param):
                        self.assertGreaterEqual(getattr(unit, param), 0)
                    else:
                        # Unused params must be left exactly as supplied.
                        self.assertEqual(getattr(unit, param), -100)

    def test_choose_input_image(self):
        """Input selection: error without any image, unit image first, else A1111."""
        with self.subTest(name="no image"):
            with self.assertRaises(ValueError):
                Script.choose_input_image(
                    p=processing.StableDiffusionProcessing(),
                    unit=external_code.ControlNetUnit(),
                    idx=0,
                )

        with self.subTest(name="control net input"):
            _, from_a1111 = Script.choose_input_image(
                p=MockImg2ImgProcessing(init_images=[TestScript.sample_np_image]),
                unit=external_code.ControlNetUnit(
                    image=TestScript.sample_base64_image, module="none"
                ),
                idx=0,
            )
            self.assertFalse(from_a1111)

        with self.subTest(name="A1111 input"):
            _, from_a1111 = Script.choose_input_image(
                p=MockImg2ImgProcessing(init_images=[TestScript.sample_np_image]),
                unit=external_code.ControlNetUnit(module="none"),
                idx=0,
            )
            self.assertTrue(from_a1111)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly (not via a test runner): run every TestCase defined
    # in this module. "__main__" is unittest.main's default module anyway.
    unittest.main(module="__main__")
|