sd-dynamic-prompts/prompts/generators/combinatorial.py

80 lines
2.6 KiB
Python

from itertools import chain
import logging
from prompts.wildcardmanager import WildcardManager
from prompts import constants
from . import PromptGenerator, re_combinations, re_wildcard
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class CombinatorialPromptGenerator(PromptGenerator):
    """Expand a prompt template into the prompts it can produce.

    Two kinds of placeholder are expanded:

    * variants -- spans found by ``re_combinations`` and substituted as
      ``{a|b|c}``; every ``|``-separated option yields its own prompt;
    * wildcards -- spans found by ``re_wildcard`` and substituted as
      ``__name__``; every value obtained from the wildcard manager for
      ``name`` yields its own prompt.
    """

    def __init__(self, wildcardmanager: WildcardManager, template):
        """Store the collaborators used by :meth:`generate`.

        Args:
            wildcardmanager: Object whose ``match_files(name)`` returns the
                wildcard files for a wildcard name.
            template: The prompt template to expand.
        """
        self._wildcard_manager = wildcardmanager
        self._template = template

    def generate_from_variants(self, seed_template):
        """Expand every variant currently present in ``seed_template``.

        Args:
            seed_template: Template text that may contain ``{a|b}`` variants.

        Returns:
            A list with one template per combination of variant options.
            When the template contains no variants the list is
            ``[seed_template]``.
        """
        templates = [seed_template]

        for variant in re_combinations.findall(seed_template):
            placeholder = f"{{{variant}}}"
            # ``count=1`` so that a variant occurring twice in the template is
            # consumed one occurrence per pass. Options form the outer loop to
            # keep the historical output order.
            templates = [
                template.replace(placeholder, option, 1)
                for option in variant.split("|")
                for template in templates
            ]

        # ``str.split`` never returns an empty list, so ``templates`` always
        # holds at least one entry here.
        return templates

    def generate_from_wildcards(self, seed_template, recursion=0):
        """Recursively replace every wildcard in ``seed_template``.

        Only the first wildcard is expanded at each recursion level; the
        recursive call handles the remaining ones. Expanding all wildcards at
        every level would emit each combination once per expansion order,
        i.e. as duplicates.

        Args:
            seed_template: Template text that may contain ``__name__``
                wildcards.
            recursion: Current recursion depth; callers normally leave this
                at its default.

        Returns:
            A list of templates without wildcards. The list is empty when a
            wildcard has no values, because no prompt can be built from it.

        Raises:
            Exception: If the depth exceeds ``constants.MAX_RECURSIONS``.
        """
        if recursion > constants.MAX_RECURSIONS:
            raise Exception(
                "Too many recursions, something went wrong with generating the prompt: "
                + seed_template
            )

        wildcards = re_wildcard.findall(seed_template)
        if not wildcards:
            return [seed_template]

        wildcard = wildcards[0]
        wildcard_files = self._wildcard_manager.match_files(wildcard)
        values = list(
            chain.from_iterable(f.get_wildcards() for f in wildcard_files)
        )
        if not values:
            logger.warning(
                "No values found for wildcard %r; dropping template: %s",
                wildcard,
                seed_template,
            )
            return []

        placeholder = f"__{wildcard}__"
        expanded = []
        for val in values:
            new_template = seed_template.replace(placeholder, val, 1)
            logger.debug("New template: %s", new_template)
            expanded += self.generate_from_wildcards(
                new_template, recursion=recursion + 1
            )
        return expanded

    def generate(self, max_prompts=constants.MAX_IMAGES) -> list[str]:
        """Generate up to ``max_prompts`` prompts from the stored template.

        Templates are processed through a FIFO work queue: each round expands
        wildcards for the template at the head of the queue, then expands
        variants for the next head. A template whose variant expansion yields
        exactly one result is finished and collected; otherwise the results
        are queued for another round. Prompts are returned in the order the
        queue completes them, which need not match the order in which the
        options are written in the template.

        Args:
            max_prompts: Maximum number of prompts to return.

        Returns:
            At most ``max_prompts`` fully expanded prompts.
        """
        templates = [self._template]
        all_prompts = []

        while templates and len(all_prompts) < max_prompts:
            template = templates.pop(0)
            templates.extend(self.generate_from_wildcards(template))
            if not templates:
                # Wildcard expansion produced nothing (for example a wildcard
                # without values) and no other work is queued.
                break

            template = templates.pop(0)
            new_prompts = self.generate_from_variants(template)
            if len(new_prompts) == 1:
                all_prompts.append(new_prompts[0])
            else:
                templates.extend(new_prompts)

        return all_prompts[:max_prompts]