Add and use repeat_iterable_to_length helper

pull/597/head
Aarni Koskela 2023-08-15 14:34:51 +03:00
parent 5cb0a814b6
commit 688199fc9b
2 changed files with 46 additions and 16 deletions

View File

@ -23,6 +23,7 @@ from sd_dynamic_prompts.helpers import (
generate_prompts,
get_seeds,
load_magicprompt_models,
repeat_iterable_to_length,
should_freeze_prompt,
)
from sd_dynamic_prompts.paths import (
@ -69,6 +70,16 @@ def _get_install_error_message() -> str | None:
return None
def _get_hr_fix_prompts(
prompts: list[str],
original_hr_prompt: str,
original_prompt: str,
) -> list[str]:
if original_prompt == original_hr_prompt:
return list(prompts)
return repeat_iterable_to_length([original_hr_prompt], len(prompts))
class Script(scripts.Script):
def __init__(self):
global loaded_count
@ -517,13 +528,13 @@ class Script(scripts.Script):
p.prompt = original_prompt
if hr_fix_enabled:
p.all_hr_prompts = (
all_prompts
if original_prompt == original_hr_prompt
else len(all_prompts) * [original_hr_prompt]
p.all_hr_prompts = _get_hr_fix_prompts(
all_prompts,
original_hr_prompt,
original_prompt,
)
p.all_hr_negative_prompts = (
all_negative_prompts
if original_negative_prompt == original_negative_hr_prompt
else len(all_negative_prompts) * [original_negative_hr_prompt]
p.all_hr_negative_prompts = _get_hr_fix_prompts(
all_negative_prompts,
original_negative_hr_prompt,
original_negative_prompt,
)

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import logging
from itertools import product
from itertools import cycle, islice, product
from pathlib import Path
from dynamicprompts.generators.promptgenerator import PromptGenerator
@ -102,14 +102,14 @@ def generate_prompts(
if num_prompts is None:
return generate_prompt_cross_product(all_prompts, all_negative_prompts)
else:
return all_prompts, (all_negative_prompts * num_prompts)[0:num_prompts]
return all_prompts, repeat_iterable_to_length(all_negative_prompts, num_prompts)
def generate_prompt_cross_product(
prompts: list[str],
negative_prompts: list[str],
) -> tuple(list[str], list[str]):
) -> tuple[list[str], list[str]]:
"""
Create a cross product of all the items in `prompts` and `negative_prompts`.
Return the positive prompts and negative prompts in two separate lists
@ -121,10 +121,29 @@ def generate_prompt_cross_product(
Returns:
- Tuple containing list of positive and negative prompts
"""
if len(prompts) == 0 or len(negative_prompts) == 0:
if not (prompts and negative_prompts):
return [], []
positive_prompts, negative_prompts = list(
zip(*product(prompts, negative_prompts), strict=True),
new_positive_prompts, new_negative_prompts = zip(
*product(prompts, negative_prompts),
strict=True,
)
return list(positive_prompts), list(negative_prompts)
return list(new_positive_prompts), list(new_negative_prompts)
def repeat_iterable_to_length(iterable, length: int) -> list:
"""Repeat an iterable to a given length.
If the iterable is shorter than the desired length, it will be repeated
until it is long enough. If it is longer than the desired length, it will
be truncated.
Args:
iterable (Iterable): The iterable to repeat.
length (int): The desired length of the iterable.
Returns:
list: The repeated iterable.
"""
return list(islice(cycle(iterable), length))