Squashed commit of the following:

commit 0571f2dab8bf3d8e1826f0fb5d53ca5b9bfcf6b0
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 01:05:27 2022 +0900

    Update README.md

commit b4a5164081662db9085e2b5c17af6e98473180ff
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:44:25 2022 +0900

    remove

commit 127946a3d200825eb15244ad09db5361db8bf89c
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:28:28 2022 +0900

    Revert "remove"

    This reverts commit f7666de150d1bb8acdcedcb0b82ddfb939206088.

commit f7666de150d1bb8acdcedcb0b82ddfb939206088
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:26:44 2022 +0900

    remove

commit 41b99ba60bd59050384bdd55a498395e4148bfda
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:26:20 2022 +0900

    Update .gitignore

commit 2fdce554bdd92bc87dce13181aa67c58d523d830
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:24:56 2022 +0900

    Update .gitignore

commit 50927e064e5ffe052d785038f4c7de851ede6c91
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:23:49 2022 +0900

    upload database

commit 39172118e8d8fac2ed7f3aab0c53e2a6e89c16b3
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:22:34 2022 +0900

    Update .gitignore

commit 7c6225ade3a55fdf342eb2b3617302edab595192
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:21:25 2022 +0900

    Create test

commit 8010ec05a5e7836fc004450bc82cc3ce93d975ad
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:20:22 2022 +0900

    Update .gitignore

commit 6265874465281db81a3935f651537ac9c6cfc30a
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Wed Dec 28 00:17:17 2022 +0900

    update README.md

commit 3f52e109aa0b7a41f4970b11b933c7f57c7ddd62
Merge: 180c4c5 867a48a
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Tue Dec 27 23:20:32 2022 +0900

    Merge branch 'main' into feature/well-filtered-database

commit 180c4c5b414058939eb47dfe7783eb561c5da65e
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Tue Dec 27 23:19:51 2022 +0900

    implement dynamic database loader

commit fdccc1a9cb9a80ce9b4042892afc9922a20e6524
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Sun Dec 25 19:39:32 2022 +0900

    Update README.md

commit b539cbf4d3791777e424883b2a5bd64a73de9bf0
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Sun Dec 25 19:34:11 2022 +0900

    remove thread lock to allow generating prompt while doing other tasks

commit 43b5bfd73a549f2bced6704cdcfacb1483648b17
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Sun Dec 25 19:26:26 2022 +0900

    update README

commit 58df449f36678de2e9b35726315d53a281bdc380
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date:   Sun Dec 25 19:05:30 2022 +0900

    add new language model and database

    The size of data will be so large
main
toshiaki1729 2022-12-28 01:10:02 +09:00
parent 867a48ad52
commit 374325e3e2
14 changed files with 14859 additions and 5579 deletions

.gitignore

@@ -6,6 +6,8 @@ __pycache__/
 # C extensions
 *.so
+archive/
 # Distribution / packaging
 .Python
 build/

README.md

@@ -18,9 +18,13 @@ git clone https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt.git
 ## Usage
 1. Type some words into "Input Theme"
 1. Push "Generate" button
+![](pic/pic1.png)
 ### Tips
 - For more creative result
@@ -34,20 +38,34 @@ git clone https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt.git
 - You can enter very long sentences, but the more specific it is, the fewer results you will get.
 ## How it works
 It's doing nothing special;
 1. Danbooru tags and it's descriptions are in the `data` folder
     - descriptions are generated from wiki and already tokenized
-    - [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) model is used to tokenize the text
-    - for now, some tags (such as <1k tagged or containing title of the work) are deleted to prevent from "noisy" result
-1. Tokenize your input text and calculate cosine similarity to each tag descriptions
+    - [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) models are used to tokenize the text
+1. Tokenize your input text and calculate cosine similarity with all tag descriptions
 1. Choose some tags depending on their similarities
+![](pic/pic1.png)
+### Database (Optional)
+You can choose the following dataset if needed.
+Download the following, unzip and put its contents into `text2prompt-root-dir/data/danbooru/`.
+|Tag description|all-mpnet-base-v2|all-MiniLM-L6-v2|
+|:---|:---:|:---:|
+|**well filtered (recommended)**|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_strict_all-mpnet-base-v2.zip) (preinstalled)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_strict_all-MiniLM-L6-v2.zip)|
+|normal (same as previous one)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_normal_all-mpnet-base-v2.zip)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_normal_all-MiniLM-L6-v2.zip)|
+|full (noisy)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_full_all-mpnet-base-v2.zip)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_full_all-MiniLM-L6-v2.zip)|
+**well filtered:** Tags are removed if their description include the title of the work. These tags are heavily related to a specific work, meaning they are not "general" tags.
+**normal:** Tags containing the title of a work, like tag_name(work_name), are removed.
+**full:** Including all tags.
 ---
-### More detailed
+## More detailed description
 $i \in N = \\{1, 2, ..., n\\}$ for index number of the tag
 $s_i = S_C(d_i, t)$ for cosine similarity between tag description $d_i$ and your text $t$
 $P_i$ for probability for the tag to be chosen
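The README defines $s_i = S_C(d_i, t)$ and a per-tag probability $P_i$. As a rough numpy sketch of that pipeline (similarity → probability → weighted choice) — the vectors are made up, and the exact "Cutoff and Power" formula is an assumption based on the UI label, not the extension's actual code:

```python
import numpy as np

# Toy "tag description" embeddings d_i and an input-text embedding t.
# These vectors are invented for illustration only.
db = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0], [0.5, 0.5]])
t = np.array([1.0, 0.5])

def cosine_similarity(db: np.ndarray, t: np.ndarray) -> np.ndarray:
    # s_i = S_C(d_i, t) for every row d_i of the database
    db_n = db / np.linalg.norm(db, axis=1, keepdims=True)
    return db_n @ (t / np.linalg.norm(t))

def cutoff_and_power(s: np.ndarray, power: float = 2.0) -> np.ndarray:
    # Assumed form of "Cutoff and Power": clip negative similarities,
    # raise the rest to `power`, then normalize into probabilities P_i.
    p = np.clip(s, 0.0, None) ** power
    return p / p.sum()

probs = cutoff_and_power(cosine_similarity(db, t))
rng = np.random.default_rng(0)
# Weighted choice of 2 distinct tags, as in "Use weighted choice"
picked = rng.choice(len(db), size=2, replace=False, p=probs)
```

Note that the opposite-direction embedding (row 3) gets probability zero after the cutoff, so it can never be sampled.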

@@ -0,0 +1,8 @@
+50,13005
+100,11294
+200,9490
+500,7099
+1000,5425
+2000,3956
+5000,2458
+10000,1633
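This new two-column CSV maps a minimum tag-usage count to the number of tags at or above it (e.g. 13005 tags are used more than 50 times). `Database.read_files` in this commit parses it into `tag_idx` and prepends `(0, total)` as the "no filter" entry; a self-contained sketch of that parse (the `total_tags` value here is a made-up example, not from the release data):

```python
import csv
import io

# The same eight rows as the committed *_tagidx.csv file
CSV_TEXT = """50,13005
100,11294
200,9490
500,7099
1000,5425
2000,3956
5000,2458
10000,1633
"""

def read_tag_idx(text: str, total_tags: int):
    # Mirror of Database.read_files: parse (min_count, n_tags) pairs,
    # sort by min_count, then prepend (0, total) so slider index 0 means "all tags".
    tag_idx = [(int(a), int(b)) for a, b in csv.reader(io.StringIO(text))]
    tag_idx.sort(key=lambda t: t[0])
    return [(0, total_tags)] + tag_idx

tag_idx = read_tag_idx(CSV_TEXT, total_tags=14000)  # 14000 is illustrative
```

The "Tag count filter" slider then simply indexes into this list.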

File diff suppressed because it is too large

File diff suppressed because it is too large

Binary file not shown.

Image file replaced (Before: 101 KiB, After: 132 KiB)

@@ -6,7 +6,7 @@ from modules import generation_parameters_copypaste as params_copypaste
 import scripts.t2p.prompt_generator as pgen
-wd_like = None
+wd_like = pgen.WDLike()
 # brought from modules/deepbooru.py
 re_special = re.compile(r'([\\()])')
@@ -22,18 +22,35 @@ def get_sampling(choice: int):
     elif choice == 2: return pgen.SamplingMethod.TOP_P
     else: raise NotImplementedError()
-def generate_prompt(text: str, conversion: int, power: float, sampling: int, n: int, k: int, p: float, weighted: bool, replace_underscore: bool, excape_brackets: bool):
-    global wd_like
-    if not wd_like:
-        print('Loading tag data and model')
-        wd_like = pgen.WDLike()
-        wd_like.load_model()
-        print('Loaded')
-    tags = wd_like(text, pgen.GenerationSettings(get_conversion(conversion), power, get_sampling(sampling), n, k, p, weighted))
+def get_tag_range_txt(tag_range: int):
+    if wd_like.database is None:
+        return 'Tag range: NONE'
+    maxval = len(wd_like.database.tag_idx) - 1
+    i = max(0, min(tag_range, maxval))
+    r = wd_like.database.tag_idx[i]
+    return f'Tag range: <b> &gt; {r[0]} tagged</b> ({r[1] + 1} tags total)'
+def dd_database_changed(database_name: str, tag_range: int):
+    wd_like.load_data(database_name)
+    return [
+        gr.Slider.update(tag_range, 0, len(wd_like.database.tag_idx) - 1),
+        get_tag_range_txt(tag_range)
+    ]
+def sl_tag_range_changed(tag_range: int):
+    return get_tag_range_txt(tag_range)
+def generate_prompt(text: str, tag_range: int, conversion: int, power: float, sampling: int, n: int, k: int, p: float, weighted: bool, replace_underscore: bool, excape_brackets: bool):
+    wd_like.load_model() # skip loading if not needed
+    tags = wd_like(text, pgen.GenerationSettings(tag_range, get_conversion(conversion), power, get_sampling(sampling), n, k, p, weighted))
     if replace_underscore: tags = [t.replace('_', ' ') for t in tags]
     if excape_brackets: tags = [re.sub(re_special, r'\\\1', t) for t in tags]
     return ', '.join(tags)
 def on_ui_tabs():
     with gr.Blocks(analytics_enabled=False) as text2prompt_interface:
         with gr.Row():
@@ -47,8 +64,14 @@ def on_ui_tabs():
                 buttons = params_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
                 params_copypaste.bind_buttons(buttons, None, tb_output)
-            with gr.Column():
+            with gr.Column(variant='panel'):
                 gr.HTML(value='Generation Settings')
+                choices = wd_like.get_model_names()
+                with gr.Column():
+                    dd_database = gr.Dropdown(choices=choices, value=None, interactive=True, label='Database')
+                    sl_tag_range = gr.Slider(0, 8, 0, step=1, interactive=True, label='Tag count filter')
+                    txt_tag_range = gr.HTML(get_tag_range_txt(0))
+                with gr.Column():
                     rb_prob_conversion_method = gr.Radio(choices=['Cutoff and Power', 'Softmax'], value='Cutoff and Power', type='index', label='Method to convert similarity into probability')
                     sl_power = gr.Slider(0, 5, value=2, step=0.1, label='Power', interactive=True)
                     rb_sampling_method = gr.Radio(choices=['NONE', 'Top-k', 'Top-p (Nucleus)'], value='Top-k', type='index', label='Sampling method')
@@ -57,6 +80,18 @@ def on_ui_tabs():
                     sl_p_value = gr.Slider(0, 1, label='p value', value=0.1, step=0.01, interactive=True)
                     cb_weighted = gr.Checkbox(value=True, label='Use weighted choice', interactive=True)
+        dd_database.change(
+            fn=dd_database_changed,
+            inputs=[dd_database, sl_tag_range],
+            outputs=[sl_tag_range, txt_tag_range]
+        )
+        sl_tag_range.change(
+            fn=sl_tag_range_changed,
+            inputs=sl_tag_range,
+            outputs=txt_tag_range
+        )
         nb_max_tag_num.change(
             fn=lambda x: max(0, x),
             inputs=nb_max_tag_num,
@@ -71,7 +106,19 @@ def on_ui_tabs():
         btn_generate.click(
             fn=generate_prompt,
-            inputs=[tb_input, rb_prob_conversion_method, sl_power, rb_sampling_method, nb_max_tag_num, nb_k_value, sl_p_value, cb_weighted, cb_replace_underscore, cb_escape_brackets],
+            inputs=[
+                tb_input,
+                sl_tag_range,
+                rb_prob_conversion_method,
+                sl_power,
+                rb_sampling_method,
+                nb_max_tag_num,
+                nb_k_value,
+                sl_p_value,
+                cb_weighted,
+                cb_replace_underscore,
+                cb_escape_brackets
+            ],
             outputs=tb_output
         )

scripts/t2p/prompt_generator/__init__.py

@@ -13,6 +13,7 @@ class ProbabilityConversion(Enum):
 class GenerationSettings:
     def __init__(
         self,
+        tag_range: int = 0,
         conversion: ProbabilityConversion = ProbabilityConversion.CUTOFF_AND_POWER,
         prob_power: float = 2,
         sampling: SamplingMethod = SamplingMethod.TOP_K,
@@ -21,6 +22,7 @@ class GenerationSettings:
         p: Optional[float] = 0.3,
         weighted: bool = True):
+        self.tag_range = tag_range
         self.sampling = sampling
         self.conversion = conversion
         self.n = n
@@ -31,6 +33,13 @@ class GenerationSettings:
 class PromptGenerator:
+    def clear(self): pass
+    def load_data(self, model_id: str, data_id: str):
+        raise NotImplementedError()
+    def load_model(self, model_id: str):
+        raise NotImplementedError()
+    def ready(self) -> bool:
+        raise NotImplementedError()
     def __call__(self, text: str, settings: GenerationSettings) -> List[str]:
         raise NotImplementedError()

scripts/t2p/prompt_generator/database_loader.py

@@ -0,0 +1,100 @@
+import os
+import re
+import csv
+from typing import Dict
+
+import numpy as np
+
+from .. import settings
+
+
+class Database:
+    def __init__(self, database_path: str, re_filename: re.Pattern[str]):
+        self.read_files(database_path, re_filename)
+
+    def read_files(self, database_path: str, re_filename: re.Pattern[str]):
+        self.clear()
+        self.database_path = database_path
+        fn, _ = os.path.splitext(os.path.basename(database_path))
+        m = re_filename.match(fn)
+        self.size_name = m.group(1)
+        self.model_name = m.group(2)
+        if self.model_name not in settings.TOKENIZER_NAMES:
+            print(f'Cannot use database in {database_path}; Incompatible model name "{self.model_name}"')
+            self.clear()
+            return
+        tag_path = os.path.join(os.path.dirname(database_path), f'{self.size_name}_tags.txt')
+        if not os.path.isfile(tag_path):
+            print(f'Cannot use database in {database_path}; No tag file exists')
+            self.clear()
+            return
+        with open(tag_path, mode='r', encoding='utf8', newline='\n') as f:
+            self.tags = [l.strip() for l in f.readlines()]
+        tag_idx_path = os.path.join(os.path.dirname(database_path), f'{self.size_name}_tagidx.csv')
+        if not os.path.isfile(tag_idx_path):
+            print(f'Cannot read tag indices file. Tag count filter cannot be used.')
+        else:
+            with open(tag_idx_path, mode='r', encoding='utf8', newline='') as f:
+                cr = csv.reader(f)
+                for row in cr:
+                    self.tag_idx.append((int(row[0]), int(row[1])))
+            self.tag_idx.sort(key=lambda t : t[0])
+        self.tag_idx = [(0, len(self.tags))] + self.tag_idx
+
+    def clear(self):
+        self.database_path = ''
+        self.model_name = ''
+        self.size_name = ''
+        self.tag_idx = []
+        self.tags = []
+        self.data: np.ndarray = None
+
+    def ready_to_load(self):
+        return self.database_path \
+            and self.model_name \
+            and self.size_name \
+            and self.tags \
+            and self.tag_idx
+
+    def loaded(self):
+        return self.data is not None
+
+    def load(self):
+        if not self.ready_to_load(): return None
+        if not self.loaded():
+            self.data = np.load(self.database_path)['db']
+        return self
+
+    def name(self):
+        return f'{self.model_name} : {self.size_name}'
+
+
+class DatabaseLoader:
+    def __init__(self, path: str, re_filename: re.Pattern[str]):
+        self.datas: Dict[str, Database] = dict()
+        self.preload(path, re_filename)
+
+    def preload(self, path: str, re_filename: re.Pattern[str]):
+        dirs = os.listdir(path)
+        for d in dirs:
+            filepath = os.path.join(path, d)
+            if not os.path.isfile(filepath): continue
+            _, ext = os.path.splitext(filepath)
+            if ext == '.npz':
+                ds = Database(filepath, re_filename)
+                self.datas[ds.name()] = ds
+        print('[text2prompt] Loaded following databases')
+        print(sorted(self.datas.keys()))
+
+    def load(self, database_name: str):
+        database = self.datas.get(database_name)
+        return database.load() if database else None

@@ -6,6 +6,7 @@ import torch.nn.functional as F
 from transformers import AutoModel, AutoTokenizer
 from . import PromptGenerator, GenerationSettings, SamplingMethod, ProbabilityConversion
+from .database_loader import DatabaseLoader
 from .. import settings
@@ -21,28 +22,62 @@ def mean_pooling(model_output, attention_mask):
 class WDLike(PromptGenerator):
     def __init__(self):
-        self.tags = []
+        self.clear()
+        self.database_loader = DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)
+
+    def get_model_names(self):
+        return sorted(self.database_loader.datas.keys())
+
+    def clear(self):
+        self.database_loader = None
+        self.database = None
+        self.tags = None
         self.tokens = None
         self.tokenizer = None
         self.model = None
-        self.load_data()
+        self.loaded_model_name = None
 
-    def load_data(self):
-        with open(settings.WDLIKE_TAG_PATH, mode='r', encoding='utf8', newline='\n') as f:
-            self.tags = [l.strip() for l in f.readlines()]
-        with open(settings.WDLIKE_TOKEN_PATH, mode='rb') as f:
-            self.tokens = np.load(f)
+    def load_data(self, database_name: str):
+        print('Loading database...')
+        self.database = self.database_loader.load(database_name)
+        print('Loaded')
 
     def load_model(self):
+        if self.database is None:
+            print('Cannot load model; Database is not loaded.')
+            return
         from modules.devices import device
         # brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
         # Load model from HuggingFace Hub
-        self.tokenizer = AutoTokenizer.from_pretrained(settings.WDLIKE_MODEL_NAME)
-        self.model = AutoModel.from_pretrained(settings.WDLIKE_MODEL_NAME).to(device)
+        if self.loaded_model_name and self.loaded_model_name == self.database.model_name:
+            return
+        else:
+            print('Loading model...')
+            self.tokenizer = AutoTokenizer.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name])
+            self.model = AutoModel.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name]).to(device)
+            self.loaded_model_name = self.database.model_name
+            print('model loaded')
+
+    def unload_model(self):
+        if self.tokenizer is not None:
+            del self.tokenizer
+        if self.model is not None:
+            del self.model
+
+    def ready(self) -> bool:
+        return self.database is not None \
+            and self.database.loaded() \
+            and self.model is not None \
+            and self.tokenizer is not None \
+            and self.loaded_model_name and self.loaded_model_name == self.database.model_name
 
     def __call__(self, text: str, opts: GenerationSettings) -> List[str]:
-        if not self.model or not self.tokenizer:
-            return ''
+        if not self.ready(): return ''
+
+        i = max(0, min(opts.tag_range, len(self.database.tag_idx) - 1))
+        r = self.database.tag_idx[i][1]
+        self.tokens = self.database.data[:r, :]
+        self.tags = self.database.tags[:r]
 
         from modules.devices import device
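The new tag-range logic in `WDLike.__call__` clamps the slider value into `tag_idx` and slices the embedding matrix and tag list down to the filtered subset. A standalone sketch of just that clamping-and-slicing step (the `tag_idx`, `data`, and `tags` values here are hypothetical stand-ins for the loaded database):

```python
import numpy as np

# Hypothetical stand-ins for database.tag_idx / database.data / database.tags.
# tag_idx rows are (min tag-usage count, number of rows to keep); rows are
# assumed to be sorted so more-frequent tags come first.
tag_idx = [(0, 6), (50, 4), (100, 2)]
data = np.arange(18, dtype=np.float32).reshape(6, 3)  # 6 fake tag embeddings
tags = ['a', 'b', 'c', 'd', 'e', 'f']

def apply_tag_range(tag_range: int):
    # Mirror of WDLike.__call__: clamp the slider index, then slice both
    # the token matrix and the tag names to the first r entries.
    i = max(0, min(tag_range, len(tag_idx) - 1))
    r = tag_idx[i][1]
    return data[:r, :], tags[:r]

tokens, kept = apply_tag_range(2)            # strictest filter: 2 tags survive
wide_tokens, wide_kept = apply_tag_range(99) # out-of-range input is clamped
```

Because the index is clamped, an out-of-range slider value behaves like the last (strictest) filter rather than raising an error.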

scripts/t2p/settings.py

@@ -1,6 +1,18 @@
 import os
+import re
+
 from modules import scripts
-WDLIKE_TAG_PATH = os.path.abspath(os.path.join(scripts.basedir(), 'data/danbooru_wiki_tags.txt'))
-WDLIKE_TOKEN_PATH = os.path.abspath(os.path.join(scripts.basedir(), 'data/danbooru_wiki_token_all-mpnet-base-v2.npy'))
-WDLIKE_MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
+
+def get_abspath(path: str):
+    return os.path.abspath(os.path.join(scripts.basedir(), path))
+
+TOKENIZER_NAMES = ['all-mpnet-base-v2', 'all-MiniLM-L6-v2']
+TOKENIZER_MODELS = {
+    TOKENIZER_NAMES[0]: f'sentence-transformers/{TOKENIZER_NAMES[0]}',
+    TOKENIZER_NAMES[1]: f'sentence-transformers/{TOKENIZER_NAMES[1]}'
+}
+DATABASE_PATH_DANBOORU = get_abspath('data/danbooru')
+RE_TOKENFILE_DANBOORU = re.compile(r'(danbooru_[^_]+)_token_([^_]+)')
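The new `RE_TOKENFILE_DANBOORU` pattern is what `Database.read_files` uses to split a database filename stem into a size name (group 1) and a model name (group 2). A quick check against an example stem (the stem below is illustrative, modeled on the release zip names, not a file guaranteed to ship in this commit):

```python
import re

# Same pattern as the committed settings.py
RE_TOKENFILE_DANBOORU = re.compile(r'(danbooru_[^_]+)_token_([^_]+)')

# Group 1 captures 'danbooru_<size>', group 2 the model name; the model
# names contain hyphens but no underscores, which is why [^_]+ works.
m = RE_TOKENFILE_DANBOORU.match('danbooru_strict_token_all-mpnet-base-v2')
size_name, model_name = m.group(1), m.group(2)
```

`size_name` then drives the companion file lookups (`{size_name}_tags.txt`, `{size_name}_tagidx.csv`), and `model_name` is validated against `TOKENIZER_NAMES`.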