enable hotreload with "Apply and restart UI" on Extensions tab
parent
38fc93c5fa
commit
4972255f31
|
|
@ -152,3 +152,5 @@ cython_debug/
|
|||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
.vscode
|
||||
|
|
@ -4,9 +4,18 @@ import gradio as gr
|
|||
from modules import script_callbacks
|
||||
from modules import generation_parameters_copypaste as params_copypaste
|
||||
|
||||
import scripts.t2p.prompt_generator as pgen
|
||||
import scripts.t2p.settings as settings
|
||||
|
||||
wd_like = pgen.WDLike()
|
||||
if settings.DEVELOP:
|
||||
import scripts.t2p.prompt_generator as pgen
|
||||
from scripts.t2p.prompt_generator.wd_like import WDLike
|
||||
else:
|
||||
from scripts.t2p.dynamic_import import dynamic_import
|
||||
_wd_like = dynamic_import('scripts/t2p/prompt_generator/wd_like.py')
|
||||
WDLike = _wd_like.WDLike
|
||||
pgen = _wd_like.pgen
|
||||
|
||||
wd_like = WDLike()
|
||||
|
||||
# brought from modules/deepbooru.py
|
||||
re_special = re.compile(r'([\\()])')
|
||||
|
|
@ -28,7 +37,7 @@ def get_tag_range_txt(tag_range: int):
|
|||
maxval = len(wd_like.database.tag_idx) - 1
|
||||
i = max(0, min(tag_range, maxval))
|
||||
r = wd_like.database.tag_idx[i]
|
||||
return f'Tag range: <b> > {r[0]} tagged</b> ({r[1] + 1} tags total)'
|
||||
return f'Tag range: <b> ≥ {r[0]} tagged</b> ({r[1] + 1} tags total)'
|
||||
|
||||
|
||||
def dd_database_changed(database_name: str, tag_range: int):
|
||||
|
|
@ -68,8 +77,8 @@ def on_ui_tabs():
|
|||
gr.HTML(value='Generation Settings')
|
||||
choices = wd_like.get_model_names()
|
||||
with gr.Column():
|
||||
if choices: wd_like.load_data(choices[0])
|
||||
dd_database = gr.Dropdown(choices=choices, value=choices[0] if choices else None, interactive=True, label='Database')
|
||||
if choices: wd_like.load_data(choices[-1])
|
||||
dd_database = gr.Dropdown(choices=choices, value=choices[-1] if choices else None, interactive=True, label='Database')
|
||||
sl_tag_range = gr.Slider(0, 8, 0, step=1, interactive=True, label='Tag count filter')
|
||||
txt_tag_range = gr.HTML(get_tag_range_txt(0))
|
||||
with gr.Column():
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
print('[text2prompt] Load dynamic_import.py')
|
||||
|
||||
def dynamic_import(path: str):
|
||||
import os
|
||||
from modules import scripts, script_loading
|
||||
path = os.path.abspath(os.path.join(scripts.basedir(), path))
|
||||
return script_loading.load_module(path)
|
||||
|
|
@ -41,7 +41,4 @@ class PromptGenerator:
|
|||
def ready(self) -> bool:
|
||||
raise NotImplementedError()
|
||||
def __call__(self, text: str, settings: GenerationSettings) -> List[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
from .wd_like import WDLike
|
||||
raise NotImplementedError()
|
||||
|
|
@ -4,7 +4,7 @@ import csv
|
|||
from typing import Dict
|
||||
import numpy as np
|
||||
|
||||
from .. import settings
|
||||
import scripts.t2p.settings as settings
|
||||
|
||||
|
||||
class Database:
|
||||
|
|
@ -91,8 +91,9 @@ class DatabaseLoader:
|
|||
if ext == '.npz':
|
||||
ds = Database(filepath, re_filename)
|
||||
self.datas[ds.name()] = ds
|
||||
print('[text2prompt] Loaded following databases')
|
||||
print(f' {sorted(self.datas.keys())}')
|
||||
print('[text2prompt] Following databases are available:')
|
||||
for name in sorted(self.datas.keys()):
|
||||
print(f' {name}')
|
||||
|
||||
|
||||
def load(self, database_name: str):
|
||||
|
|
|
|||
|
|
@ -5,10 +5,15 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from . import PromptGenerator, GenerationSettings, SamplingMethod, ProbabilityConversion
|
||||
from .database_loader import DatabaseLoader
|
||||
from .. import settings
|
||||
import scripts.t2p.settings as settings
|
||||
|
||||
if settings.DEVELOP:
|
||||
import scripts.t2p.prompt_generator as pgen
|
||||
import scripts.t2p.prompt_generator.database_loader as dloader
|
||||
else:
|
||||
from scripts.t2p.dynamic_import import dynamic_import
|
||||
pgen = dynamic_import('scripts/t2p/prompt_generator/__init__.py')
|
||||
dloader = dynamic_import('scripts/t2p/prompt_generator/database_loader.py')
|
||||
|
||||
NUM_CHOICE = 10
|
||||
K_VALUE = 50
|
||||
|
|
@ -20,10 +25,10 @@ def mean_pooling(model_output, attention_mask):
|
|||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
||||
|
||||
|
||||
class WDLike(PromptGenerator):
|
||||
class WDLike(pgen.PromptGenerator):
|
||||
def __init__(self):
|
||||
self.clear()
|
||||
self.database_loader = DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)
|
||||
self.database_loader = dloader.DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)
|
||||
|
||||
def get_model_names(self):
|
||||
return sorted(self.database_loader.datas.keys())
|
||||
|
|
@ -38,9 +43,9 @@ class WDLike(PromptGenerator):
|
|||
self.loaded_model_name = None
|
||||
|
||||
def load_data(self, database_name: str):
|
||||
print('[text2prompt] Loading database...')
|
||||
print(f'[text2prompt] Loading database with name "{database_name}"...')
|
||||
self.database = self.database_loader.load(database_name)
|
||||
print('[text2prompt] Loaded')
|
||||
print('[text2prompt] Database loaded')
|
||||
|
||||
def load_model(self):
|
||||
if self.database is None:
|
||||
|
|
@ -52,11 +57,11 @@ class WDLike(PromptGenerator):
|
|||
if self.loaded_model_name and self.loaded_model_name == self.database.model_name:
|
||||
return
|
||||
else:
|
||||
print('[text2prompt] Loading model...')
|
||||
print(f'[text2prompt] Loading model with name "{self.database.model_name}"...')
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name])
|
||||
self.model = AutoModel.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name]).to(device)
|
||||
self.loaded_model_name = self.database.model_name
|
||||
print('[text2prompt] model loaded')
|
||||
print('[text2prompt] Model loaded')
|
||||
|
||||
def unload_model(self):
|
||||
if self.tokenizer is not None:
|
||||
|
|
@ -71,7 +76,7 @@ class WDLike(PromptGenerator):
|
|||
and self.tokenizer is not None \
|
||||
and self.loaded_model_name and self.loaded_model_name == self.database.model_name
|
||||
|
||||
def __call__(self, text: str, opts: GenerationSettings) -> List[str]:
|
||||
def __call__(self, text: str, opts: pgen.GenerationSettings) -> List[str]:
|
||||
if not self.ready(): return ''
|
||||
|
||||
i = max(0, min(opts.tag_range, len(self.database.tag_idx) - 1))
|
||||
|
|
@ -103,16 +108,16 @@ class WDLike(PromptGenerator):
|
|||
similarity: torch.Tensor = cos(sentence_embeddings[0], tag_tokens_dev)
|
||||
|
||||
# Convert similarity into probablity
|
||||
if opts.conversion == ProbabilityConversion.CUTOFF_AND_POWER:
|
||||
if opts.conversion is pgen.ProbabilityConversion.CUTOFF_AND_POWER:
|
||||
probs_cpu = torch.clamp(similarity.detach().cpu(), 0, 1) ** opts.prob_power
|
||||
elif opts.conversion == ProbabilityConversion.SOFTMAX:
|
||||
elif opts.conversion is pgen.ProbabilityConversion.SOFTMAX:
|
||||
probs_cpu = torch.softmax(similarity.detach().cpu(), dim=0)
|
||||
|
||||
probs_cpu = probs_cpu / probs_cpu.sum(dim=0)
|
||||
|
||||
results = None
|
||||
|
||||
if opts.sampling == SamplingMethod.NONE:
|
||||
if opts.sampling is pgen.SamplingMethod.NONE:
|
||||
tags_np = np.array(self.tags)
|
||||
opts.n = min(tags_np.shape[0], opts.n)
|
||||
if opts.n <= 0: return []
|
||||
|
|
@ -128,7 +133,7 @@ class WDLike(PromptGenerator):
|
|||
# Just sample randomly
|
||||
results = np.random.choice(a=tags_np, size=opts.n, replace=False)
|
||||
|
||||
elif opts.sampling == SamplingMethod.TOP_K:
|
||||
elif opts.sampling is pgen.SamplingMethod.TOP_K:
|
||||
probs, indices = probs_cpu.topk(opts.k)
|
||||
indices = indices.detach().numpy().tolist()
|
||||
if len(indices) <= 0: return []
|
||||
|
|
@ -147,7 +152,7 @@ class WDLike(PromptGenerator):
|
|||
results = np.random.choice(tags_np, opts.n, replace=False)
|
||||
|
||||
# brought from https://nn.labml.ai/sampling/nucleus.html
|
||||
elif opts.sampling == SamplingMethod.TOP_P:
|
||||
elif opts.sampling is pgen.SamplingMethod.TOP_P:
|
||||
sorted_probs, sorted_indices = probs_cpu.sort(descending=True)
|
||||
cs_probs = torch.cumsum(sorted_probs, dim=0)
|
||||
nucleus = cs_probs < opts.p
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ import os
|
|||
import re
|
||||
from modules import scripts
|
||||
|
||||
# to use intellisense on vscode
|
||||
DEVELOP = False
|
||||
|
||||
def get_abspath(path: str):
|
||||
return os.path.abspath(os.path.join(scripts.basedir(), path))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue