Merge branch 'akx-paths-refactor'

no-strict-zip
Adi Eyal 2023-08-19 16:38:37 +02:00
commit 4e41203d97
6 changed files with 74 additions and 59 deletions

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import logging import logging
import math import math
from functools import lru_cache from functools import lru_cache
from pathlib import Path
from string import Template from string import Template
import dynamicprompts import dynamicprompts
@ -22,11 +21,15 @@ from sd_dynamic_prompts.element_ids import make_element_id
from sd_dynamic_prompts.generator_builder import GeneratorBuilder from sd_dynamic_prompts.generator_builder import GeneratorBuilder
from sd_dynamic_prompts.helpers import ( from sd_dynamic_prompts.helpers import (
generate_prompts, generate_prompts,
get_magicmodels_path,
get_seeds, get_seeds,
load_magicprompt_models, load_magicprompt_models,
should_freeze_prompt, should_freeze_prompt,
) )
from sd_dynamic_prompts.paths import (
get_extension_base_path,
get_magicprompt_models_txt_path,
get_wildcard_dir,
)
from sd_dynamic_prompts.pnginfo_saver import PngInfoSaver from sd_dynamic_prompts.pnginfo_saver import PngInfoSaver
from sd_dynamic_prompts.prompt_writer import PromptWriter from sd_dynamic_prompts.prompt_writer import PromptWriter
@ -40,21 +43,6 @@ is_debug = getattr(opts, "is_debug", False)
if is_debug: if is_debug:
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
base_dir = Path(scripts.basedir())
magicprompt_models_path = get_magicmodels_path(base_dir)
def get_wildcard_dir() -> Path:
wildcard_dir = getattr(opts, "wildcard_dir", None)
if wildcard_dir is None:
wildcard_dir = base_dir / "wildcards"
wildcard_dir = Path(wildcard_dir)
try:
wildcard_dir.mkdir(parents=True, exist_ok=True)
except Exception:
logger.exception(f"Failed to create wildcard directory {wildcard_dir}")
return wildcard_dir
def _get_effective_prompt(prompts: list[str], prompt: str) -> str: def _get_effective_prompt(prompts: list[str], prompt: str) -> str:
return prompts[0] if prompts else prompt return prompts[0] if prompts else prompt
@ -113,6 +101,7 @@ class Script(scripts.Script):
return scripts.AlwaysVisible return scripts.AlwaysVisible
def ui(self, is_img2img): def ui(self, is_img2img):
base_dir = get_extension_base_path()
install_message = _get_install_error_message() install_message = _get_install_error_message()
correct_lib_version = bool(not install_message) correct_lib_version = bool(not install_message)
@ -172,9 +161,7 @@ class Script(scripts.Script):
with gr.Accordion("Prompt Magic", open=False): with gr.Accordion("Prompt Magic", open=False):
with gr.Group(): with gr.Group():
try: try:
magicprompt_models = load_magicprompt_models( magicprompt_models = load_magicprompt_models()
magicprompt_models_path,
)
default_magicprompt_model = ( default_magicprompt_model = (
opts.dp_magicprompt_default_model opts.dp_magicprompt_default_model
if hasattr(opts, "dp_magicprompt_default_model") if hasattr(opts, "dp_magicprompt_default_model")
@ -183,7 +170,8 @@ class Script(scripts.Script):
is_magic_model_available = True is_magic_model_available = True
except IndexError: except IndexError:
logger.warning( logger.warning(
f"The magicprompts config file at {magicprompt_models_path} does not contain any models.", f"The magic prompts config file {get_magicprompt_models_txt_path()} "
f"does not contain any models.",
) )
magicprompt_models = [] magicprompt_models = []

View File

@ -6,6 +6,8 @@ from pathlib import Path
from dynamicprompts.generators.promptgenerator import PromptGenerator from dynamicprompts.generators.promptgenerator import PromptGenerator
from sd_dynamic_prompts.paths import get_magicprompt_models_txt_path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -48,27 +50,24 @@ def should_freeze_prompt(p):
return p.subseed_strength > 0 return p.subseed_strength > 0
def load_magicprompt_models(modelfile: str) -> list[str]: def load_magicprompt_models(models_file: Path | None = None) -> list[str]:
if not models_file:
models_file = get_magicprompt_models_txt_path()
try: try:
models = [] # ignore empty lines
with open(modelfile) as f: return [
for line in f: model
# ignore comments and empty lines for model in (
line = line.split("#")[0].strip() line.partition("#")[0].strip()
if line: for line in models_file.read_text().splitlines()
models.append(line) )
return models if model
]
except FileNotFoundError: except FileNotFoundError:
logger.warning(f"Could not find magicprompts config file at {modelfile}") logger.warning(f"Could not find magicprompts config file at {models_file}")
return [] return []
def get_magicmodels_path(base_dir: str) -> str:
magicprompt_models_path = Path(base_dir / "config" / "magicprompt_models.txt")
return magicprompt_models_path
def generate_prompts( def generate_prompts(
prompt_generator: PromptGenerator, prompt_generator: PromptGenerator,
negative_prompt_generator: PromptGenerator, negative_prompt_generator: PromptGenerator,

View File

@ -0,0 +1,39 @@
from __future__ import annotations
import logging
from functools import lru_cache
from pathlib import Path
logger = logging.getLogger(__name__)
@lru_cache(maxsize=1)
def get_extension_base_path() -> Path:
"""
Get the directory the extension is installed in.
"""
path = Path(__file__).parent.parent
assert (path / "sd_dynamic_prompts").is_dir() # sanity check
assert (path / "scripts").is_dir() # sanity check
return path
def get_magicprompt_models_txt_path() -> Path:
return Path(get_extension_base_path() / "config" / "magicprompt_models.txt")
def get_wildcard_dir() -> Path:
try:
from modules.shared import opts
except ImportError: # likely not in an a1111 context
opts = None
wildcard_dir = getattr(opts, "wildcard_dir", None)
if wildcard_dir is None:
wildcard_dir = get_extension_base_path() / "wildcards"
wildcard_dir = Path(wildcard_dir)
try:
wildcard_dir.mkdir(parents=True, exist_ok=True)
except Exception:
logger.exception(f"Failed to create wildcard directory {wildcard_dir}")
return wildcard_dir

View File

@ -1,11 +1,7 @@
from pathlib import Path
import gradio as gr import gradio as gr
from modules import scripts, shared from modules import shared
from sd_dynamic_prompts.helpers import get_magicmodels_path, load_magicprompt_models from sd_dynamic_prompts.helpers import load_magicprompt_models
base_dir = Path(scripts.basedir())
def on_ui_settings(): def on_ui_settings():
@ -103,7 +99,7 @@ def on_ui_settings():
), ),
) )
magic_models = load_magicprompt_models(get_magicmodels_path(base_dir)) magic_models = load_magicprompt_models()
shared.opts.add_option( shared.opts.add_option(
key="dp_magicprompt_default_model", key="dp_magicprompt_default_model",
info=shared.OptionInfo( info=shared.OptionInfo(

View File

@ -8,7 +8,6 @@ import traceback
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
import modules.scripts as scripts
from dynamicprompts.wildcards import WildcardManager from dynamicprompts.wildcards import WildcardManager
from dynamicprompts.wildcards.collection import WildcardTextFile from dynamicprompts.wildcards.collection import WildcardTextFile
from dynamicprompts.wildcards.tree import WildcardTreeNode from dynamicprompts.wildcards.tree import WildcardTreeNode
@ -26,13 +25,15 @@ logger = logging.getLogger(__name__)
wildcard_manager: WildcardManager wildcard_manager: WildcardManager
collections_path = Path(scripts.basedir()) / "collections"
def get_collection_dirs() -> dict[str, Path]: def get_collection_dirs() -> dict[str, Path]:
""" """
Get a mapping of name -> subdirectory path for the extension's collections/ directory. Get a mapping of name -> subdirectory path for the extension's collections/ directory.
""" """
from sd_dynamic_prompts.paths import get_extension_base_path
collections_path = get_extension_base_path() / "collections"
return { return {
str(pth.relative_to(collections_path)): pth str(pth.relative_to(collections_path)): pth
for pth in collections_path.iterdir() for pth in collections_path.iterdir()

View File

@ -1,5 +1,3 @@
import os
import tempfile
from unittest import mock from unittest import mock
import pytest import pytest
@ -90,7 +88,7 @@ def test_get_seeds_with_random_seed(processing):
assert subseeds == list(range(subseed, subseed + num_seeds)) assert subseeds == list(range(subseed, subseed + num_seeds))
def test_load_magicprompt_models(): def test_load_magicprompt_models(tmp_path):
s = """# a comment s = """# a comment
model1 # another comment model1 # another comment
# empty lines below # empty lines below
@ -100,15 +98,9 @@ model 2
""" """
p = tmp_path / "magicprompt_models.txt"
with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file: p.write_text(s)
tmp_file.write(s) assert load_magicprompt_models(p) == ["model1", "model 2"]
tmp_filename = tmp_file.name
try:
load_magicprompt_models(tmp_filename)
finally:
os.unlink(tmp_filename)
def test_cross_product(): def test_cross_product():