# Source: sd-dynamic-prompts/prompts/generator_builder.py (Python)

from __future__ import annotations
import logging
from dynamicprompts.generators import (
DummyGenerator,
FeelingLuckyGenerator,
RandomPromptGenerator,
CombinatorialPromptGenerator,
BatchedCombinatorialPromptGenerator,
PromptGenerator,
JinjaGenerator
)
from dynamicprompts.generators.magicprompt import MagicPromptGenerator
from dynamicprompts.generators.attentiongenerator import AttentionGenerator
# Module-level logger named after this module; used by
# GeneratorBuilder.log_configuration to emit the debug summary.
logger = logging.getLogger(__name__)
class GeneratorBuilder:
    """Fluent builder that assembles a prompt generator from option flags.

    Options are configured through the chainable ``set_*`` methods (each
    returns the builder itself) and the generator is then produced by
    :meth:`create_generator`.
    """

    def __init__(self, wildcard_manager, ignore_whitespace=False):
        """Store the wildcard manager and reset every option to its default.

        Args:
            wildcard_manager: Handed unchanged to the random, combinatorial
                and Jinja generators when they are constructed.
            ignore_whitespace: Forwarded as the ``ignore_whitespace`` keyword
                to the random and combinatorial generators.
        """
        self._wildcard_manager = wildcard_manager
        self._ignore_whitespace = ignore_whitespace

        # Base generator selection flags (checked in this order by
        # create_generator: dummy, feeling lucky, jinja, then basic).
        self._is_dummy = False
        self._is_feeling_lucky = False
        self._is_jinja_template = False
        self._is_combinatorial = False
        self._combinatorial_batches = 1

        # Optional MagicPromptGenerator wrapper settings.
        self._is_magic_prompt = False
        self._magic_prompt_length = 100
        self._magic_temp_value = 0.7
        # Passed as the second positional argument of MagicPromptGenerator.
        # NOTE(review): presumably a device index/identifier - confirm.
        self._device = 0

        # Optional AttentionGenerator wrapper settings.
        self._is_attention_grabber = False
        self._min_attention = 1.1
        self._max_attention = 1.5

    def log_configuration(self):
        """Emit the current option values as one multi-line debug message."""
        # (label, value) pairs in the exact order and wording of the log text.
        settings = (
            ("is_dummy", self._is_dummy),
            ("is_feeling_lucky", self._is_feeling_lucky),
            ("enable_jinja_templates", self._is_jinja_template),
            ("is_combinatorial", self._is_combinatorial),
            ("is_magic_prompt", self._is_magic_prompt),
            ("combinatorial_batches", self._combinatorial_batches),
            ("magic_prompt_length", self._magic_prompt_length),
            ("magic_temp_value", self._magic_temp_value),
            ("is_attention_grabber", self._is_attention_grabber),
            ("min_attention", self._min_attention),
            ("max_attention", self._max_attention),
        )
        message = "\nCreating generator:\n" + "".join(
            f"{label}: {value}\n" for label, value in settings
        )
        logger.debug(message)

    def set_is_dummy(self, is_dummy: bool = True) -> "GeneratorBuilder":
        """Select (or deselect) the dummy generator; returns ``self``."""
        self._is_dummy = is_dummy
        return self

    def set_is_feeling_lucky(
        self, is_feeling_lucky: bool = True
    ) -> "GeneratorBuilder":
        """Select (or deselect) the feeling-lucky generator; returns ``self``."""
        self._is_feeling_lucky = is_feeling_lucky
        return self

    def set_is_attention_grabber(
        self,
        is_attention_grabber: bool = True,
        min_attention: float = 1.1,
        max_attention: float = 1.5,
    ) -> "GeneratorBuilder":
        """Configure the AttentionGenerator wrapper; returns ``self``.

        Args:
            is_attention_grabber: Whether create_generator wraps its result
                in an AttentionGenerator.
            min_attention: Forwarded as ``min_attention`` to the wrapper.
            max_attention: Forwarded as ``max_attention`` to the wrapper.
        """
        self._is_attention_grabber = is_attention_grabber
        self._min_attention = min_attention
        self._max_attention = max_attention
        return self

    def set_is_jinja_template(
        self, is_jinja_template: bool = True
    ) -> "GeneratorBuilder":
        """Select (or deselect) the Jinja template generator; returns ``self``."""
        self._is_jinja_template = is_jinja_template
        return self

    def set_is_combinatorial(
        self, is_combinatorial: bool = True, combinatorial_batches: int = 1
    ) -> "GeneratorBuilder":
        """Configure combinatorial generation; returns ``self``.

        Args:
            is_combinatorial: Whether create_basic_generator builds a batched
                combinatorial generator instead of a random one.
            combinatorial_batches: Second argument given to
                BatchedCombinatorialPromptGenerator.
        """
        self._is_combinatorial = is_combinatorial
        self._combinatorial_batches = combinatorial_batches
        return self

    def set_is_magic_prompt(
        self,
        is_magic_prompt: bool = True,
        magic_prompt_length: int = 100,
        magic_temp_value: float = 0.7,
        device=0,
    ) -> "GeneratorBuilder":
        """Configure the MagicPromptGenerator wrapper; returns ``self``.

        Args:
            is_magic_prompt: Whether create_generator wraps its result in a
                MagicPromptGenerator.
            magic_prompt_length: Third positional argument of the wrapper.
            magic_temp_value: Fourth positional argument of the wrapper.
            device: Second positional argument of the wrapper.
                NOTE(review): presumably a device index - confirm.
        """
        self._is_magic_prompt = is_magic_prompt
        self._magic_prompt_length = magic_prompt_length
        self._magic_temp_value = magic_temp_value
        self._device = device
        return self

    def create_generator(
        self,
        original_seed,
        context,
        unlink_seed_from_prompt=False,
    ):
        """Build the generator described by the current configuration.

        The dummy generator short-circuits everything and is returned
        unwrapped. Otherwise a base generator is chosen (feeling lucky, then
        Jinja, then the basic random/combinatorial one) and optionally
        wrapped by MagicPromptGenerator and then AttentionGenerator.

        Args:
            original_seed: Seed given to the basic generator and, as
                ``seed``, to the MagicPromptGenerator wrapper.
            context: Object handed to create_jinja_generator when Jinja
                templates are enabled; unused otherwise.
            unlink_seed_from_prompt: Forwarded to create_basic_generator.

        Returns:
            The configured generator instance.
        """
        if self._is_dummy:
            return DummyGenerator()

        if self._is_feeling_lucky:
            generator = FeelingLuckyGenerator()
        elif self._is_jinja_template:
            generator = self.create_jinja_generator(context)
        else:
            generator = self.create_basic_generator(
                original_seed,
                unlink_seed_from_prompt,
            )

        # Wrapping order matters: magic prompt first, attention outermost.
        if self._is_magic_prompt:
            generator = MagicPromptGenerator(
                generator,
                self._device,
                self._magic_prompt_length,
                self._magic_temp_value,
                seed=original_seed,
            )

        if self._is_attention_grabber:
            generator = AttentionGenerator(
                generator,
                min_attention=self._min_attention,
                max_attention=self._max_attention,
            )

        return generator

    def create_basic_generator(
        self,
        original_seed: int,
        unlink_seed_from_prompt: bool = False,
    ) -> "PromptGenerator":
        """Build the random or batched-combinatorial generator.

        Args:
            original_seed: Second positional argument of
                RandomPromptGenerator; unused in combinatorial mode.
            unlink_seed_from_prompt: Third positional argument of
                RandomPromptGenerator; unused in combinatorial mode.

        Returns:
            A BatchedCombinatorialPromptGenerator wrapping a
            CombinatorialPromptGenerator when combinatorial mode is enabled,
            otherwise a RandomPromptGenerator.
        """
        if self._is_combinatorial:
            combinatorial = CombinatorialPromptGenerator(
                self._wildcard_manager,
                ignore_whitespace=self._ignore_whitespace,
            )
            return BatchedCombinatorialPromptGenerator(
                combinatorial, self._combinatorial_batches
            )

        return RandomPromptGenerator(
            self._wildcard_manager,
            original_seed,
            unlink_seed_from_prompt,
            ignore_whitespace=self._ignore_whitespace,
        )

    @staticmethod
    def _build_jinja_context(p) -> dict:
        """Collect the template context exposed to Jinja prompts from ``p``."""
        # Truthiness instead of len(...) > 0: a missing (None) or empty list
        # falls back to the single prompt rather than raising TypeError.
        original_prompt = p.all_prompts[0] if p.all_prompts else p.prompt
        original_negative_prompt = (
            p.all_negative_prompts[0]
            if p.all_negative_prompts
            else p.negative_prompt
        )
        checkpoint = p.sd_model.sd_checkpoint_info

        return {
            "model": {
                "filename": checkpoint.filename,
                "title": checkpoint.title,
                "hash": checkpoint.hash,
                "model_name": checkpoint.model_name,
            },
            "image": {
                "width": p.width,
                "height": p.height,
            },
            "parameters": {
                "steps": p.steps,
                "batch_size": p.batch_size,
                "num_batches": p.n_iter,
                "width": p.width,
                "height": p.height,
                "cfg_scale": p.cfg_scale,
                "sampler_name": p.sampler_name,
                "seed": p.seed,
            },
            "prompt": {
                "prompt": original_prompt,
                "negative_prompt": original_negative_prompt,
            },
        }

    def create_jinja_generator(self, p) -> "PromptGenerator":
        """Build a JinjaGenerator whose context is derived from ``p``.

        Args:
            p: Object exposing ``all_prompts``, ``prompt``,
                ``all_negative_prompts``, ``negative_prompt``,
                ``sd_model.sd_checkpoint_info`` and the image/sampling
                attributes read by ``_build_jinja_context``.
                NOTE(review): presumably the web UI's processing object -
                confirm against the caller.

        Returns:
            A JinjaGenerator created with the wildcard manager and context.
        """
        context = self._build_jinja_context(p)
        return JinjaGenerator(self._wildcard_manager, context)