Merge branch 'akx-paths-refactor'
commit
4e41203d97
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
|
||||||
from string import Template
|
from string import Template
|
||||||
|
|
||||||
import dynamicprompts
|
import dynamicprompts
|
||||||
|
|
@ -22,11 +21,15 @@ from sd_dynamic_prompts.element_ids import make_element_id
|
||||||
from sd_dynamic_prompts.generator_builder import GeneratorBuilder
|
from sd_dynamic_prompts.generator_builder import GeneratorBuilder
|
||||||
from sd_dynamic_prompts.helpers import (
|
from sd_dynamic_prompts.helpers import (
|
||||||
generate_prompts,
|
generate_prompts,
|
||||||
get_magicmodels_path,
|
|
||||||
get_seeds,
|
get_seeds,
|
||||||
load_magicprompt_models,
|
load_magicprompt_models,
|
||||||
should_freeze_prompt,
|
should_freeze_prompt,
|
||||||
)
|
)
|
||||||
|
from sd_dynamic_prompts.paths import (
|
||||||
|
get_extension_base_path,
|
||||||
|
get_magicprompt_models_txt_path,
|
||||||
|
get_wildcard_dir,
|
||||||
|
)
|
||||||
from sd_dynamic_prompts.pnginfo_saver import PngInfoSaver
|
from sd_dynamic_prompts.pnginfo_saver import PngInfoSaver
|
||||||
from sd_dynamic_prompts.prompt_writer import PromptWriter
|
from sd_dynamic_prompts.prompt_writer import PromptWriter
|
||||||
|
|
||||||
|
|
@ -40,21 +43,6 @@ is_debug = getattr(opts, "is_debug", False)
|
||||||
if is_debug:
|
if is_debug:
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
base_dir = Path(scripts.basedir())
|
|
||||||
magicprompt_models_path = get_magicmodels_path(base_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def get_wildcard_dir() -> Path:
|
|
||||||
wildcard_dir = getattr(opts, "wildcard_dir", None)
|
|
||||||
if wildcard_dir is None:
|
|
||||||
wildcard_dir = base_dir / "wildcards"
|
|
||||||
wildcard_dir = Path(wildcard_dir)
|
|
||||||
try:
|
|
||||||
wildcard_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
except Exception:
|
|
||||||
logger.exception(f"Failed to create wildcard directory {wildcard_dir}")
|
|
||||||
return wildcard_dir
|
|
||||||
|
|
||||||
|
|
||||||
def _get_effective_prompt(prompts: list[str], prompt: str) -> str:
|
def _get_effective_prompt(prompts: list[str], prompt: str) -> str:
|
||||||
return prompts[0] if prompts else prompt
|
return prompts[0] if prompts else prompt
|
||||||
|
|
@ -113,6 +101,7 @@ class Script(scripts.Script):
|
||||||
return scripts.AlwaysVisible
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
|
base_dir = get_extension_base_path()
|
||||||
install_message = _get_install_error_message()
|
install_message = _get_install_error_message()
|
||||||
correct_lib_version = bool(not install_message)
|
correct_lib_version = bool(not install_message)
|
||||||
|
|
||||||
|
|
@ -172,9 +161,7 @@ class Script(scripts.Script):
|
||||||
with gr.Accordion("Prompt Magic", open=False):
|
with gr.Accordion("Prompt Magic", open=False):
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
try:
|
try:
|
||||||
magicprompt_models = load_magicprompt_models(
|
magicprompt_models = load_magicprompt_models()
|
||||||
magicprompt_models_path,
|
|
||||||
)
|
|
||||||
default_magicprompt_model = (
|
default_magicprompt_model = (
|
||||||
opts.dp_magicprompt_default_model
|
opts.dp_magicprompt_default_model
|
||||||
if hasattr(opts, "dp_magicprompt_default_model")
|
if hasattr(opts, "dp_magicprompt_default_model")
|
||||||
|
|
@ -183,7 +170,8 @@ class Script(scripts.Script):
|
||||||
is_magic_model_available = True
|
is_magic_model_available = True
|
||||||
except IndexError:
|
except IndexError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"The magicprompts config file at {magicprompt_models_path} does not contain any models.",
|
f"The magic prompts config file {get_magicprompt_models_txt_path()} "
|
||||||
|
f"does not contain any models.",
|
||||||
)
|
)
|
||||||
|
|
||||||
magicprompt_models = []
|
magicprompt_models = []
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,8 @@ from pathlib import Path
|
||||||
|
|
||||||
from dynamicprompts.generators.promptgenerator import PromptGenerator
|
from dynamicprompts.generators.promptgenerator import PromptGenerator
|
||||||
|
|
||||||
|
from sd_dynamic_prompts.paths import get_magicprompt_models_txt_path
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -48,27 +50,24 @@ def should_freeze_prompt(p):
|
||||||
return p.subseed_strength > 0
|
return p.subseed_strength > 0
|
||||||
|
|
||||||
|
|
||||||
def load_magicprompt_models(modelfile: str) -> list[str]:
|
def load_magicprompt_models(models_file: Path | None = None) -> list[str]:
|
||||||
|
if not models_file:
|
||||||
|
models_file = get_magicprompt_models_txt_path()
|
||||||
try:
|
try:
|
||||||
models = []
|
# ignore empty lines
|
||||||
with open(modelfile) as f:
|
return [
|
||||||
for line in f:
|
model
|
||||||
# ignore comments and empty lines
|
for model in (
|
||||||
line = line.split("#")[0].strip()
|
line.partition("#")[0].strip()
|
||||||
if line:
|
for line in models_file.read_text().splitlines()
|
||||||
models.append(line)
|
)
|
||||||
return models
|
if model
|
||||||
|
]
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.warning(f"Could not find magicprompts config file at {modelfile}")
|
logger.warning(f"Could not find magicprompts config file at {models_file}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def get_magicmodels_path(base_dir: str) -> str:
|
|
||||||
magicprompt_models_path = Path(base_dir / "config" / "magicprompt_models.txt")
|
|
||||||
|
|
||||||
return magicprompt_models_path
|
|
||||||
|
|
||||||
|
|
||||||
def generate_prompts(
|
def generate_prompts(
|
||||||
prompt_generator: PromptGenerator,
|
prompt_generator: PromptGenerator,
|
||||||
negative_prompt_generator: PromptGenerator,
|
negative_prompt_generator: PromptGenerator,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_extension_base_path() -> Path:
|
||||||
|
"""
|
||||||
|
Get the directory the extension is installed in.
|
||||||
|
"""
|
||||||
|
path = Path(__file__).parent.parent
|
||||||
|
assert (path / "sd_dynamic_prompts").is_dir() # sanity check
|
||||||
|
assert (path / "scripts").is_dir() # sanity check
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def get_magicprompt_models_txt_path() -> Path:
|
||||||
|
return Path(get_extension_base_path() / "config" / "magicprompt_models.txt")
|
||||||
|
|
||||||
|
|
||||||
|
def get_wildcard_dir() -> Path:
|
||||||
|
try:
|
||||||
|
from modules.shared import opts
|
||||||
|
except ImportError: # likely not in an a1111 context
|
||||||
|
opts = None
|
||||||
|
|
||||||
|
wildcard_dir = getattr(opts, "wildcard_dir", None)
|
||||||
|
if wildcard_dir is None:
|
||||||
|
wildcard_dir = get_extension_base_path() / "wildcards"
|
||||||
|
wildcard_dir = Path(wildcard_dir)
|
||||||
|
try:
|
||||||
|
wildcard_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(f"Failed to create wildcard directory {wildcard_dir}")
|
||||||
|
return wildcard_dir
|
||||||
|
|
@ -1,11 +1,7 @@
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules import scripts, shared
|
from modules import shared
|
||||||
|
|
||||||
from sd_dynamic_prompts.helpers import get_magicmodels_path, load_magicprompt_models
|
from sd_dynamic_prompts.helpers import load_magicprompt_models
|
||||||
|
|
||||||
base_dir = Path(scripts.basedir())
|
|
||||||
|
|
||||||
|
|
||||||
def on_ui_settings():
|
def on_ui_settings():
|
||||||
|
|
@ -103,7 +99,7 @@ def on_ui_settings():
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
magic_models = load_magicprompt_models(get_magicmodels_path(base_dir))
|
magic_models = load_magicprompt_models()
|
||||||
shared.opts.add_option(
|
shared.opts.add_option(
|
||||||
key="dp_magicprompt_default_model",
|
key="dp_magicprompt_default_model",
|
||||||
info=shared.OptionInfo(
|
info=shared.OptionInfo(
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.scripts as scripts
|
|
||||||
from dynamicprompts.wildcards import WildcardManager
|
from dynamicprompts.wildcards import WildcardManager
|
||||||
from dynamicprompts.wildcards.collection import WildcardTextFile
|
from dynamicprompts.wildcards.collection import WildcardTextFile
|
||||||
from dynamicprompts.wildcards.tree import WildcardTreeNode
|
from dynamicprompts.wildcards.tree import WildcardTreeNode
|
||||||
|
|
@ -26,13 +25,15 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
wildcard_manager: WildcardManager
|
wildcard_manager: WildcardManager
|
||||||
|
|
||||||
collections_path = Path(scripts.basedir()) / "collections"
|
|
||||||
|
|
||||||
|
|
||||||
def get_collection_dirs() -> dict[str, Path]:
|
def get_collection_dirs() -> dict[str, Path]:
|
||||||
"""
|
"""
|
||||||
Get a mapping of name -> subdirectory path for the extension's collections/ directory.
|
Get a mapping of name -> subdirectory path for the extension's collections/ directory.
|
||||||
"""
|
"""
|
||||||
|
from sd_dynamic_prompts.paths import get_extension_base_path
|
||||||
|
|
||||||
|
collections_path = get_extension_base_path() / "collections"
|
||||||
|
|
||||||
return {
|
return {
|
||||||
str(pth.relative_to(collections_path)): pth
|
str(pth.relative_to(collections_path)): pth
|
||||||
for pth in collections_path.iterdir()
|
for pth in collections_path.iterdir()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -90,7 +88,7 @@ def test_get_seeds_with_random_seed(processing):
|
||||||
assert subseeds == list(range(subseed, subseed + num_seeds))
|
assert subseeds == list(range(subseed, subseed + num_seeds))
|
||||||
|
|
||||||
|
|
||||||
def test_load_magicprompt_models():
|
def test_load_magicprompt_models(tmp_path):
|
||||||
s = """# a comment
|
s = """# a comment
|
||||||
model1 # another comment
|
model1 # another comment
|
||||||
# empty lines below
|
# empty lines below
|
||||||
|
|
@ -100,15 +98,9 @@ model 2
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
p = tmp_path / "magicprompt_models.txt"
|
||||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file:
|
p.write_text(s)
|
||||||
tmp_file.write(s)
|
assert load_magicprompt_models(p) == ["model1", "model 2"]
|
||||||
tmp_filename = tmp_file.name
|
|
||||||
|
|
||||||
try:
|
|
||||||
load_magicprompt_models(tmp_filename)
|
|
||||||
finally:
|
|
||||||
os.unlink(tmp_filename)
|
|
||||||
|
|
||||||
|
|
||||||
def test_cross_product():
|
def test_cross_product():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue