# 254 lines · 8.1 KiB · Python
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
from dynamicprompts.generators import (
|
|
BatchedCombinatorialPromptGenerator,
|
|
CombinatorialPromptGenerator,
|
|
DummyGenerator,
|
|
FeelingLuckyGenerator,
|
|
JinjaGenerator,
|
|
PromptGenerator,
|
|
RandomPromptGenerator,
|
|
)
|
|
from dynamicprompts.parser.parse import default_parser_config
|
|
|
|
from sd_dynamic_prompts.frozenprompt_generator import FrozenPromptGenerator
|
|
|
|
# Module-level logger. GeneratorBuilder uses it for the DEBUG configuration dump
# and to report a failed attention-generator import.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class GeneratorBuilder:
    """Fluent builder that configures and assembles a prompt generator.

    Each ``set_*`` method stores one option and returns ``self`` so calls can
    be chained. :meth:`create_generator` then selects a base generator and
    wraps it, in order, with the optional MagicPrompt, attention-grabber and
    frozen-prompt layers.
    """

    def __init__(
        self,
        wildcard_manager,
        parser_config=default_parser_config,
        ignore_whitespace=False,
    ):
        """Create a builder with every optional feature disabled.

        Args:
            wildcard_manager: Passed unchanged to the random, combinatorial and
                Jinja generators.
            parser_config: Parser configuration for the random and
                combinatorial generators.
            ignore_whitespace: Forwarded as ``ignore_whitespace`` to the
                random, combinatorial and Jinja generators.
        """
        self._wildcard_manager = wildcard_manager

        # Feature toggles; each is switched on through its ``set_*`` method.
        self._is_dummy = False
        self._should_freeze_prompt = False
        self._is_feeling_lucky = False
        self._is_jinja_template = False
        self._is_combinatorial = False
        self._is_magic_prompt = False
        self._is_attention_grabber = False

        # Per-feature settings. Defaults mirror the setters' defaults, so every
        # attribute read while building a generator exists even when its
        # setter was never called.
        self._combinatorial_batches = 1
        self._limit_jinja_prompts = False
        self._magic_model = None
        self._magic_prompt_length = 100
        self._magic_temp_value = 0.7
        self._magic_blocklist_regex = None
        self._magic_batch_size = 1
        self._min_attention = 1.1
        self._max_attention = 1.5
        self._device = 0
        self._ignore_whitespace = ignore_whitespace
        self._unlink_seed_from_prompt = False
        self._seed = -1
        # Object passed to create_jinja_generator(); set via set_context().
        self._context = None
        self._parser_config = parser_config

    def log_configuration(self):
        """Log the current feature flags and settings at DEBUG level."""
        logger.debug(
            f"""
        Creating generator:
            is_dummy: {self._is_dummy}
            is_feeling_lucky: {self._is_feeling_lucky}
            enable_jinja_templates: {self._is_jinja_template}
            is_combinatorial: {self._is_combinatorial}
            is_magic_prompt: {self._is_magic_prompt}
            combinatorial_batches: {self._combinatorial_batches}
            magic_prompt_length: {self._magic_prompt_length}
            magic_temp_value: {self._magic_temp_value}
            magic_blocklist_regex: {self._magic_blocklist_regex}
            is_attention_grabber: {self._is_attention_grabber}
            min_attention: {self._min_attention}
            max_attention: {self._max_attention}

        """,
        )

    def set_is_dummy(self, is_dummy=True) -> GeneratorBuilder:
        """Make create_generator() return a bare DummyGenerator."""
        self._is_dummy = is_dummy
        return self

    def set_is_feeling_lucky(self, is_feeling_lucky=True) -> GeneratorBuilder:
        """Use FeelingLuckyGenerator as the base generator."""
        self._is_feeling_lucky = is_feeling_lucky
        return self

    def set_is_attention_grabber(
        self,
        is_attention_grabber=True,
        min_attention=1.1,
        max_attention=1.5,
    ) -> GeneratorBuilder:
        """Enable/disable the attention wrapper and set its attention range.

        Note that the range is always overwritten, so disabling with only
        ``False`` also resets it to the defaults (1.1, 1.5).
        """
        self._is_attention_grabber = is_attention_grabber
        self._min_attention = min_attention
        self._max_attention = max_attention
        return self

    def set_is_jinja_template(
        self,
        is_jinja_template=True,
        limit_prompts=False,
    ) -> GeneratorBuilder:
        """Use JinjaGenerator as the base generator.

        ``limit_prompts`` is forwarded to JinjaGenerator as ``limit_prompts``.
        """
        self._is_jinja_template = is_jinja_template
        self._limit_jinja_prompts = limit_prompts
        return self

    def set_is_combinatorial(
        self,
        is_combinatorial=True,
        combinatorial_batches=1,
    ) -> GeneratorBuilder:
        """Use the combinatorial generator instead of the random one.

        ``combinatorial_batches`` is forwarded as ``batches`` to
        BatchedCombinatorialPromptGenerator.
        """
        self._is_combinatorial = is_combinatorial
        self._combinatorial_batches = combinatorial_batches
        return self

    def set_is_magic_prompt(
        self,
        is_magic_prompt=True,
        magic_model=None,
        magic_prompt_length=100,
        magic_temp_value=0.7,
        device=0,
        magic_blocklist_regex: str | None = None,
        batch_size=1,
    ) -> GeneratorBuilder:
        """Configure the MagicPrompt wrapper.

        If ``magic_model`` is falsy, MagicPrompt is disabled and every other
        argument is ignored (previously stored settings are kept).
        """
        if not magic_model:
            self._is_magic_prompt = False
            return self

        self._magic_model = magic_model
        self._magic_prompt_length = magic_prompt_length
        self._magic_temp_value = magic_temp_value
        self._magic_blocklist_regex = magic_blocklist_regex
        self._is_magic_prompt = is_magic_prompt
        self._magic_batch_size = batch_size
        self._device = device

        return self

    def set_unlink_seed_from_prompt(
        self,
        unlink_seed_from_prompt=True,
    ) -> GeneratorBuilder:
        """Set ``unlink_seed_from_prompt`` (only used by RandomPromptGenerator)."""
        self._unlink_seed_from_prompt = unlink_seed_from_prompt
        return self

    def set_seed(self, seed) -> GeneratorBuilder:
        """Set the seed passed to RandomPromptGenerator and MagicPromptGenerator."""
        self._seed = seed
        return self

    def set_freeze_prompt(self, should_freeze: bool) -> GeneratorBuilder:
        """Wrap the final (non-dummy) generator in FrozenPromptGenerator."""
        self._should_freeze_prompt = should_freeze
        return self

    def set_context(self, context) -> GeneratorBuilder:
        """Store the object used to build the Jinja template context.

        Presumably the webui processing object (``p``); see
        create_jinja_generator for the attributes it must provide.
        """
        self._context = context
        return self

    def disable_prompt_magic(self) -> GeneratorBuilder:
        """Turn off the attention wrapper, MagicPrompt and feeling-lucky mode."""
        self.set_is_attention_grabber(False)
        self.set_is_magic_prompt(False)
        self.set_is_feeling_lucky(False)

        return self

    def create_generator(self):
        """Build the configured generator stack.

        Base generator precedence: dummy, then feeling-lucky, then Jinja, then
        the basic random/combinatorial generator. A dummy generator is
        returned as-is. Any other base is optionally wrapped, in this order,
        by MagicPromptGenerator, the attention generator and
        FrozenPromptGenerator.
        """
        if self._is_dummy:
            return DummyGenerator()

        if self._is_feeling_lucky:
            generator = FeelingLuckyGenerator()
        elif self._is_jinja_template:
            generator = self.create_jinja_generator(self._context)
        else:
            generator = self.create_basic_generator()

        if self._is_magic_prompt:
            # Imported only when MagicPrompt is enabled (presumably to avoid
            # loading its heavier dependencies otherwise).
            from dynamicprompts.generators.magicprompt import MagicPromptGenerator

            generator = MagicPromptGenerator(
                generator,
                model_name=self._magic_model,
                device=self._device,
                max_prompt_length=self._magic_prompt_length,
                temperature=self._magic_temp_value,
                seed=self._seed,
                blocklist_regex=self._magic_blocklist_regex,
                batch_size=self._magic_batch_size,
            )

        if self._is_attention_grabber:
            try:
                from sd_dynamic_prompts.attention_generator import (
                    SpecialSyntaxAwareAttentionGenerator,
                )

                generator = SpecialSyntaxAwareAttentionGenerator(
                    generator,
                    min_attention=self._min_attention,
                    max_attention=self._max_attention,
                )
            except ImportError as ie:
                # Best effort: log and continue with the unwrapped generator.
                logger.error(f"Not using AttentionGenerator: {ie}")

        if self._should_freeze_prompt:
            generator = FrozenPromptGenerator(generator)
        return generator

    def create_basic_generator(
        self,
    ) -> PromptGenerator:
        """Return the combinatorial or random wildcard generator.

        The combinatorial generator is wrapped in
        BatchedCombinatorialPromptGenerator. The seed and
        ``unlink_seed_from_prompt`` settings only apply to the random
        generator.
        """
        if self._is_combinatorial:
            prompt_generator = CombinatorialPromptGenerator(
                self._wildcard_manager,
                parser_config=self._parser_config,
                ignore_whitespace=self._ignore_whitespace,
            )
            return BatchedCombinatorialPromptGenerator(
                prompt_generator,
                batches=self._combinatorial_batches,
            )

        return RandomPromptGenerator(
            self._wildcard_manager,
            seed=self._seed,
            parser_config=self._parser_config,
            unlink_seed_from_prompt=self._unlink_seed_from_prompt,
            ignore_whitespace=self._ignore_whitespace,
        )

    def create_jinja_generator(self, p) -> PromptGenerator:
        """Return a JinjaGenerator whose template context is built from ``p``.

        ``p`` must expose the prompt fields, ``sd_model.sd_checkpoint_info``
        and the sampling parameters read below (presumably a webui
        processing object — confirm against callers). The context exposes
        ``model``, ``image``, ``parameters`` and ``prompt`` sections.
        """
        # all_prompts / all_negative_prompts may be None or empty before the
        # prompts are populated; fall back to the raw prompt fields then.
        original_prompt = p.all_prompts[0] if p.all_prompts else p.prompt
        original_negative_prompt = (
            p.all_negative_prompts[0]
            if p.all_negative_prompts
            else p.negative_prompt
        )
        checkpoint_info = p.sd_model.sd_checkpoint_info
        context = {
            "model": {
                "filename": checkpoint_info.filename,
                "title": checkpoint_info.title,
                "hash": checkpoint_info.hash,
                "model_name": checkpoint_info.model_name,
            },
            "image": {
                "width": p.width,
                "height": p.height,
            },
            "parameters": {
                "steps": p.steps,
                "batch_size": p.batch_size,
                "num_batches": p.n_iter,
                "width": p.width,
                "height": p.height,
                "cfg_scale": p.cfg_scale,
                "sampler_name": p.sampler_name,
                "seed": p.seed,
            },
            "prompt": {
                "prompt": original_prompt,
                "negative_prompt": original_negative_prompt,
            },
        }

        generator = JinjaGenerator(
            self._wildcard_manager,
            context,
            limit_prompts=self._limit_jinja_prompts,
            ignore_whitespace=self._ignore_whitespace,
        )
        return generator
|