diff --git a/sd_dynamic_prompts/dynamic_prompting.py b/sd_dynamic_prompts/dynamic_prompting.py index 81574ef..ae4c012 100644 --- a/sd_dynamic_prompts/dynamic_prompting.py +++ b/sd_dynamic_prompts/dynamic_prompting.py @@ -12,7 +12,6 @@ import torch from dynamicprompts.generators.promptgenerator import GeneratorException from dynamicprompts.parser.parse import ParserConfig from dynamicprompts.wildcards import WildcardManager -from modules import devices from modules.processing import fix_seed from modules.shared import opts @@ -23,6 +22,7 @@ from sd_dynamic_prompts.helpers import ( generate_prompts, get_seeds, load_magicprompt_models, + repeat_iterable_to_length, should_freeze_prompt, ) from sd_dynamic_prompts.paths import ( @@ -48,11 +48,6 @@ def _get_effective_prompt(prompts: list[str], prompt: str) -> str: return prompts[0] if prompts else prompt -device = devices.device -# There might be a bug in auto1111 where the correct device is not inferred in some scenarios -if device.type == "cuda" and not device.index: - device = torch.device("cuda:0") - loaded_count = 0 @@ -69,6 +64,26 @@ def _get_install_error_message() -> str | None: return None +def _get_hr_fix_prompts( + prompts: list[str], + original_hr_prompt: str, + original_prompt: str, +) -> list[str]: + if original_prompt == original_hr_prompt: + return list(prompts) + return repeat_iterable_to_length([original_hr_prompt], len(prompts)) + + +def get_magic_prompt_device() -> torch.device: + from modules import devices + + device = devices.device + # There might be a bug in auto1111 where the correct device is not inferred in some scenarios + if device.type == "cuda" and not device.index: + device = torch.device("cuda:0") + return device + + class Script(scripts.Script): def __init__(self): global loaded_count @@ -334,23 +349,23 @@ class Script(scripts.Script): def process( self, p, - is_enabled, - is_combinatorial, - combinatorial_batches, - is_magic_prompt, - is_feeling_lucky, - is_attention_grabber, - min_attention, - max_attention, - magic_prompt_length, - magic_temp_value, - use_fixed_seed, - unlink_seed_from_prompt, - disable_negative_prompt, - enable_jinja_templates, - no_image_generation, - max_generations, - magic_model, + is_enabled: bool, + is_combinatorial: bool, + combinatorial_batches: int, + is_magic_prompt: bool, + is_feeling_lucky: bool, + is_attention_grabber: bool, + min_attention: float, + max_attention: float, + magic_prompt_length: int, + magic_temp_value: float, + use_fixed_seed: bool, + unlink_seed_from_prompt: bool, + disable_negative_prompt: bool, + enable_jinja_templates: bool, + no_image_generation: bool, + max_generations: int, + magic_model: str | None, magic_blocklist_regex: str | None, ): if not is_enabled: @@ -439,7 +454,7 @@ class Script(scripts.Script): magic_temp_value=magic_temp_value, magic_blocklist_regex=magic_blocklist_regex, batch_size=magicprompt_batch_size, - device=device, + device=get_magic_prompt_device(), ) .set_is_dummy(False) .set_unlink_seed_from_prompt(unlink_seed_from_prompt) @@ -468,12 +483,12 @@ class Script(scripts.Script): all_seeds = p.all_seeds all_prompts, all_negative_prompts = generate_prompts( - generator, - negative_generator, - original_prompt, - original_negative_prompt, - num_images, - all_seeds, + prompt_generator=generator, + negative_prompt_generator=negative_generator, + prompt=original_prompt, + negative_prompt=original_negative_prompt, + num_prompts=num_images, + seeds=all_seeds, ) except GeneratorException as e: @@ -517,13 +532,13 @@ class Script(scripts.Script): p.prompt = original_prompt if hr_fix_enabled: - p.all_hr_prompts = ( - all_prompts - if original_prompt == original_hr_prompt - else len(all_prompts) * [original_hr_prompt] + p.all_hr_prompts = _get_hr_fix_prompts( + all_prompts, + original_hr_prompt, + original_prompt, ) - p.all_hr_negative_prompts = ( - all_negative_prompts - if original_negative_prompt == original_negative_hr_prompt - else len(all_negative_prompts) * [original_negative_hr_prompt] + p.all_hr_negative_prompts = _get_hr_fix_prompts( + all_negative_prompts, + original_negative_hr_prompt, + original_negative_prompt, ) diff --git a/sd_dynamic_prompts/generator_builder.py b/sd_dynamic_prompts/generator_builder.py index 62b81be..ebf3468 100644 --- a/sd_dynamic_prompts/generator_builder.py +++ b/sd_dynamic_prompts/generator_builder.py @@ -198,20 +198,17 @@ class GeneratorBuilder: parser_config=self._parser_config, ignore_whitespace=self._ignore_whitespace, ) - prompt_generator = BatchedCombinatorialPromptGenerator( + return BatchedCombinatorialPromptGenerator( prompt_generator, batches=self._combinatorial_batches, ) - else: - prompt_generator = RandomPromptGenerator( - self._wildcard_manager, - seed=self._seed, - parser_config=self._parser_config, - unlink_seed_from_prompt=self._unlink_seed_from_prompt, - ignore_whitespace=self._ignore_whitespace, - ) - - return prompt_generator + return RandomPromptGenerator( + self._wildcard_manager, + seed=self._seed, + parser_config=self._parser_config, + unlink_seed_from_prompt=self._unlink_seed_from_prompt, + ignore_whitespace=self._ignore_whitespace, + ) def create_jinja_generator(self, p) -> PromptGenerator: original_prompt = p.all_prompts[0] if len(p.all_prompts) > 0 else p.prompt diff --git a/sd_dynamic_prompts/helpers.py b/sd_dynamic_prompts/helpers.py index 15ee93e..4aef51b 100644 --- a/sd_dynamic_prompts/helpers.py +++ b/sd_dynamic_prompts/helpers.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from itertools import product +from itertools import cycle, islice, product from pathlib import Path from dynamicprompts.generators.promptgenerator import PromptGenerator @@ -17,7 +17,7 @@ def get_seeds( use_fixed_seed, is_combinatorial=False, combinatorial_batches=1, -): +) -> tuple[list[int], list[int]]: if p.subseed_strength != 0: seed = int(p.all_seeds[0]) subseed = int(p.all_subseeds[0]) @@ -74,7 +74,7 @@ def generate_prompts( prompt: str, negative_prompt: str | None, num_prompts: int, - seeds: list[int], + seeds: list[int] | None, ) -> tuple[list[str], list[str]]: """ Generate positive and negative prompts. @@ -102,14 +102,14 @@ def generate_prompts( if num_prompts is None: return generate_prompt_cross_product(all_prompts, all_negative_prompts) - else: - return all_prompts, (all_negative_prompts * num_prompts)[0:num_prompts] + + return all_prompts, repeat_iterable_to_length(all_negative_prompts, num_prompts) def generate_prompt_cross_product( prompts: list[str], negative_prompts: list[str], -) -> tuple(list[str], list[str]): +) -> tuple[list[str], list[str]]: """ Create a cross product of all the items in `prompts` and `negative_prompts`. Return the positive prompts and negative prompts in two separate lists @@ -121,10 +121,29 @@ def generate_prompt_cross_product( Returns: - Tuple containing list of positive and negative prompts """ - if len(prompts) == 0 or len(negative_prompts) == 0: + if not (prompts and negative_prompts): return [], [] - positive_prompts, negative_prompts = list( - zip(*product(prompts, negative_prompts), strict=True), + new_positive_prompts, new_negative_prompts = zip( + *product(prompts, negative_prompts), + strict=True, ) - return list(positive_prompts), list(negative_prompts) + return list(new_positive_prompts), list(new_negative_prompts) + + +def repeat_iterable_to_length(iterable, length: int) -> list: + """Repeat an iterable to a given length. + + If the iterable is shorter than the desired length, it will be repeated + until it is long enough. If it is longer than the desired length, it will + be truncated. + + Args: + iterable (Iterable): The iterable to repeat. + length (int): The desired length of the iterable. + + Returns: + list: The repeated iterable. + + """ + return list(islice(cycle(iterable), length)) diff --git a/sd_dynamic_prompts/pnginfo_saver.py b/sd_dynamic_prompts/pnginfo_saver.py index a4b03c6..16a03c1 100644 --- a/sd_dynamic_prompts/pnginfo_saver.py +++ b/sd_dynamic_prompts/pnginfo_saver.py @@ -42,7 +42,7 @@ class PngInfoSaver: return parameters - def strip_template_info(self, parameters: dict[str, Any]) -> str: + def strip_template_info(self, parameters: dict[str, Any]) -> dict[str, Any]: if "Prompt" in parameters and f"{TEMPLATE_LABEL}:" in parameters["Prompt"]: parameters["Prompt"] = ( parameters["Prompt"].split(f"{TEMPLATE_LABEL}:")[0].strip() diff --git a/sd_dynamic_prompts/version_tools.py b/sd_dynamic_prompts/version_tools.py index 11bae81..3dd3a50 100644 --- a/sd_dynamic_prompts/version_tools.py +++ b/sd_dynamic_prompts/version_tools.py @@ -32,7 +32,9 @@ try: from packaging.requirements import Requirement except ImportError: # pip has had this since 2018 - from pip._vendor.packaging.requirements import Requirement + from pip._vendor.packaging.requirements import ( # type: ignore[assignment] + Requirement, + ) logger = logging.getLogger(__name__) @@ -40,7 +42,7 @@ logger = logging.getLogger(__name__) @dataclasses.dataclass class InstallResult: requirement: Requirement - installed: str + installed: str | None @property def message(self) -> str | None: @@ -59,7 +61,9 @@ class InstallResult: @property def correct(self) -> bool: - return self.installed and self.requirement.specifier.contains(self.installed) + return bool( + self.installed and self.requirement.specifier.contains(self.installed), + ) @property def pip_install_command(self) -> str: @@ -72,9 +76,10 @@ class InstallResult: @lru_cache(maxsize=1) -def get_requirements() -> tuple[str]: +def get_requirements() -> tuple[str, ...]: toml_text = (Path(__file__).parent.parent / "pyproject.toml").read_text() - return tuple(tomli.loads(toml_text)["project"]["dependencies"]) + deps = tomli.loads(toml_text)["project"]["dependencies"] + return tuple(str(dep) for dep in deps) def get_install_result(req_str: str) -> InstallResult: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5cce4bf --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import dataclasses +import sys +import types +from unittest.mock import MagicMock, Mock + +import pytest + + +@dataclasses.dataclass +class MockProcessing: + seed: int + subseed: int + all_seeds: list[int] + all_subseeds: list[int] + subseed_strength: float + prompt: str = "" + negative_prompt: str = "" + hr_prompt: str = "" + hr_negative_prompt: str = "" + n_iter: int = 1 + batch_size: int = 1 + enable_hr: bool = False + all_prompts: list[str] = dataclasses.field(default_factory=list) + all_hr_prompts: list[str] = dataclasses.field(default_factory=list) + all_negative_prompts: list[str] = dataclasses.field(default_factory=list) + all_hr_negative_prompts: list[str] = dataclasses.field(default_factory=list) + + def set_prompt_for_test(self, prompt): + self.prompt = prompt + self.hr_prompt = prompt + self.all_prompts = [prompt] * self.n_iter * self.batch_size + self.all_hr_prompts = self.all_prompts.copy() + + def set_negative_prompt_for_test(self, negative_prompt): + self.negative_prompt = negative_prompt + self.hr_negative_prompt = negative_prompt + self.all_negative_prompts = [negative_prompt] * self.n_iter * self.batch_size + self.all_hr_negative_prompts = self.all_negative_prompts.copy() + + +@pytest.fixture +def processing() -> MockProcessing: + p = MockProcessing( + seed=1000, + subseed=2000, + all_seeds=list(range(3000, 3000 + 10)), + all_subseeds=list(range(4000, 4000 + 10)), + subseed_strength=0, + ) + p.set_prompt_for_test("beautiful sheep") + p.set_negative_prompt_for_test("ugly") + return p + + +@pytest.fixture +def monkeypatch_webui(monkeypatch, tmp_path): + """ + Patch sys.modules to look like we have a (partial) WebUI installation. + """ + import torch + + fake_webui = { + "gradio": {"__getattr__": MagicMock()}, + "modules": {}, + "modules.scripts": {"Script": object, "basedir": lambda: str(tmp_path)}, + "modules.devices": {"device": torch.device("cpu")}, + "modules.processing": {"fix_seed": Mock()}, + "modules.shared": { + "opts": types.SimpleNamespace( + dp_auto_purge_cache=True, + dp_ignore_whitespace=True, + dp_limit_jinja_prompts=False, + dp_magicprompt_batch_size=1, + dp_parser_variant_end="}", + dp_parser_variant_start="{", + dp_parser_wildcard_wrap="__", + dp_wildcard_manager_no_dedupe=False, + dp_wildcard_manager_no_sort=False, + dp_wildcard_manager_shuffle=False, + dp_write_prompts_to_file=False, + dp_write_raw_template=False, + ), + }, + "modules.script_callbacks": { + "ImageSaveParams": object, + "__getattr__": MagicMock(), + }, + "modules.generation_parameters_copypaste": { + "parse_generation_parameters": Mock(), + }, + } + + for module_name, contents in fake_webui.items(): + if module_name in sys.modules: + continue + mod = types.ModuleType(module_name) + for name, obj in contents.items(): + setattr(mod, name, obj) + monkeypatch.setitem(sys.modules, module_name, mod) diff --git a/tests/prompts/test_helpers.py b/tests/prompts/test_helpers.py index 094d7d3..e6bcf49 100644 --- a/tests/prompts/test_helpers.py +++ b/tests/prompts/test_helpers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest import mock import pytest @@ -10,18 +12,6 @@ from sd_dynamic_prompts.helpers import ( ) -@pytest.fixture -def processing(): - m = mock.Mock() - m.seed = 1000 - m.subseed = 2000 - m.all_seeds = list(range(3000, 3000 + 10)) - m.all_subseeds = list(range(4000, 4000 + 10)) - m.subseed_strength = 0 - - return m - - def test_get_seeds_with_fixed_seed(processing): num_seeds = 10 @@ -142,6 +132,8 @@ def test_generate_with_num_prompts(num_prompts: int | None): num_prompts, seeds, ) + assert isinstance(positive_prompts, list) + assert isinstance(negative_prompts, list) if num_prompts: assert positive_prompts == [ diff --git a/tests/test_paths.py b/tests/test_paths.py new file mode 100644 index 0000000..9a07dce --- /dev/null +++ b/tests/test_paths.py @@ -0,0 +1,11 @@ +from sd_dynamic_prompts.paths import ( + get_extension_base_path, + get_magicprompt_models_txt_path, + get_wildcard_dir, +) + + +def test_get_paths(): + assert get_extension_base_path().is_dir() + assert get_magicprompt_models_txt_path().is_file() + assert get_wildcard_dir().is_dir() diff --git a/tests/test_script.py b/tests/test_script.py new file mode 100644 index 0000000..eecb2ab --- /dev/null +++ b/tests/test_script.py @@ -0,0 +1,55 @@ +import pytest + + +@pytest.mark.parametrize("enable_hr", [True, False], ids=["yes_hr", "no_hr"]) +@pytest.mark.parametrize("is_combinatorial", [True, False], ids=["yes_comb", "no_comb"]) +def test_script( + monkeypatch, + monkeypatch_webui, + processing, + enable_hr, + is_combinatorial, +): + from scripts.dynamic_prompting import Script + + s = Script() + if not is_combinatorial: + processing.batch_size = 3 + processing.set_prompt_for_test("{red|green|blue} ball") + processing.set_negative_prompt_for_test("ugly") + processing.enable_hr = enable_hr + s.process( + p=processing, + is_enabled=True, + is_combinatorial=is_combinatorial, + combinatorial_batches=1, + is_magic_prompt=False, + is_feeling_lucky=False, + is_attention_grabber=False, + min_attention=0, + max_attention=1, + magic_prompt_length=0, + magic_temp_value=1, + use_fixed_seed=False, + unlink_seed_from_prompt=False, + disable_negative_prompt=False, + enable_jinja_templates=False, + no_image_generation=False, + max_generations=0, + magic_model="magic", + magic_blocklist_regex=None, + ) + assert isinstance(processing.all_prompts, list) + assert isinstance(processing.all_negative_prompts, list) + assert isinstance(processing.all_hr_prompts, list) + assert isinstance(processing.all_hr_negative_prompts, list) + + if is_combinatorial: + assert processing.all_prompts == ["red ball", "green ball", "blue ball"] + assert processing.all_negative_prompts == ["ugly"] * 3 + + if enable_hr: + assert processing.all_hr_prompts == processing.all_prompts + assert processing.all_hr_negative_prompts == processing.all_negative_prompts + else: + assert len(processing.all_prompts) == 3 # can't assert on the contents