Squashed commit of the following:

commit 0571f2dab8bf3d8e1826f0fb5d53ca5b9bfcf6b0
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 01:05:27 2022 +0900

    Update README.md

commit b4a5164081662db9085e2b5c17af6e98473180ff
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:44:25 2022 +0900

    remove

commit 127946a3d200825eb15244ad09db5361db8bf89c
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:28:28 2022 +0900

    Revert "remove"

    This reverts commit f7666de150d1bb8acdcedcb0b82ddfb939206088.

commit f7666de150d1bb8acdcedcb0b82ddfb939206088
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:26:44 2022 +0900

    remove

commit 41b99ba60bd59050384bdd55a498395e4148bfda
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:26:20 2022 +0900

    Update .gitignore

commit 2fdce554bdd92bc87dce13181aa67c58d523d830
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:24:56 2022 +0900

    Update .gitignore

commit 50927e064e5ffe052d785038f4c7de851ede6c91
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:23:49 2022 +0900

    upload database

commit 39172118e8d8fac2ed7f3aab0c53e2a6e89c16b3
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:22:34 2022 +0900

    Update .gitignore

commit 7c6225ade3a55fdf342eb2b3617302edab595192
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:21:25 2022 +0900

    Create test

commit 8010ec05a5e7836fc004450bc82cc3ce93d975ad
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:20:22 2022 +0900

    Update .gitignore

commit 6265874465281db81a3935f651537ac9c6cfc30a
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:17:17 2022 +0900

    update README.md

commit 3f52e109aa0b7a41f4970b11b933c7f57c7ddd62
Merge: 180c4c5 867a48a
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Tue Dec 27 23:20:32 2022 +0900

    Merge branch 'main' into feature/well-filtered-database

commit 180c4c5b414058939eb47dfe7783eb561c5da65e
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Tue Dec 27 23:19:51 2022 +0900

    implement dynamic database loader

commit fdccc1a9cb9a80ce9b4042892afc9922a20e6524
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Sun Dec 25 19:39:32 2022 +0900

    Update README.md

commit b539cbf4d3791777e424883b2a5bd64a73de9bf0
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Sun Dec 25 19:34:11 2022 +0900

    remove thread lock to allow generating prompt while doing other tasks

commit 43b5bfd73a549f2bced6704cdcfacb1483648b17
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Sun Dec 25 19:26:26 2022 +0900

    update README

commit 58df449f36678de2e9b35726315d53a281bdc380
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Sun Dec 25 19:05:30 2022 +0900

    add new language model and database

    The size of data will be so large
main
toshiaki1729 2022-12-28 01:10:02 +09:00
parent 867a48ad52
commit 374325e3e2
14 changed files with 14859 additions and 5579 deletions

.gitignore (vendored)

@ -6,6 +6,8 @@ __pycache__/
# C extensions
*.so
archive/
# Distribution / packaging
.Python
build/


@ -18,9 +18,13 @@ git clone https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt.git
## Usage
1. Type some words into "Input Theme"
1. Push "Generate" button
![](pic/pic1.png)
### Tips
- For more creative results
@ -34,20 +38,34 @@ git clone https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt.git
- You can enter very long sentences, but the more specific the input is, the fewer results you will get.
## How it works
It does nothing special:
1. Danbooru tags and their descriptions are in the `data` folder
- the descriptions are generated from the wiki and are already tokenized
- the [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) model is used to tokenize the text
- for now, some tags (such as those with <1k posts or containing the title of a work) are deleted to prevent "noisy" results
1. Tokenize your input text and calculate cosine similarity to each tag description
- the [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) models are used to tokenize the text
1. Tokenize your input text and calculate cosine similarity with all tag descriptions
1. Choose some tags depending on their similarities
![](pic/pic1.png)
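The matching step above can be sketched with plain NumPy, assuming the tag descriptions have already been embedded (the extension uses the sentence-transformers models for that; the toy vectors and names below are hypothetical):

```python
import numpy as np

def top_k_tags(text_embedding, tag_embeddings, tags, k=3):
    """Rank tags by cosine similarity between the input embedding
    and each precomputed tag-description embedding."""
    # Normalize so that a dot product equals cosine similarity.
    t = text_embedding / np.linalg.norm(text_embedding)
    d = tag_embeddings / np.linalg.norm(tag_embeddings, axis=1, keepdims=True)
    sims = d @ t
    # Highest similarity first, keep the top k.
    order = np.argsort(sims)[::-1][:k]
    return [tags[i] for i in order]

# Toy 2-D embeddings; a real database holds one row per Danbooru tag.
tags = ['1girl', 'outdoors', 'night_sky']
tag_embeddings = np.array([[1.0, 0.0], [0.7, 0.7], [0.0, 1.0]])
text_embedding = np.array([0.1, 0.9])

print(top_k_tags(text_embedding, tag_embeddings, tags, k=2))
```

The actual extension then samples from these ranked tags (top-k or top-p) instead of always returning the best matches.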
### Database (Optional)
You can choose one of the following datasets if needed.
Download one, unzip it, and put its contents into `text2prompt-root-dir/data/danbooru/`.
|Tag description|all-mpnet-base-v2|all-MiniLM-L6-v2|
|:---|:---:|:---:|
|**well filtered (recommended)**|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_strict_all-mpnet-base-v2.zip) (preinstalled)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_strict_all-MiniLM-L6-v2.zip)|
|normal (same as previous one)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_normal_all-mpnet-base-v2.zip)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_normal_all-MiniLM-L6-v2.zip)|
|full (noisy)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_full_all-mpnet-base-v2.zip)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_full_all-MiniLM-L6-v2.zip)|
**well filtered:** Tags are removed if their description includes the title of a work. Such tags are tied to a specific work, meaning they are not "general" tags.
**normal:** Tags whose name contains the title of a work, like tag_name(work_name), are removed.
**full:** Includes all tags.
---
### More detailed
## More detailed description
$i \in N = \\{1, 2, ..., n\\}$ for the index of a tag
$s_i = S_C(d_i, t)$ for the cosine similarity between tag description $d_i$ and your text $t$
$P_i$ for the probability that tag $i$ is chosen
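With the definitions above, the "Cutoff and Power" conversion offered in the UI can be sketched as follows. The exact formula is an assumption: presumably negative similarities are cut off, the rest are raised to the power set by the "Power" slider, and the result is normalized so the $P_i$ sum to 1.

```python
import numpy as np

def cutoff_and_power(sims, power=2.0):
    """Convert cosine similarities s_i into probabilities P_i:
    clip negatives to zero, raise to `power`, normalize to sum to 1.
    (Assumed formula, not taken verbatim from the extension.)"""
    w = np.clip(sims, 0.0, None) ** power
    return w / w.sum()

sims = np.array([0.8, 0.5, -0.2, 0.1])
probs = cutoff_and_power(sims)
print(probs.round(3))  # highest similarity gets most of the mass
```

A larger power concentrates probability on the best-matching tags, which is why lowering it gives more "creative" (more uniform) results.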


@ -0,0 +1,8 @@
50,13005
100,11294
200,9490
500,7099
1000,5425
2000,3956
5000,2458
10000,1633
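Each row added above pairs a tag-count cutoff with an index into the tag list (the UI label in this commit reports `r[1] + 1` tags in range, so 13005 is an index, not a count). A minimal sketch of reading such a file, mirroring what `Database.read_files` does with the `csv` module; the inline string stands in for the `*_tagidx.csv` file:

```python
import csv
import io

# Inline copy of the CSV rows added in this commit.
raw = """50,13005
100,11294
200,9490
500,7099
1000,5425
2000,3956
5000,2458
10000,1633"""

# Parse (min_tag_count, tag_index) pairs and sort by cutoff,
# as Database.read_files does.
tag_idx = [(int(a), int(b)) for a, b in csv.reader(io.StringIO(raw))]
tag_idx.sort(key=lambda t: t[0])

# The "Tag count filter" slider selects one of these cutoffs.
for min_count, idx in tag_idx:
    print(f'> {min_count} tagged: {idx + 1} tags total')
```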

File diff suppressed because it is too large.

File diff suppressed because it is too large.

Binary file not shown.

Image changed: 101 KiB (before) → 132 KiB (after)


@ -6,7 +6,7 @@ from modules import generation_parameters_copypaste as params_copypaste
import scripts.t2p.prompt_generator as pgen
wd_like = None
wd_like = pgen.WDLike()
# brought from modules/deepbooru.py
re_special = re.compile(r'([\\()])')
@ -22,18 +22,35 @@ def get_sampling(choice: int):
elif choice == 2: return pgen.SamplingMethod.TOP_P
else: raise NotImplementedError()
def generate_prompt(text: str, conversion: int, power: float, sampling: int, n: int, k: int, p: float, weighted: bool, replace_underscore: bool, escape_brackets: bool):
global wd_like
if not wd_like:
print('Loading tag data and model')
wd_like = pgen.WDLike()
wd_like.load_model()
print('Loaded')
tags = wd_like(text, pgen.GenerationSettings(get_conversion(conversion), power, get_sampling(sampling), n, k, p, weighted))
def get_tag_range_txt(tag_range: int):
if wd_like.database is None:
return 'Tag range: NONE'
maxval = len(wd_like.database.tag_idx) - 1
i = max(0, min(tag_range, maxval))
r = wd_like.database.tag_idx[i]
return f'Tag range: <b> &gt; {r[0]} tagged</b> ({r[1] + 1} tags total)'
def dd_database_changed(database_name: str, tag_range: int):
wd_like.load_data(database_name)
return [
gr.Slider.update(tag_range, 0, len(wd_like.database.tag_idx) - 1),
get_tag_range_txt(tag_range)
]
def sl_tag_range_changed(tag_range: int):
return get_tag_range_txt(tag_range)
def generate_prompt(text: str, tag_range: int, conversion: int, power: float, sampling: int, n: int, k: int, p: float, weighted: bool, replace_underscore: bool, escape_brackets: bool):
wd_like.load_model()  # skips loading if the model is already loaded
tags = wd_like(text, pgen.GenerationSettings(tag_range, get_conversion(conversion), power, get_sampling(sampling), n, k, p, weighted))
if replace_underscore: tags = [t.replace('_', ' ') for t in tags]
if escape_brackets: tags = [re.sub(re_special, r'\\\1', t) for t in tags]
return ', '.join(tags)
def on_ui_tabs():
with gr.Blocks(analytics_enabled=False) as text2prompt_interface:
with gr.Row():
@ -47,15 +64,33 @@ def on_ui_tabs():
buttons = params_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
params_copypaste.bind_buttons(buttons, None, tb_output)
with gr.Column():
with gr.Column(variant='panel'):
gr.HTML(value='Generation Settings')
rb_prob_conversion_method = gr.Radio(choices=['Cutoff and Power', 'Softmax'], value='Cutoff and Power', type='index', label='Method to convert similarity into probability')
sl_power = gr.Slider(0, 5, value=2, step=0.1, label='Power', interactive=True)
rb_sampling_method = gr.Radio(choices=['NONE', 'Top-k', 'Top-p (Nucleus)'], value='Top-k', type='index', label='Sampling method')
nb_max_tag_num = gr.Number(value=20, label='Max number of tags', precision=0, interactive=True)
nb_k_value = gr.Number(value=50, label='k value', precision=0, interactive=True)
sl_p_value = gr.Slider(0, 1, label='p value', value=0.1, step=0.01, interactive=True)
cb_weighted = gr.Checkbox(value=True, label='Use weighted choice', interactive=True)
choices = wd_like.get_model_names()
with gr.Column():
dd_database = gr.Dropdown(choices=choices, value=None, interactive=True, label='Database')
sl_tag_range = gr.Slider(0, 8, 0, step=1, interactive=True, label='Tag count filter')
txt_tag_range = gr.HTML(get_tag_range_txt(0))
with gr.Column():
rb_prob_conversion_method = gr.Radio(choices=['Cutoff and Power', 'Softmax'], value='Cutoff and Power', type='index', label='Method to convert similarity into probability')
sl_power = gr.Slider(0, 5, value=2, step=0.1, label='Power', interactive=True)
rb_sampling_method = gr.Radio(choices=['NONE', 'Top-k', 'Top-p (Nucleus)'], value='Top-k', type='index', label='Sampling method')
nb_max_tag_num = gr.Number(value=20, label='Max number of tags', precision=0, interactive=True)
nb_k_value = gr.Number(value=50, label='k value', precision=0, interactive=True)
sl_p_value = gr.Slider(0, 1, label='p value', value=0.1, step=0.01, interactive=True)
cb_weighted = gr.Checkbox(value=True, label='Use weighted choice', interactive=True)
dd_database.change(
fn=dd_database_changed,
inputs=[dd_database, sl_tag_range],
outputs=[sl_tag_range, txt_tag_range]
)
sl_tag_range.change(
fn=sl_tag_range_changed,
inputs=sl_tag_range,
outputs=txt_tag_range
)
nb_max_tag_num.change(
fn=lambda x: max(0, x),
@ -71,7 +106,19 @@ def on_ui_tabs():
btn_generate.click(
fn=generate_prompt,
inputs=[tb_input, rb_prob_conversion_method, sl_power, rb_sampling_method, nb_max_tag_num, nb_k_value, sl_p_value, cb_weighted, cb_replace_underscore, cb_escape_brackets],
inputs=[
tb_input,
sl_tag_range,
rb_prob_conversion_method,
sl_power,
rb_sampling_method,
nb_max_tag_num,
nb_k_value,
sl_p_value,
cb_weighted,
cb_replace_underscore,
cb_escape_brackets
],
outputs=tb_output
)


@ -12,8 +12,9 @@ class ProbabilityConversion(Enum):
class GenerationSettings:
def __init__(
self,
conversion: ProbabilityConversion = ProbabilityConversion.CUTOFF_AND_POWER,
self,
tag_range: int = 0,
conversion: ProbabilityConversion = ProbabilityConversion.CUTOFF_AND_POWER,
prob_power: float = 2,
sampling: SamplingMethod = SamplingMethod.TOP_K,
n:int = 20,
@ -21,6 +22,7 @@ class GenerationSettings:
p: Optional[float] = 0.3,
weighted: bool = True):
self.tag_range = tag_range
self.sampling = sampling
self.conversion = conversion
self.n = n
@ -31,6 +33,13 @@ class GenerationSettings:
class PromptGenerator:
def clear(self): pass
def load_data(self, model_id: str, data_id: str):
raise NotImplementedError()
def load_model(self, model_id: str):
raise NotImplementedError()
def ready(self) -> bool:
raise NotImplementedError()
def __call__(self, text: str, settings: GenerationSettings) -> List[str]:
raise NotImplementedError()


@ -0,0 +1,100 @@
import os
import re
import csv
from typing import Dict
import numpy as np
from .. import settings
class Database:
    def __init__(self, database_path: str, re_filename: re.Pattern[str]):
        self.read_files(database_path, re_filename)

    def read_files(self, database_path: str, re_filename: re.Pattern[str]):
        self.clear()
        self.database_path = database_path
        fn, _ = os.path.splitext(os.path.basename(database_path))
        m = re_filename.match(fn)
        if m is None:  # guard: skip files whose name does not match the expected pattern
            print(f'Cannot use database in {database_path}; Unexpected file name')
            self.clear()
            return
        self.size_name = m.group(1)
        self.model_name = m.group(2)
        if self.model_name not in settings.TOKENIZER_NAMES:
            print(f'Cannot use database in {database_path}; Incompatible model name "{self.model_name}"')
            self.clear()
            return
        tag_path = os.path.join(os.path.dirname(database_path), f'{self.size_name}_tags.txt')
        if not os.path.isfile(tag_path):
            print(f'Cannot use database in {database_path}; No tag file exists')
            self.clear()
            return
        with open(tag_path, mode='r', encoding='utf8', newline='\n') as f:
            self.tags = [l.strip() for l in f.readlines()]
        tag_idx_path = os.path.join(os.path.dirname(database_path), f'{self.size_name}_tagidx.csv')
        if not os.path.isfile(tag_idx_path):
            print('Cannot read tag indices file. Tag count filter cannot be used.')
        else:
            with open(tag_idx_path, mode='r', encoding='utf8', newline='') as f:
                cr = csv.reader(f)
                for row in cr:
                    self.tag_idx.append((int(row[0]), int(row[1])))
            self.tag_idx.sort(key=lambda t: t[0])
        self.tag_idx = [(0, len(self.tags))] + self.tag_idx

    def clear(self):
        self.database_path = ''
        self.model_name = ''
        self.size_name = ''
        self.tag_idx = []
        self.tags = []
        self.data: np.ndarray = None

    def ready_to_load(self):
        return self.database_path \
            and self.model_name \
            and self.size_name \
            and self.tags \
            and self.tag_idx

    def loaded(self):
        return self.data is not None

    def load(self):
        if not self.ready_to_load(): return None
        if not self.loaded():
            self.data = np.load(self.database_path)['db']
        return self

    def name(self):
        return f'{self.model_name} : {self.size_name}'


class DatabaseLoader:
    def __init__(self, path: str, re_filename: re.Pattern[str]):
        self.datas: Dict[str, Database] = dict()
        self.preload(path, re_filename)

    def preload(self, path: str, re_filename: re.Pattern[str]):
        for d in os.listdir(path):
            filepath = os.path.join(path, d)
            if not os.path.isfile(filepath): continue
            _, ext = os.path.splitext(filepath)
            if ext == '.npz':
                ds = Database(filepath, re_filename)
                if ds.ready_to_load():  # guard: skip databases that failed to parse
                    self.datas[ds.name()] = ds
        print('[text2prompt] Loaded the following databases')
        print(sorted(self.datas.keys()))

    def load(self, database_name: str):
        database = self.datas.get(database_name)
        return database.load() if database else None


@ -6,6 +6,7 @@ import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from . import PromptGenerator, GenerationSettings, SamplingMethod, ProbabilityConversion
from .database_loader import DatabaseLoader
from .. import settings
@ -21,28 +22,62 @@ def mean_pooling(model_output, attention_mask):
class WDLike(PromptGenerator):
def __init__(self):
self.tags = []
self.clear()
self.database_loader = DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)
def get_model_names(self):
return sorted(self.database_loader.datas.keys())
def clear(self):
self.database_loader = None
self.database = None
self.tags = None
self.tokens = None
self.tokenizer = None
self.model = None
self.load_data()
self.loaded_model_name = None
def load_data(self):
with open(settings.WDLIKE_TAG_PATH, mode='r', encoding='utf8', newline='\n') as f:
self.tags = [l.strip() for l in f.readlines()]
with open(settings.WDLIKE_TOKEN_PATH, mode='rb') as f:
self.tokens = np.load(f)
def load_data(self, database_name: str):
print('Loading database...')
self.database = self.database_loader.load(database_name)
print('Loaded')
def load_model(self):
if self.database is None:
print('Cannot load model; Database is not loaded.')
return
from modules.devices import device
# brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
# Load model from HuggingFace Hub
self.tokenizer = AutoTokenizer.from_pretrained(settings.WDLIKE_MODEL_NAME)
self.model = AutoModel.from_pretrained(settings.WDLIKE_MODEL_NAME).to(device)
if self.loaded_model_name and self.loaded_model_name == self.database.model_name:
return
else:
print('Loading model...')
self.tokenizer = AutoTokenizer.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name])
self.model = AutoModel.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name]).to(device)
self.loaded_model_name = self.database.model_name
print('model loaded')
def unload_model(self):
if self.tokenizer is not None:
del self.tokenizer
self.tokenizer = None
if self.model is not None:
del self.model
self.model = None
self.loaded_model_name = None
def ready(self) -> bool:
return self.database is not None \
and self.database.loaded() \
and self.model is not None \
and self.tokenizer is not None \
and self.loaded_model_name and self.loaded_model_name == self.database.model_name
def __call__(self, text: str, opts: GenerationSettings) -> List[str]:
if not self.model or not self.tokenizer:
return []
if not self.ready(): return []
i = max(0, min(opts.tag_range, len(self.database.tag_idx) - 1))
r = self.database.tag_idx[i][1]
self.tokens = self.database.data[:r, :]
self.tags = self.database.tags[:r]
from modules.devices import device


@ -1,6 +1,18 @@
import os
import re
from modules import scripts
WDLIKE_TAG_PATH = os.path.abspath(os.path.join(scripts.basedir(), 'data/danbooru_wiki_tags.txt'))
WDLIKE_TOKEN_PATH = os.path.abspath(os.path.join(scripts.basedir(), 'data/danbooru_wiki_token_all-mpnet-base-v2.npy'))
WDLIKE_MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
def get_abspath(path: str):
return os.path.abspath(os.path.join(scripts.basedir(), path))
TOKENIZER_NAMES = ['all-mpnet-base-v2', 'all-MiniLM-L6-v2']
TOKENIZER_MODELS = {
TOKENIZER_NAMES[0]: f'sentence-transformers/{TOKENIZER_NAMES[0]}',
TOKENIZER_NAMES[1]: f'sentence-transformers/{TOKENIZER_NAMES[1]}'
}
DATABASE_PATH_DANBOORU = get_abspath('data/danbooru')
RE_TOKENFILE_DANBOORU = re.compile(r'(danbooru_[^_]+)_token_([^_]+)')