Merge branch 'akx-paths-refactor'
commit
4e41203d97
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||
import logging
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
|
||||
import dynamicprompts
|
||||
|
|
@ -22,11 +21,15 @@ from sd_dynamic_prompts.element_ids import make_element_id
|
|||
from sd_dynamic_prompts.generator_builder import GeneratorBuilder
|
||||
from sd_dynamic_prompts.helpers import (
|
||||
generate_prompts,
|
||||
get_magicmodels_path,
|
||||
get_seeds,
|
||||
load_magicprompt_models,
|
||||
should_freeze_prompt,
|
||||
)
|
||||
from sd_dynamic_prompts.paths import (
|
||||
get_extension_base_path,
|
||||
get_magicprompt_models_txt_path,
|
||||
get_wildcard_dir,
|
||||
)
|
||||
from sd_dynamic_prompts.pnginfo_saver import PngInfoSaver
|
||||
from sd_dynamic_prompts.prompt_writer import PromptWriter
|
||||
|
||||
|
|
@ -40,21 +43,6 @@ is_debug = getattr(opts, "is_debug", False)
|
|||
if is_debug:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
base_dir = Path(scripts.basedir())
|
||||
magicprompt_models_path = get_magicmodels_path(base_dir)
|
||||
|
||||
|
||||
def get_wildcard_dir() -> Path:
|
||||
wildcard_dir = getattr(opts, "wildcard_dir", None)
|
||||
if wildcard_dir is None:
|
||||
wildcard_dir = base_dir / "wildcards"
|
||||
wildcard_dir = Path(wildcard_dir)
|
||||
try:
|
||||
wildcard_dir.mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to create wildcard directory {wildcard_dir}")
|
||||
return wildcard_dir
|
||||
|
||||
|
||||
def _get_effective_prompt(prompts: list[str], prompt: str) -> str:
|
||||
return prompts[0] if prompts else prompt
|
||||
|
|
@ -113,6 +101,7 @@ class Script(scripts.Script):
|
|||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
base_dir = get_extension_base_path()
|
||||
install_message = _get_install_error_message()
|
||||
correct_lib_version = bool(not install_message)
|
||||
|
||||
|
|
@ -172,9 +161,7 @@ class Script(scripts.Script):
|
|||
with gr.Accordion("Prompt Magic", open=False):
|
||||
with gr.Group():
|
||||
try:
|
||||
magicprompt_models = load_magicprompt_models(
|
||||
magicprompt_models_path,
|
||||
)
|
||||
magicprompt_models = load_magicprompt_models()
|
||||
default_magicprompt_model = (
|
||||
opts.dp_magicprompt_default_model
|
||||
if hasattr(opts, "dp_magicprompt_default_model")
|
||||
|
|
@ -183,7 +170,8 @@ class Script(scripts.Script):
|
|||
is_magic_model_available = True
|
||||
except IndexError:
|
||||
logger.warning(
|
||||
f"The magicprompts config file at {magicprompt_models_path} does not contain any models.",
|
||||
f"The magic prompts config file {get_magicprompt_models_txt_path()} "
|
||||
f"does not contain any models.",
|
||||
)
|
||||
|
||||
magicprompt_models = []
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ from pathlib import Path
|
|||
|
||||
from dynamicprompts.generators.promptgenerator import PromptGenerator
|
||||
|
||||
from sd_dynamic_prompts.paths import get_magicprompt_models_txt_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -48,27 +50,24 @@ def should_freeze_prompt(p):
|
|||
return p.subseed_strength > 0
|
||||
|
||||
|
||||
def load_magicprompt_models(modelfile: str) -> list[str]:
|
||||
def load_magicprompt_models(models_file: Path | None = None) -> list[str]:
|
||||
if not models_file:
|
||||
models_file = get_magicprompt_models_txt_path()
|
||||
try:
|
||||
models = []
|
||||
with open(modelfile) as f:
|
||||
for line in f:
|
||||
# ignore comments and empty lines
|
||||
line = line.split("#")[0].strip()
|
||||
if line:
|
||||
models.append(line)
|
||||
return models
|
||||
# ignore empty lines
|
||||
return [
|
||||
model
|
||||
for model in (
|
||||
line.partition("#")[0].strip()
|
||||
for line in models_file.read_text().splitlines()
|
||||
)
|
||||
if model
|
||||
]
|
||||
except FileNotFoundError:
|
||||
logger.warning(f"Could not find magicprompts config file at {modelfile}")
|
||||
logger.warning(f"Could not find magicprompts config file at {models_file}")
|
||||
return []
|
||||
|
||||
|
||||
def get_magicmodels_path(base_dir: str) -> str:
|
||||
magicprompt_models_path = Path(base_dir / "config" / "magicprompt_models.txt")
|
||||
|
||||
return magicprompt_models_path
|
||||
|
||||
|
||||
def generate_prompts(
|
||||
prompt_generator: PromptGenerator,
|
||||
negative_prompt_generator: PromptGenerator,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_extension_base_path() -> Path:
|
||||
"""
|
||||
Get the directory the extension is installed in.
|
||||
"""
|
||||
path = Path(__file__).parent.parent
|
||||
assert (path / "sd_dynamic_prompts").is_dir() # sanity check
|
||||
assert (path / "scripts").is_dir() # sanity check
|
||||
return path
|
||||
|
||||
|
||||
def get_magicprompt_models_txt_path() -> Path:
|
||||
return Path(get_extension_base_path() / "config" / "magicprompt_models.txt")
|
||||
|
||||
|
||||
def get_wildcard_dir() -> Path:
|
||||
try:
|
||||
from modules.shared import opts
|
||||
except ImportError: # likely not in an a1111 context
|
||||
opts = None
|
||||
|
||||
wildcard_dir = getattr(opts, "wildcard_dir", None)
|
||||
if wildcard_dir is None:
|
||||
wildcard_dir = get_extension_base_path() / "wildcards"
|
||||
wildcard_dir = Path(wildcard_dir)
|
||||
try:
|
||||
wildcard_dir.mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to create wildcard directory {wildcard_dir}")
|
||||
return wildcard_dir
|
||||
|
|
@ -1,11 +1,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
from modules import scripts, shared
|
||||
from modules import shared
|
||||
|
||||
from sd_dynamic_prompts.helpers import get_magicmodels_path, load_magicprompt_models
|
||||
|
||||
base_dir = Path(scripts.basedir())
|
||||
from sd_dynamic_prompts.helpers import load_magicprompt_models
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
|
|
@ -103,7 +99,7 @@ def on_ui_settings():
|
|||
),
|
||||
)
|
||||
|
||||
magic_models = load_magicprompt_models(get_magicmodels_path(base_dir))
|
||||
magic_models = load_magicprompt_models()
|
||||
shared.opts.add_option(
|
||||
key="dp_magicprompt_default_model",
|
||||
info=shared.OptionInfo(
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import traceback
|
|||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import modules.scripts as scripts
|
||||
from dynamicprompts.wildcards import WildcardManager
|
||||
from dynamicprompts.wildcards.collection import WildcardTextFile
|
||||
from dynamicprompts.wildcards.tree import WildcardTreeNode
|
||||
|
|
@ -26,13 +25,15 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
wildcard_manager: WildcardManager
|
||||
|
||||
collections_path = Path(scripts.basedir()) / "collections"
|
||||
|
||||
|
||||
def get_collection_dirs() -> dict[str, Path]:
|
||||
"""
|
||||
Get a mapping of name -> subdirectory path for the extension's collections/ directory.
|
||||
"""
|
||||
from sd_dynamic_prompts.paths import get_extension_base_path
|
||||
|
||||
collections_path = get_extension_base_path() / "collections"
|
||||
|
||||
return {
|
||||
str(pth.relative_to(collections_path)): pth
|
||||
for pth in collections_path.iterdir()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import os
|
||||
import tempfile
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
|
@ -90,7 +88,7 @@ def test_get_seeds_with_random_seed(processing):
|
|||
assert subseeds == list(range(subseed, subseed + num_seeds))
|
||||
|
||||
|
||||
def test_load_magicprompt_models():
|
||||
def test_load_magicprompt_models(tmp_path):
|
||||
s = """# a comment
|
||||
model1 # another comment
|
||||
# empty lines below
|
||||
|
|
@ -100,15 +98,9 @@ model 2
|
|||
|
||||
|
||||
"""
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file:
|
||||
tmp_file.write(s)
|
||||
tmp_filename = tmp_file.name
|
||||
|
||||
try:
|
||||
load_magicprompt_models(tmp_filename)
|
||||
finally:
|
||||
os.unlink(tmp_filename)
|
||||
p = tmp_path / "magicprompt_models.txt"
|
||||
p.write_text(s)
|
||||
assert load_magicprompt_models(p) == ["model1", "model 2"]
|
||||
|
||||
|
||||
def test_cross_product():
|
||||
|
|
|
|||
Loading…
Reference in New Issue