From 4972255f31b0d5be926449951ba127e86cddd217 Mon Sep 17 00:00:00 2001 From: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com> Date: Wed, 28 Dec 2022 17:07:22 +0900 Subject: [PATCH] enable hotreload with "Apply and restart UI" on Extensions tab --- .gitignore | 2 ++ scripts/main.py | 19 +++++++--- scripts/t2p/dynamic_import.py | 7 ++++ scripts/t2p/prompt_generator/__init__.py | 5 +-- .../t2p/prompt_generator/database_loader.py | 7 ++-- scripts/t2p/prompt_generator/wd_like.py | 35 +++++++++++-------- scripts/t2p/settings.py | 3 ++ 7 files changed, 51 insertions(+), 27 deletions(-) create mode 100644 scripts/t2p/dynamic_import.py diff --git a/.gitignore b/.gitignore index a5b0ea1..be811df 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,5 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +.vscode \ No newline at end of file diff --git a/scripts/main.py b/scripts/main.py index 1c1ce51..862f776 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -4,9 +4,18 @@ import gradio as gr from modules import script_callbacks from modules import generation_parameters_copypaste as params_copypaste -import scripts.t2p.prompt_generator as pgen +import scripts.t2p.settings as settings -wd_like = pgen.WDLike() +if settings.DEVELOP: + import scripts.t2p.prompt_generator as pgen + from scripts.t2p.prompt_generator.wd_like import WDLike +else: + from scripts.t2p.dynamic_import import dynamic_import + _wd_like = dynamic_import('scripts/t2p/prompt_generator/wd_like.py') + WDLike = _wd_like.WDLike + pgen = _wd_like.pgen + +wd_like = WDLike() # brought from modules/deepbooru.py re_special = re.compile(r'([\\()])') @@ -28,7 +37,7 @@ def get_tag_range_txt(tag_range: int): maxval = len(wd_like.database.tag_idx) - 1 i = max(0, min(tag_range, maxval)) r = wd_like.database.tag_idx[i] - return f'Tag range: > {r[0]} tagged ({r[1] + 1} tags total)' + return f'Tag range: ≥ {r[0]} tagged ({r[1] + 1} tags total)' def dd_database_changed(database_name: str, tag_range: int): @@ -68,8 +77,8 @@ def on_ui_tabs(): gr.HTML(value='Generation Settings') choices = wd_like.get_model_names() with gr.Column(): - if choices: wd_like.load_data(choices[0]) - dd_database = gr.Dropdown(choices=choices, value=choices[0] if choices else None, interactive=True, label='Database') + if choices: wd_like.load_data(choices[-1]) + dd_database = gr.Dropdown(choices=choices, value=choices[-1] if choices else None, interactive=True, label='Database') sl_tag_range = gr.Slider(0, 8, 0, step=1, interactive=True, label='Tag count filter') txt_tag_range = gr.HTML(get_tag_range_txt(0)) with gr.Column(): diff --git a/scripts/t2p/dynamic_import.py b/scripts/t2p/dynamic_import.py new file mode 100644 index 0000000..08669c0 --- /dev/null +++ b/scripts/t2p/dynamic_import.py @@ -0,0 +1,7 @@ +print('[text2prompt] Load dynamic_import.py') + +def dynamic_import(path: str): + import os + from modules import scripts, script_loading + path = os.path.abspath(os.path.join(scripts.basedir(), path)) + return script_loading.load_module(path) \ No newline at end of file diff --git a/scripts/t2p/prompt_generator/__init__.py b/scripts/t2p/prompt_generator/__init__.py index 37d91f9..6dd5d8c 100644 --- a/scripts/t2p/prompt_generator/__init__.py +++ b/scripts/t2p/prompt_generator/__init__.py @@ -41,7 +41,4 @@ class PromptGenerator: def ready(self) -> bool: raise NotImplementedError() def __call__(self, text: str, settings: GenerationSettings) -> List[str]: - raise NotImplementedError() - - -from .wd_like import WDLike \ No newline at end of file + raise NotImplementedError() \ No newline at end of file diff --git a/scripts/t2p/prompt_generator/database_loader.py b/scripts/t2p/prompt_generator/database_loader.py index c7faa43..d0e483b 100644 --- a/scripts/t2p/prompt_generator/database_loader.py +++ b/scripts/t2p/prompt_generator/database_loader.py @@ -4,7 +4,7 @@ import csv from typing import Dict import numpy as np -from .. import settings +import scripts.t2p.settings as settings class Database: @@ -91,8 +91,9 @@ class DatabaseLoader: if ext == '.npz': ds = Database(filepath, re_filename) self.datas[ds.name()] = ds - print('[text2prompt] Loaded following databases') - print(f' {sorted(self.datas.keys())}') + print('[text2prompt] Following databases are available:') + for name in sorted(self.datas.keys()): + print(f' {name}') def load(self, database_name: str): diff --git a/scripts/t2p/prompt_generator/wd_like.py b/scripts/t2p/prompt_generator/wd_like.py index 383e1e8..0d03f1f 100644 --- a/scripts/t2p/prompt_generator/wd_like.py +++ b/scripts/t2p/prompt_generator/wd_like.py @@ -5,10 +5,15 @@ import torch import torch.nn.functional as F from transformers import AutoModel, AutoTokenizer -from . import PromptGenerator, GenerationSettings, SamplingMethod, ProbabilityConversion -from .database_loader import DatabaseLoader -from .. import settings +import scripts.t2p.settings as settings +if settings.DEVELOP: + import scripts.t2p.prompt_generator as pgen + import scripts.t2p.prompt_generator.database_loader as dloader +else: + from scripts.t2p.dynamic_import import dynamic_import + pgen = dynamic_import('scripts/t2p/prompt_generator/__init__.py') + dloader = dynamic_import('scripts/t2p/prompt_generator/database_loader.py') NUM_CHOICE = 10 K_VALUE = 50 @@ -20,10 +25,10 @@ def mean_pooling(model_output, attention_mask): return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) -class WDLike(PromptGenerator): +class WDLike(pgen.PromptGenerator): def __init__(self): self.clear() - self.database_loader = DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU) + self.database_loader = dloader.DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU) def get_model_names(self): return sorted(self.database_loader.datas.keys()) @@ -38,9 +43,9 @@ class WDLike(PromptGenerator): self.loaded_model_name = None def load_data(self, database_name: str): - print('[text2prompt] Loading database...') + print(f'[text2prompt] Loading database with name "{database_name}"...') self.database = self.database_loader.load(database_name) - print('[text2prompt] Loaded') + print('[text2prompt] Database loaded') def load_model(self): if self.database is None: @@ -52,11 +57,11 @@ class WDLike(PromptGenerator): if self.loaded_model_name and self.loaded_model_name == self.database.model_name: return else: - print('[text2prompt] Loading model...') + print(f'[text2prompt] Loading model with name "{self.database.model_name}"...') self.tokenizer = AutoTokenizer.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name]) self.model = AutoModel.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name]).to(device) self.loaded_model_name = self.database.model_name - print('[text2prompt] model loaded') + print('[text2prompt] Model loaded') def unload_model(self): if self.tokenizer is not None: @@ -71,7 +76,7 @@ class WDLike(PromptGenerator): and self.tokenizer is not None \ and self.loaded_model_name and self.loaded_model_name == self.database.model_name - def __call__(self, text: str, opts: GenerationSettings) -> List[str]: + def __call__(self, text: str, opts: pgen.GenerationSettings) -> List[str]: if not self.ready(): return '' i = max(0, min(opts.tag_range, len(self.database.tag_idx) - 1)) @@ -103,16 +108,16 @@ class WDLike(PromptGenerator): similarity: torch.Tensor = cos(sentence_embeddings[0], tag_tokens_dev) # Convert similarity into probablity - if opts.conversion == ProbabilityConversion.CUTOFF_AND_POWER: + if opts.conversion is pgen.ProbabilityConversion.CUTOFF_AND_POWER: probs_cpu = torch.clamp(similarity.detach().cpu(), 0, 1) ** opts.prob_power - elif opts.conversion == ProbabilityConversion.SOFTMAX: + elif opts.conversion is pgen.ProbabilityConversion.SOFTMAX: probs_cpu = torch.softmax(similarity.detach().cpu(), dim=0) probs_cpu = probs_cpu / probs_cpu.sum(dim=0) results = None - if opts.sampling == SamplingMethod.NONE: + if opts.sampling is pgen.SamplingMethod.NONE: tags_np = np.array(self.tags) opts.n = min(tags_np.shape[0], opts.n) if opts.n <= 0: return [] @@ -128,7 +133,7 @@ class WDLike(PromptGenerator): # Just sample randomly results = np.random.choice(a=tags_np, size=opts.n, replace=False) - elif opts.sampling == SamplingMethod.TOP_K: + elif opts.sampling is pgen.SamplingMethod.TOP_K: probs, indices = probs_cpu.topk(opts.k) indices = indices.detach().numpy().tolist() if len(indices) <= 0: return [] @@ -147,7 +152,7 @@ class WDLike(PromptGenerator): results = np.random.choice(tags_np, opts.n, replace=False) # brought from https://nn.labml.ai/sampling/nucleus.html - elif opts.sampling == SamplingMethod.TOP_P: + elif opts.sampling is pgen.SamplingMethod.TOP_P: sorted_probs, sorted_indices = probs_cpu.sort(descending=True) cs_probs = torch.cumsum(sorted_probs, dim=0) nucleus = cs_probs < opts.p diff --git a/scripts/t2p/settings.py b/scripts/t2p/settings.py index 8eac28b..4abe67a 100644 --- a/scripts/t2p/settings.py +++ b/scripts/t2p/settings.py @@ -2,6 +2,9 @@ import os import re from modules import scripts +# to use intellisense on vscode +DEVELOP = False + def get_abspath(path: str): return os.path.abspath(os.path.join(scripts.basedir(), path))