Make Magic Prompts aware of LoRA syntax too (#708)

Refs #707
pull/710/head
Aarni Koskela 2024-01-16 12:41:46 +02:00 committed by GitHub
parent b1edd80487
commit eea2ebb68f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 91 additions and 28 deletions

View File

@@ -1,31 +1,9 @@
import re
from dynamicprompts.generators.attentiongenerator import AttentionGenerator
# A1111 special syntax (LoRA, hypernet, etc.)
A1111_SPECIAL_SYNTAX_RE = re.compile(r"\s*<[^>]+>")


def remove_a1111_special_syntax_chunks(s: str) -> tuple[str, list[str]]:
    """
    Remove A1111 special syntax chunks from a string and return the string and the chunks.
    """
    # The pattern has no capturing groups, so findall yields the full matches,
    # i.e. exactly the text that the subsequent substitution removes.
    extracted = A1111_SPECIAL_SYNTAX_RE.findall(s)
    cleaned = A1111_SPECIAL_SYNTAX_RE.sub("", s)
    return cleaned, extracted
def append_chunks(s: str, chunks: list[str]) -> str:
    """
    Append (A1111 special syntax) chunks to a string.

    A falsy `chunks` value (empty list or None) leaves `s` unchanged.
    """
    return s + "".join(chunks) if chunks else s
from sd_dynamic_prompts.special_syntax import (
append_chunks,
remove_a1111_special_syntax_chunks,
)
class SpecialSyntaxAwareAttentionGenerator(AttentionGenerator):

View File

@@ -158,9 +158,11 @@ class GeneratorBuilder:
generator = self.create_basic_generator()
if self._is_magic_prompt:
from dynamicprompts.generators.magicprompt import MagicPromptGenerator
from sd_dynamic_prompts.magic_prompt import (
SpecialSyntaxAwareMagicPromptGenerator,
)
generator = MagicPromptGenerator(
generator = SpecialSyntaxAwareMagicPromptGenerator(
generator,
model_name=self._magic_model,
device=self._device,

View File

@@ -0,0 +1,26 @@
from itertools import zip_longest
from dynamicprompts.generators.magicprompt import MagicPromptGenerator
from sd_dynamic_prompts.special_syntax import (
append_chunks,
remove_a1111_special_syntax_chunks,
)
class SpecialSyntaxAwareMagicPromptGenerator(MagicPromptGenerator):
    """
    Magic Prompt generator that is aware of A1111 special syntax (LoRA, hypernet, etc.).

    Special syntax chunks are stripped from each prompt before it is handed to
    the magic-prompt model, then re-appended to the generated result, so the
    model never sees (or mangles) e.g. `<lora:...>` tags.
    """

    def _generate_magic_prompts(self, orig_prompts: list[str]) -> list[str]:
        # Guard: `zip(*(...))` over an empty sequence yields nothing, so the
        # two-target unpack would raise ValueError for an empty prompt list.
        if not orig_prompts:
            return []
        stripped_prompts: list[str] = []
        chunk_lists: list[list[str]] = []
        for prompt in orig_prompts:
            stripped, chunks = remove_a1111_special_syntax_chunks(prompt)
            stripped_prompts.append(stripped)
            chunk_lists.append(chunks)
        # Pass a real list (not a tuple) to match the declared annotation.
        magic_prompts = super()._generate_magic_prompts(stripped_prompts)
        # In case we somehow get fewer magic prompts than we started with,
        # use zip_longest instead of zip; a missing chunk list becomes None,
        # which append_chunks treats as "nothing to append".
        return [
            append_chunks(prompt, chunks)
            for prompt, chunks in zip_longest(magic_prompts, chunk_lists, fillvalue=None)
        ]

View File

@@ -0,0 +1,26 @@
import re
# A1111 special syntax (LoRA, hypernet, etc.)
A1111_SPECIAL_SYNTAX_RE = re.compile(r"\s*<[^>]+>")


def remove_a1111_special_syntax_chunks(s: str) -> tuple[str, list[str]]:
    """
    Remove A1111 special syntax chunks from a string and return the string and the chunks.
    """
    # Collect the full text of every special-syntax match, then strip them
    # all out in a single substitution pass.
    extracted = [match.group(0) for match in A1111_SPECIAL_SYNTAX_RE.finditer(s)]
    return A1111_SPECIAL_SYNTAX_RE.sub("", s), extracted
def append_chunks(s: str, chunks: list[str]) -> str:
    """
    Append (A1111 special syntax) chunks to a string.

    With a falsy `chunks` value (empty list or None), `s` is returned as-is.
    """
    if chunks:
        return s + "".join(chunks)
    return s

View File

@@ -0,0 +1,31 @@
def fake_generator(prompts, **_kwargs):
    """Minimal stand-in for a transformers text-generation pipeline."""
    for text in prompts:
        # The wrapper under test must strip A1111 special syntax
        # before the prompt ever reaches the pipeline.
        assert "<" not in text
        yield [{"generated_text": f"magical {text}"}]
def test_magic_prompts(monkeypatch):
    """End-to-end check: special syntax survives magic-prompt generation untouched."""
    # Instrument the superclass so it doesn't try to load the model.
    import dynamicprompts.generators.magicprompt as mp

    if hasattr(mp, "_import_transformers"):
        monkeypatch.setattr(mp, "_import_transformers", lambda: None)
    monkeypatch.setattr(
        mp.MagicPromptGenerator,
        "_load_pipeline",
        lambda self, model_name: fake_generator,
    )
    from sd_dynamic_prompts.magic_prompt import SpecialSyntaxAwareMagicPromptGenerator

    prompts = SpecialSyntaxAwareMagicPromptGenerator().generate(
        "purple cat singing opera, artistic, painting "
        "<lora:loraname:0.7> <hypernet:v18000Steps:1>",
        5,
    )
    for prompt in prompts:
        # The special-syntax chunks must come through unchanged...
        assert "<lora:loraname:0.7>" in prompt
        assert "<hypernet:v18000Steps:1>" in prompt
        # ...while the rest of the prompt went through the (fake) model.
        assert prompt.startswith("magical ")