diff --git a/sd_dynamic_prompts/dynamic_prompting.py b/sd_dynamic_prompts/dynamic_prompting.py index 6c8cb93..81574ef 100644 --- a/sd_dynamic_prompts/dynamic_prompting.py +++ b/sd_dynamic_prompts/dynamic_prompting.py @@ -3,7 +3,6 @@ from __future__ import annotations import logging import math from functools import lru_cache -from pathlib import Path from string import Template import dynamicprompts @@ -22,11 +21,15 @@ from sd_dynamic_prompts.element_ids import make_element_id from sd_dynamic_prompts.generator_builder import GeneratorBuilder from sd_dynamic_prompts.helpers import ( generate_prompts, - get_magicmodels_path, get_seeds, load_magicprompt_models, should_freeze_prompt, ) +from sd_dynamic_prompts.paths import ( + get_extension_base_path, + get_magicprompt_models_txt_path, + get_wildcard_dir, +) from sd_dynamic_prompts.pnginfo_saver import PngInfoSaver from sd_dynamic_prompts.prompt_writer import PromptWriter @@ -40,21 +43,6 @@ is_debug = getattr(opts, "is_debug", False) if is_debug: logger.setLevel(logging.DEBUG) -base_dir = Path(scripts.basedir()) -magicprompt_models_path = get_magicmodels_path(base_dir) - - -def get_wildcard_dir() -> Path: - wildcard_dir = getattr(opts, "wildcard_dir", None) - if wildcard_dir is None: - wildcard_dir = base_dir / "wildcards" - wildcard_dir = Path(wildcard_dir) - try: - wildcard_dir.mkdir(parents=True, exist_ok=True) - except Exception: - logger.exception(f"Failed to create wildcard directory {wildcard_dir}") - return wildcard_dir - def _get_effective_prompt(prompts: list[str], prompt: str) -> str: return prompts[0] if prompts else prompt @@ -113,6 +101,7 @@ class Script(scripts.Script): return scripts.AlwaysVisible def ui(self, is_img2img): + base_dir = get_extension_base_path() install_message = _get_install_error_message() correct_lib_version = bool(not install_message) @@ -172,9 +161,7 @@ class Script(scripts.Script): with gr.Accordion("Prompt Magic", open=False): with gr.Group(): try: - magicprompt_models = load_magicprompt_models( - magicprompt_models_path, - ) + magicprompt_models = load_magicprompt_models() default_magicprompt_model = ( opts.dp_magicprompt_default_model if hasattr(opts, "dp_magicprompt_default_model") @@ -183,7 +170,8 @@ class Script(scripts.Script): is_magic_model_available = True except IndexError: logger.warning( - f"The magicprompts config file at {magicprompt_models_path} does not contain any models.", + f"The magic prompts config file {get_magicprompt_models_txt_path()} " + f"does not contain any models.", ) magicprompt_models = [] diff --git a/sd_dynamic_prompts/helpers.py b/sd_dynamic_prompts/helpers.py index 82c1759..15ee93e 100644 --- a/sd_dynamic_prompts/helpers.py +++ b/sd_dynamic_prompts/helpers.py @@ -6,6 +6,8 @@ from pathlib import Path from dynamicprompts.generators.promptgenerator import PromptGenerator +from sd_dynamic_prompts.paths import get_magicprompt_models_txt_path + logger = logging.getLogger(__name__) @@ -48,27 +50,24 @@ def should_freeze_prompt(p): return p.subseed_strength > 0 -def load_magicprompt_models(modelfile: str) -> list[str]: +def load_magicprompt_models(models_file: Path | None = None) -> list[str]: + if not models_file: + models_file = get_magicprompt_models_txt_path() try: - models = [] - with open(modelfile) as f: - for line in f: - # ignore comments and empty lines - line = line.split("#")[0].strip() - if line: - models.append(line) - return models + # ignore empty lines + return [ + model + for model in ( + line.partition("#")[0].strip() + for line in models_file.read_text().splitlines() + ) + if model + ] except FileNotFoundError: - logger.warning(f"Could not find magicprompts config file at {modelfile}") + logger.warning(f"Could not find magicprompts config file at {models_file}") return [] -def get_magicmodels_path(base_dir: str) -> str: - magicprompt_models_path = Path(base_dir / "config" / "magicprompt_models.txt") - - return magicprompt_models_path - - def generate_prompts( prompt_generator: PromptGenerator, negative_prompt_generator: PromptGenerator, diff --git a/sd_dynamic_prompts/paths.py b/sd_dynamic_prompts/paths.py new file mode 100644 index 0000000..d192016 --- /dev/null +++ b/sd_dynamic_prompts/paths.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import logging +from functools import lru_cache +from pathlib import Path + +logger = logging.getLogger(__name__) + + +@lru_cache(maxsize=1) +def get_extension_base_path() -> Path: + """ + Get the directory the extension is installed in. + """ + path = Path(__file__).parent.parent + assert (path / "sd_dynamic_prompts").is_dir() # sanity check + assert (path / "scripts").is_dir() # sanity check + return path + + +def get_magicprompt_models_txt_path() -> Path: + return Path(get_extension_base_path() / "config" / "magicprompt_models.txt") + + +def get_wildcard_dir() -> Path: + try: + from modules.shared import opts + except ImportError: # likely not in an a1111 context + opts = None + + wildcard_dir = getattr(opts, "wildcard_dir", None) + if wildcard_dir is None: + wildcard_dir = get_extension_base_path() / "wildcards" + wildcard_dir = Path(wildcard_dir) + try: + wildcard_dir.mkdir(parents=True, exist_ok=True) + except Exception: + logger.exception(f"Failed to create wildcard directory {wildcard_dir}") + return wildcard_dir diff --git a/sd_dynamic_prompts/settings.py b/sd_dynamic_prompts/settings.py index b06068b..70c2848 100644 --- a/sd_dynamic_prompts/settings.py +++ b/sd_dynamic_prompts/settings.py @@ -1,11 +1,7 @@ -from pathlib import Path - import gradio as gr -from modules import scripts, shared +from modules import shared -from sd_dynamic_prompts.helpers import get_magicmodels_path, load_magicprompt_models - -base_dir = Path(scripts.basedir()) +from sd_dynamic_prompts.helpers import load_magicprompt_models def on_ui_settings(): @@ -103,7 +99,7 @@ def on_ui_settings(): ), ) - magic_models = load_magicprompt_models(get_magicmodels_path(base_dir)) + magic_models = load_magicprompt_models() shared.opts.add_option( key="dp_magicprompt_default_model", info=shared.OptionInfo( diff --git a/sd_dynamic_prompts/wildcards_tab.py b/sd_dynamic_prompts/wildcards_tab.py index d14df4b..c490df1 100644 --- a/sd_dynamic_prompts/wildcards_tab.py +++ b/sd_dynamic_prompts/wildcards_tab.py @@ -8,7 +8,6 @@ import traceback from pathlib import Path import gradio as gr -import modules.scripts as scripts from dynamicprompts.wildcards import WildcardManager from dynamicprompts.wildcards.collection import WildcardTextFile from dynamicprompts.wildcards.tree import WildcardTreeNode @@ -26,13 +25,15 @@ logger = logging.getLogger(__name__) wildcard_manager: WildcardManager -collections_path = Path(scripts.basedir()) / "collections" - def get_collection_dirs() -> dict[str, Path]: """ Get a mapping of name -> subdirectory path for the extension's collections/ directory. """ + from sd_dynamic_prompts.paths import get_extension_base_path + + collections_path = get_extension_base_path() / "collections" + return { str(pth.relative_to(collections_path)): pth for pth in collections_path.iterdir() diff --git a/tests/prompts/test_helpers.py b/tests/prompts/test_helpers.py index aff3d4c..094d7d3 100644 --- a/tests/prompts/test_helpers.py +++ b/tests/prompts/test_helpers.py @@ -1,5 +1,3 @@ -import os -import tempfile from unittest import mock import pytest @@ -90,7 +88,7 @@ def test_get_seeds_with_random_seed(processing): assert subseeds == list(range(subseed, subseed + num_seeds)) -def test_load_magicprompt_models(): +def test_load_magicprompt_models(tmp_path): s = """# a comment model1 # another comment # empty lines below @@ -100,15 +98,9 @@ model 2 """ - - with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file: - tmp_file.write(s) - tmp_filename = tmp_file.name - - try: - load_magicprompt_models(tmp_filename) - finally: - os.unlink(tmp_filename) + p = tmp_path / "magicprompt_models.txt" + p.write_text(s) + assert load_magicprompt_models(p) == ["model1", "model 2"] def test_cross_product():