Merge pull request #352 from adieyal/bug/variation-seed
Freezing the prompt if variation strength is > 0pull/355/head
commit
b293722b3d
|
|
@ -1,3 +1,4 @@
|
|||
- 2.8.12 Prompts are frozen if the variation strength is greater than 0. See [#310](https://github.com/adieyal/sd-dynamic-prompts/issues/310)
|
||||
- 2.8.11 Fixed the broken wildcards manager, see [#338](https://github.com/adieyal/sd-dynamic-prompts/issues/338)
|
||||
- 2.8.10 Magic Prompt now works on M1/M2 Mac - see [#329](https://github.com/adieyal/sd-dynamic-prompts/issues/329)
|
||||
- 2.8.9 Updated dynamicprompts to 0.10.5 which fixes #307
|
||||
|
|
|
|||
|
|
@ -19,11 +19,11 @@ from modules.shared import opts
|
|||
from sd_dynamic_prompts import callbacks
|
||||
from sd_dynamic_prompts.consts import MAGIC_PROMPT_MODELS
|
||||
from sd_dynamic_prompts.generator_builder import GeneratorBuilder
|
||||
from sd_dynamic_prompts.helpers import get_seeds
|
||||
from sd_dynamic_prompts.helpers import get_seeds, should_freeze_prompt
|
||||
from sd_dynamic_prompts.ui.pnginfo_saver import PngInfoSaver
|
||||
from sd_dynamic_prompts.ui.prompt_writer import PromptWriter
|
||||
|
||||
VERSION = "2.8.11"
|
||||
VERSION = "2.8.12"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
@ -413,6 +413,7 @@ class Script(scripts.Script):
|
|||
.set_unlink_seed_from_prompt(unlink_seed_from_prompt)
|
||||
.set_seed(original_seed)
|
||||
.set_context(p)
|
||||
.set_freeze_prompt(should_freeze_prompt(p))
|
||||
)
|
||||
|
||||
generator = generator_builder.create_generator()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dynamicprompts.generators.promptgenerator import PromptGenerator
|
||||
|
||||
|
||||
class FrozenPromptGenerator(PromptGenerator):
|
||||
"""
|
||||
Generates a prompt once and repeats that prompt as num_images times
|
||||
"""
|
||||
|
||||
def __init__(self, prompt_generator: PromptGenerator):
|
||||
self._generator = prompt_generator
|
||||
|
||||
def generate(
|
||||
self,
|
||||
template: str,
|
||||
num_images: int = 1,
|
||||
) -> list[str]:
|
||||
prompts = self._generator.generate(template, 1)
|
||||
return num_images * prompts
|
||||
|
|
@ -14,6 +14,7 @@ from dynamicprompts.generators import (
|
|||
from dynamicprompts.parser.parse import default_parser_config
|
||||
|
||||
from sd_dynamic_prompts.consts import DEFAULT_MAGIC_MODEL
|
||||
from sd_dynamic_prompts.frozenprompt_generator import FrozenPromptGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -28,6 +29,7 @@ class GeneratorBuilder:
|
|||
self._wildcard_manager = wildcard_manager
|
||||
|
||||
self._is_dummy = False
|
||||
self._should_freeze_prompt = False
|
||||
self._is_feeling_lucky = False
|
||||
self._is_jinja_template = False
|
||||
self._is_combinatorial = False
|
||||
|
|
@ -125,6 +127,10 @@ class GeneratorBuilder:
|
|||
self._seed = seed
|
||||
return self
|
||||
|
||||
def set_freeze_prompt(self, should_freeze: bool):
|
||||
self._should_freeze_prompt = should_freeze
|
||||
return self
|
||||
|
||||
def set_context(self, context):
|
||||
self._context = context
|
||||
return self
|
||||
|
|
@ -175,6 +181,9 @@ class GeneratorBuilder:
|
|||
)
|
||||
except ImportError as ie:
|
||||
logger.error(f"Not using AttentionGenerator: {ie}")
|
||||
|
||||
if self._should_freeze_prompt:
|
||||
generator = FrozenPromptGenerator(generator)
|
||||
return generator
|
||||
|
||||
def create_basic_generator(
|
||||
|
|
|
|||
|
|
@ -18,3 +18,8 @@ def get_seeds(p, num_seeds, use_fixed_seed):
|
|||
all_subseeds = [subseed + i for i in range(num_seeds)]
|
||||
|
||||
return all_seeds, all_subseeds
|
||||
|
||||
|
||||
def should_freeze_prompt(p):
|
||||
# When using a variation seed, the prompt shouldn't change between generations
|
||||
return p.subseed_strength > 0
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
from dynamicprompts.generators import RandomPromptGenerator
|
||||
|
||||
from sd_dynamic_prompts.frozenprompt_generator import FrozenPromptGenerator
|
||||
|
||||
|
||||
def test_repeats_correctly():
|
||||
generator = FrozenPromptGenerator(RandomPromptGenerator())
|
||||
template = "{A|B|C|D|E|F|G|H|I|J|K}"
|
||||
prompts = generator.generate(template, 10)
|
||||
|
||||
assert len(prompts) == 10
|
||||
assert len(set(prompts)) == 1
|
||||
|
||||
prompts2 = generator.generate(template, 10)
|
||||
|
||||
assert len(prompts2) == 10
|
||||
assert len(set(prompts2)) == 1
|
||||
|
||||
assert prompts[0] != prompts2[0]
|
||||
|
|
@ -3,6 +3,7 @@ from unittest.mock import patch
|
|||
from dynamicprompts.generators.magicprompt import MagicPromptGenerator
|
||||
from dynamicprompts.wildcards import WildcardManager
|
||||
|
||||
from sd_dynamic_prompts.frozenprompt_generator import FrozenPromptGenerator
|
||||
from sd_dynamic_prompts.generator_builder import GeneratorBuilder
|
||||
|
||||
|
||||
|
|
@ -17,3 +18,11 @@ def test_magic_blocklist_regexp(tmp_path):
|
|||
gen = gb.create_generator()
|
||||
assert isinstance(gen, MagicPromptGenerator)
|
||||
assert gen._blocklist_regex.pattern == popular_artist
|
||||
|
||||
|
||||
def test_frozen_generator(tmp_path):
|
||||
gb = GeneratorBuilder(wildcard_manager=WildcardManager(tmp_path)).set_freeze_prompt(
|
||||
True,
|
||||
)
|
||||
gen = gb.create_generator()
|
||||
assert type(gen) == FrozenPromptGenerator
|
||||
|
|
|
|||
Loading…
Reference in New Issue