Compare commits

...

20 Commits

Author SHA1 Message Date
toshiaki1729 752a660d2b
Update README.md 2024-05-28 13:10:35 +09:00
toshiaki1729 bcfec2f988
README.md の更新
wrong terms
2024-05-24 12:24:00 +09:00
toshiaki1729 c4d8cec126 fix #3
fix type hint incompatible with python < 3.9
2023-02-24 17:29:55 +09:00
toshiaki1729 a6c052ef7a possible fix for #2 2023-02-24 14:40:09 +09:00
toshiaki1729 f45016a4d1 Update wd_like.py 2022-12-29 22:26:16 +09:00
toshiaki1729 238318c605 fix generating with all tags if all probability < 0 2022-12-29 22:18:19 +09:00
toshiaki1729 362da88287 Update README.md 2022-12-29 21:16:33 +09:00
toshiaki1729 f5737b77ea Update database_loader.py
oops
2022-12-29 21:05:59 +09:00
toshiaki1729 da5471671e Update README.md 2022-12-29 21:02:54 +09:00
toshiaki1729 0c45abc73e Update pic1.png 2022-12-29 20:54:47 +09:00
toshiaki1729 c28e403288 update README: implement negative text 2022-12-29 20:50:08 +09:00
toshiaki1729 0406db1963 Merge branch 'main' of https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt 2022-12-29 20:45:16 +09:00
toshiaki1729 03d40a1456 implement negative text 2022-12-29 20:45:12 +09:00
toshiaki1729 6435280d03
Update README.md 2022-12-29 00:47:23 +09:00
toshiaki1729 7e0841d4af remove excess code 2022-12-28 17:13:08 +09:00
toshiaki1729 b99ce707f4 Merge branch 'main' of https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt 2022-12-28 17:07:25 +09:00
toshiaki1729 4972255f31 enable hotreload with "Apply and restart UI" on Extensions tab 2022-12-28 17:07:22 +09:00
toshiaki1729 1e174b6c83
Create LICENSE 2022-12-28 14:51:20 +09:00
toshiaki1729 38fc93c5fa
Update README.md 2022-12-28 02:39:45 +09:00
toshiaki1729 5a13040437 output text change 2022-12-28 02:24:04 +09:00
10 changed files with 121 additions and 57 deletions

2
.gitignore vendored
View File

@ -152,3 +152,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.vscode

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 toshiaki1729
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -20,6 +20,7 @@ git clone https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt.git
## Usage
1. Type some words into "Input Theme"
1. Type some unwanted words into "Input Negative Theme"
1. Push "Generate" button
![](pic/pic1.png)
@ -41,14 +42,14 @@ git clone https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt.git
It's doing nothing special;
1. Danbooru tags and its descriptions are in the `data` folder
- descriptions are generated from wiki and already tokenized
- [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) models are used to tokenize the text
1. Danbooru tags and its descriptions are in the `data` folder
- embeddings of descriptions are generated from wiki
- [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) models are used to make embeddings from the text
1. Tokenize your input text and calculate cosine similarity with all tag descriptions
1. Choose some tags depending on their similarities
### Database (Optional)
## Database (Optional)
You can choose the following dataset if needed.
Download the following, unzip and put its contents into `text2prompt-root-dir/data/danbooru/`.
@ -59,7 +60,7 @@ Download the following, unzip and put its contents into `text2prompt-root-dir/da
|normal (same as previous one)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_normal_all-mpnet-base-v2.zip)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_normal_all-MiniLM-L6-v2.zip)|
|full (noisy)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_full_all-mpnet-base-v2.zip)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_full_all-MiniLM-L6-v2.zip)|
**well filtered:** Tags are removed if their description include the title of the work. These tags are heavily related to a specific work, meaning they are not "general" tags.
**well filtered:** Tags are removed if their descriptions include the title of some work. These tags are heavily related to a specific work, meaning they are not "general" tags.
**normal:** Tags containing the title of a work, like tag_name(work_name), are removed.
**full:** Including all tags.
@ -67,7 +68,7 @@ Download the following, unzip and put its contents into `text2prompt-root-dir/da
## More detailed description
$i \in N = \\{1, 2, ..., n\\}$ for index number of the tag
$s_i = S_C(d_i, t)$ for cosine similarity between tag description $d_i$ and your text $t$
$s_i = S_C(d_i, t)$ for cosine similarity between tag description $d_i$ and your text $t$
$P_i$ for probability for the tag to be chosen
### "Method to convert similarity into probability"

Binary file not shown.

Before

Width:  |  Height:  |  Size: 132 KiB

After

Width:  |  Height:  |  Size: 163 KiB

View File

@ -4,9 +4,18 @@ import gradio as gr
from modules import script_callbacks
from modules import generation_parameters_copypaste as params_copypaste
import scripts.t2p.prompt_generator as pgen
import scripts.t2p.settings as settings
wd_like = pgen.WDLike()
if settings.DEVELOP:
import scripts.t2p.prompt_generator as pgen
from scripts.t2p.prompt_generator.wd_like import WDLike
else:
from scripts.t2p.dynamic_import import dynamic_import
_wd_like = dynamic_import('scripts/t2p/prompt_generator/wd_like.py')
WDLike = _wd_like.WDLike
pgen = _wd_like.pgen
wd_like = WDLike()
# brought from modules/deepbooru.py
re_special = re.compile(r'([\\()])')
@ -28,7 +37,7 @@ def get_tag_range_txt(tag_range: int):
maxval = len(wd_like.database.tag_idx) - 1
i = max(0, min(tag_range, maxval))
r = wd_like.database.tag_idx[i]
return f'Tag range: <b> &gt; {r[0]} tagged</b> ({r[1] + 1} tags total)'
return f'Tag range: <b> {r[0]} tagged</b> ({r[1] + 1} tags total)'
def dd_database_changed(database_name: str, tag_range: int):
@ -43,9 +52,9 @@ def sl_tag_range_changed(tag_range: int):
return get_tag_range_txt(tag_range)
def generate_prompt(text: str, tag_range: int, conversion: int, power: float, sampling: int, n: int, k: int, p: float, weighted: bool, replace_underscore: bool, excape_brackets: bool):
def generate_prompt(text: str, text_neg: str, neg_weight: float, tag_range: int, conversion: int, power: float, sampling: int, n: int, k: int, p: float, weighted: bool, replace_underscore: bool, excape_brackets: bool):
wd_like.load_model() #skip loading if not needed
tags = wd_like(text, pgen.GenerationSettings(tag_range, get_conversion(conversion), power, get_sampling(sampling), n, k, p, weighted))
tags = wd_like(text, text_neg, neg_weight, pgen.GenerationSettings(tag_range, get_conversion(conversion), power, get_sampling(sampling), n, k, p, weighted))
if replace_underscore: tags = [t.replace('_', ' ') for t in tags]
if excape_brackets: tags = [re.sub(re_special, r'\\\1', t) for t in tags]
return ', '.join(tags)
@ -56,6 +65,8 @@ def on_ui_tabs():
with gr.Row():
with gr.Column():
tb_input = gr.Textbox(label='Input Theme', interactive=True)
tb_input_neg = gr.Textbox(label='Input Negative Theme', interactive=True)
sl_negative_strength = gr.Slider(0, 3, value=1, step=0.01, label='Negative strength', interactive=True)
cb_replace_underscore = gr.Checkbox(value=True, label='Replace underscore in tag with whitespace', interactive=True)
cb_escape_brackets = gr.Checkbox(value=True, label='Escape brackets in tag', interactive=True)
btn_generate = gr.Button(value='Generate', variant='primary')
@ -68,7 +79,8 @@ def on_ui_tabs():
gr.HTML(value='Generation Settings')
choices = wd_like.get_model_names()
with gr.Column():
dd_database = gr.Dropdown(choices=choices, value=None, interactive=True, label='Database')
if choices: wd_like.load_data(choices[-1])
dd_database = gr.Dropdown(choices=choices, value=choices[-1] if choices else None, interactive=True, label='Database')
sl_tag_range = gr.Slider(0, 8, 0, step=1, interactive=True, label='Tag count filter')
txt_tag_range = gr.HTML(get_tag_range_txt(0))
with gr.Column():
@ -108,6 +120,8 @@ def on_ui_tabs():
fn=generate_prompt,
inputs=[
tb_input,
tb_input_neg,
sl_negative_strength,
sl_tag_range,
rb_prob_conversion_method,
sl_power,

View File

@ -0,0 +1,5 @@
def dynamic_import(path: str):
    """Load a Python module from a path relative to this extension's base directory.

    Args:
        path: Module file path relative to the extension root
            (e.g. 'scripts/t2p/prompt_generator/wd_like.py').

    Returns:
        The loaded module object, as produced by webui's script_loading helper.
    """
    import os
    # webui-provided helpers: scripts.basedir() is this extension's root,
    # script_loading.load_module() imports a module directly from a file path.
    from modules import scripts, script_loading
    # Resolve to an absolute path so loading works regardless of the cwd.
    path = os.path.abspath(os.path.join(scripts.basedir(), path))
    return script_loading.load_module(path)

View File

@ -40,8 +40,5 @@ class PromptGenerator:
raise NotImplementedError()
def ready(self) -> bool:
raise NotImplementedError()
def __call__(self, text: str, settings: GenerationSettings) -> List[str]:
raise NotImplementedError()
from .wd_like import WDLike
def __call__(self, text: str, text_neg: str, neg_weight: float, settings: GenerationSettings) -> List[str]:
raise NotImplementedError()

View File

@ -4,15 +4,15 @@ import csv
from typing import Dict
import numpy as np
from .. import settings
import scripts.t2p.settings as settings
class Database:
def __init__(self, database_path: str, re_filename: re.Pattern[str]):
def __init__(self, database_path: str, re_filename: re.Pattern):
self.read_files(database_path, re_filename)
def read_files(self, database_path: str, re_filename: re.Pattern[str]):
def read_files(self, database_path: str, re_filename: re.Pattern):
self.clear()
self.database_path = database_path
fn, _ = os.path.splitext(os.path.basename(database_path))
@ -20,13 +20,13 @@ class Database:
self.size_name = m.group(1)
self.model_name = m.group(2)
if self.model_name not in settings.TOKENIZER_NAMES:
print(f'Cannot use database in {database_path}; Incompatible model name "{self.model_name}"')
print(f'[text2prompt] Cannot use database in {database_path}; Incompatible model name "{self.model_name}"')
self.clear()
return
tag_path = os.path.join(os.path.dirname(database_path), f'{self.size_name}_tags.txt')
if not os.path.isfile(tag_path):
print(f'Cannot use database in {database_path}; No tag file exists')
print(f'[text2prompt] Cannot use database in {database_path}; No tag file exists')
self.clear()
return
@ -35,14 +35,14 @@ class Database:
tag_idx_path = os.path.join(os.path.dirname(database_path), f'{self.size_name}_tagidx.csv')
if not os.path.isfile(tag_idx_path):
print(f'Cannot read tag indices file. Tag count filter cannot be used.')
print(f'[text2prompt] Cannot read tag indices file. Tag count filter cannot be used.')
else:
with open(tag_idx_path, mode='r', encoding='utf8', newline='') as f:
cr = csv.reader(f)
for row in cr:
self.tag_idx.append((int(row[0]), int(row[1])))
self.tag_idx.sort(key=lambda t : t[0])
self.tag_idx = [(0, len(self.tags))] + self.tag_idx
self.tag_idx = [(0, len(self.tags) - 1)] + self.tag_idx
def clear(self):
@ -77,12 +77,12 @@ class Database:
class DatabaseLoader:
def __init__(self, path: str, re_filename: re.Pattern[str]):
def __init__(self, path: str, re_filename: re.Pattern):
self.datas: Dict[str, Database] = dict()
self.preload(path, re_filename)
def preload(self, path: str, re_filename: re.Pattern[str]):
def preload(self, path: str, re_filename: re.Pattern):
dirs = os.listdir(path)
for d in dirs:
filepath = os.path.join(path, d)
@ -91,8 +91,9 @@ class DatabaseLoader:
if ext == '.npz':
ds = Database(filepath, re_filename)
self.datas[ds.name()] = ds
print('[text2prompt] Loaded following databases')
print(sorted(self.datas.keys()))
print('[text2prompt] Following databases are available:')
for name in sorted(self.datas.keys()):
print(f' {name}')
def load(self, database_name: str):

View File

@ -5,10 +5,15 @@ import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from . import PromptGenerator, GenerationSettings, SamplingMethod, ProbabilityConversion
from .database_loader import DatabaseLoader
from .. import settings
import scripts.t2p.settings as settings
if settings.DEVELOP:
import scripts.t2p.prompt_generator as pgen
import scripts.t2p.prompt_generator.database_loader as dloader
else:
from scripts.t2p.dynamic_import import dynamic_import
pgen = dynamic_import('scripts/t2p/prompt_generator/__init__.py')
dloader = dynamic_import('scripts/t2p/prompt_generator/database_loader.py')
NUM_CHOICE = 10
K_VALUE = 50
@ -20,10 +25,10 @@ def mean_pooling(model_output, attention_mask):
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
class WDLike(PromptGenerator):
class WDLike(pgen.PromptGenerator):
def __init__(self):
self.clear()
self.database_loader = DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)
self.database_loader = dloader.DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)
def get_model_names(self):
return sorted(self.database_loader.datas.keys())
@ -38,13 +43,13 @@ class WDLike(PromptGenerator):
self.loaded_model_name = None
def load_data(self, database_name: str):
print('Loading database...')
print(f'[text2prompt] Loading database with name "{database_name}"...')
self.database = self.database_loader.load(database_name)
print('Loaded')
print('[text2prompt] Database loaded')
def load_model(self):
if self.database is None:
print('Cannot load model; Database is not loaded.')
print('[text2prompt] Cannot load model; Database is not loaded.')
return
from modules.devices import device
# brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
@ -52,11 +57,11 @@ class WDLike(PromptGenerator):
if self.loaded_model_name and self.loaded_model_name == self.database.model_name:
return
else:
print('Loading model...')
print(f'[text2prompt] Loading model with name "{self.database.model_name}"...')
self.tokenizer = AutoTokenizer.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name])
self.model = AutoModel.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name]).to(device)
self.loaded_model_name = self.database.model_name
print('model loaded')
print('[text2prompt] Model loaded')
def unload_model(self):
if self.tokenizer is not None:
@ -71,7 +76,7 @@ class WDLike(PromptGenerator):
and self.tokenizer is not None \
and self.loaded_model_name and self.loaded_model_name == self.database.model_name
def __call__(self, text: str, opts: GenerationSettings) -> List[str]:
def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.GenerationSettings) -> List[str]:
if not self.ready(): return ''
i = max(0, min(opts.tag_range, len(self.database.tag_idx) - 1))
@ -85,50 +90,57 @@ class WDLike(PromptGenerator):
# brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
# Tokenize sentences
encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)
if text_neg:
encoded_input_neg = self.tokenizer(text_neg, padding=True, truncation=True, return_tensors='pt').to(device)
# Compute token embeddings
with torch.no_grad():
model_output = self.model(**encoded_input)
if text_neg:
model_output_neg = self.model(**encoded_input_neg)
# Perform pooling
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
if text_neg:
sentence_embeddings -= neg_weight*mean_pooling(model_output_neg, encoded_input_neg['attention_mask'])
# Normalize embeddings
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
# --------------------------------------------------------------------------------------------------------------------------------
# Get cosine similarity between given text and tag descriptions
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
tag_tokens_dev = torch.from_numpy(self.tokens).to(device)
tag_tokens_dev = torch.from_numpy(self.tokens).type(torch.float32).to(device)
similarity: torch.Tensor = cos(sentence_embeddings[0], tag_tokens_dev)
# Convert similarity into probability
if opts.conversion == ProbabilityConversion.CUTOFF_AND_POWER:
if opts.conversion is pgen.ProbabilityConversion.CUTOFF_AND_POWER:
probs_cpu = torch.clamp(similarity.detach().cpu(), 0, 1) ** opts.prob_power
elif opts.conversion == ProbabilityConversion.SOFTMAX:
elif opts.conversion is pgen.ProbabilityConversion.SOFTMAX:
probs_cpu = torch.softmax(similarity.detach().cpu(), dim=0)
probs_cpu = probs_cpu / probs_cpu.sum(dim=0)
probs_cpu = probs_cpu.nan_to_num()
results = None
if opts.sampling == SamplingMethod.NONE:
if opts.sampling is pgen.SamplingMethod.NONE:
tags_np = np.array(self.tags)
opts.n = min(tags_np.shape[0], opts.n)
if opts.n <= 0: return []
if opts.weighted:
probs_np = probs_cpu.detach().numpy()
probs_np /= np.sum(probs_np)
if np.count_nonzero(probs_np) <= opts.n:
results = tags_np
num_nonzero = np.count_nonzero(probs_np)
if num_nonzero <= opts.n:
if num_nonzero > 0:
results=np.random.choice(tags_np, num_nonzero, replace=False, p=probs_np)
else:
results = np.random.choice(tags_np, opts.n, replace=False)
else:
results = np.random.choice(a=tags_np, size=opts.n, replace=False, p=probs_np)
else:
# Just sample randomly
results = np.random.choice(a=tags_np, size=opts.n, replace=False)
elif opts.sampling == SamplingMethod.TOP_K:
elif opts.sampling is pgen.SamplingMethod.TOP_K:
probs, indices = probs_cpu.topk(opts.k)
indices = indices.detach().numpy().tolist()
if len(indices) <= 0: return []
@ -138,16 +150,20 @@ class WDLike(PromptGenerator):
if opts.weighted:
probs_np = probs.detach().numpy()
probs_np /= np.sum(probs_np)
if np.count_nonzero(probs_np) <= opts.n:
results = tags_np
probs_np = np.nan_to_num(probs_np)
num_nonzero = np.count_nonzero(probs_np)
if num_nonzero <= opts.n:
if num_nonzero > 0:
results=np.random.choice(tags_np, num_nonzero, replace=False, p=probs_np)
else:
results = np.random.choice(tags_np, opts.n, replace=False)
else:
results = np.random.choice(tags_np, opts.n, replace=False, p=probs_np)
else:
results = np.random.choice(tags_np, opts.n, replace=False)
# brought from https://nn.labml.ai/sampling/nucleus.html
elif opts.sampling == SamplingMethod.TOP_P:
elif opts.sampling is pgen.SamplingMethod.TOP_P:
sorted_probs, sorted_indices = probs_cpu.sort(descending=True)
cs_probs = torch.cumsum(sorted_probs, dim=0)
nucleus = cs_probs < opts.p
@ -163,9 +179,13 @@ class WDLike(PromptGenerator):
if opts.weighted:
probs_np = np.array([sorted_probs[i] for i in indices])
probs_np /= np.sum(probs_np)
if np.count_nonzero(probs_np) <= opts.n:
results = tags_np
probs_np = np.nan_to_num(probs_np)
num_nonzero = np.count_nonzero(probs_np)
if num_nonzero <= opts.n:
if num_nonzero > 0:
results=np.random.choice(tags_np, num_nonzero, replace=False, p=probs_np)
else:
results = np.random.choice(tags_np, opts.n, replace=False)
else:
results = np.random.choice(tags_np, opts.n, replace=False, p=probs_np)
else:

View File

@ -2,6 +2,9 @@ import os
import re
from modules import scripts
# to use intellisense on vscode
DEVELOP = False
def get_abspath(path: str):
return os.path.abspath(os.path.join(scripts.basedir(), path))