Merge branch 'akx-paths-refactor'

no-strict-zip
Adi Eyal 2023-08-19 16:38:37 +02:00
commit 4e41203d97
6 changed files with 74 additions and 59 deletions

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import logging
import math
from functools import lru_cache
from pathlib import Path
from string import Template
import dynamicprompts
@ -22,11 +21,15 @@ from sd_dynamic_prompts.element_ids import make_element_id
from sd_dynamic_prompts.generator_builder import GeneratorBuilder
from sd_dynamic_prompts.helpers import (
generate_prompts,
get_magicmodels_path,
get_seeds,
load_magicprompt_models,
should_freeze_prompt,
)
from sd_dynamic_prompts.paths import (
get_extension_base_path,
get_magicprompt_models_txt_path,
get_wildcard_dir,
)
from sd_dynamic_prompts.pnginfo_saver import PngInfoSaver
from sd_dynamic_prompts.prompt_writer import PromptWriter
@ -40,21 +43,6 @@ is_debug = getattr(opts, "is_debug", False)
if is_debug:
logger.setLevel(logging.DEBUG)
base_dir = Path(scripts.basedir())
magicprompt_models_path = get_magicmodels_path(base_dir)
def get_wildcard_dir() -> Path:
wildcard_dir = getattr(opts, "wildcard_dir", None)
if wildcard_dir is None:
wildcard_dir = base_dir / "wildcards"
wildcard_dir = Path(wildcard_dir)
try:
wildcard_dir.mkdir(parents=True, exist_ok=True)
except Exception:
logger.exception(f"Failed to create wildcard directory {wildcard_dir}")
return wildcard_dir
def _get_effective_prompt(prompts: list[str], prompt: str) -> str:
return prompts[0] if prompts else prompt
@ -113,6 +101,7 @@ class Script(scripts.Script):
return scripts.AlwaysVisible
def ui(self, is_img2img):
base_dir = get_extension_base_path()
install_message = _get_install_error_message()
correct_lib_version = bool(not install_message)
@ -172,9 +161,7 @@ class Script(scripts.Script):
with gr.Accordion("Prompt Magic", open=False):
with gr.Group():
try:
magicprompt_models = load_magicprompt_models(
magicprompt_models_path,
)
magicprompt_models = load_magicprompt_models()
default_magicprompt_model = (
opts.dp_magicprompt_default_model
if hasattr(opts, "dp_magicprompt_default_model")
@ -183,7 +170,8 @@ class Script(scripts.Script):
is_magic_model_available = True
except IndexError:
logger.warning(
f"The magicprompts config file at {magicprompt_models_path} does not contain any models.",
f"The magic prompts config file {get_magicprompt_models_txt_path()} "
f"does not contain any models.",
)
magicprompt_models = []

View File

@ -6,6 +6,8 @@ from pathlib import Path
from dynamicprompts.generators.promptgenerator import PromptGenerator
from sd_dynamic_prompts.paths import get_magicprompt_models_txt_path
logger = logging.getLogger(__name__)
@ -48,27 +50,24 @@ def should_freeze_prompt(p):
return p.subseed_strength > 0
def load_magicprompt_models(modelfile: str) -> list[str]:
def load_magicprompt_models(models_file: Path | None = None) -> list[str]:
if not models_file:
models_file = get_magicprompt_models_txt_path()
try:
models = []
with open(modelfile) as f:
for line in f:
# ignore comments and empty lines
line = line.split("#")[0].strip()
if line:
models.append(line)
return models
# ignore empty lines
return [
model
for model in (
line.partition("#")[0].strip()
for line in models_file.read_text().splitlines()
)
if model
]
except FileNotFoundError:
logger.warning(f"Could not find magicprompts config file at {modelfile}")
logger.warning(f"Could not find magicprompts config file at {models_file}")
return []
def get_magicmodels_path(base_dir: str) -> str:
magicprompt_models_path = Path(base_dir / "config" / "magicprompt_models.txt")
return magicprompt_models_path
def generate_prompts(
prompt_generator: PromptGenerator,
negative_prompt_generator: PromptGenerator,

View File

@ -0,0 +1,39 @@
from __future__ import annotations
import logging
from functools import lru_cache
from pathlib import Path
logger = logging.getLogger(__name__)
@lru_cache(maxsize=1)
def get_extension_base_path() -> Path:
"""
Get the directory the extension is installed in.
"""
path = Path(__file__).parent.parent
assert (path / "sd_dynamic_prompts").is_dir() # sanity check
assert (path / "scripts").is_dir() # sanity check
return path
def get_magicprompt_models_txt_path() -> Path:
return Path(get_extension_base_path() / "config" / "magicprompt_models.txt")
def get_wildcard_dir() -> Path:
try:
from modules.shared import opts
except ImportError: # likely not in an a1111 context
opts = None
wildcard_dir = getattr(opts, "wildcard_dir", None)
if wildcard_dir is None:
wildcard_dir = get_extension_base_path() / "wildcards"
wildcard_dir = Path(wildcard_dir)
try:
wildcard_dir.mkdir(parents=True, exist_ok=True)
except Exception:
logger.exception(f"Failed to create wildcard directory {wildcard_dir}")
return wildcard_dir

View File

@ -1,11 +1,7 @@
from pathlib import Path
import gradio as gr
from modules import scripts, shared
from modules import shared
from sd_dynamic_prompts.helpers import get_magicmodels_path, load_magicprompt_models
base_dir = Path(scripts.basedir())
from sd_dynamic_prompts.helpers import load_magicprompt_models
def on_ui_settings():
@ -103,7 +99,7 @@ def on_ui_settings():
),
)
magic_models = load_magicprompt_models(get_magicmodels_path(base_dir))
magic_models = load_magicprompt_models()
shared.opts.add_option(
key="dp_magicprompt_default_model",
info=shared.OptionInfo(

View File

@ -8,7 +8,6 @@ import traceback
from pathlib import Path
import gradio as gr
import modules.scripts as scripts
from dynamicprompts.wildcards import WildcardManager
from dynamicprompts.wildcards.collection import WildcardTextFile
from dynamicprompts.wildcards.tree import WildcardTreeNode
@ -26,13 +25,15 @@ logger = logging.getLogger(__name__)
wildcard_manager: WildcardManager
collections_path = Path(scripts.basedir()) / "collections"
def get_collection_dirs() -> dict[str, Path]:
"""
Get a mapping of name -> subdirectory path for the extension's collections/ directory.
"""
from sd_dynamic_prompts.paths import get_extension_base_path
collections_path = get_extension_base_path() / "collections"
return {
str(pth.relative_to(collections_path)): pth
for pth in collections_path.iterdir()

View File

@ -1,5 +1,3 @@
import os
import tempfile
from unittest import mock
import pytest
@ -90,7 +88,7 @@ def test_get_seeds_with_random_seed(processing):
assert subseeds == list(range(subseed, subseed + num_seeds))
def test_load_magicprompt_models():
def test_load_magicprompt_models(tmp_path):
s = """# a comment
model1 # another comment
# empty lines below
@ -100,15 +98,9 @@ model 2
"""
with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file:
tmp_file.write(s)
tmp_filename = tmp_file.name
try:
load_magicprompt_models(tmp_filename)
finally:
os.unlink(tmp_filename)
p = tmp_path / "magicprompt_models.txt"
p.write_text(s)
assert load_magicprompt_models(p) == ["model1", "model 2"]
def test_cross_product():