parent
b1edd80487
commit
eea2ebb68f
|
|
@ -1,31 +1,9 @@
|
|||
import re
|
||||
|
||||
from dynamicprompts.generators.attentiongenerator import AttentionGenerator
|
||||
|
||||
# A1111 special syntax (LoRA, hypernet, etc.)
|
||||
A1111_SPECIAL_SYNTAX_RE = re.compile(r"\s*<[^>]+>")
|
||||
|
||||
|
||||
def remove_a1111_special_syntax_chunks(s: str) -> tuple[str, list[str]]:
|
||||
"""
|
||||
Remove A1111 special syntax chunks from a string and return the string and the chunks.
|
||||
"""
|
||||
chunks: list[str] = []
|
||||
|
||||
def put_chunk(m):
|
||||
chunks.append(m.group(0))
|
||||
return ""
|
||||
|
||||
return re.sub(A1111_SPECIAL_SYNTAX_RE, put_chunk, s), chunks
|
||||
|
||||
|
||||
def append_chunks(s: str, chunks: list[str]) -> str:
|
||||
"""
|
||||
Append (A1111 special syntax) chunks to a string.
|
||||
"""
|
||||
if not chunks:
|
||||
return s
|
||||
return f"{s}{''.join(chunks)}"
|
||||
from sd_dynamic_prompts.special_syntax import (
|
||||
append_chunks,
|
||||
remove_a1111_special_syntax_chunks,
|
||||
)
|
||||
|
||||
|
||||
class SpecialSyntaxAwareAttentionGenerator(AttentionGenerator):
|
||||
|
|
|
|||
|
|
@ -158,9 +158,11 @@ class GeneratorBuilder:
|
|||
generator = self.create_basic_generator()
|
||||
|
||||
if self._is_magic_prompt:
|
||||
from dynamicprompts.generators.magicprompt import MagicPromptGenerator
|
||||
from sd_dynamic_prompts.magic_prompt import (
|
||||
SpecialSyntaxAwareMagicPromptGenerator,
|
||||
)
|
||||
|
||||
generator = MagicPromptGenerator(
|
||||
generator = SpecialSyntaxAwareMagicPromptGenerator(
|
||||
generator,
|
||||
model_name=self._magic_model,
|
||||
device=self._device,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
from itertools import zip_longest
|
||||
|
||||
from dynamicprompts.generators.magicprompt import MagicPromptGenerator
|
||||
|
||||
from sd_dynamic_prompts.special_syntax import (
|
||||
append_chunks,
|
||||
remove_a1111_special_syntax_chunks,
|
||||
)
|
||||
|
||||
|
||||
class SpecialSyntaxAwareMagicPromptGenerator(MagicPromptGenerator):
|
||||
"""
|
||||
Magic Prompt generator that is aware of A1111 special syntax (LoRA, hypernet, etc.).
|
||||
"""
|
||||
|
||||
def _generate_magic_prompts(self, orig_prompts: list[str]) -> list[str]:
|
||||
orig_prompts, chunks = zip(
|
||||
*(remove_a1111_special_syntax_chunks(p) for p in orig_prompts),
|
||||
)
|
||||
magic_prompts = super()._generate_magic_prompts(orig_prompts)
|
||||
# in case we somehow get less magic prompts than we started with,
|
||||
# use zip_longest instead of zip.
|
||||
return [
|
||||
append_chunks(prompt, chunk)
|
||||
for prompt, chunk in zip_longest(magic_prompts, chunks, fillvalue=None)
|
||||
]
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
import re
|
||||
|
||||
# A1111 special syntax (LoRA, hypernet, etc.)
|
||||
A1111_SPECIAL_SYNTAX_RE = re.compile(r"\s*<[^>]+>")
|
||||
|
||||
|
||||
def remove_a1111_special_syntax_chunks(s: str) -> tuple[str, list[str]]:
|
||||
"""
|
||||
Remove A1111 special syntax chunks from a string and return the string and the chunks.
|
||||
"""
|
||||
chunks: list[str] = []
|
||||
|
||||
def put_chunk(m):
|
||||
chunks.append(m.group(0))
|
||||
return ""
|
||||
|
||||
return re.sub(A1111_SPECIAL_SYNTAX_RE, put_chunk, s), chunks
|
||||
|
||||
|
||||
def append_chunks(s: str, chunks: list[str]) -> str:
|
||||
"""
|
||||
Append (A1111 special syntax) chunks to a string.
|
||||
"""
|
||||
if not chunks:
|
||||
return s
|
||||
return f"{s}{''.join(chunks)}"
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
def fake_generator(prompts, **_kwargs):
|
||||
for prompt in prompts:
|
||||
assert "<" not in prompt # should have been stripped
|
||||
yield [{"generated_text": f"magical {prompt}"}]
|
||||
|
||||
|
||||
def test_magic_prompts(monkeypatch):
|
||||
# Instrument the superclass so it doesn't try to load the model
|
||||
import dynamicprompts.generators.magicprompt as mp
|
||||
|
||||
if hasattr(mp, "_import_transformers"):
|
||||
monkeypatch.setattr(mp, "_import_transformers", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
mp.MagicPromptGenerator,
|
||||
"_load_pipeline",
|
||||
lambda self, model_name: fake_generator,
|
||||
)
|
||||
|
||||
from sd_dynamic_prompts.magic_prompt import SpecialSyntaxAwareMagicPromptGenerator
|
||||
|
||||
generator = SpecialSyntaxAwareMagicPromptGenerator()
|
||||
for prompt in generator.generate(
|
||||
"purple cat singing opera, artistic, painting "
|
||||
"<lora:loraname:0.7> <hypernet:v18000Steps:1>",
|
||||
5,
|
||||
):
|
||||
# These must remain unchanged
|
||||
assert "<lora:loraname:0.7>" in prompt
|
||||
assert "<hypernet:v18000Steps:1>" in prompt
|
||||
# but we should expect to see some magic
|
||||
assert prompt.startswith("magical ")
|
||||
Loading…
Reference in New Issue