Squashed commit of the following:
commit 0571f2dab8bf3d8e1826f0fb5d53ca5b9bfcf6b0
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Wed Dec 28 01:05:27 2022 +0900
Update README.md
commit b4a5164081662db9085e2b5c17af6e98473180ff
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Wed Dec 28 00:44:25 2022 +0900
remove
commit 127946a3d200825eb15244ad09db5361db8bf89c
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Wed Dec 28 00:28:28 2022 +0900
Revert "remove"
This reverts commit f7666de150d1bb8acdcedcb0b82ddfb939206088.
commit f7666de150d1bb8acdcedcb0b82ddfb939206088
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Wed Dec 28 00:26:44 2022 +0900
remove
commit 41b99ba60bd59050384bdd55a498395e4148bfda
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Wed Dec 28 00:26:20 2022 +0900
Update .gitignore
commit 2fdce554bdd92bc87dce13181aa67c58d523d830
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Wed Dec 28 00:24:56 2022 +0900
Update .gitignore
commit 50927e064e5ffe052d785038f4c7de851ede6c91
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Wed Dec 28 00:23:49 2022 +0900
upload database
commit 39172118e8d8fac2ed7f3aab0c53e2a6e89c16b3
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Wed Dec 28 00:22:34 2022 +0900
Update .gitignore
commit 7c6225ade3a55fdf342eb2b3617302edab595192
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Wed Dec 28 00:21:25 2022 +0900
Create test
commit 8010ec05a5e7836fc004450bc82cc3ce93d975ad
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Wed Dec 28 00:20:22 2022 +0900
Update .gitignore
commit 6265874465281db81a3935f651537ac9c6cfc30a
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Wed Dec 28 00:17:17 2022 +0900
update README.md
commit 3f52e109aa0b7a41f4970b11b933c7f57c7ddd62
Merge: 180c4c5 867a48a
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Tue Dec 27 23:20:32 2022 +0900
Merge branch 'main' into feature/well-filtered-database
commit 180c4c5b414058939eb47dfe7783eb561c5da65e
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Tue Dec 27 23:19:51 2022 +0900
implement dynamic database loader
commit fdccc1a9cb9a80ce9b4042892afc9922a20e6524
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Sun Dec 25 19:39:32 2022 +0900
Update README.md
commit b539cbf4d3791777e424883b2a5bd64a73de9bf0
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Sun Dec 25 19:34:11 2022 +0900
remove thread lock to allow generating prompt while doing other tasks
commit 43b5bfd73a549f2bced6704cdcfacb1483648b17
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Sun Dec 25 19:26:26 2022 +0900
update README
commit 58df449f36678de2e9b35726315d53a281bdc380
Author: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Sun Dec 25 19:05:30 2022 +0900
add new language model and database
The size of data will be so large
main · parent 867a48ad52 · commit 374325e3e2
@@ -6,6 +6,8 @@ __pycache__/
 # C extensions
 *.so
+
+archive/
 
 # Distribution / packaging
 .Python
 build/
README.md (28 changes)
@@ -18,9 +18,13 @@ git clone https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt.git
 
 ## Usage
 
 1. Type some words into "Input Theme"
 1. Push the "Generate" button
 
 [image]
 
 ### Tips
 - For more creative results
@@ -34,20 +38,34 @@ git clone https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt.git
 - You can enter very long sentences, but the more specific it is, the fewer results you will get.
 
 ## How it works
 
 It's doing nothing special;
 
 1. Danbooru tags and their descriptions are in the `data` folder
     - descriptions are generated from the wiki and already tokenized
-    - [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) model is used to tokenize the text
-    - for now, some tags (such as <1k tagged or containing the title of the work) are deleted to prevent "noisy" results
-1. Tokenize your input text and calculate cosine similarity to each tag description
+    - [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) models are used to tokenize the text
+1. Tokenize your input text and calculate cosine similarity with all tag descriptions
 1. Choose some tags depending on their similarities
 
 [image]
 
+### Database (Optional)
+
+You can choose one of the following datasets if needed.
+Download it, unzip it, and put its contents into `text2prompt-root-dir/data/danbooru/`.
+
+|Tag description|all-mpnet-base-v2|all-MiniLM-L6-v2|
+|:---|:---:|:---:|
+|**well filtered (recommended)**|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_strict_all-mpnet-base-v2.zip) (preinstalled)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_strict_all-MiniLM-L6-v2.zip)|
+|normal (same as the previous version)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_normal_all-mpnet-base-v2.zip)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_normal_all-MiniLM-L6-v2.zip)|
+|full (noisy)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_full_all-mpnet-base-v2.zip)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_full_all-MiniLM-L6-v2.zip)|
+
+**well filtered:** Tags are removed if their description includes the title of a work. Such tags are heavily tied to a specific work, meaning they are not "general" tags.
+**normal:** Tags containing the title of a work, like tag_name(work_name), are removed.
+**full:** All tags included.
 
 ---
 
-### More detailed
+## More detailed description
 
 $i \in N = \\{1, 2, ..., n\\}$ for the index number of a tag
 $s_i = S_C(d_i, t)$ for the cosine similarity between tag description $d_i$ and your text $t$
 $P_i$ for the probability of the tag being chosen
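The definitions above suggest a simple pipeline: convert similarities $s_i$ into probabilities $P_i$, then draw tags weighted by $P_i$. A minimal NumPy sketch, assuming "Cutoff and Power" means clipping negative similarities to zero before raising to a power and normalizing (illustrative names, not the extension's code):

```python
import numpy as np

def cutoff_and_power(similarities: np.ndarray, power: float = 2.0) -> np.ndarray:
    """Turn cosine similarities s_i into probabilities P_i:
    clip negatives to zero, raise to `power`, then normalize."""
    s = np.clip(similarities, 0.0, None) ** power
    total = s.sum()
    return s / total if total > 0 else np.full(len(s), 1.0 / len(s))

# Weighted choice of tags by P_i (hypothetical tags and similarities)
rng = np.random.default_rng(0)
tags = ['tag_a', 'tag_b', 'tag_c', 'tag_d']
probs = cutoff_and_power(np.array([0.8, 0.5, 0.2, -0.1]))
chosen = rng.choice(tags, size=2, replace=False, p=probs)
```

A higher `power` sharpens the distribution toward the most similar tags, which matches the UI's "Power" slider.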
@@ -0,0 +1,8 @@
+50,13005
+100,11294
+200,9490
+500,7099
+1000,5425
+2000,3956
+5000,2458
+10000,1633
(Two file diffs suppressed because they are too large; binary files not shown.)

BIN pic/pic1.png (binary file not shown; Before: 101 KiB, After: 132 KiB)
@@ -6,7 +6,7 @@ from modules import generation_parameters_copypaste as params_copypaste
 
 import scripts.t2p.prompt_generator as pgen
 
-wd_like = None
+wd_like = pgen.WDLike()
 
 # brought from modules/deepbooru.py
 re_special = re.compile(r'([\\()])')
@@ -22,18 +22,35 @@ def get_sampling(choice: int):
     elif choice == 2: return pgen.SamplingMethod.TOP_P
     else: raise NotImplementedError()
 
-def generate_prompt(text: str, conversion: int, power: float, sampling: int, n: int, k: int, p: float, weighted: bool, replace_underscore: bool, excape_brackets: bool):
-    global wd_like
-    if not wd_like:
-        print('Loading tag data and model')
-        wd_like = pgen.WDLike()
-        wd_like.load_model()
-        print('Loaded')
-    tags = wd_like(text, pgen.GenerationSettings(get_conversion(conversion), power, get_sampling(sampling), n, k, p, weighted))
+def get_tag_range_txt(tag_range: int):
+    if wd_like.database is None:
+        return 'Tag range: NONE'
+    maxval = len(wd_like.database.tag_idx) - 1
+    i = max(0, min(tag_range, maxval))
+    r = wd_like.database.tag_idx[i]
+    return f'Tag range: <b> > {r[0]} tagged</b> ({r[1] + 1} tags total)'
+
+
+def dd_database_changed(database_name: str, tag_range: int):
+    wd_like.load_data(database_name)
+    return [
+        gr.Slider.update(tag_range, 0, len(wd_like.database.tag_idx) - 1),
+        get_tag_range_txt(tag_range)
+    ]
+
+
+def sl_tag_range_changed(tag_range: int):
+    return get_tag_range_txt(tag_range)
+
+
+def generate_prompt(text: str, tag_range: int, conversion: int, power: float, sampling: int, n: int, k: int, p: float, weighted: bool, replace_underscore: bool, excape_brackets: bool):
+    wd_like.load_model()  # skip loading if not needed
+    tags = wd_like(text, pgen.GenerationSettings(tag_range, get_conversion(conversion), power, get_sampling(sampling), n, k, p, weighted))
     if replace_underscore: tags = [t.replace('_', ' ') for t in tags]
     if excape_brackets: tags = [re.sub(re_special, r'\\\1', t) for t in tags]
     return ', '.join(tags)
 
 
 def on_ui_tabs():
     with gr.Blocks(analytics_enabled=False) as text2prompt_interface:
         with gr.Row():
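The `SamplingMethod.TOP_K` / `TOP_P` options dispatched above are the usual truncated-sampling schemes; a generic NumPy sketch under that assumption (not the extension's actual implementation):

```python
import numpy as np

def top_k_filter(probs: np.ndarray, k: int) -> np.ndarray:
    """Keep only the k most probable entries; zero the rest and renormalize."""
    out = np.zeros_like(probs)
    idx = np.argsort(probs)[::-1][:k]
    out[idx] = probs[idx]
    return out / out.sum()

def top_p_filter(probs: np.ndarray, p: float) -> np.ndarray:
    """Keep the smallest set of entries whose cumulative probability reaches p
    (nucleus sampling); zero the rest and renormalize."""
    order = np.argsort(probs)[::-1]
    cum = np.cumsum(probs[order])
    keep = order[:np.searchsorted(cum, p) + 1]
    out = np.zeros_like(probs)
    out[keep] = probs[keep]
    return out / out.sum()
```

Either filter is applied to the converted probabilities before the weighted tag choice, which is what the "k value" and "p value" UI controls feed into.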
@@ -47,15 +64,33 @@ def on_ui_tabs():
             buttons = params_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
             params_copypaste.bind_buttons(buttons, None, tb_output)
 
         with gr.Column():
             with gr.Column(variant='panel'):
                 gr.HTML(value='Generation Settings')
-                rb_prob_conversion_method = gr.Radio(choices=['Cutoff and Power', 'Softmax'], value='Cutoff and Power', type='index', label='Method to convert similarity into probability')
-                sl_power = gr.Slider(0, 5, value=2, step=0.1, label='Power', interactive=True)
-                rb_sampling_method = gr.Radio(choices=['NONE', 'Top-k', 'Top-p (Nucleus)'], value='Top-k', type='index', label='Sampling method')
-                nb_max_tag_num = gr.Number(value=20, label='Max number of tags', precision=0, interactive=True)
-                nb_k_value = gr.Number(value=50, label='k value', precision=0, interactive=True)
-                sl_p_value = gr.Slider(0, 1, label='p value', value=0.1, step=0.01, interactive=True)
-                cb_weighted = gr.Checkbox(value=True, label='Use weighted choice', interactive=True)
+                choices = wd_like.get_model_names()
+                with gr.Column():
+                    dd_database = gr.Dropdown(choices=choices, value=None, interactive=True, label='Database')
+                    sl_tag_range = gr.Slider(0, 8, 0, step=1, interactive=True, label='Tag count filter')
+                    txt_tag_range = gr.HTML(get_tag_range_txt(0))
+                with gr.Column():
+                    rb_prob_conversion_method = gr.Radio(choices=['Cutoff and Power', 'Softmax'], value='Cutoff and Power', type='index', label='Method to convert similarity into probability')
+                    sl_power = gr.Slider(0, 5, value=2, step=0.1, label='Power', interactive=True)
+                    rb_sampling_method = gr.Radio(choices=['NONE', 'Top-k', 'Top-p (Nucleus)'], value='Top-k', type='index', label='Sampling method')
+                    nb_max_tag_num = gr.Number(value=20, label='Max number of tags', precision=0, interactive=True)
+                    nb_k_value = gr.Number(value=50, label='k value', precision=0, interactive=True)
+                    sl_p_value = gr.Slider(0, 1, label='p value', value=0.1, step=0.01, interactive=True)
+                    cb_weighted = gr.Checkbox(value=True, label='Use weighted choice', interactive=True)
 
+        dd_database.change(
+            fn=dd_database_changed,
+            inputs=[dd_database, sl_tag_range],
+            outputs=[sl_tag_range, txt_tag_range]
+        )
+
+        sl_tag_range.change(
+            fn=sl_tag_range_changed,
+            inputs=sl_tag_range,
+            outputs=txt_tag_range
+        )
+
         nb_max_tag_num.change(
             fn=lambda x: max(0, x),
@@ -71,7 +106,19 @@ def on_ui_tabs():
 
         btn_generate.click(
             fn=generate_prompt,
-            inputs=[tb_input, rb_prob_conversion_method, sl_power, rb_sampling_method, nb_max_tag_num, nb_k_value, sl_p_value, cb_weighted, cb_replace_underscore, cb_escape_brackets],
+            inputs=[
+                tb_input,
+                sl_tag_range,
+                rb_prob_conversion_method,
+                sl_power,
+                rb_sampling_method,
+                nb_max_tag_num,
+                nb_k_value,
+                sl_p_value,
+                cb_weighted,
+                cb_replace_underscore,
+                cb_escape_brackets
+            ],
             outputs=tb_output
         )
@@ -12,8 +12,9 @@ class ProbabilityConversion(Enum):
 
 class GenerationSettings:
     def __init__(
         self,
+        tag_range: int = 0,
         conversion: ProbabilityConversion = ProbabilityConversion.CUTOFF_AND_POWER,
         prob_power: float = 2,
         sampling: SamplingMethod = SamplingMethod.TOP_K,
         n: int = 20,
@@ -21,6 +22,7 @@ class GenerationSettings:
         p: Optional[float] = 0.3,
         weighted: bool = True):
 
+        self.tag_range = tag_range
         self.sampling = sampling
         self.conversion = conversion
         self.n = n
@@ -31,6 +33,13 @@ class GenerationSettings:
 
 
 class PromptGenerator:
     def clear(self): pass
+
+    def load_data(self, model_id: str, data_id: str):
+        raise NotImplementedError()
+    def load_model(self, model_id: str):
+        raise NotImplementedError()
+    def ready(self) -> bool:
+        raise NotImplementedError()
     def __call__(self, text: str, settings: GenerationSettings) -> List[str]:
         raise NotImplementedError()
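The abstract `PromptGenerator` interface above can be illustrated with a toy subclass; the stand-in classes below are simplified re-declarations so the example is self-contained, not the extension's real ones:

```python
from typing import List

class GenerationSettings:
    # simplified stand-in; the real class has many more options
    def __init__(self, n: int = 20):
        self.n = n

class PromptGenerator:
    # abstract interface, mirroring the diff above
    def clear(self): pass
    def load_data(self, model_id: str, data_id: str): raise NotImplementedError()
    def load_model(self, model_id: str): raise NotImplementedError()
    def ready(self) -> bool: raise NotImplementedError()
    def __call__(self, text: str, settings: GenerationSettings) -> List[str]:
        raise NotImplementedError()

class EchoGenerator(PromptGenerator):
    """Toy generator: splits the input text into at most n 'tags'."""
    def load_data(self, model_id: str, data_id: str): pass
    def load_model(self, model_id: str): pass
    def ready(self) -> bool: return True
    def __call__(self, text: str, settings: GenerationSettings) -> List[str]:
        return text.split()[:settings.n]

gen = EchoGenerator()
result = gen('blue sky cloud', GenerationSettings(n=2))
```

`WDLike` (further down in this diff) is the one concrete implementation of this interface in the extension.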
@@ -0,0 +1,100 @@
+import os
+import re
+import csv
+from typing import Dict
+import numpy as np
+
+from .. import settings
+
+
+class Database:
+    def __init__(self, database_path: str, re_filename: re.Pattern[str]):
+        self.read_files(database_path, re_filename)
+
+
+    def read_files(self, database_path: str, re_filename: re.Pattern[str]):
+        self.clear()
+        self.database_path = database_path
+        fn, _ = os.path.splitext(os.path.basename(database_path))
+        m = re_filename.match(fn)
+        self.size_name = m.group(1)
+        self.model_name = m.group(2)
+        if self.model_name not in settings.TOKENIZER_NAMES:
+            print(f'Cannot use database in {database_path}; Incompatible model name "{self.model_name}"')
+            self.clear()
+            return
+
+        tag_path = os.path.join(os.path.dirname(database_path), f'{self.size_name}_tags.txt')
+        if not os.path.isfile(tag_path):
+            print(f'Cannot use database in {database_path}; No tag file exists')
+            self.clear()
+            return
+
+        with open(tag_path, mode='r', encoding='utf8', newline='\n') as f:
+            self.tags = [l.strip() for l in f.readlines()]
+
+        tag_idx_path = os.path.join(os.path.dirname(database_path), f'{self.size_name}_tagidx.csv')
+        if not os.path.isfile(tag_idx_path):
+            print(f'Cannot read tag indices file. Tag count filter cannot be used.')
+        else:
+            with open(tag_idx_path, mode='r', encoding='utf8', newline='') as f:
+                cr = csv.reader(f)
+                for row in cr:
+                    self.tag_idx.append((int(row[0]), int(row[1])))
+            self.tag_idx.sort(key=lambda t: t[0])
+        self.tag_idx = [(0, len(self.tags))] + self.tag_idx
+
+
+    def clear(self):
+        self.database_path = ''
+        self.model_name = ''
+        self.size_name = ''
+        self.tag_idx = []
+        self.tags = []
+        self.data: np.ndarray = None
+
+
+    def ready_to_load(self):
+        return self.database_path \
+            and self.model_name \
+            and self.size_name \
+            and self.tags \
+            and self.tag_idx
+
+    def loaded(self):
+        return self.data is not None
+
+    def load(self):
+        if not self.ready_to_load(): return None
+        if not self.loaded():
+            self.data = np.load(self.database_path)['db']
+        return self
+
+
+    def name(self):
+        return f'{self.model_name} : {self.size_name}'
+
+
+class DatabaseLoader:
+    def __init__(self, path: str, re_filename: re.Pattern[str]):
+        self.datas: Dict[str, Database] = dict()
+        self.preload(path, re_filename)
+
+
+    def preload(self, path: str, re_filename: re.Pattern[str]):
+        dirs = os.listdir(path)
+        for d in dirs:
+            filepath = os.path.join(path, d)
+            if not os.path.isfile(filepath): continue
+            _, ext = os.path.splitext(filepath)
+            if ext == '.npz':
+                ds = Database(filepath, re_filename)
+                self.datas[ds.name()] = ds
+        print('[text2prompt] Loaded following databases')
+        print(sorted(self.datas.keys()))
+
+
+    def load(self, database_name: str):
+        database = self.datas.get(database_name)
+        return database.load() if database else None
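A sketch of the on-disk layout this loader appears to expect: an `.npz` whose stem matches the filename regex, with the embedding matrix stored under the `db` key as `Database.load` reads it, plus `{size}_tags.txt` and `{size}_tagidx.csv` beside it. File names, shapes, and contents here are illustrative assumptions, not a shipped database:

```python
import os
import tempfile
import numpy as np

d = tempfile.mkdtemp()

# Embedding matrix: one row per tag description (768 dims matches
# all-mpnet-base-v2; the values here are placeholders).
np.savez(os.path.join(d, 'danbooru_strict_token_all-mpnet-base-v2.npz'),
         db=np.zeros((4, 768), dtype=np.float32))

# One tag per line, in the same order as the matrix rows.
with open(os.path.join(d, 'danbooru_strict_tags.txt'), 'w', encoding='utf8') as f:
    f.write('1girl\nsolo\nlong_hair\nsmile\n')

# (tag count threshold, index) rows for the tag count filter.
with open(os.path.join(d, 'danbooru_strict_tagidx.csv'), 'w', encoding='utf8') as f:
    f.write('50,3\n100,2\n')

data = np.load(os.path.join(d, 'danbooru_strict_token_all-mpnet-base-v2.npz'))['db']
```

Dropping such a trio into `data/danbooru/` is what the README's "unzip and put its contents into `text2prompt-root-dir/data/danbooru/`" step amounts to.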
@@ -6,6 +6,7 @@ import torch.nn.functional as F
 from transformers import AutoModel, AutoTokenizer
 
 from . import PromptGenerator, GenerationSettings, SamplingMethod, ProbabilityConversion
+from .database_loader import DatabaseLoader
 from .. import settings
 
 
@@ -21,28 +22,62 @@ def mean_pooling(model_output, attention_mask):
 
 class WDLike(PromptGenerator):
     def __init__(self):
-        self.tags = []
         self.clear()
+        self.database_loader = DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)
+
+    def get_model_names(self):
+        return sorted(self.database_loader.datas.keys())
 
     def clear(self):
+        self.database_loader = None
+        self.database = None
         self.tags = None
         self.tokens = None
         self.tokenizer = None
         self.model = None
-        self.load_data()
+        self.loaded_model_name = None
 
-    def load_data(self):
-        with open(settings.WDLIKE_TAG_PATH, mode='r', encoding='utf8', newline='\n') as f:
-            self.tags = [l.strip() for l in f.readlines()]
-        with open(settings.WDLIKE_TOKEN_PATH, mode='rb') as f:
-            self.tokens = np.load(f)
+    def load_data(self, database_name: str):
+        print('Loading database...')
+        self.database = self.database_loader.load(database_name)
+        print('Loaded')
 
     def load_model(self):
+        if self.database is None:
+            print('Cannot load model; Database is not loaded.')
+            return
         from modules.devices import device
-        # brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
-        # Load model from HuggingFace Hub
-        self.tokenizer = AutoTokenizer.from_pretrained(settings.WDLIKE_MODEL_NAME)
-        self.model = AutoModel.from_pretrained(settings.WDLIKE_MODEL_NAME).to(device)
+        if self.loaded_model_name and self.loaded_model_name == self.database.model_name:
+            return
+        else:
+            print('Loading model...')
+            self.tokenizer = AutoTokenizer.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name])
+            self.model = AutoModel.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name]).to(device)
+            self.loaded_model_name = self.database.model_name
+            print('model loaded')
 
     def unload_model(self):
         if self.tokenizer is not None:
             del self.tokenizer
         if self.model is not None:
             del self.model
 
+    def ready(self) -> bool:
+        return self.database is not None \
+            and self.database.loaded() \
+            and self.model is not None \
+            and self.tokenizer is not None \
+            and self.loaded_model_name and self.loaded_model_name == self.database.model_name
+
     def __call__(self, text: str, opts: GenerationSettings) -> List[str]:
-        if not self.model or not self.tokenizer:
-            return ''
+        if not self.ready(): return ''
+
+        i = max(0, min(opts.tag_range, len(self.database.tag_idx) - 1))
+        r = self.database.tag_idx[i][1]
+        self.tokens = self.database.data[:r, :]
+        self.tags = self.database.tags[:r]
 
         from modules.devices import device
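The hunk above is anchored at `mean_pooling`, the standard sentence-transformers pooling that averages token embeddings under the attention mask to get one sentence vector. A NumPy sketch of that operation (the extension's version operates on torch tensors from the transformers model):

```python
import numpy as np

def mean_pooling(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token embeddings, counting only positions where the mask is 1.

    token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1.
    """
    mask = attention_mask[:, :, None].astype(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # avoid division by zero
    return summed / counts
```

The pooled vectors are what the cosine similarity $s_i = S_C(d_i, t)$ from the README is computed over.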
@@ -1,6 +1,18 @@
 import os
+import re
 from modules import scripts
 
-WDLIKE_TAG_PATH = os.path.abspath(os.path.join(scripts.basedir(), 'data/danbooru_wiki_tags.txt'))
-WDLIKE_TOKEN_PATH = os.path.abspath(os.path.join(scripts.basedir(), 'data/danbooru_wiki_token_all-mpnet-base-v2.npy'))
-WDLIKE_MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
+def get_abspath(path: str):
+    return os.path.abspath(os.path.join(scripts.basedir(), path))
+
+
+TOKENIZER_NAMES = ['all-mpnet-base-v2', 'all-MiniLM-L6-v2']
+
+TOKENIZER_MODELS = {
+    TOKENIZER_NAMES[0]: f'sentence-transformers/{TOKENIZER_NAMES[0]}',
+    TOKENIZER_NAMES[1]: f'sentence-transformers/{TOKENIZER_NAMES[1]}'
+}
+
+DATABASE_PATH_DANBOORU = get_abspath('data/danbooru')
+
+RE_TOKENFILE_DANBOORU = re.compile(r'(danbooru_[^_]+)_token_([^_]+)')
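For illustration, `RE_TOKENFILE_DANBOORU` splits a database file stem into the size and model parts that `Database.read_files` reads as `group(1)` and `group(2)` (the example file name is hypothetical but matches the release naming above):

```python
import re

RE_TOKENFILE_DANBOORU = re.compile(r'(danbooru_[^_]+)_token_([^_]+)')

# e.g. the stem of 'danbooru_strict_token_all-mpnet-base-v2.npz'
m = RE_TOKENFILE_DANBOORU.match('danbooru_strict_token_all-mpnet-base-v2')
size_name, model_name = m.group(1), m.group(2)
```

`[^_]+` works for the model part because the model names contain hyphens but no underscores; `size_name` then prefixes the companion `_tags.txt` and `_tagidx.csv` files.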