Lots of additional debugging
parent
64450c4ac6
commit
e10a8f8882
|
|
@ -1,12 +1,13 @@
|
|||
from __future__ import annotations
|
||||
from . import PromptGenerator
|
||||
|
||||
import logging
|
||||
import re
|
||||
from tqdm import trange
|
||||
|
||||
MODEL_NAME = "Gustavosta/MagicPrompt-Stable-Diffusion"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
class MagicPromptGenerator(PromptGenerator):
|
||||
generator = None
|
||||
|
||||
|
|
@ -32,30 +33,41 @@ class MagicPromptGenerator(PromptGenerator):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
label: str,
|
||||
prompt_generator: PromptGenerator,
|
||||
max_prompt_length: int = 100,
|
||||
temperature: float = 0.7,
|
||||
):
|
||||
self._label = label
|
||||
self._generator = self._load_pipeline()
|
||||
self._prompt_generator = prompt_generator
|
||||
self._max_prompt_length = max_prompt_length
|
||||
self._temperature = float(temperature)
|
||||
logger.debug(f"{self._label} - MagicPromptGenerator initialized")
|
||||
logger.debug(self._generator)
|
||||
|
||||
def generate(self, *args, **kwargs) -> list[str]:
|
||||
logger.debug(f"{self._label} - Start of magic prompt generation")
|
||||
prompts = self._prompt_generator.generate(*args, **kwargs)
|
||||
logger.debug(f"{self._label} - Got prompts from prompt generator")
|
||||
logger.debug(prompts)
|
||||
|
||||
new_prompts = []
|
||||
for i in trange(len(prompts), desc="Generating Magic prompts"):
|
||||
logger.debug(f"{self._label} - Generating magic prompt for {prompts[i]}")
|
||||
orig_prompt = prompts[i]
|
||||
magic_prompt = self._generator(
|
||||
orig_prompt,
|
||||
max_length=self._max_prompt_length,
|
||||
temperature=self._temperature,
|
||||
)[0]["generated_text"]
|
||||
logger.debug(f"{self._label} - Got magic prompt: {magic_prompt}")
|
||||
|
||||
magic_prompt = self.clean_up_magic_prompt(orig_prompt, magic_prompt)
|
||||
logger.debug(f"{self._label} - Cleaned up magic prompt: {magic_prompt}")
|
||||
new_prompts.append(magic_prompt)
|
||||
|
||||
logger.debug(f"{self._label} - Returning {len(new_prompts)} magic prompts")
|
||||
return new_prompts
|
||||
|
||||
def clean_up_magic_prompt(self, orig_prompt, prompt):
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ def new_generation(prompt) -> PromptGenerator:
|
|||
return generator
|
||||
|
||||
class Script(scripts.Script):
|
||||
def _create_generator(self, original_prompt, original_seed, is_dummy=False, is_feeling_lucky=False, is_attention_grabber=False, enable_jinja_templates=False, is_combinatorial=False, is_magic_prompt=False, combinatorial_batches=1, magic_prompt_length=100, magic_temp_value=0.7):
|
||||
def _create_generator(self, label, original_prompt, original_seed, is_dummy=False, is_feeling_lucky=False, is_attention_grabber=False, enable_jinja_templates=False, is_combinatorial=False, is_magic_prompt=False, combinatorial_batches=1, magic_prompt_length=100, magic_temp_value=0.7):
|
||||
logger.debug(f"""
|
||||
Creating generator:
|
||||
original_prompt: {original_prompt}
|
||||
|
|
@ -138,7 +138,7 @@ class Script(scripts.Script):
|
|||
|
||||
if is_magic_prompt:
|
||||
generator = MagicPromptGenerator(
|
||||
generator, magic_prompt_length, magic_temp_value
|
||||
label, generator, magic_prompt_length, magic_temp_value
|
||||
)
|
||||
|
||||
if is_attention_grabber:
|
||||
|
|
@ -341,6 +341,7 @@ class Script(scripts.Script):
|
|||
try:
|
||||
logger.debug("Creating positive generator")
|
||||
generator = self._create_generator(
|
||||
"Positive prompt generator",
|
||||
original_prompt,
|
||||
original_seed,
|
||||
is_feeling_lucky=is_feeling_lucky,
|
||||
|
|
@ -356,6 +357,7 @@ class Script(scripts.Script):
|
|||
|
||||
logger.debug("Creating negative generator")
|
||||
self._negative_prompt_generator = self._create_generator(
|
||||
"Negative prompt generator",
|
||||
p.negative_prompt,
|
||||
original_seed,
|
||||
is_feeling_lucky=is_feeling_lucky,
|
||||
|
|
@ -419,7 +421,7 @@ class Script(scripts.Script):
|
|||
p.prompt = original_prompt
|
||||
p.seed = original_seed
|
||||
|
||||
logger.debug("Finall positive prompts check")
|
||||
logger.debug("Final positive prompts check")
|
||||
for prompt in p.all_prompts:
|
||||
logger.debug(f"Prompt: {prompt}")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue