Lots of additional debugging

bug/magic-prompt-not-working
Adi Eyal 2022-11-20 14:53:25 +02:00
parent 64450c4ac6
commit e10a8f8882
2 changed files with 19 additions and 5 deletions

View File

@ -1,12 +1,13 @@
from __future__ import annotations
from . import PromptGenerator
import logging
import re
from tqdm import trange
MODEL_NAME = "Gustavosta/MagicPrompt-Stable-Diffusion"
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class MagicPromptGenerator(PromptGenerator):
generator = None
@ -32,30 +33,41 @@ class MagicPromptGenerator(PromptGenerator):
def __init__(
self,
label: str,
prompt_generator: PromptGenerator,
max_prompt_length: int = 100,
temperature: float = 0.7,
):
self._label = label
self._generator = self._load_pipeline()
self._prompt_generator = prompt_generator
self._max_prompt_length = max_prompt_length
self._temperature = float(temperature)
logger.debug(f"{self._label} - MagicPromptGenerator initialized")
logger.debug(self._generator)
def generate(self, *args, **kwargs) -> list[str]:
logger.debug(f"{self._label} - Start of magic prompt generation")
prompts = self._prompt_generator.generate(*args, **kwargs)
logger.debug(f"{self._label} - Got prompts from prompt generator")
logger.debug(prompts)
new_prompts = []
for i in trange(len(prompts), desc="Generating Magic prompts"):
logger.debug(f"{self._label} - Generating magic prompt for {prompts[i]}")
orig_prompt = prompts[i]
magic_prompt = self._generator(
orig_prompt,
max_length=self._max_prompt_length,
temperature=self._temperature,
)[0]["generated_text"]
logger.debug(f"{self._label} - Got magic prompt: {magic_prompt}")
magic_prompt = self.clean_up_magic_prompt(orig_prompt, magic_prompt)
logger.debug(f"{self._label} - Cleaned up magic prompt: {magic_prompt}")
new_prompts.append(magic_prompt)
logger.debug(f"{self._label} - Returning {len(new_prompts)} magic prompts")
return new_prompts
def clean_up_magic_prompt(self, orig_prompt, prompt):

View File

@ -107,7 +107,7 @@ def new_generation(prompt) -> PromptGenerator:
return generator
class Script(scripts.Script):
def _create_generator(self, original_prompt, original_seed, is_dummy=False, is_feeling_lucky=False, is_attention_grabber=False, enable_jinja_templates=False, is_combinatorial=False, is_magic_prompt=False, combinatorial_batches=1, magic_prompt_length=100, magic_temp_value=0.7):
def _create_generator(self, label, original_prompt, original_seed, is_dummy=False, is_feeling_lucky=False, is_attention_grabber=False, enable_jinja_templates=False, is_combinatorial=False, is_magic_prompt=False, combinatorial_batches=1, magic_prompt_length=100, magic_temp_value=0.7):
logger.debug(f"""
Creating generator:
original_prompt: {original_prompt}
@ -138,7 +138,7 @@ class Script(scripts.Script):
if is_magic_prompt:
generator = MagicPromptGenerator(
generator, magic_prompt_length, magic_temp_value
label, generator, magic_prompt_length, magic_temp_value
)
if is_attention_grabber:
@ -341,6 +341,7 @@ class Script(scripts.Script):
try:
logger.debug("Creating positive generator")
generator = self._create_generator(
"Positive prompt generator",
original_prompt,
original_seed,
is_feeling_lucky=is_feeling_lucky,
@ -356,6 +357,7 @@ class Script(scripts.Script):
logger.debug("Creating negative generator")
self._negative_prompt_generator = self._create_generator(
"Negative prompt generator",
p.negative_prompt,
original_seed,
is_feeling_lucky=is_feeling_lucky,
@ -419,7 +421,7 @@ class Script(scripts.Script):
p.prompt = original_prompt
p.seed = original_seed
logger.debug("Finall positive prompts check")
logger.debug("Final positive prompts check")
for prompt in p.all_prompts:
logger.debug(f"Prompt: {prompt}")