Add and use repeat_iterable_to_length helper
parent
5cb0a814b6
commit
688199fc9b
|
|
@ -23,6 +23,7 @@ from sd_dynamic_prompts.helpers import (
|
|||
generate_prompts,
|
||||
get_seeds,
|
||||
load_magicprompt_models,
|
||||
repeat_iterable_to_length,
|
||||
should_freeze_prompt,
|
||||
)
|
||||
from sd_dynamic_prompts.paths import (
|
||||
|
|
@ -69,6 +70,16 @@ def _get_install_error_message() -> str | None:
|
|||
return None
|
||||
|
||||
|
||||
def _get_hr_fix_prompts(
|
||||
prompts: list[str],
|
||||
original_hr_prompt: str,
|
||||
original_prompt: str,
|
||||
) -> list[str]:
|
||||
if original_prompt == original_hr_prompt:
|
||||
return list(prompts)
|
||||
return repeat_iterable_to_length([original_hr_prompt], len(prompts))
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
def __init__(self):
|
||||
global loaded_count
|
||||
|
|
@ -517,13 +528,13 @@ class Script(scripts.Script):
|
|||
p.prompt = original_prompt
|
||||
|
||||
if hr_fix_enabled:
|
||||
p.all_hr_prompts = (
|
||||
all_prompts
|
||||
if original_prompt == original_hr_prompt
|
||||
else len(all_prompts) * [original_hr_prompt]
|
||||
p.all_hr_prompts = _get_hr_fix_prompts(
|
||||
all_prompts,
|
||||
original_hr_prompt,
|
||||
original_prompt,
|
||||
)
|
||||
p.all_hr_negative_prompts = (
|
||||
all_negative_prompts
|
||||
if original_negative_prompt == original_negative_hr_prompt
|
||||
else len(all_negative_prompts) * [original_negative_hr_prompt]
|
||||
p.all_hr_negative_prompts = _get_hr_fix_prompts(
|
||||
all_negative_prompts,
|
||||
original_negative_hr_prompt,
|
||||
original_negative_prompt,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from itertools import product
|
||||
from itertools import cycle, islice, product
|
||||
from pathlib import Path
|
||||
|
||||
from dynamicprompts.generators.promptgenerator import PromptGenerator
|
||||
|
|
@ -102,14 +102,14 @@ def generate_prompts(
|
|||
|
||||
if num_prompts is None:
|
||||
return generate_prompt_cross_product(all_prompts, all_negative_prompts)
|
||||
else:
|
||||
return all_prompts, (all_negative_prompts * num_prompts)[0:num_prompts]
|
||||
|
||||
return all_prompts, repeat_iterable_to_length(all_negative_prompts, num_prompts)
|
||||
|
||||
|
||||
def generate_prompt_cross_product(
|
||||
prompts: list[str],
|
||||
negative_prompts: list[str],
|
||||
) -> tuple(list[str], list[str]):
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Create a cross product of all the items in `prompts` and `negative_prompts`.
|
||||
Return the positive prompts and negative prompts in two separate lists
|
||||
|
|
@ -121,10 +121,29 @@ def generate_prompt_cross_product(
|
|||
Returns:
|
||||
- Tuple containing list of positive and negative prompts
|
||||
"""
|
||||
if len(prompts) == 0 or len(negative_prompts) == 0:
|
||||
if not (prompts and negative_prompts):
|
||||
return [], []
|
||||
|
||||
positive_prompts, negative_prompts = list(
|
||||
zip(*product(prompts, negative_prompts), strict=True),
|
||||
new_positive_prompts, new_negative_prompts = zip(
|
||||
*product(prompts, negative_prompts),
|
||||
strict=True,
|
||||
)
|
||||
return list(positive_prompts), list(negative_prompts)
|
||||
return list(new_positive_prompts), list(new_negative_prompts)
|
||||
|
||||
|
||||
def repeat_iterable_to_length(iterable, length: int) -> list:
|
||||
"""Repeat an iterable to a given length.
|
||||
|
||||
If the iterable is shorter than the desired length, it will be repeated
|
||||
until it is long enough. If it is longer than the desired length, it will
|
||||
be truncated.
|
||||
|
||||
Args:
|
||||
iterable (Iterable): The iterable to repeat.
|
||||
length (int): The desired length of the iterable.
|
||||
|
||||
Returns:
|
||||
list: The repeated iterable.
|
||||
|
||||
"""
|
||||
return list(islice(cycle(iterable), length))
|
||||
|
|
|
|||
Loading…
Reference in New Issue