From 4972255f31b0d5be926449951ba127e86cddd217 Mon Sep 17 00:00:00 2001
From: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Wed, 28 Dec 2022 17:07:22 +0900
Subject: [PATCH] enable hotreload with "Apply and restart UI" on Extensions
tab
---
.gitignore | 2 ++
scripts/main.py | 19 +++++++---
scripts/t2p/dynamic_import.py | 7 ++++
scripts/t2p/prompt_generator/__init__.py | 5 +--
.../t2p/prompt_generator/database_loader.py | 7 ++--
scripts/t2p/prompt_generator/wd_like.py | 35 +++++++++++--------
scripts/t2p/settings.py | 3 ++
7 files changed, 51 insertions(+), 27 deletions(-)
create mode 100644 scripts/t2p/dynamic_import.py
diff --git a/.gitignore b/.gitignore
index a5b0ea1..be811df 100644
--- a/.gitignore
+++ b/.gitignore
@@ -152,3 +152,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
+
+.vscode
\ No newline at end of file
diff --git a/scripts/main.py b/scripts/main.py
index 1c1ce51..862f776 100644
--- a/scripts/main.py
+++ b/scripts/main.py
@@ -4,9 +4,18 @@ import gradio as gr
from modules import script_callbacks
from modules import generation_parameters_copypaste as params_copypaste
-import scripts.t2p.prompt_generator as pgen
+import scripts.t2p.settings as settings
-wd_like = pgen.WDLike()
+if settings.DEVELOP:
+ import scripts.t2p.prompt_generator as pgen
+ from scripts.t2p.prompt_generator.wd_like import WDLike
+else:
+ from scripts.t2p.dynamic_import import dynamic_import
+ _wd_like = dynamic_import('scripts/t2p/prompt_generator/wd_like.py')
+ WDLike = _wd_like.WDLike
+ pgen = _wd_like.pgen
+
+wd_like = WDLike()
# brought from modules/deepbooru.py
re_special = re.compile(r'([\\()])')
@@ -28,7 +37,7 @@ def get_tag_range_txt(tag_range: int):
maxval = len(wd_like.database.tag_idx) - 1
i = max(0, min(tag_range, maxval))
r = wd_like.database.tag_idx[i]
- return f'Tag range: > {r[0]} tagged ({r[1] + 1} tags total)'
+ return f'Tag range: ≥ {r[0]} tagged ({r[1] + 1} tags total)'
def dd_database_changed(database_name: str, tag_range: int):
@@ -68,8 +77,8 @@ def on_ui_tabs():
gr.HTML(value='Generation Settings')
choices = wd_like.get_model_names()
with gr.Column():
- if choices: wd_like.load_data(choices[0])
- dd_database = gr.Dropdown(choices=choices, value=choices[0] if choices else None, interactive=True, label='Database')
+ if choices: wd_like.load_data(choices[-1])
+ dd_database = gr.Dropdown(choices=choices, value=choices[-1] if choices else None, interactive=True, label='Database')
sl_tag_range = gr.Slider(0, 8, 0, step=1, interactive=True, label='Tag count filter')
txt_tag_range = gr.HTML(get_tag_range_txt(0))
with gr.Column():
diff --git a/scripts/t2p/dynamic_import.py b/scripts/t2p/dynamic_import.py
new file mode 100644
index 0000000..08669c0
--- /dev/null
+++ b/scripts/t2p/dynamic_import.py
@@ -0,0 +1,7 @@
+print('[text2prompt] Load dynamic_import.py')
+
+def dynamic_import(path: str):
+ import os
+ from modules import scripts, script_loading
+ path = os.path.abspath(os.path.join(scripts.basedir(), path))
+ return script_loading.load_module(path)
\ No newline at end of file
diff --git a/scripts/t2p/prompt_generator/__init__.py b/scripts/t2p/prompt_generator/__init__.py
index 37d91f9..6dd5d8c 100644
--- a/scripts/t2p/prompt_generator/__init__.py
+++ b/scripts/t2p/prompt_generator/__init__.py
@@ -41,7 +41,4 @@ class PromptGenerator:
def ready(self) -> bool:
raise NotImplementedError()
def __call__(self, text: str, settings: GenerationSettings) -> List[str]:
- raise NotImplementedError()
-
-
-from .wd_like import WDLike
\ No newline at end of file
+ raise NotImplementedError()
\ No newline at end of file
diff --git a/scripts/t2p/prompt_generator/database_loader.py b/scripts/t2p/prompt_generator/database_loader.py
index c7faa43..d0e483b 100644
--- a/scripts/t2p/prompt_generator/database_loader.py
+++ b/scripts/t2p/prompt_generator/database_loader.py
@@ -4,7 +4,7 @@ import csv
from typing import Dict
import numpy as np
-from .. import settings
+import scripts.t2p.settings as settings
class Database:
@@ -91,8 +91,9 @@ class DatabaseLoader:
if ext == '.npz':
ds = Database(filepath, re_filename)
self.datas[ds.name()] = ds
- print('[text2prompt] Loaded following databases')
- print(f' {sorted(self.datas.keys())}')
+ print('[text2prompt] Following databases are available:')
+ for name in sorted(self.datas.keys()):
+ print(f' {name}')
def load(self, database_name: str):
diff --git a/scripts/t2p/prompt_generator/wd_like.py b/scripts/t2p/prompt_generator/wd_like.py
index 383e1e8..0d03f1f 100644
--- a/scripts/t2p/prompt_generator/wd_like.py
+++ b/scripts/t2p/prompt_generator/wd_like.py
@@ -5,10 +5,15 @@ import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
-from . import PromptGenerator, GenerationSettings, SamplingMethod, ProbabilityConversion
-from .database_loader import DatabaseLoader
-from .. import settings
+import scripts.t2p.settings as settings
+if settings.DEVELOP:
+ import scripts.t2p.prompt_generator as pgen
+ import scripts.t2p.prompt_generator.database_loader as dloader
+else:
+ from scripts.t2p.dynamic_import import dynamic_import
+ pgen = dynamic_import('scripts/t2p/prompt_generator/__init__.py')
+ dloader = dynamic_import('scripts/t2p/prompt_generator/database_loader.py')
NUM_CHOICE = 10
K_VALUE = 50
@@ -20,10 +25,10 @@ def mean_pooling(model_output, attention_mask):
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-class WDLike(PromptGenerator):
+class WDLike(pgen.PromptGenerator):
def __init__(self):
self.clear()
- self.database_loader = DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)
+ self.database_loader = dloader.DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)
def get_model_names(self):
return sorted(self.database_loader.datas.keys())
@@ -38,9 +43,9 @@ class WDLike(PromptGenerator):
self.loaded_model_name = None
def load_data(self, database_name: str):
- print('[text2prompt] Loading database...')
+ print(f'[text2prompt] Loading database with name "{database_name}"...')
self.database = self.database_loader.load(database_name)
- print('[text2prompt] Loaded')
+ print('[text2prompt] Database loaded')
def load_model(self):
if self.database is None:
@@ -52,11 +57,11 @@ class WDLike(PromptGenerator):
if self.loaded_model_name and self.loaded_model_name == self.database.model_name:
return
else:
- print('[text2prompt] Loading model...')
+ print(f'[text2prompt] Loading model with name "{self.database.model_name}"...')
self.tokenizer = AutoTokenizer.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name])
self.model = AutoModel.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name]).to(device)
self.loaded_model_name = self.database.model_name
- print('[text2prompt] model loaded')
+ print('[text2prompt] Model loaded')
def unload_model(self):
if self.tokenizer is not None:
@@ -71,7 +76,7 @@ class WDLike(PromptGenerator):
and self.tokenizer is not None \
and self.loaded_model_name and self.loaded_model_name == self.database.model_name
- def __call__(self, text: str, opts: GenerationSettings) -> List[str]:
+ def __call__(self, text: str, opts: pgen.GenerationSettings) -> List[str]:
if not self.ready(): return ''
i = max(0, min(opts.tag_range, len(self.database.tag_idx) - 1))
@@ -103,16 +108,16 @@ class WDLike(PromptGenerator):
similarity: torch.Tensor = cos(sentence_embeddings[0], tag_tokens_dev)
# Convert similarity into probablity
- if opts.conversion == ProbabilityConversion.CUTOFF_AND_POWER:
+ if opts.conversion is pgen.ProbabilityConversion.CUTOFF_AND_POWER:
probs_cpu = torch.clamp(similarity.detach().cpu(), 0, 1) ** opts.prob_power
- elif opts.conversion == ProbabilityConversion.SOFTMAX:
+ elif opts.conversion is pgen.ProbabilityConversion.SOFTMAX:
probs_cpu = torch.softmax(similarity.detach().cpu(), dim=0)
probs_cpu = probs_cpu / probs_cpu.sum(dim=0)
results = None
- if opts.sampling == SamplingMethod.NONE:
+ if opts.sampling is pgen.SamplingMethod.NONE:
tags_np = np.array(self.tags)
opts.n = min(tags_np.shape[0], opts.n)
if opts.n <= 0: return []
@@ -128,7 +133,7 @@ class WDLike(PromptGenerator):
# Just sample randomly
results = np.random.choice(a=tags_np, size=opts.n, replace=False)
- elif opts.sampling == SamplingMethod.TOP_K:
+ elif opts.sampling is pgen.SamplingMethod.TOP_K:
probs, indices = probs_cpu.topk(opts.k)
indices = indices.detach().numpy().tolist()
if len(indices) <= 0: return []
@@ -147,7 +152,7 @@ class WDLike(PromptGenerator):
results = np.random.choice(tags_np, opts.n, replace=False)
# brought from https://nn.labml.ai/sampling/nucleus.html
- elif opts.sampling == SamplingMethod.TOP_P:
+ elif opts.sampling is pgen.SamplingMethod.TOP_P:
sorted_probs, sorted_indices = probs_cpu.sort(descending=True)
cs_probs = torch.cumsum(sorted_probs, dim=0)
nucleus = cs_probs < opts.p
diff --git a/scripts/t2p/settings.py b/scripts/t2p/settings.py
index 8eac28b..4abe67a 100644
--- a/scripts/t2p/settings.py
+++ b/scripts/t2p/settings.py
@@ -2,6 +2,9 @@ import os
import re
from modules import scripts
+# to use intellisense on vscode
+DEVELOP = False
+
def get_abspath(path: str):
return os.path.abspath(os.path.join(scripts.basedir(), path))