Enable hot reload with "Apply and restart UI" on Extensions tab

main
toshiaki1729 2022-12-28 17:07:22 +09:00
parent 38fc93c5fa
commit 4972255f31
7 changed files with 51 additions and 27 deletions

2
.gitignore vendored
View File

@ -152,3 +152,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.vscode

View File

@ -4,9 +4,18 @@ import gradio as gr
from modules import script_callbacks
from modules import generation_parameters_copypaste as params_copypaste
import scripts.t2p.prompt_generator as pgen
import scripts.t2p.settings as settings
wd_like = pgen.WDLike()
if settings.DEVELOP:
import scripts.t2p.prompt_generator as pgen
from scripts.t2p.prompt_generator.wd_like import WDLike
else:
from scripts.t2p.dynamic_import import dynamic_import
_wd_like = dynamic_import('scripts/t2p/prompt_generator/wd_like.py')
WDLike = _wd_like.WDLike
pgen = _wd_like.pgen
wd_like = WDLike()
# brought from modules/deepbooru.py
re_special = re.compile(r'([\\()])')
@ -28,7 +37,7 @@ def get_tag_range_txt(tag_range: int):
maxval = len(wd_like.database.tag_idx) - 1
i = max(0, min(tag_range, maxval))
r = wd_like.database.tag_idx[i]
return f'Tag range: <b> &gt; {r[0]} tagged</b> ({r[1] + 1} tags total)'
return f'Tag range: <b> {r[0]} tagged</b> ({r[1] + 1} tags total)'
def dd_database_changed(database_name: str, tag_range: int):
@ -68,8 +77,8 @@ def on_ui_tabs():
gr.HTML(value='Generation Settings')
choices = wd_like.get_model_names()
with gr.Column():
if choices: wd_like.load_data(choices[0])
dd_database = gr.Dropdown(choices=choices, value=choices[0] if choices else None, interactive=True, label='Database')
if choices: wd_like.load_data(choices[-1])
dd_database = gr.Dropdown(choices=choices, value=choices[-1] if choices else None, interactive=True, label='Database')
sl_tag_range = gr.Slider(0, 8, 0, step=1, interactive=True, label='Tag count filter')
txt_tag_range = gr.HTML(get_tag_range_txt(0))
with gr.Column():

View File

@ -0,0 +1,7 @@
print('[text2prompt] Load dynamic_import.py')
def dynamic_import(path: str):
    """Load a Python module from a path relative to this extension's base directory.

    Resolving through ``script_loading.load_module`` (instead of a regular
    ``import``) means the file is re-read on each call, which is what makes
    "Apply and restart UI" pick up edited sources without restarting the process.

    :param path: module file path relative to the extension root, e.g.
        ``'scripts/t2p/prompt_generator/wd_like.py'``
    :return: the freshly loaded module object
    """
    # Imports are kept local so merely importing this helper has no side effects.
    import os
    from modules import scripts, script_loading

    full_path = os.path.join(scripts.basedir(), path)
    return script_loading.load_module(os.path.abspath(full_path))

View File

@ -41,7 +41,4 @@ class PromptGenerator:
def ready(self) -> bool:
raise NotImplementedError()
def __call__(self, text: str, settings: GenerationSettings) -> List[str]:
raise NotImplementedError()
from .wd_like import WDLike
raise NotImplementedError()

View File

@ -4,7 +4,7 @@ import csv
from typing import Dict
import numpy as np
from .. import settings
import scripts.t2p.settings as settings
class Database:
@ -91,8 +91,9 @@ class DatabaseLoader:
if ext == '.npz':
ds = Database(filepath, re_filename)
self.datas[ds.name()] = ds
print('[text2prompt] Loaded following databases')
print(f' {sorted(self.datas.keys())}')
print('[text2prompt] Following databases are available:')
for name in sorted(self.datas.keys()):
print(f' {name}')
def load(self, database_name: str):

View File

@ -5,10 +5,15 @@ import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from . import PromptGenerator, GenerationSettings, SamplingMethod, ProbabilityConversion
from .database_loader import DatabaseLoader
from .. import settings
import scripts.t2p.settings as settings
if settings.DEVELOP:
import scripts.t2p.prompt_generator as pgen
import scripts.t2p.prompt_generator.database_loader as dloader
else:
from scripts.t2p.dynamic_import import dynamic_import
pgen = dynamic_import('scripts/t2p/prompt_generator/__init__.py')
dloader = dynamic_import('scripts/t2p/prompt_generator/database_loader.py')
NUM_CHOICE = 10
K_VALUE = 50
@ -20,10 +25,10 @@ def mean_pooling(model_output, attention_mask):
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
class WDLike(PromptGenerator):
class WDLike(pgen.PromptGenerator):
def __init__(self):
self.clear()
self.database_loader = DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)
self.database_loader = dloader.DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)
def get_model_names(self):
return sorted(self.database_loader.datas.keys())
@ -38,9 +43,9 @@ class WDLike(PromptGenerator):
self.loaded_model_name = None
def load_data(self, database_name: str):
print('[text2prompt] Loading database...')
print(f'[text2prompt] Loading database with name "{database_name}"...')
self.database = self.database_loader.load(database_name)
print('[text2prompt] Loaded')
print('[text2prompt] Database loaded')
def load_model(self):
if self.database is None:
@ -52,11 +57,11 @@ class WDLike(PromptGenerator):
if self.loaded_model_name and self.loaded_model_name == self.database.model_name:
return
else:
print('[text2prompt] Loading model...')
print(f'[text2prompt] Loading model with name "{self.database.model_name}"...')
self.tokenizer = AutoTokenizer.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name])
self.model = AutoModel.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name]).to(device)
self.loaded_model_name = self.database.model_name
print('[text2prompt] model loaded')
print('[text2prompt] Model loaded')
def unload_model(self):
if self.tokenizer is not None:
@ -71,7 +76,7 @@ class WDLike(PromptGenerator):
and self.tokenizer is not None \
and self.loaded_model_name and self.loaded_model_name == self.database.model_name
def __call__(self, text: str, opts: GenerationSettings) -> List[str]:
def __call__(self, text: str, opts: pgen.GenerationSettings) -> List[str]:
if not self.ready(): return ''
i = max(0, min(opts.tag_range, len(self.database.tag_idx) - 1))
@ -103,16 +108,16 @@ class WDLike(PromptGenerator):
similarity: torch.Tensor = cos(sentence_embeddings[0], tag_tokens_dev)
# Convert similarity into probability
if opts.conversion == ProbabilityConversion.CUTOFF_AND_POWER:
if opts.conversion is pgen.ProbabilityConversion.CUTOFF_AND_POWER:
probs_cpu = torch.clamp(similarity.detach().cpu(), 0, 1) ** opts.prob_power
elif opts.conversion == ProbabilityConversion.SOFTMAX:
elif opts.conversion is pgen.ProbabilityConversion.SOFTMAX:
probs_cpu = torch.softmax(similarity.detach().cpu(), dim=0)
probs_cpu = probs_cpu / probs_cpu.sum(dim=0)
results = None
if opts.sampling == SamplingMethod.NONE:
if opts.sampling is pgen.SamplingMethod.NONE:
tags_np = np.array(self.tags)
opts.n = min(tags_np.shape[0], opts.n)
if opts.n <= 0: return []
@ -128,7 +133,7 @@ class WDLike(PromptGenerator):
# Just sample randomly
results = np.random.choice(a=tags_np, size=opts.n, replace=False)
elif opts.sampling == SamplingMethod.TOP_K:
elif opts.sampling is pgen.SamplingMethod.TOP_K:
probs, indices = probs_cpu.topk(opts.k)
indices = indices.detach().numpy().tolist()
if len(indices) <= 0: return []
@ -147,7 +152,7 @@ class WDLike(PromptGenerator):
results = np.random.choice(tags_np, opts.n, replace=False)
# brought from https://nn.labml.ai/sampling/nucleus.html
elif opts.sampling == SamplingMethod.TOP_P:
elif opts.sampling is pgen.SamplingMethod.TOP_P:
sorted_probs, sorted_indices = probs_cpu.sort(descending=True)
cs_probs = torch.cumsum(sorted_probs, dim=0)
nucleus = cs_probs < opts.p

View File

@ -2,6 +2,9 @@ import os
import re
from modules import scripts
# to use intellisense on vscode
DEVELOP = False
def get_abspath(path: str):
return os.path.abspath(os.path.join(scripts.basedir(), path))