Updated how prompt seeds are generated
These are now returned from the get_seeds function which decides whether if should be the same as image seeds or generated separately. Fixes #535bug/fixed-seeds
parent
78d599c256
commit
2b74c19cae
|
|
@ -464,16 +464,18 @@ class Script(scripts.Script):
|
|||
else:
|
||||
negative_generator = generator
|
||||
|
||||
all_seeds = None
|
||||
if num_images and not unlink_seed_from_prompt:
|
||||
p.all_seeds, p.all_subseeds = get_seeds(
|
||||
prompt_seeds = p.all_seeds
|
||||
if num_images:
|
||||
image_seeds, image_subseeds, prompt_seeds = get_seeds(
|
||||
p,
|
||||
num_images,
|
||||
use_fixed_seed,
|
||||
is_combinatorial,
|
||||
combinatorial_batches,
|
||||
unlink_seed_from_prompt,
|
||||
)
|
||||
all_seeds = p.all_seeds
|
||||
p.all_seeds = image_seeds
|
||||
p.all_subseeds = image_subseeds
|
||||
|
||||
all_prompts, all_negative_prompts = generate_prompts(
|
||||
generator,
|
||||
|
|
@ -481,7 +483,7 @@ class Script(scripts.Script):
|
|||
original_prompt,
|
||||
original_negative_prompt,
|
||||
num_images,
|
||||
all_seeds,
|
||||
prompt_seeds,
|
||||
)
|
||||
|
||||
except GeneratorException as e:
|
||||
|
|
@ -493,12 +495,13 @@ class Script(scripts.Script):
|
|||
p.n_iter = math.ceil(updated_count / p.batch_size)
|
||||
|
||||
if num_images != updated_count:
|
||||
p.all_seeds, p.all_subseeds = get_seeds(
|
||||
p.all_seeds, p.all_subseeds, _ = get_seeds(
|
||||
p,
|
||||
updated_count,
|
||||
use_fixed_seed,
|
||||
is_combinatorial,
|
||||
combinatorial_batches,
|
||||
unlink_seed_from_prompt,
|
||||
)
|
||||
|
||||
if updated_count > 1:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
from dynamicprompts.generators.promptgenerator import PromptGenerator
|
||||
|
|
@ -8,12 +9,21 @@ from dynamicprompts.generators.promptgenerator import PromptGenerator
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_fixed_seed(seed):
|
||||
# Copied from auto1111 modules/processing.py
|
||||
if seed is None or seed == "" or seed == -1:
|
||||
return int(random.randrange(4294967294))
|
||||
|
||||
return seed
|
||||
|
||||
|
||||
def get_seeds(
|
||||
p,
|
||||
num_seeds,
|
||||
use_fixed_seed,
|
||||
is_combinatorial=False,
|
||||
combinatorial_batches=1,
|
||||
unlink_seed_from_prompt=False,
|
||||
):
|
||||
if p.subseed_strength != 0:
|
||||
seed = int(p.all_seeds[0])
|
||||
|
|
@ -24,22 +34,27 @@ def get_seeds(
|
|||
|
||||
if use_fixed_seed:
|
||||
if is_combinatorial:
|
||||
all_seeds = []
|
||||
all_subseeds = [subseed] * num_seeds
|
||||
image_seeds = []
|
||||
image_subseeds = [subseed] * num_seeds
|
||||
for i in range(combinatorial_batches):
|
||||
all_seeds.extend([seed + i] * (num_seeds // combinatorial_batches))
|
||||
image_seeds.extend([seed + i] * (num_seeds // combinatorial_batches))
|
||||
else:
|
||||
all_seeds = [seed] * num_seeds
|
||||
all_subseeds = [subseed] * num_seeds
|
||||
image_seeds = [seed] * num_seeds
|
||||
image_subseeds = [subseed] * num_seeds
|
||||
else:
|
||||
if p.subseed_strength == 0:
|
||||
all_seeds = [seed + i for i in range(num_seeds)]
|
||||
image_seeds = [seed + i for i in range(num_seeds)]
|
||||
else:
|
||||
all_seeds = [seed] * num_seeds
|
||||
image_seeds = [seed] * num_seeds
|
||||
|
||||
all_subseeds = [subseed + i for i in range(num_seeds)]
|
||||
image_subseeds = [subseed + i for i in range(num_seeds)]
|
||||
|
||||
return all_seeds, all_subseeds
|
||||
if unlink_seed_from_prompt:
|
||||
prompt_seeds = [get_fixed_seed(None) for _ in range(num_seeds)]
|
||||
else:
|
||||
prompt_seeds = image_seeds
|
||||
|
||||
return image_seeds, image_subseeds, prompt_seeds
|
||||
|
||||
|
||||
def should_freeze_prompt(p):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ from sd_dynamic_prompts.frozenprompt_generator import FrozenPromptGenerator
|
|||
|
||||
|
||||
def test_repeats_correctly():
|
||||
generator = FrozenPromptGenerator(RandomPromptGenerator())
|
||||
generator = FrozenPromptGenerator(
|
||||
RandomPromptGenerator(unlink_seed_from_prompt=True),
|
||||
)
|
||||
template = "{A|B|C|D|E|F|G|H|I|J|K}"
|
||||
prompts = generator.generate(template, 10)
|
||||
|
||||
|
|
@ -15,5 +17,4 @@ def test_repeats_correctly():
|
|||
|
||||
assert len(prompts2) == 10
|
||||
assert len(set(prompts2)) == 1
|
||||
|
||||
assert prompts[0] != prompts2[0]
|
||||
|
|
|
|||
|
|
@ -22,21 +22,29 @@ def processing():
|
|||
def test_get_seeds_with_fixed_seed(processing):
|
||||
num_seeds = 10
|
||||
|
||||
seeds, subseeds = get_seeds(processing, num_seeds, use_fixed_seed=True)
|
||||
assert seeds == [processing.seed] * num_seeds
|
||||
assert subseeds == [processing.subseed] * num_seeds
|
||||
image_seeds, image_subseeds, _ = get_seeds(
|
||||
processing,
|
||||
num_seeds,
|
||||
use_fixed_seed=True,
|
||||
)
|
||||
assert image_seeds == [processing.seed] * num_seeds
|
||||
assert image_subseeds == [processing.subseed] * num_seeds
|
||||
|
||||
processing.subseed_strength = 0.5
|
||||
|
||||
seeds, subseeds = get_seeds(processing, num_seeds, use_fixed_seed=True)
|
||||
assert seeds == [processing.all_seeds[0]] * num_seeds
|
||||
assert subseeds == [processing.all_subseeds[0]] * num_seeds
|
||||
image_seeds, image_subseeds, _ = get_seeds(
|
||||
processing,
|
||||
num_seeds,
|
||||
use_fixed_seed=True,
|
||||
)
|
||||
assert image_seeds == [processing.all_seeds[0]] * num_seeds
|
||||
assert image_subseeds == [processing.all_subseeds[0]] * num_seeds
|
||||
|
||||
|
||||
def test_get_seeds_with_fixed_seed_batched_combinatorial(processing):
|
||||
num_seeds = 10
|
||||
combinatorial_batches = 3
|
||||
seeds, subseeds = get_seeds(
|
||||
image_seeds, image_subseeds, _ = get_seeds(
|
||||
processing,
|
||||
num_seeds,
|
||||
use_fixed_seed=True,
|
||||
|
|
@ -44,16 +52,16 @@ def test_get_seeds_with_fixed_seed_batched_combinatorial(processing):
|
|||
combinatorial_batches=combinatorial_batches,
|
||||
)
|
||||
seed0 = processing.seed
|
||||
assert seeds == (
|
||||
assert image_seeds == (
|
||||
[seed0] * (num_seeds // 3)
|
||||
+ [seed0 + 1] * (num_seeds // 3)
|
||||
+ [seed0 + 2] * (num_seeds // 3)
|
||||
)
|
||||
assert subseeds == [processing.subseed] * num_seeds
|
||||
assert image_subseeds == [processing.subseed] * num_seeds
|
||||
|
||||
processing.subseed_strength = 0.5
|
||||
|
||||
seeds, subseeds = get_seeds(
|
||||
image_seeds, image_subseeds, _ = get_seeds(
|
||||
processing,
|
||||
num_seeds,
|
||||
use_fixed_seed=True,
|
||||
|
|
@ -61,28 +69,57 @@ def test_get_seeds_with_fixed_seed_batched_combinatorial(processing):
|
|||
combinatorial_batches=combinatorial_batches,
|
||||
)
|
||||
seed0 = processing.all_seeds[0]
|
||||
assert seeds == (
|
||||
assert image_seeds == (
|
||||
[seed0] * (num_seeds // 3)
|
||||
+ [seed0 + 1] * (num_seeds // 3)
|
||||
+ [seed0 + 2] * (num_seeds // 3)
|
||||
)
|
||||
assert subseeds == [processing.all_subseeds[0]] * num_seeds
|
||||
assert image_subseeds == [processing.all_subseeds[0]] * num_seeds
|
||||
|
||||
|
||||
def test_get_seeds_with_random_seed(processing):
|
||||
num_seeds = 10
|
||||
|
||||
seed, subseed = processing.seed, processing.subseed
|
||||
seeds, subseeds = get_seeds(processing, num_seeds=num_seeds, use_fixed_seed=False)
|
||||
assert seeds == list(range(seed, seed + num_seeds))
|
||||
assert subseeds == list(range(subseed, subseed + num_seeds))
|
||||
image_seeds, image_subseeds = processing.seed, processing.subseed
|
||||
seeds, subseeds, _ = get_seeds(
|
||||
processing,
|
||||
num_seeds=num_seeds,
|
||||
use_fixed_seed=False,
|
||||
)
|
||||
assert seeds == list(range(image_seeds, image_seeds + num_seeds))
|
||||
assert subseeds == list(range(image_subseeds, image_subseeds + num_seeds))
|
||||
|
||||
processing.subseed_strength = 0.5
|
||||
|
||||
seed, subseed = processing.all_seeds[0], processing.all_subseeds[0]
|
||||
seeds, subseeds = get_seeds(processing, num_seeds=num_seeds, use_fixed_seed=False)
|
||||
assert seeds == [seed] * num_seeds
|
||||
assert subseeds == list(range(subseed, subseed + num_seeds))
|
||||
image_seeds, image_subseeds = processing.all_seeds[0], processing.all_subseeds[0]
|
||||
seeds, subseeds, _ = get_seeds(
|
||||
processing,
|
||||
num_seeds=num_seeds,
|
||||
use_fixed_seed=False,
|
||||
)
|
||||
assert seeds == [image_seeds] * num_seeds
|
||||
assert subseeds == list(range(image_subseeds, image_subseeds + num_seeds))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fixed_seed", [True, False])
|
||||
def test_get_with_unlinked_seed(processing, use_fixed_seed):
|
||||
num_seeds = 10
|
||||
|
||||
image_seeds, _, prompt_seeds = get_seeds(
|
||||
processing,
|
||||
num_seeds,
|
||||
use_fixed_seed=use_fixed_seed,
|
||||
unlink_seed_from_prompt=False,
|
||||
)
|
||||
assert image_seeds == prompt_seeds
|
||||
|
||||
image_seeds, _, prompt_seeds = get_seeds(
|
||||
processing,
|
||||
num_seeds,
|
||||
use_fixed_seed=use_fixed_seed,
|
||||
unlink_seed_from_prompt=True,
|
||||
)
|
||||
assert image_seeds != prompt_seeds
|
||||
|
||||
|
||||
def test_load_magicprompt_models():
|
||||
|
|
|
|||
Loading…
Reference in New Issue