Refactored Magic Prompts configuration

Magic Prompt models are now read from a text file enabling users to add
their own. Default model and magicprompt batch size options have been
added to the settings tab.
pull/355/head
Adi Eyal 2023-03-29 18:51:41 +03:00
parent b293722b3d
commit 4dd9df6251
8 changed files with 143 additions and 41 deletions

View File

@ -0,0 +1,15 @@
Gustavosta/MagicPrompt-Stable-Diffusion
daspartho/prompt-extend
FredZhang7/anime-anything-promptgen-v2
succinctly/text2image-prompt-generator
microsoft/Promptist
AUTOMATIC/promptgen-lexart
AUTOMATIC/promptgen-majinai-safe
AUTOMATIC/promptgen-majinai-unsafe
kmewhort/stable-diffusion-prompt-bolster
Gustavosta/MagicPrompt-Dalle
Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator
Ar4ikov/gpt2-medium-650k-stable-diffusion-prompt-generator
crumb/bloom-560m-RLHF-SD2-prompter-aesthetic
Meli/GPT2-Prompt
DrishtiSharma/StableDiffusion-Prompt-Generator-GPT-Neo-125M

View File

@ -1,3 +1,4 @@
- 2.9.0 Magic Prompt models are now read from a text file enabling users to add their own. Default model and magicprompt batch size options have been added to the settings tab.
- 2.8.12 Prompts are frozen if the variation strength is greater than 0. See [#310](https://github.com/adieyal/sd-dynamic-prompts/issues/310) - 2.8.12 Prompts are frozen if the variation strength is greater than 0. See [#310](https://github.com/adieyal/sd-dynamic-prompts/issues/310)
- 2.8.11 Fixed the broken wildcards manager, see [#338](https://github.com/adieyal/sd-dynamic-prompts/issues/338) - 2.8.11 Fixed the broken wildcards manager, see [#338](https://github.com/adieyal/sd-dynamic-prompts/issues/338)
- 2.8.10 Magic Prompt now works on M1/M2 Mac - see [#329](https://github.com/adieyal/sd-dynamic-prompts/issues/329) - 2.8.10 Magic Prompt now works on M1/M2 Mac - see [#329](https://github.com/adieyal/sd-dynamic-prompts/issues/329)

View File

@ -1,20 +0,0 @@
from __future__ import annotations
MAGIC_PROMPT_MODELS = [
"Gustavosta/MagicPrompt-Stable-Diffusion",
"daspartho/prompt-extend",
"succinctly/text2image-prompt-generator",
"microsoft/Promptist",
"AUTOMATIC/promptgen-lexart",
"AUTOMATIC/promptgen-majinai-safe",
"AUTOMATIC/promptgen-majinai-unsafe",
"kmewhort/stable-diffusion-prompt-bolster",
"Gustavosta/MagicPrompt-Dalle",
"Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator",
"Ar4ikov/gpt2-medium-650k-stable-diffusion-prompt-generator",
"crumb/bloom-560m-RLHF-SD2-prompter-aesthetic",
"Meli/GPT2-Prompt",
"DrishtiSharma/StableDiffusion-Prompt-Generator-GPT-Neo-125M",
]
DEFAULT_MAGIC_MODEL = MAGIC_PROMPT_MODELS[0]
OPTION_WRITE_RAW_TEMPLATE = "pp_write_raw_template"

View File

@ -17,13 +17,17 @@ from modules.processing import fix_seed
from modules.shared import opts from modules.shared import opts
from sd_dynamic_prompts import callbacks from sd_dynamic_prompts import callbacks
from sd_dynamic_prompts.consts import MAGIC_PROMPT_MODELS
from sd_dynamic_prompts.generator_builder import GeneratorBuilder from sd_dynamic_prompts.generator_builder import GeneratorBuilder
from sd_dynamic_prompts.helpers import get_seeds, should_freeze_prompt from sd_dynamic_prompts.helpers import (
get_magicmodels_path,
get_seeds,
load_magicprompt_models,
should_freeze_prompt,
)
from sd_dynamic_prompts.ui.pnginfo_saver import PngInfoSaver from sd_dynamic_prompts.ui.pnginfo_saver import PngInfoSaver
from sd_dynamic_prompts.ui.prompt_writer import PromptWriter from sd_dynamic_prompts.ui.prompt_writer import PromptWriter
VERSION = "2.8.12" VERSION = "2.9.0"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -34,6 +38,7 @@ if is_debug:
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
base_dir = Path(scripts.basedir()) base_dir = Path(scripts.basedir())
magicprompt_models_path = get_magicmodels_path(base_dir)
def get_wildcard_dir() -> Path: def get_wildcard_dir() -> Path:
@ -171,10 +176,30 @@ class Script(scripts.Script):
with gr.Accordion("Prompt Magic", open=False): with gr.Accordion("Prompt Magic", open=False):
with gr.Group(): with gr.Group():
try:
magicprompt_models = load_magicprompt_models(
magicprompt_models_path,
)
default_magicprompt_model = (
opts.dp_magicprompt_default_model
if hasattr(opts, "dp_magicprompt_default_model")
else magicprompt_models[0]
)
is_magic_model_available = True
except IndexError:
logger.warning(
f"The magicprompts config file at {magicprompt_models_path} does not contain any models.",
)
magicprompt_models = []
default_magicprompt_model = ""
is_magic_model_available = False
is_magic_prompt = gr.Checkbox( is_magic_prompt = gr.Checkbox(
label="Magic prompt", label="Magic prompt",
value=False, value=False,
elem_id="is-magicprompt", elem_id="is-magicprompt",
interactive=is_magic_model_available,
) )
magic_prompt_length = gr.Slider( magic_prompt_length = gr.Slider(
@ -183,6 +208,7 @@ class Script(scripts.Script):
minimum=30, minimum=30,
maximum=300, maximum=300,
step=10, step=10,
interactive=is_magic_model_available,
) )
magic_temp_value = gr.Slider( magic_temp_value = gr.Slider(
@ -191,14 +217,16 @@ class Script(scripts.Script):
minimum=0.1, minimum=0.1,
maximum=3.0, maximum=3.0,
step=0.10, step=0.10,
interactive=is_magic_model_available,
) )
magic_model = gr.Dropdown( magic_model = gr.Dropdown(
MAGIC_PROMPT_MODELS, magicprompt_models,
value=MAGIC_PROMPT_MODELS[0], value=default_magicprompt_model,
multiselect=False, multiselect=False,
label="Magic prompt model", label="Magic prompt model",
elem_id="magic-prompt-model", elem_id="magic-prompt-model",
interactive=is_magic_model_available,
) )
magic_blocklist_regex = gr.Textbox( magic_blocklist_regex = gr.Textbox(
@ -209,14 +237,7 @@ class Script(scripts.Script):
"Regular expression pattern for blocking terms out of the generated prompt. Applied case-insensitively. " "Regular expression pattern for blocking terms out of the generated prompt. Applied case-insensitively. "
'For instance, to block both "purple" and "interdimensional", you could use the pattern "purple|interdimensional".' 'For instance, to block both "purple" and "interdimensional", you could use the pattern "purple|interdimensional".'
), ),
) interactive=is_magic_model_available,
magic_batch_size = gr.Slider(
label="Magic Prompt batch size",
value=1,
minimum=1,
maximum=64,
step=1,
) )
is_feeling_lucky = gr.Checkbox( is_feeling_lucky = gr.Checkbox(
@ -325,7 +346,6 @@ class Script(scripts.Script):
max_generations, max_generations,
magic_model, magic_model,
magic_blocklist_regex, magic_blocklist_regex,
magic_batch_size,
] ]
def process( def process(
@ -349,7 +369,6 @@ class Script(scripts.Script):
max_generations, max_generations,
magic_model, magic_model,
magic_blocklist_regex: str | None, magic_blocklist_regex: str | None,
magic_batch_size,
): ):
if not is_enabled: if not is_enabled:
logger.debug("Dynamic prompts disabled - exiting") logger.debug("Dynamic prompts disabled - exiting")
@ -360,6 +379,7 @@ class Script(scripts.Script):
self._pnginfo_saver.enabled = opts.dp_write_raw_template self._pnginfo_saver.enabled = opts.dp_write_raw_template
self._prompt_writer.enabled = opts.dp_write_prompts_to_file self._prompt_writer.enabled = opts.dp_write_prompts_to_file
self._limit_jinja_prompts = opts.dp_limit_jinja_prompts self._limit_jinja_prompts = opts.dp_limit_jinja_prompts
magicprompt_batch_size = opts.dp_magicprompt_batch_size
parser_config = ParserConfig( parser_config = ParserConfig(
variant_start=opts.dp_parser_variant_start, variant_start=opts.dp_parser_variant_start,
@ -383,6 +403,7 @@ class Script(scripts.Script):
try: try:
logger.debug("Creating generator") logger.debug("Creating generator")
generator_builder = ( generator_builder = (
GeneratorBuilder( GeneratorBuilder(
self._wildcard_manager, self._wildcard_manager,
@ -401,12 +422,12 @@ class Script(scripts.Script):
) )
.set_is_combinatorial(is_combinatorial, combinatorial_batches) .set_is_combinatorial(is_combinatorial, combinatorial_batches)
.set_is_magic_prompt( .set_is_magic_prompt(
is_magic_prompt, is_magic_prompt=is_magic_prompt,
magic_model=magic_model, magic_model=magic_model,
magic_prompt_length=magic_prompt_length, magic_prompt_length=magic_prompt_length,
magic_temp_value=magic_temp_value, magic_temp_value=magic_temp_value,
magic_blocklist_regex=magic_blocklist_regex, magic_blocklist_regex=magic_blocklist_regex,
batch_size=magic_batch_size, batch_size=magicprompt_batch_size,
device=device, device=device,
) )
.set_is_dummy(False) .set_is_dummy(False)

View File

@ -13,7 +13,6 @@ from dynamicprompts.generators import (
) )
from dynamicprompts.parser.parse import default_parser_config from dynamicprompts.parser.parse import default_parser_config
from sd_dynamic_prompts.consts import DEFAULT_MAGIC_MODEL
from sd_dynamic_prompts.frozenprompt_generator import FrozenPromptGenerator from sd_dynamic_prompts.frozenprompt_generator import FrozenPromptGenerator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -102,13 +101,17 @@ class GeneratorBuilder:
def set_is_magic_prompt( def set_is_magic_prompt(
self, self,
is_magic_prompt=True, is_magic_prompt=True,
magic_model=DEFAULT_MAGIC_MODEL, magic_model=None,
magic_prompt_length=100, magic_prompt_length=100,
magic_temp_value=0.7, magic_temp_value=0.7,
device=0, device=0,
magic_blocklist_regex: str | None = None, magic_blocklist_regex: str | None = None,
batch_size=1, batch_size=1,
): ):
if not magic_model:
self._is_magic_prompt = False
return self
self._magic_model = magic_model self._magic_model = magic_model
self._magic_prompt_length = magic_prompt_length self._magic_prompt_length = magic_prompt_length
self._magic_temp_value = magic_temp_value self._magic_temp_value = magic_temp_value

View File

@ -1,3 +1,11 @@
from __future__ import annotations
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
def get_seeds(p, num_seeds, use_fixed_seed): def get_seeds(p, num_seeds, use_fixed_seed):
if p.subseed_strength != 0: if p.subseed_strength != 0:
seed = int(p.all_seeds[0]) seed = int(p.all_seeds[0])
@ -23,3 +31,24 @@ def get_seeds(p, num_seeds, use_fixed_seed):
def should_freeze_prompt(p): def should_freeze_prompt(p):
# When using a variation seed, the prompt shouldn't change between generations # When using a variation seed, the prompt shouldn't change between generations
return p.subseed_strength > 0 return p.subseed_strength > 0
def load_magicprompt_models(modelfile: str) -> list[str]:
try:
models = []
with open(modelfile) as f:
for line in f:
# ignore comments and empty lines
line = line.split("#")[0].strip()
if line:
models.append(line)
return models
except FileNotFoundError:
logger.warning(f"Could not find magicprompts config file at {modelfile}")
return []
def get_magicmodels_path(base_dir: str) -> str:
magicprompt_models_path = Path(base_dir / "config" / "magicprompt_models.txt")
return magicprompt_models_path

View File

@ -1,4 +1,11 @@
from modules import shared from pathlib import Path
import gradio as gr
from modules import scripts, shared
from sd_dynamic_prompts.helpers import get_magicmodels_path, load_magicprompt_models
base_dir = Path(scripts.basedir())
def on_ui_settings(): def on_ui_settings():
@ -59,3 +66,26 @@ def on_ui_settings():
section=section, section=section,
), ),
) )
magic_models = load_magicprompt_models(get_magicmodels_path(base_dir))
shared.opts.add_option(
key="dp_magicprompt_default_model",
info=shared.OptionInfo(
magic_models[0] if magic_models else "",
label="Default magic prompt model",
component=gr.Dropdown,
component_args={"choices": magic_models},
section=section,
),
)
shared.opts.add_option(
key="dp_magicprompt_batch_size",
info=shared.OptionInfo(
1,
label="Magic Prompt batch size (higher is faster but uses more memory)",
component=gr.Slider,
component_args={"minimum": 1, "maximum": 64, "step": 1},
section=section,
),
)

View File

@ -1,8 +1,10 @@
import os
import tempfile
from unittest import mock from unittest import mock
import pytest import pytest
from sd_dynamic_prompts.helpers import get_seeds from sd_dynamic_prompts.helpers import get_seeds, load_magicprompt_models
@pytest.fixture @pytest.fixture
@ -45,3 +47,24 @@ def test_get_seeds_with_random_seed(processing):
seeds, subseeds = get_seeds(processing, num_seeds=num_seeds, use_fixed_seed=False) seeds, subseeds = get_seeds(processing, num_seeds=num_seeds, use_fixed_seed=False)
assert seeds == [seed] * num_seeds assert seeds == [seed] * num_seeds
assert subseeds == list(range(subseed, subseed + num_seeds)) assert subseeds == list(range(subseed, subseed + num_seeds))
def test_load_magicprompt_models():
s = """# a comment
model1 # another comment
# empty lines below
model 2
"""
with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file:
tmp_file.write(s)
tmp_filename = tmp_file.name
try:
load_magicprompt_models(tmp_filename)
finally:
os.unlink(tmp_filename)