Merge pull request #597 from akx/better-testing

Better tests
no-strict-zip^2
Adi Eyal 2023-08-25 22:29:38 +02:00 committed by GitHub
commit d4cc1194b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 272 additions and 77 deletions

View File

@ -12,7 +12,6 @@ import torch
from dynamicprompts.generators.promptgenerator import GeneratorException
from dynamicprompts.parser.parse import ParserConfig
from dynamicprompts.wildcards import WildcardManager
from modules import devices
from modules.processing import fix_seed
from modules.shared import opts
@ -23,6 +22,7 @@ from sd_dynamic_prompts.helpers import (
generate_prompts,
get_seeds,
load_magicprompt_models,
repeat_iterable_to_length,
should_freeze_prompt,
)
from sd_dynamic_prompts.paths import (
@ -48,11 +48,6 @@ def _get_effective_prompt(prompts: list[str], prompt: str) -> str:
return prompts[0] if prompts else prompt
device = devices.device
# There might be a bug in auto1111 where the correct device is not inferred in some scenarios
if device.type == "cuda" and not device.index:
device = torch.device("cuda:0")
loaded_count = 0
@ -69,6 +64,26 @@ def _get_install_error_message() -> str | None:
return None
def _get_hr_fix_prompts(
    prompts: list[str],
    original_hr_prompt: str,
    original_prompt: str,
) -> list[str]:
    """Build the hires-fix prompt list, one entry per generated prompt.

    If the user left the hires prompt identical to the base prompt, the
    already-generated `prompts` are reused verbatim (as a fresh list).
    Otherwise the dedicated hires prompt is repeated to match the number
    of generated prompts.
    """
    if original_hr_prompt != original_prompt:
        # Inlined equivalent of repeat_iterable_to_length([x], n) for a
        # single-element iterable: just repeat the string n times.
        return [original_hr_prompt] * len(prompts)
    return list(prompts)
def get_magic_prompt_device() -> torch.device:
    """Resolve the torch device MagicPrompt models should run on.

    Works around a suspected auto1111 quirk where ``devices.device`` can
    come back as a bare "cuda" with no device index; in that case the
    first GPU ("cuda:0") is selected explicitly.
    """
    # Imported lazily: `modules` only exists inside a WebUI installation.
    from modules import devices

    resolved = devices.device
    if resolved.type == "cuda" and not resolved.index:
        resolved = torch.device("cuda:0")
    return resolved
class Script(scripts.Script):
def __init__(self):
global loaded_count
@ -334,23 +349,23 @@ class Script(scripts.Script):
def process(
self,
p,
is_enabled,
is_combinatorial,
combinatorial_batches,
is_magic_prompt,
is_feeling_lucky,
is_attention_grabber,
min_attention,
max_attention,
magic_prompt_length,
magic_temp_value,
use_fixed_seed,
unlink_seed_from_prompt,
disable_negative_prompt,
enable_jinja_templates,
no_image_generation,
max_generations,
magic_model,
is_enabled: bool,
is_combinatorial: bool,
combinatorial_batches: int,
is_magic_prompt: bool,
is_feeling_lucky: bool,
is_attention_grabber: bool,
min_attention: float,
max_attention: float,
magic_prompt_length: int,
magic_temp_value: float,
use_fixed_seed: bool,
unlink_seed_from_prompt: bool,
disable_negative_prompt: bool,
enable_jinja_templates: bool,
no_image_generation: bool,
max_generations: int,
magic_model: str | None,
magic_blocklist_regex: str | None,
):
if not is_enabled:
@ -439,7 +454,7 @@ class Script(scripts.Script):
magic_temp_value=magic_temp_value,
magic_blocklist_regex=magic_blocklist_regex,
batch_size=magicprompt_batch_size,
device=device,
device=get_magic_prompt_device(),
)
.set_is_dummy(False)
.set_unlink_seed_from_prompt(unlink_seed_from_prompt)
@ -468,12 +483,12 @@ class Script(scripts.Script):
all_seeds = p.all_seeds
all_prompts, all_negative_prompts = generate_prompts(
generator,
negative_generator,
original_prompt,
original_negative_prompt,
num_images,
all_seeds,
prompt_generator=generator,
negative_prompt_generator=negative_generator,
prompt=original_prompt,
negative_prompt=original_negative_prompt,
num_prompts=num_images,
seeds=all_seeds,
)
except GeneratorException as e:
@ -517,13 +532,13 @@ class Script(scripts.Script):
p.prompt = original_prompt
if hr_fix_enabled:
p.all_hr_prompts = (
all_prompts
if original_prompt == original_hr_prompt
else len(all_prompts) * [original_hr_prompt]
p.all_hr_prompts = _get_hr_fix_prompts(
all_prompts,
original_hr_prompt,
original_prompt,
)
p.all_hr_negative_prompts = (
all_negative_prompts
if original_negative_prompt == original_negative_hr_prompt
else len(all_negative_prompts) * [original_negative_hr_prompt]
p.all_hr_negative_prompts = _get_hr_fix_prompts(
all_negative_prompts,
original_negative_hr_prompt,
original_negative_prompt,
)

View File

@ -198,20 +198,17 @@ class GeneratorBuilder:
parser_config=self._parser_config,
ignore_whitespace=self._ignore_whitespace,
)
prompt_generator = BatchedCombinatorialPromptGenerator(
return BatchedCombinatorialPromptGenerator(
prompt_generator,
batches=self._combinatorial_batches,
)
else:
prompt_generator = RandomPromptGenerator(
self._wildcard_manager,
seed=self._seed,
parser_config=self._parser_config,
unlink_seed_from_prompt=self._unlink_seed_from_prompt,
ignore_whitespace=self._ignore_whitespace,
)
return prompt_generator
return RandomPromptGenerator(
self._wildcard_manager,
seed=self._seed,
parser_config=self._parser_config,
unlink_seed_from_prompt=self._unlink_seed_from_prompt,
ignore_whitespace=self._ignore_whitespace,
)
def create_jinja_generator(self, p) -> PromptGenerator:
original_prompt = p.all_prompts[0] if len(p.all_prompts) > 0 else p.prompt

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import logging
from itertools import product
from itertools import cycle, islice, product
from pathlib import Path
from dynamicprompts.generators.promptgenerator import PromptGenerator
@ -17,7 +17,7 @@ def get_seeds(
use_fixed_seed,
is_combinatorial=False,
combinatorial_batches=1,
):
) -> tuple[list[int], list[int]]:
if p.subseed_strength != 0:
seed = int(p.all_seeds[0])
subseed = int(p.all_subseeds[0])
@ -74,7 +74,7 @@ def generate_prompts(
prompt: str,
negative_prompt: str | None,
num_prompts: int,
seeds: list[int],
seeds: list[int] | None,
) -> tuple[list[str], list[str]]:
"""
Generate positive and negative prompts.
@ -102,14 +102,14 @@ def generate_prompts(
if num_prompts is None:
return generate_prompt_cross_product(all_prompts, all_negative_prompts)
else:
return all_prompts, (all_negative_prompts * num_prompts)[0:num_prompts]
return all_prompts, repeat_iterable_to_length(all_negative_prompts, num_prompts)
def generate_prompt_cross_product(
prompts: list[str],
negative_prompts: list[str],
) -> tuple(list[str], list[str]):
) -> tuple[list[str], list[str]]:
"""
Create a cross product of all the items in `prompts` and `negative_prompts`.
Return the positive prompts and negative prompts in two separate lists
@ -121,10 +121,29 @@ def generate_prompt_cross_product(
Returns:
- Tuple containing list of positive and negative prompts
"""
if len(prompts) == 0 or len(negative_prompts) == 0:
if not (prompts and negative_prompts):
return [], []
positive_prompts, negative_prompts = list(
zip(*product(prompts, negative_prompts), strict=True),
new_positive_prompts, new_negative_prompts = zip(
*product(prompts, negative_prompts),
strict=True,
)
return list(positive_prompts), list(negative_prompts)
return list(new_positive_prompts), list(new_negative_prompts)
def repeat_iterable_to_length(iterable, length: int) -> list:
    """Return a list of exactly `length` items drawn from `iterable`.

    Items are recycled from the start when the iterable runs out, and
    surplus items are dropped when it is too long. An empty iterable
    always produces an empty list.

    Args:
        iterable (Iterable): The source of items to draw from.
        length (int): The number of items the result should contain.

    Returns:
        list: The cycled-and-truncated items.
    """
    endlessly_repeating = cycle(iterable)
    return list(islice(endlessly_repeating, length))

View File

@ -42,7 +42,7 @@ class PngInfoSaver:
return parameters
def strip_template_info(self, parameters: dict[str, Any]) -> str:
def strip_template_info(self, parameters: dict[str, Any]) -> dict[str, Any]:
if "Prompt" in parameters and f"{TEMPLATE_LABEL}:" in parameters["Prompt"]:
parameters["Prompt"] = (
parameters["Prompt"].split(f"{TEMPLATE_LABEL}:")[0].strip()

View File

@ -32,7 +32,9 @@ try:
from packaging.requirements import Requirement
except ImportError:
# pip has had this since 2018
from pip._vendor.packaging.requirements import Requirement
from pip._vendor.packaging.requirements import ( # type: ignore[assignment]
Requirement,
)
logger = logging.getLogger(__name__)
@ -40,7 +42,7 @@ logger = logging.getLogger(__name__)
@dataclasses.dataclass
class InstallResult:
requirement: Requirement
installed: str
installed: str | None
@property
def message(self) -> str | None:
@ -59,7 +61,9 @@ class InstallResult:
@property
def correct(self) -> bool:
return self.installed and self.requirement.specifier.contains(self.installed)
return bool(
self.installed and self.requirement.specifier.contains(self.installed),
)
@property
def pip_install_command(self) -> str:
@ -72,9 +76,10 @@ class InstallResult:
@lru_cache(maxsize=1)
def get_requirements() -> tuple[str]:
def get_requirements() -> tuple[str, ...]:
toml_text = (Path(__file__).parent.parent / "pyproject.toml").read_text()
return tuple(tomli.loads(toml_text)["project"]["dependencies"])
deps = tomli.loads(toml_text)["project"]["dependencies"]
return tuple(str(dep) for dep in deps)
def get_install_result(req_str: str) -> InstallResult:

101
tests/conftest.py Normal file
View File

@ -0,0 +1,101 @@
from __future__ import annotations
import dataclasses
import sys
import types
from unittest.mock import MagicMock, Mock
import pytest
@dataclasses.dataclass
class MockProcessing:
    """Minimal stand-in for WebUI's processing object, used by the tests."""

    seed: int
    subseed: int
    all_seeds: list[int]
    all_subseeds: list[int]
    subseed_strength: float
    prompt: str = ""
    negative_prompt: str = ""
    hr_prompt: str = ""
    hr_negative_prompt: str = ""
    n_iter: int = 1
    batch_size: int = 1
    enable_hr: bool = False
    all_prompts: list[str] = dataclasses.field(default_factory=list)
    all_hr_prompts: list[str] = dataclasses.field(default_factory=list)
    all_negative_prompts: list[str] = dataclasses.field(default_factory=list)
    all_hr_negative_prompts: list[str] = dataclasses.field(default_factory=list)

    def set_prompt_for_test(self, prompt):
        """Set `prompt` on every positive-prompt field, fanned out per image."""
        total_images = self.n_iter * self.batch_size
        self.prompt = prompt
        self.hr_prompt = prompt
        self.all_prompts = [prompt] * total_images
        self.all_hr_prompts = list(self.all_prompts)

    def set_negative_prompt_for_test(self, negative_prompt):
        """Set `negative_prompt` on every negative-prompt field, fanned out per image."""
        total_images = self.n_iter * self.batch_size
        self.negative_prompt = negative_prompt
        self.hr_negative_prompt = negative_prompt
        self.all_negative_prompts = [negative_prompt] * total_images
        self.all_hr_negative_prompts = list(self.all_negative_prompts)
@pytest.fixture
def processing() -> MockProcessing:
    # Fresh MockProcessing with deterministic seeds/subseeds and a default
    # positive/negative prompt pair; each test gets its own instance.
    p = MockProcessing(
        seed=1000,
        subseed=2000,
        all_seeds=list(range(3000, 3000 + 10)),
        all_subseeds=list(range(4000, 4000 + 10)),
        subseed_strength=0,
    )
    p.set_prompt_for_test("beautiful sheep")
    p.set_negative_prompt_for_test("ugly")
    return p
@pytest.fixture
def monkeypatch_webui(monkeypatch, tmp_path):
    """
    Patch sys.modules to look like we have a (partial) WebUI installation.
    """
    import torch

    # Maps dotted module name -> the attributes the fake module should expose.
    fake_webui = {
        "gradio": {"__getattr__": MagicMock()},
        "modules": {},
        "modules.scripts": {"Script": object, "basedir": lambda: str(tmp_path)},
        # Force CPU so tests do not require a GPU.
        "modules.devices": {"device": torch.device("cpu")},
        "modules.processing": {"fix_seed": Mock()},
        "modules.shared": {
            # Mirrors the dynamic-prompts options the extension reads from `opts`.
            "opts": types.SimpleNamespace(
                dp_auto_purge_cache=True,
                dp_ignore_whitespace=True,
                dp_limit_jinja_prompts=False,
                dp_magicprompt_batch_size=1,
                dp_parser_variant_end="}",
                dp_parser_variant_start="{",
                dp_parser_wildcard_wrap="__",
                dp_wildcard_manager_no_dedupe=False,
                dp_wildcard_manager_no_sort=False,
                dp_wildcard_manager_shuffle=False,
                dp_write_prompts_to_file=False,
                dp_write_raw_template=False,
            ),
        },
        "modules.script_callbacks": {
            "ImageSaveParams": object,
            "__getattr__": MagicMock(),
        },
        "modules.generation_parameters_copypaste": {
            "parse_generation_parameters": Mock(),
        },
    }
    for module_name, contents in fake_webui.items():
        # Never clobber a module that is genuinely importable/imported.
        if module_name in sys.modules:
            continue
        mod = types.ModuleType(module_name)
        for name, obj in contents.items():
            setattr(mod, name, obj)
        # monkeypatch.setitem restores sys.modules after the test finishes.
        monkeypatch.setitem(sys.modules, module_name, mod)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from unittest import mock
import pytest
@ -10,18 +12,6 @@ from sd_dynamic_prompts.helpers import (
)
@pytest.fixture
def processing():
m = mock.Mock()
m.seed = 1000
m.subseed = 2000
m.all_seeds = list(range(3000, 3000 + 10))
m.all_subseeds = list(range(4000, 4000 + 10))
m.subseed_strength = 0
return m
def test_get_seeds_with_fixed_seed(processing):
num_seeds = 10
@ -142,6 +132,8 @@ def test_generate_with_num_prompts(num_prompts: int | None):
num_prompts,
seeds,
)
assert isinstance(positive_prompts, list)
assert isinstance(negative_prompts, list)
if num_prompts:
assert positive_prompts == [

11
tests/test_paths.py Normal file
View File

@ -0,0 +1,11 @@
from sd_dynamic_prompts.paths import (
get_extension_base_path,
get_magicprompt_models_txt_path,
get_wildcard_dir,
)
def test_get_paths():
    # Smoke test: the extension's well-known paths resolve to real
    # files/directories in a checkout of this repository.
    assert get_extension_base_path().is_dir()
    assert get_magicprompt_models_txt_path().is_file()
    assert get_wildcard_dir().is_dir()

55
tests/test_script.py Normal file
View File

@ -0,0 +1,55 @@
import pytest
@pytest.mark.parametrize("enable_hr", [True, False], ids=["yes_hr", "no_hr"])
@pytest.mark.parametrize("is_combinatorial", [True, False], ids=["yes_comb", "no_comb"])
def test_script(
    monkeypatch,
    monkeypatch_webui,
    processing,
    enable_hr,
    is_combinatorial,
):
    # End-to-end test of Script.process over the fake WebUI installed by
    # the `monkeypatch_webui` fixture. Imported lazily so the patched
    # `modules.*` entries are in sys.modules before the script imports them.
    from scripts.dynamic_prompting import Script

    s = Script()
    if not is_combinatorial:
        # Random mode: generate 3 prompts per iteration instead of
        # exhausting the variant combinations.
        processing.batch_size = 3
    processing.set_prompt_for_test("{red|green|blue} ball")
    processing.set_negative_prompt_for_test("ugly")
    processing.enable_hr = enable_hr
    s.process(
        p=processing,
        is_enabled=True,
        is_combinatorial=is_combinatorial,
        combinatorial_batches=1,
        is_magic_prompt=False,
        is_feeling_lucky=False,
        is_attention_grabber=False,
        min_attention=0,
        max_attention=1,
        magic_prompt_length=0,
        magic_temp_value=1,
        use_fixed_seed=False,
        unlink_seed_from_prompt=False,
        disable_negative_prompt=False,
        enable_jinja_templates=False,
        no_image_generation=False,
        max_generations=0,
        magic_model="magic",
        magic_blocklist_regex=None,
    )
    assert isinstance(processing.all_prompts, list)
    assert isinstance(processing.all_negative_prompts, list)
    assert isinstance(processing.all_hr_prompts, list)
    assert isinstance(processing.all_hr_negative_prompts, list)
    if is_combinatorial:
        # Combinatorial mode enumerates the variant values in order.
        assert processing.all_prompts == ["red ball", "green ball", "blue ball"]
        assert processing.all_negative_prompts == ["ugly"] * 3
        if enable_hr:
            assert processing.all_hr_prompts == processing.all_prompts
            assert processing.all_hr_negative_prompts == processing.all_negative_prompts
    else:
        # Random mode: colors are seeded randomly, so only the count is stable.
        assert len(processing.all_prompts) == 3  # can't assert on the contents