diff --git a/config/magicprompt_models.txt b/config/magicprompt_models.txt new file mode 100644 index 0000000..15ff9d4 --- /dev/null +++ b/config/magicprompt_models.txt @@ -0,0 +1,15 @@ +Gustavosta/MagicPrompt-Stable-Diffusion +daspartho/prompt-extend +FredZhang7/anime-anything-promptgen-v2 +succinctly/text2image-prompt-generator +microsoft/Promptist +AUTOMATIC/promptgen-lexart +AUTOMATIC/promptgen-majinai-safe +AUTOMATIC/promptgen-majinai-unsafe +kmewhort/stable-diffusion-prompt-bolster +Gustavosta/MagicPrompt-Dalle +Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator +Ar4ikov/gpt2-medium-650k-stable-diffusion-prompt-generator +crumb/bloom-560m-RLHF-SD2-prompter-aesthetic +Meli/GPT2-Prompt +DrishtiSharma/StableDiffusion-Prompt-Generator-GPT-Neo-125M diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 6b5d2ba..f824268 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,3 +1,4 @@ +- 2.9.0 Magic Prompt models are now read from a text file enabling users to add their own. Default model and magicprompt batch size options have been added to the settings tab. - 2.8.12 Prompts are frozen if the variation strength is greater than 0. See [#310](https://github.com/adieyal/sd-dynamic-prompts/issues/310) - 2.8.11 Fixed the broken wildcards manager, see [#338](https://github.com/adieyal/sd-dynamic-prompts/issues/338) - 2.8.10 Magic Prompt now works on M1/M2 Mac - see [#329](https://github.com/adieyal/sd-dynamic-prompts/issues/329) diff --git a/sd_dynamic_prompts/consts.py b/sd_dynamic_prompts/consts.py deleted file mode 100644 index bae4ee0..0000000 --- a/sd_dynamic_prompts/consts.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -MAGIC_PROMPT_MODELS = [ - "Gustavosta/MagicPrompt-Stable-Diffusion", - "daspartho/prompt-extend", - "succinctly/text2image-prompt-generator", - "microsoft/Promptist", - "AUTOMATIC/promptgen-lexart", - "AUTOMATIC/promptgen-majinai-safe", - "AUTOMATIC/promptgen-majinai-unsafe", - "kmewhort/stable-diffusion-prompt-bolster", - "Gustavosta/MagicPrompt-Dalle", - "Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator", - "Ar4ikov/gpt2-medium-650k-stable-diffusion-prompt-generator", - "crumb/bloom-560m-RLHF-SD2-prompter-aesthetic", - "Meli/GPT2-Prompt", - "DrishtiSharma/StableDiffusion-Prompt-Generator-GPT-Neo-125M", -] -DEFAULT_MAGIC_MODEL = MAGIC_PROMPT_MODELS[0] -OPTION_WRITE_RAW_TEMPLATE = "pp_write_raw_template" diff --git a/sd_dynamic_prompts/dynamic_prompting.py b/sd_dynamic_prompts/dynamic_prompting.py index 2fe2f53..9154ab6 100644 --- a/sd_dynamic_prompts/dynamic_prompting.py +++ b/sd_dynamic_prompts/dynamic_prompting.py @@ -17,13 +17,17 @@ from modules.processing import fix_seed from modules.shared import opts from sd_dynamic_prompts import callbacks -from sd_dynamic_prompts.consts import MAGIC_PROMPT_MODELS from sd_dynamic_prompts.generator_builder import GeneratorBuilder -from sd_dynamic_prompts.helpers import get_seeds, should_freeze_prompt +from sd_dynamic_prompts.helpers import ( + get_magicmodels_path, + get_seeds, + load_magicprompt_models, + should_freeze_prompt, +) from sd_dynamic_prompts.ui.pnginfo_saver import PngInfoSaver from sd_dynamic_prompts.ui.prompt_writer import PromptWriter -VERSION = "2.8.12" +VERSION = "2.9.0" logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -34,6 +38,7 @@ if is_debug: logger.setLevel(logging.DEBUG) base_dir = Path(scripts.basedir()) +magicprompt_models_path = get_magicmodels_path(base_dir) def get_wildcard_dir() -> Path: @@ -171,10 +176,30 @@ class Script(scripts.Script): with gr.Accordion("Prompt Magic", open=False): with gr.Group(): + try: + magicprompt_models = load_magicprompt_models( + magicprompt_models_path, + ) + default_magicprompt_model = ( + opts.dp_magicprompt_default_model + if hasattr(opts, "dp_magicprompt_default_model") + else magicprompt_models[0] + ) + is_magic_model_available = True + except IndexError: + logger.warning( + f"The magicprompts config file at {magicprompt_models_path} does not contain any models.", + ) + + magicprompt_models = [] + default_magicprompt_model = "" + is_magic_model_available = False + is_magic_prompt = gr.Checkbox( label="Magic prompt", value=False, elem_id="is-magicprompt", + interactive=is_magic_model_available, ) magic_prompt_length = gr.Slider( @@ -183,6 +208,7 @@ class Script(scripts.Script): minimum=30, maximum=300, step=10, + interactive=is_magic_model_available, ) magic_temp_value = gr.Slider( @@ -191,14 +217,16 @@ class Script(scripts.Script): minimum=0.1, maximum=3.0, step=0.10, + interactive=is_magic_model_available, ) magic_model = gr.Dropdown( - MAGIC_PROMPT_MODELS, - value=MAGIC_PROMPT_MODELS[0], + magicprompt_models, + value=default_magicprompt_model, multiselect=False, label="Magic prompt model", elem_id="magic-prompt-model", + interactive=is_magic_model_available, ) magic_blocklist_regex = gr.Textbox( @@ -209,14 +237,7 @@ class Script(scripts.Script): "Regular expression pattern for blocking terms out of the generated prompt. Applied case-insensitively. " 'For instance, to block both "purple" and "interdimensional", you could use the pattern "purple|interdimensional".' ), - ) - - magic_batch_size = gr.Slider( - label="Magic Prompt batch size", - value=1, - minimum=1, - maximum=64, - step=1, + interactive=is_magic_model_available, ) is_feeling_lucky = gr.Checkbox( @@ -325,7 +346,6 @@ class Script(scripts.Script): max_generations, magic_model, magic_blocklist_regex, - magic_batch_size, ] def process( @@ -349,7 +369,6 @@ class Script(scripts.Script): max_generations, magic_model, magic_blocklist_regex: str | None, - magic_batch_size, ): if not is_enabled: logger.debug("Dynamic prompts disabled - exiting") @@ -360,6 +379,7 @@ class Script(scripts.Script): self._pnginfo_saver.enabled = opts.dp_write_raw_template self._prompt_writer.enabled = opts.dp_write_prompts_to_file self._limit_jinja_prompts = opts.dp_limit_jinja_prompts + magicprompt_batch_size = opts.dp_magicprompt_batch_size parser_config = ParserConfig( variant_start=opts.dp_parser_variant_start, @@ -383,6 +403,7 @@ class Script(scripts.Script): try: logger.debug("Creating generator") + generator_builder = ( GeneratorBuilder( self._wildcard_manager, @@ -401,12 +422,12 @@ class Script(scripts.Script): ) .set_is_combinatorial(is_combinatorial, combinatorial_batches) .set_is_magic_prompt( - is_magic_prompt, + is_magic_prompt=is_magic_prompt, magic_model=magic_model, magic_prompt_length=magic_prompt_length, magic_temp_value=magic_temp_value, magic_blocklist_regex=magic_blocklist_regex, - batch_size=magic_batch_size, + batch_size=magicprompt_batch_size, device=device, ) .set_is_dummy(False) diff --git a/sd_dynamic_prompts/generator_builder.py b/sd_dynamic_prompts/generator_builder.py index 0dfef65..62b81be 100644 --- a/sd_dynamic_prompts/generator_builder.py +++ b/sd_dynamic_prompts/generator_builder.py @@ -13,7 +13,6 @@ from dynamicprompts.generators import ( ) from dynamicprompts.parser.parse import default_parser_config -from sd_dynamic_prompts.consts import DEFAULT_MAGIC_MODEL from sd_dynamic_prompts.frozenprompt_generator import FrozenPromptGenerator logger = logging.getLogger(__name__) @@ -102,13 +101,17 @@ class GeneratorBuilder: def set_is_magic_prompt( self, is_magic_prompt=True, - magic_model=DEFAULT_MAGIC_MODEL, + magic_model=None, magic_prompt_length=100, magic_temp_value=0.7, device=0, magic_blocklist_regex: str | None = None, batch_size=1, ): + if not magic_model: + self._is_magic_prompt = False + return self + self._magic_model = magic_model self._magic_prompt_length = magic_prompt_length self._magic_temp_value = magic_temp_value diff --git a/sd_dynamic_prompts/helpers.py b/sd_dynamic_prompts/helpers.py index 964b8cf..8649170 100644 --- a/sd_dynamic_prompts/helpers.py +++ b/sd_dynamic_prompts/helpers.py @@ -1,3 +1,11 @@ +from __future__ import annotations + +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + + def get_seeds(p, num_seeds, use_fixed_seed): if p.subseed_strength != 0: seed = int(p.all_seeds[0]) @@ -23,3 +31,24 @@ def get_seeds(p, num_seeds, use_fixed_seed): def should_freeze_prompt(p): # When using a variation seed, the prompt shouldn't change between generations return p.subseed_strength > 0 + + +def load_magicprompt_models(modelfile: str) -> list[str]: + try: + models = [] + with open(modelfile) as f: + for line in f: + # ignore comments and empty lines + line = line.split("#")[0].strip() + if line: + models.append(line) + return models + except FileNotFoundError: + logger.warning(f"Could not find magicprompts config file at {modelfile}") + return [] + + +def get_magicmodels_path(base_dir: str) -> str: + magicprompt_models_path = Path(base_dir / "config" / "magicprompt_models.txt") + + return magicprompt_models_path diff --git a/sd_dynamic_prompts/ui/settings.py b/sd_dynamic_prompts/ui/settings.py index f47a078..14cb8fe 100644 --- a/sd_dynamic_prompts/ui/settings.py +++ b/sd_dynamic_prompts/ui/settings.py @@ -1,4 +1,11 @@ -from modules import shared +from pathlib import Path + +import gradio as gr +from modules import scripts, shared + +from sd_dynamic_prompts.helpers import get_magicmodels_path, load_magicprompt_models + +base_dir = Path(scripts.basedir()) def on_ui_settings(): @@ -59,3 +66,26 @@ def on_ui_settings(): section=section, ), ) + + magic_models = load_magicprompt_models(get_magicmodels_path(base_dir)) + shared.opts.add_option( + key="dp_magicprompt_default_model", + info=shared.OptionInfo( + magic_models[0] if magic_models else "", + label="Default magic prompt model", + component=gr.Dropdown, + component_args={"choices": magic_models}, + section=section, + ), + ) + + shared.opts.add_option( + key="dp_magicprompt_batch_size", + info=shared.OptionInfo( + 1, + label="Magic Prompt batch size (higher is faster but uses more memory)", + component=gr.Slider, + component_args={"minimum": 1, "maximum": 64, "step": 1}, + section=section, + ), + ) diff --git a/tests/prompts/test_helpers.py b/tests/prompts/test_helpers.py index 337a64c..4cf8c6d 100644 --- a/tests/prompts/test_helpers.py +++ b/tests/prompts/test_helpers.py @@ -1,8 +1,10 @@ +import os +import tempfile from unittest import mock import pytest -from sd_dynamic_prompts.helpers import get_seeds +from sd_dynamic_prompts.helpers import get_seeds, load_magicprompt_models @pytest.fixture @@ -45,3 +47,24 @@ def test_get_seeds_with_random_seed(processing): seeds, subseeds = get_seeds(processing, num_seeds=num_seeds, use_fixed_seed=False) assert seeds == [seed] * num_seeds assert subseeds == list(range(subseed, subseed + num_seeds)) + + +def test_load_magicprompt_models(): + s = """# a comment +model1 # another comment +# empty lines below + + +model 2 + + + """ + + with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file: + tmp_file.write(s) + tmp_filename = tmp_file.name + + try: + load_magicprompt_models(tmp_filename) + finally: + os.unlink(tmp_filename)