Merge pull request #595 from adieyal/bug/cross-product

Bug/cross product
pull/596/head
Adi Eyal 2023-08-13 15:42:02 +02:00 committed by GitHub
commit f23624e1be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 126 additions and 9 deletions

View File

@ -1,3 +1,5 @@
- 2.16.1 Only using cross product when num_prompts is not provided
- 2.16.0 Added cross product of positive and negative prompts
- 2.15.0 Added Wildcard Manager Search
- 2.14.0 Added configuration option to shuffle wildcards for increased randomness in combinatorial mode.
- 2.13.0 Added configuration options to prevent the wildcard manager from sorting and deduplicating wildcard files

View File

@ -1 +1 @@
__version__ = "2.15.0"
__version__ = "2.16.1"

View File

@ -101,4 +101,31 @@ def generate_prompts(
seeds=negative_seeds,
) or [""]
return list(zip(*product(all_prompts, all_negative_prompts), strict=True))
if num_prompts is None:
return generate_prompt_cross_product(all_prompts, all_negative_prompts)
else:
return all_prompts, (all_negative_prompts * num_prompts)[0:num_prompts]
def generate_prompt_cross_product(
prompts: list[str],
negative_prompts: list[str],
) -> tuple(list[str], list[str]):
"""
Create a cross product of all the items in `prompts` and `negative_prompts`.
Return the positive prompts and negative prompts in two separate lists
Parameters:
- prompts: List of prompts
- negative_prompts: List of negative prompts
Returns:
- Tuple containing list of positive and negative prompts
"""
if len(prompts) == 0 or len(negative_prompts) == 0:
return [], []
positive_prompts, negative_prompts = list(
zip(*product(prompts, negative_prompts), strict=True),
)
return list(positive_prompts), list(negative_prompts)

View File

@ -4,16 +4,17 @@ from sd_dynamic_prompts.frozenprompt_generator import FrozenPromptGenerator
def test_repeats_correctly():
generator = FrozenPromptGenerator(RandomPromptGenerator())
generator = FrozenPromptGenerator(
RandomPromptGenerator(unlink_seed_from_prompt=True),
)
template = "{A|B|C|D|E|F|G|H|I|J|K}"
prompts = generator.generate(template, 10)
prompts = generator.generate(template, 40)
assert len(prompts) == 10
assert len(prompts) == 40
assert len(set(prompts)) == 1
prompts2 = generator.generate(template, 10)
prompts2 = generator.generate(template, 40)
assert len(prompts2) == 10
assert len(prompts2) == 40
assert len(set(prompts2)) == 1
assert prompts[0] != prompts2[0]

View File

@ -4,7 +4,12 @@ from unittest import mock
import pytest
from sd_dynamic_prompts.helpers import get_seeds, load_magicprompt_models
from sd_dynamic_prompts.helpers import (
generate_prompt_cross_product,
generate_prompts,
get_seeds,
load_magicprompt_models,
)
@pytest.fixture
@ -104,3 +109,85 @@ model 2
load_magicprompt_models(tmp_filename)
finally:
os.unlink(tmp_filename)
def test_cross_product():
prompts = []
negative_prompts = []
expected_output = [], []
assert generate_prompt_cross_product(prompts, negative_prompts) == expected_output
prompts = ["A", "B", "C"]
negative_prompts = ["X", "Y"]
expected_output = (["A", "A", "B", "B", "C", "C"], ["X", "Y", "X", "Y", "X", "Y"])
assert generate_prompt_cross_product(prompts, negative_prompts) == expected_output
@pytest.mark.parametrize("num_prompts", [5, None])
def test_generate_with_num_prompts(num_prompts: int | None):
prompt_generator = mock.Mock()
negative_prompt_generator = mock.Mock()
prompt_generator.generate.return_value = [
"Positive Prompt 1",
"Positive Prompt 2",
"Positive Prompt 3",
"Positive Prompt 4",
"Positive Prompt 5",
]
negative_prompt_generator.generate.return_value = [
"Negative Prompt 1",
"Negative Prompt 2",
]
prompt = "This is a positive prompt."
negative_prompt = "This is a negative prompt."
seeds = [1, 2, 3, 4, 5]
positive_prompts, negative_prompts = generate_prompts(
prompt_generator,
negative_prompt_generator,
prompt,
negative_prompt,
num_prompts,
seeds,
)
if num_prompts:
assert positive_prompts == [
"Positive Prompt 1",
"Positive Prompt 2",
"Positive Prompt 3",
"Positive Prompt 4",
"Positive Prompt 5",
]
assert negative_prompts == [
"Negative Prompt 1",
"Negative Prompt 2",
"Negative Prompt 1",
"Negative Prompt 2",
"Negative Prompt 1",
]
else:
assert positive_prompts == [
"Positive Prompt 1",
"Positive Prompt 1",
"Positive Prompt 2",
"Positive Prompt 2",
"Positive Prompt 3",
"Positive Prompt 3",
"Positive Prompt 4",
"Positive Prompt 4",
"Positive Prompt 5",
"Positive Prompt 5",
]
assert negative_prompts == [
"Negative Prompt 1",
"Negative Prompt 2",
"Negative Prompt 1",
"Negative Prompt 2",
"Negative Prompt 1",
"Negative Prompt 2",
"Negative Prompt 1",
"Negative Prompt 2",
"Negative Prompt 1",
"Negative Prompt 2",
]