Merge pull request #352 from adieyal/bug/variation-seed

Freezing the prompt if variation strength is > 0
pull/355/head
Adi Eyal 2023-03-29 08:51:48 +03:00 committed by GitHub
commit b293722b3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 66 additions and 2 deletions

View File

@ -1,3 +1,4 @@
- 2.8.12 Prompts are frozen if the variation strength is greater than 0. See [#310](https://github.com/adieyal/sd-dynamic-prompts/issues/310)
- 2.8.11 Fixed the broken wildcards manager, see [#338](https://github.com/adieyal/sd-dynamic-prompts/issues/338)
- 2.8.10 Magic Prompt now works on M1/M2 Mac - see [#329](https://github.com/adieyal/sd-dynamic-prompts/issues/329)
- 2.8.9 Updated dynamicprompts to 0.10.5 which fixes #307

View File

@ -19,11 +19,11 @@ from modules.shared import opts
from sd_dynamic_prompts import callbacks
from sd_dynamic_prompts.consts import MAGIC_PROMPT_MODELS
from sd_dynamic_prompts.generator_builder import GeneratorBuilder
from sd_dynamic_prompts.helpers import get_seeds
from sd_dynamic_prompts.helpers import get_seeds, should_freeze_prompt
from sd_dynamic_prompts.ui.pnginfo_saver import PngInfoSaver
from sd_dynamic_prompts.ui.prompt_writer import PromptWriter
VERSION = "2.8.11"
VERSION = "2.8.12"
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@ -413,6 +413,7 @@ class Script(scripts.Script):
.set_unlink_seed_from_prompt(unlink_seed_from_prompt)
.set_seed(original_seed)
.set_context(p)
.set_freeze_prompt(should_freeze_prompt(p))
)
generator = generator_builder.create_generator()

View File

@ -0,0 +1,20 @@
from __future__ import annotations
from dynamicprompts.generators.promptgenerator import PromptGenerator
class FrozenPromptGenerator(PromptGenerator):
"""
Generates a prompt once and repeats that prompt as num_images times
"""
def __init__(self, prompt_generator: PromptGenerator):
self._generator = prompt_generator
def generate(
self,
template: str,
num_images: int = 1,
) -> list[str]:
prompts = self._generator.generate(template, 1)
return num_images * prompts

View File

@ -14,6 +14,7 @@ from dynamicprompts.generators import (
from dynamicprompts.parser.parse import default_parser_config
from sd_dynamic_prompts.consts import DEFAULT_MAGIC_MODEL
from sd_dynamic_prompts.frozenprompt_generator import FrozenPromptGenerator
logger = logging.getLogger(__name__)
@ -28,6 +29,7 @@ class GeneratorBuilder:
self._wildcard_manager = wildcard_manager
self._is_dummy = False
self._should_freeze_prompt = False
self._is_feeling_lucky = False
self._is_jinja_template = False
self._is_combinatorial = False
@ -125,6 +127,10 @@ class GeneratorBuilder:
self._seed = seed
return self
def set_freeze_prompt(self, should_freeze: bool):
self._should_freeze_prompt = should_freeze
return self
def set_context(self, context):
self._context = context
return self
@ -175,6 +181,9 @@ class GeneratorBuilder:
)
except ImportError as ie:
logger.error(f"Not using AttentionGenerator: {ie}")
if self._should_freeze_prompt:
generator = FrozenPromptGenerator(generator)
return generator
def create_basic_generator(

View File

@ -18,3 +18,8 @@ def get_seeds(p, num_seeds, use_fixed_seed):
all_subseeds = [subseed + i for i in range(num_seeds)]
return all_seeds, all_subseeds
def should_freeze_prompt(p):
# When using a variation seed, the prompt shouldn't change between generations
return p.subseed_strength > 0

View File

@ -0,0 +1,19 @@
from dynamicprompts.generators import RandomPromptGenerator
from sd_dynamic_prompts.frozenprompt_generator import FrozenPromptGenerator
def test_repeats_correctly():
generator = FrozenPromptGenerator(RandomPromptGenerator())
template = "{A|B|C|D|E|F|G|H|I|J|K}"
prompts = generator.generate(template, 10)
assert len(prompts) == 10
assert len(set(prompts)) == 1
prompts2 = generator.generate(template, 10)
assert len(prompts2) == 10
assert len(set(prompts2)) == 1
assert prompts[0] != prompts2[0]

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
from dynamicprompts.generators.magicprompt import MagicPromptGenerator
from dynamicprompts.wildcards import WildcardManager
from sd_dynamic_prompts.frozenprompt_generator import FrozenPromptGenerator
from sd_dynamic_prompts.generator_builder import GeneratorBuilder
@ -17,3 +18,11 @@ def test_magic_blocklist_regexp(tmp_path):
gen = gb.create_generator()
assert isinstance(gen, MagicPromptGenerator)
assert gen._blocklist_regex.pattern == popular_artist
def test_frozen_generator(tmp_path):
gb = GeneratorBuilder(wildcard_manager=WildcardManager(tmp_path)).set_freeze_prompt(
True,
)
gen = gb.create_generator()
assert type(gen) == FrozenPromptGenerator