Merge branch 'main' into no-strict-zip
commit
320beda0f0
|
|
@ -12,7 +12,6 @@ import torch
|
|||
from dynamicprompts.generators.promptgenerator import GeneratorException
|
||||
from dynamicprompts.parser.parse import ParserConfig
|
||||
from dynamicprompts.wildcards import WildcardManager
|
||||
from modules import devices
|
||||
from modules.processing import fix_seed
|
||||
from modules.shared import opts
|
||||
|
||||
|
|
@ -23,6 +22,7 @@ from sd_dynamic_prompts.helpers import (
|
|||
generate_prompts,
|
||||
get_seeds,
|
||||
load_magicprompt_models,
|
||||
repeat_iterable_to_length,
|
||||
should_freeze_prompt,
|
||||
)
|
||||
from sd_dynamic_prompts.paths import (
|
||||
|
|
@ -48,11 +48,6 @@ def _get_effective_prompt(prompts: list[str], prompt: str) -> str:
|
|||
return prompts[0] if prompts else prompt
|
||||
|
||||
|
||||
device = devices.device
|
||||
# There might be a bug in auto1111 where the correct device is not inferred in some scenarios
|
||||
if device.type == "cuda" and not device.index:
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
loaded_count = 0
|
||||
|
||||
|
||||
|
|
@ -69,6 +64,26 @@ def _get_install_error_message() -> str | None:
|
|||
return None
|
||||
|
||||
|
||||
def _get_hr_fix_prompts(
|
||||
prompts: list[str],
|
||||
original_hr_prompt: str,
|
||||
original_prompt: str,
|
||||
) -> list[str]:
|
||||
if original_prompt == original_hr_prompt:
|
||||
return list(prompts)
|
||||
return repeat_iterable_to_length([original_hr_prompt], len(prompts))
|
||||
|
||||
|
||||
def get_magic_prompt_device() -> torch.device:
|
||||
from modules import devices
|
||||
|
||||
device = devices.device
|
||||
# There might be a bug in auto1111 where the correct device is not inferred in some scenarios
|
||||
if device.type == "cuda" and not device.index:
|
||||
device = torch.device("cuda:0")
|
||||
return device
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
def __init__(self):
|
||||
global loaded_count
|
||||
|
|
@ -334,23 +349,23 @@ class Script(scripts.Script):
|
|||
def process(
|
||||
self,
|
||||
p,
|
||||
is_enabled,
|
||||
is_combinatorial,
|
||||
combinatorial_batches,
|
||||
is_magic_prompt,
|
||||
is_feeling_lucky,
|
||||
is_attention_grabber,
|
||||
min_attention,
|
||||
max_attention,
|
||||
magic_prompt_length,
|
||||
magic_temp_value,
|
||||
use_fixed_seed,
|
||||
unlink_seed_from_prompt,
|
||||
disable_negative_prompt,
|
||||
enable_jinja_templates,
|
||||
no_image_generation,
|
||||
max_generations,
|
||||
magic_model,
|
||||
is_enabled: bool,
|
||||
is_combinatorial: bool,
|
||||
combinatorial_batches: int,
|
||||
is_magic_prompt: bool,
|
||||
is_feeling_lucky: bool,
|
||||
is_attention_grabber: bool,
|
||||
min_attention: float,
|
||||
max_attention: float,
|
||||
magic_prompt_length: int,
|
||||
magic_temp_value: float,
|
||||
use_fixed_seed: bool,
|
||||
unlink_seed_from_prompt: bool,
|
||||
disable_negative_prompt: bool,
|
||||
enable_jinja_templates: bool,
|
||||
no_image_generation: bool,
|
||||
max_generations: int,
|
||||
magic_model: str | None,
|
||||
magic_blocklist_regex: str | None,
|
||||
):
|
||||
if not is_enabled:
|
||||
|
|
@ -439,7 +454,7 @@ class Script(scripts.Script):
|
|||
magic_temp_value=magic_temp_value,
|
||||
magic_blocklist_regex=magic_blocklist_regex,
|
||||
batch_size=magicprompt_batch_size,
|
||||
device=device,
|
||||
device=get_magic_prompt_device(),
|
||||
)
|
||||
.set_is_dummy(False)
|
||||
.set_unlink_seed_from_prompt(unlink_seed_from_prompt)
|
||||
|
|
@ -468,12 +483,12 @@ class Script(scripts.Script):
|
|||
all_seeds = p.all_seeds
|
||||
|
||||
all_prompts, all_negative_prompts = generate_prompts(
|
||||
generator,
|
||||
negative_generator,
|
||||
original_prompt,
|
||||
original_negative_prompt,
|
||||
num_images,
|
||||
all_seeds,
|
||||
prompt_generator=generator,
|
||||
negative_prompt_generator=negative_generator,
|
||||
prompt=original_prompt,
|
||||
negative_prompt=original_negative_prompt,
|
||||
num_prompts=num_images,
|
||||
seeds=all_seeds,
|
||||
)
|
||||
|
||||
except GeneratorException as e:
|
||||
|
|
@ -517,13 +532,13 @@ class Script(scripts.Script):
|
|||
p.prompt = original_prompt
|
||||
|
||||
if hr_fix_enabled:
|
||||
p.all_hr_prompts = (
|
||||
all_prompts
|
||||
if original_prompt == original_hr_prompt
|
||||
else len(all_prompts) * [original_hr_prompt]
|
||||
p.all_hr_prompts = _get_hr_fix_prompts(
|
||||
all_prompts,
|
||||
original_hr_prompt,
|
||||
original_prompt,
|
||||
)
|
||||
p.all_hr_negative_prompts = (
|
||||
all_negative_prompts
|
||||
if original_negative_prompt == original_negative_hr_prompt
|
||||
else len(all_negative_prompts) * [original_negative_hr_prompt]
|
||||
p.all_hr_negative_prompts = _get_hr_fix_prompts(
|
||||
all_negative_prompts,
|
||||
original_negative_hr_prompt,
|
||||
original_negative_prompt,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -198,20 +198,17 @@ class GeneratorBuilder:
|
|||
parser_config=self._parser_config,
|
||||
ignore_whitespace=self._ignore_whitespace,
|
||||
)
|
||||
prompt_generator = BatchedCombinatorialPromptGenerator(
|
||||
return BatchedCombinatorialPromptGenerator(
|
||||
prompt_generator,
|
||||
batches=self._combinatorial_batches,
|
||||
)
|
||||
else:
|
||||
prompt_generator = RandomPromptGenerator(
|
||||
self._wildcard_manager,
|
||||
seed=self._seed,
|
||||
parser_config=self._parser_config,
|
||||
unlink_seed_from_prompt=self._unlink_seed_from_prompt,
|
||||
ignore_whitespace=self._ignore_whitespace,
|
||||
)
|
||||
|
||||
return prompt_generator
|
||||
return RandomPromptGenerator(
|
||||
self._wildcard_manager,
|
||||
seed=self._seed,
|
||||
parser_config=self._parser_config,
|
||||
unlink_seed_from_prompt=self._unlink_seed_from_prompt,
|
||||
ignore_whitespace=self._ignore_whitespace,
|
||||
)
|
||||
|
||||
def create_jinja_generator(self, p) -> PromptGenerator:
|
||||
original_prompt = p.all_prompts[0] if len(p.all_prompts) > 0 else p.prompt
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from itertools import product
|
||||
from itertools import cycle, islice, product
|
||||
from pathlib import Path
|
||||
|
||||
from dynamicprompts.generators.promptgenerator import PromptGenerator
|
||||
|
|
@ -17,7 +17,7 @@ def get_seeds(
|
|||
use_fixed_seed,
|
||||
is_combinatorial=False,
|
||||
combinatorial_batches=1,
|
||||
):
|
||||
) -> tuple[list[int], list[int]]:
|
||||
if p.subseed_strength != 0:
|
||||
seed = int(p.all_seeds[0])
|
||||
subseed = int(p.all_subseeds[0])
|
||||
|
|
@ -74,7 +74,7 @@ def generate_prompts(
|
|||
prompt: str,
|
||||
negative_prompt: str | None,
|
||||
num_prompts: int,
|
||||
seeds: list[int],
|
||||
seeds: list[int] | None,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Generate positive and negative prompts.
|
||||
|
|
@ -102,14 +102,14 @@ def generate_prompts(
|
|||
|
||||
if num_prompts is None:
|
||||
return generate_prompt_cross_product(all_prompts, all_negative_prompts)
|
||||
else:
|
||||
return all_prompts, (all_negative_prompts * num_prompts)[0:num_prompts]
|
||||
|
||||
return all_prompts, repeat_iterable_to_length(all_negative_prompts, num_prompts)
|
||||
|
||||
|
||||
def generate_prompt_cross_product(
|
||||
prompts: list[str],
|
||||
negative_prompts: list[str],
|
||||
) -> tuple(list[str], list[str]):
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Create a cross product of all the items in `prompts` and `negative_prompts`.
|
||||
Return the positive prompts and negative prompts in two separate lists
|
||||
|
|
@ -121,11 +121,30 @@ def generate_prompt_cross_product(
|
|||
Returns:
|
||||
- Tuple containing list of positive and negative prompts
|
||||
"""
|
||||
if len(prompts) == 0 or len(negative_prompts) == 0:
|
||||
if not (prompts and negative_prompts):
|
||||
return [], []
|
||||
|
||||
positive_prompts, negative_prompts = list(
|
||||
# noqa to remain compatible with python 3.9, see issue #601
|
||||
zip(*product(prompts, negative_prompts)), # noqa: B905
|
||||
# noqa to remain compatible with python 3.9, see issue #601
|
||||
new_positive_prompts, new_negative_prompts = zip(
|
||||
*product(prompts, negative_prompts),
|
||||
|
||||
)
|
||||
return list(positive_prompts), list(negative_prompts)
|
||||
return list(new_positive_prompts), list(new_negative_prompts)
|
||||
|
||||
|
||||
def repeat_iterable_to_length(iterable, length: int) -> list:
|
||||
"""Repeat an iterable to a given length.
|
||||
|
||||
If the iterable is shorter than the desired length, it will be repeated
|
||||
until it is long enough. If it is longer than the desired length, it will
|
||||
be truncated.
|
||||
|
||||
Args:
|
||||
iterable (Iterable): The iterable to repeat.
|
||||
length (int): The desired length of the iterable.
|
||||
|
||||
Returns:
|
||||
list: The repeated iterable.
|
||||
|
||||
"""
|
||||
return list(islice(cycle(iterable), length))
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class PngInfoSaver:
|
|||
|
||||
return parameters
|
||||
|
||||
def strip_template_info(self, parameters: dict[str, Any]) -> str:
|
||||
def strip_template_info(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
if "Prompt" in parameters and f"{TEMPLATE_LABEL}:" in parameters["Prompt"]:
|
||||
parameters["Prompt"] = (
|
||||
parameters["Prompt"].split(f"{TEMPLATE_LABEL}:")[0].strip()
|
||||
|
|
|
|||
|
|
@ -32,7 +32,9 @@ try:
|
|||
from packaging.requirements import Requirement
|
||||
except ImportError:
|
||||
# pip has had this since 2018
|
||||
from pip._vendor.packaging.requirements import Requirement
|
||||
from pip._vendor.packaging.requirements import ( # type: ignore[assignment]
|
||||
Requirement,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -40,7 +42,7 @@ logger = logging.getLogger(__name__)
|
|||
@dataclasses.dataclass
|
||||
class InstallResult:
|
||||
requirement: Requirement
|
||||
installed: str
|
||||
installed: str | None
|
||||
|
||||
@property
|
||||
def message(self) -> str | None:
|
||||
|
|
@ -59,7 +61,9 @@ class InstallResult:
|
|||
|
||||
@property
|
||||
def correct(self) -> bool:
|
||||
return self.installed and self.requirement.specifier.contains(self.installed)
|
||||
return bool(
|
||||
self.installed and self.requirement.specifier.contains(self.installed),
|
||||
)
|
||||
|
||||
@property
|
||||
def pip_install_command(self) -> str:
|
||||
|
|
@ -72,9 +76,10 @@ class InstallResult:
|
|||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_requirements() -> tuple[str]:
|
||||
def get_requirements() -> tuple[str, ...]:
|
||||
toml_text = (Path(__file__).parent.parent / "pyproject.toml").read_text()
|
||||
return tuple(tomli.loads(toml_text)["project"]["dependencies"])
|
||||
deps = tomli.loads(toml_text)["project"]["dependencies"]
|
||||
return tuple(str(dep) for dep in deps)
|
||||
|
||||
|
||||
def get_install_result(req_str: str) -> InstallResult:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,101 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MockProcessing:
|
||||
seed: int
|
||||
subseed: int
|
||||
all_seeds: list[int]
|
||||
all_subseeds: list[int]
|
||||
subseed_strength: float
|
||||
prompt: str = ""
|
||||
negative_prompt: str = ""
|
||||
hr_prompt: str = ""
|
||||
hr_negative_prompt: str = ""
|
||||
n_iter: int = 1
|
||||
batch_size: int = 1
|
||||
enable_hr: bool = False
|
||||
all_prompts: list[str] = dataclasses.field(default_factory=list)
|
||||
all_hr_prompts: list[str] = dataclasses.field(default_factory=list)
|
||||
all_negative_prompts: list[str] = dataclasses.field(default_factory=list)
|
||||
all_hr_negative_prompts: list[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
def set_prompt_for_test(self, prompt):
|
||||
self.prompt = prompt
|
||||
self.hr_prompt = prompt
|
||||
self.all_prompts = [prompt] * self.n_iter * self.batch_size
|
||||
self.all_hr_prompts = self.all_prompts.copy()
|
||||
|
||||
def set_negative_prompt_for_test(self, negative_prompt):
|
||||
self.negative_prompt = negative_prompt
|
||||
self.hr_negative_prompt = negative_prompt
|
||||
self.all_negative_prompts = [negative_prompt] * self.n_iter * self.batch_size
|
||||
self.all_hr_negative_prompts = self.all_negative_prompts.copy()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def processing() -> MockProcessing:
|
||||
p = MockProcessing(
|
||||
seed=1000,
|
||||
subseed=2000,
|
||||
all_seeds=list(range(3000, 3000 + 10)),
|
||||
all_subseeds=list(range(4000, 4000 + 10)),
|
||||
subseed_strength=0,
|
||||
)
|
||||
p.set_prompt_for_test("beautiful sheep")
|
||||
p.set_negative_prompt_for_test("ugly")
|
||||
return p
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def monkeypatch_webui(monkeypatch, tmp_path):
|
||||
"""
|
||||
Patch sys.modules to look like we have a (partial) WebUI installation.
|
||||
"""
|
||||
import torch
|
||||
|
||||
fake_webui = {
|
||||
"gradio": {"__getattr__": MagicMock()},
|
||||
"modules": {},
|
||||
"modules.scripts": {"Script": object, "basedir": lambda: str(tmp_path)},
|
||||
"modules.devices": {"device": torch.device("cpu")},
|
||||
"modules.processing": {"fix_seed": Mock()},
|
||||
"modules.shared": {
|
||||
"opts": types.SimpleNamespace(
|
||||
dp_auto_purge_cache=True,
|
||||
dp_ignore_whitespace=True,
|
||||
dp_limit_jinja_prompts=False,
|
||||
dp_magicprompt_batch_size=1,
|
||||
dp_parser_variant_end="}",
|
||||
dp_parser_variant_start="{",
|
||||
dp_parser_wildcard_wrap="__",
|
||||
dp_wildcard_manager_no_dedupe=False,
|
||||
dp_wildcard_manager_no_sort=False,
|
||||
dp_wildcard_manager_shuffle=False,
|
||||
dp_write_prompts_to_file=False,
|
||||
dp_write_raw_template=False,
|
||||
),
|
||||
},
|
||||
"modules.script_callbacks": {
|
||||
"ImageSaveParams": object,
|
||||
"__getattr__": MagicMock(),
|
||||
},
|
||||
"modules.generation_parameters_copypaste": {
|
||||
"parse_generation_parameters": Mock(),
|
||||
},
|
||||
}
|
||||
|
||||
for module_name, contents in fake_webui.items():
|
||||
if module_name in sys.modules:
|
||||
continue
|
||||
mod = types.ModuleType(module_name)
|
||||
for name, obj in contents.items():
|
||||
setattr(mod, name, obj)
|
||||
monkeypatch.setitem(sys.modules, module_name, mod)
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
|
@ -10,18 +12,6 @@ from sd_dynamic_prompts.helpers import (
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def processing():
|
||||
m = mock.Mock()
|
||||
m.seed = 1000
|
||||
m.subseed = 2000
|
||||
m.all_seeds = list(range(3000, 3000 + 10))
|
||||
m.all_subseeds = list(range(4000, 4000 + 10))
|
||||
m.subseed_strength = 0
|
||||
|
||||
return m
|
||||
|
||||
|
||||
def test_get_seeds_with_fixed_seed(processing):
|
||||
num_seeds = 10
|
||||
|
||||
|
|
@ -142,6 +132,8 @@ def test_generate_with_num_prompts(num_prompts: int | None):
|
|||
num_prompts,
|
||||
seeds,
|
||||
)
|
||||
assert isinstance(positive_prompts, list)
|
||||
assert isinstance(negative_prompts, list)
|
||||
|
||||
if num_prompts:
|
||||
assert positive_prompts == [
|
||||
|
|
|
|||
|
|
@ -0,0 +1,11 @@
|
|||
from sd_dynamic_prompts.paths import (
|
||||
get_extension_base_path,
|
||||
get_magicprompt_models_txt_path,
|
||||
get_wildcard_dir,
|
||||
)
|
||||
|
||||
|
||||
def test_get_paths():
|
||||
assert get_extension_base_path().is_dir()
|
||||
assert get_magicprompt_models_txt_path().is_file()
|
||||
assert get_wildcard_dir().is_dir()
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_hr", [True, False], ids=["yes_hr", "no_hr"])
|
||||
@pytest.mark.parametrize("is_combinatorial", [True, False], ids=["yes_comb", "no_comb"])
|
||||
def test_script(
|
||||
monkeypatch,
|
||||
monkeypatch_webui,
|
||||
processing,
|
||||
enable_hr,
|
||||
is_combinatorial,
|
||||
):
|
||||
from scripts.dynamic_prompting import Script
|
||||
|
||||
s = Script()
|
||||
if not is_combinatorial:
|
||||
processing.batch_size = 3
|
||||
processing.set_prompt_for_test("{red|green|blue} ball")
|
||||
processing.set_negative_prompt_for_test("ugly")
|
||||
processing.enable_hr = enable_hr
|
||||
s.process(
|
||||
p=processing,
|
||||
is_enabled=True,
|
||||
is_combinatorial=is_combinatorial,
|
||||
combinatorial_batches=1,
|
||||
is_magic_prompt=False,
|
||||
is_feeling_lucky=False,
|
||||
is_attention_grabber=False,
|
||||
min_attention=0,
|
||||
max_attention=1,
|
||||
magic_prompt_length=0,
|
||||
magic_temp_value=1,
|
||||
use_fixed_seed=False,
|
||||
unlink_seed_from_prompt=False,
|
||||
disable_negative_prompt=False,
|
||||
enable_jinja_templates=False,
|
||||
no_image_generation=False,
|
||||
max_generations=0,
|
||||
magic_model="magic",
|
||||
magic_blocklist_regex=None,
|
||||
)
|
||||
assert isinstance(processing.all_prompts, list)
|
||||
assert isinstance(processing.all_negative_prompts, list)
|
||||
assert isinstance(processing.all_hr_prompts, list)
|
||||
assert isinstance(processing.all_hr_negative_prompts, list)
|
||||
|
||||
if is_combinatorial:
|
||||
assert processing.all_prompts == ["red ball", "green ball", "blue ball"]
|
||||
assert processing.all_negative_prompts == ["ugly"] * 3
|
||||
|
||||
if enable_hr:
|
||||
assert processing.all_hr_prompts == processing.all_prompts
|
||||
assert processing.all_hr_negative_prompts == processing.all_negative_prompts
|
||||
else:
|
||||
assert len(processing.all_prompts) == 3 # can't assert on the contents
|
||||
Loading…
Reference in New Issue