From d00d148540a96e959b728ab7f17d10081f1d0008 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <-> Date: Sun, 30 Jul 2023 20:06:25 +0300 Subject: [PATCH] add --wildcards-dir commandline argument make wildcards work with hires fix prompts --- .gitignore | 1 + preload.py | 3 +++ scripts/wildcards.py | 27 ++++++++++++++++++--------- 3 files changed, 22 insertions(+), 9 deletions(-) create mode 100644 preload.py diff --git a/.gitignore b/.gitignore index 2a467eb..4f95d83 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ wildcards +__pycache__ diff --git a/preload.py b/preload.py new file mode 100644 index 0000000..e9bc5be --- /dev/null +++ b/preload.py @@ -0,0 +1,3 @@ + +def preload(parser): + parser.add_argument("--wildcards-dir", type=str, help="directory with wildcards", default=None) diff --git a/scripts/wildcards.py b/scripts/wildcards.py index 68eb0dd..9fced7b 100644 --- a/scripts/wildcards.py +++ b/scripts/wildcards.py @@ -5,7 +5,7 @@ import sys from modules import scripts, script_callbacks, shared warned_about_files = {} -wildcard_dir = scripts.basedir() +repo_dir = scripts.basedir() class WildcardsScript(scripts.Script): @@ -19,7 +19,9 @@ class WildcardsScript(scripts.Script): if " " in text or len(text) == 0: return text - replacement_file = os.path.join(wildcard_dir, "wildcards", f"{text}.txt") + wildcards_dir = shared.cmd_opts.wildcards_dir or os.path.join(repo_dir, "wildcards") + + replacement_file = os.path.join(wildcards_dir, f"{text}.txt") if os.path.exists(replacement_file): with open(replacement_file, encoding="utf8") as f: return gen.choice(f.read().splitlines()) @@ -30,16 +32,23 @@ class WildcardsScript(scripts.Script): return text + def replace_prompts(self, prompts, seeds): + res = [] + + for i, text in enumerate(prompts): + gen = random.Random() + gen.seed(seeds[0 if shared.opts.wildcards_same_seed else i]) + + res.append("".join(self.replace_wildcard(chunk, gen) for chunk in text.split("__"))) + + return res + def process(self, p): original_prompt = p.all_prompts[0] - for i in range(len(p.all_prompts)): - gen = random.Random() - gen.seed(p.all_seeds[0 if shared.opts.wildcards_same_seed else i]) - - prompt = p.all_prompts[i] - prompt = "".join(self.replace_wildcard(chunk, gen) for chunk in prompt.split("__")) - p.all_prompts[i] = prompt + p.all_prompts = self.replace_prompts(p.all_prompts, p.all_seeds) + if hasattr(p, 'all_hr_prompts'): + p.all_hr_prompts = self.replace_prompts(p.all_hr_prompts, p.all_seeds) if original_prompt != p.all_prompts[0]: p.extra_generation_params["Wildcard prompt"] = original_prompt