Compare commits
20 Commits
danbooru-d
...
main
| Author | SHA1 | Date |
|---|---|---|
|
|
752a660d2b | |
|
|
bcfec2f988 | |
|
|
c4d8cec126 | |
|
|
a6c052ef7a | |
|
|
f45016a4d1 | |
|
|
238318c605 | |
|
|
362da88287 | |
|
|
f5737b77ea | |
|
|
da5471671e | |
|
|
0c45abc73e | |
|
|
c28e403288 | |
|
|
0406db1963 | |
|
|
03d40a1456 | |
|
|
6435280d03 | |
|
|
7e0841d4af | |
|
|
b99ce707f4 | |
|
|
4972255f31 | |
|
|
1e174b6c83 | |
|
|
38fc93c5fa | |
|
|
5a13040437 |
|
|
@ -152,3 +152,5 @@ cython_debug/
|
|||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
.vscode
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2022 toshiaki1729
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
13
README.md
13
README.md
|
|
@ -20,6 +20,7 @@ git clone https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt.git
|
|||
## Usage
|
||||
|
||||
1. Type some words into "Input Theme"
|
||||
1. Type some unwanted words into "Input Negative Theme"
|
||||
1. Push "Generate" button
|
||||
|
||||

|
||||
|
|
@ -41,14 +42,14 @@ git clone https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt.git
|
|||
|
||||
It's doing nothing special;
|
||||
|
||||
1. Danbooru tags and their descriptions are in the `data` folder
|
||||
- descriptions are generated from wiki and already tokenized
|
||||
- [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) models are used to tokenize the text
|
||||
1. Danbooru tags and their descriptions are in the `data` folder
|
||||
- embeddings of descriptions are generated from wiki
|
||||
- [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) models are used to make embeddings from the text
|
||||
1. Tokenize your input text and calculate cosine similarity with all tag descriptions
|
||||
1. Choose some tags depending on their similarities
|
||||
|
||||
|
||||
### Database (Optional)
|
||||
## Database (Optional)
|
||||
|
||||
You can choose the following dataset if needed.
|
||||
Download the following, unzip and put its contents into `text2prompt-root-dir/data/danbooru/`.
|
||||
|
|
@ -59,7 +60,7 @@ Download the following, unzip and put its contents into `text2prompt-root-dir/da
|
|||
|normal (same as previous one)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_normal_all-mpnet-base-v2.zip)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_normal_all-MiniLM-L6-v2.zip)|
|
||||
|full (noisy)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_full_all-mpnet-base-v2.zip)|[download](https://github.com/toshiaki1729/stable-diffusion-webui-text2prompt/releases/download/danbooru-database-v1.0.0/danbooru_full_all-MiniLM-L6-v2.zip)|
|
||||
|
||||
**well filtered:** Tags are removed if their description include the title of the work. These tags are heavily related to a specific work, meaning they are not "general" tags.
|
||||
**well filtered:** Tags are removed if their descriptions include the title of some work. These tags are heavily related to a specific work, meaning they are not "general" tags.
|
||||
**normal:** Tags containing the title of a work, like tag_name(work_name), are removed.
|
||||
**full:** Including all tags.
|
||||
|
||||
|
|
@ -67,7 +68,7 @@ Download the following, unzip and put its contents into `text2prompt-root-dir/da
|
|||
|
||||
## More detailed description
|
||||
$i \in N = \\{1, 2, ..., n\\}$ for index number of the tag
|
||||
$s_i = S_C(d_i, t)$ for cosine similarity between tag description $d_i$ and your text $t$
|
||||
$s_i = S_C(d_i, t)$ for cosine similarity between tag description $d_i$ and your text $t$
|
||||
$P_i$ for probability for the tag to be chosen
|
||||
|
||||
### "Method to convert similarity into probability"
|
||||
|
|
|
|||
BIN
pic/pic1.png
BIN
pic/pic1.png
Binary file not shown.
|
Before Width: | Height: | Size: 132 KiB After Width: | Height: | Size: 163 KiB |
|
|
@ -4,9 +4,18 @@ import gradio as gr
|
|||
from modules import script_callbacks
|
||||
from modules import generation_parameters_copypaste as params_copypaste
|
||||
|
||||
import scripts.t2p.prompt_generator as pgen
|
||||
import scripts.t2p.settings as settings
|
||||
|
||||
wd_like = pgen.WDLike()
|
||||
if settings.DEVELOP:
|
||||
import scripts.t2p.prompt_generator as pgen
|
||||
from scripts.t2p.prompt_generator.wd_like import WDLike
|
||||
else:
|
||||
from scripts.t2p.dynamic_import import dynamic_import
|
||||
_wd_like = dynamic_import('scripts/t2p/prompt_generator/wd_like.py')
|
||||
WDLike = _wd_like.WDLike
|
||||
pgen = _wd_like.pgen
|
||||
|
||||
wd_like = WDLike()
|
||||
|
||||
# brought from modules/deepbooru.py
|
||||
re_special = re.compile(r'([\\()])')
|
||||
|
|
@ -28,7 +37,7 @@ def get_tag_range_txt(tag_range: int):
|
|||
maxval = len(wd_like.database.tag_idx) - 1
|
||||
i = max(0, min(tag_range, maxval))
|
||||
r = wd_like.database.tag_idx[i]
|
||||
return f'Tag range: <b> > {r[0]} tagged</b> ({r[1] + 1} tags total)'
|
||||
return f'Tag range: <b> ≥ {r[0]} tagged</b> ({r[1] + 1} tags total)'
|
||||
|
||||
|
||||
def dd_database_changed(database_name: str, tag_range: int):
|
||||
|
|
@ -43,9 +52,9 @@ def sl_tag_range_changed(tag_range: int):
|
|||
return get_tag_range_txt(tag_range)
|
||||
|
||||
|
||||
def generate_prompt(text: str, tag_range: int, conversion: int, power: float, sampling: int, n: int, k: int, p: float, weighted: bool, replace_underscore: bool, excape_brackets: bool):
|
||||
def generate_prompt(text: str, text_neg: str, neg_weight: float, tag_range: int, conversion: int, power: float, sampling: int, n: int, k: int, p: float, weighted: bool, replace_underscore: bool, excape_brackets: bool):
|
||||
wd_like.load_model() #skip loading if not needed
|
||||
tags = wd_like(text, pgen.GenerationSettings(tag_range, get_conversion(conversion), power, get_sampling(sampling), n, k, p, weighted))
|
||||
tags = wd_like(text, text_neg, neg_weight, pgen.GenerationSettings(tag_range, get_conversion(conversion), power, get_sampling(sampling), n, k, p, weighted))
|
||||
if replace_underscore: tags = [t.replace('_', ' ') for t in tags]
|
||||
if excape_brackets: tags = [re.sub(re_special, r'\\\1', t) for t in tags]
|
||||
return ', '.join(tags)
|
||||
|
|
@ -56,6 +65,8 @@ def on_ui_tabs():
|
|||
with gr.Row():
|
||||
with gr.Column():
|
||||
tb_input = gr.Textbox(label='Input Theme', interactive=True)
|
||||
tb_input_neg = gr.Textbox(label='Input Negative Theme', interactive=True)
|
||||
sl_negative_strength = gr.Slider(0, 3, value=1, step=0.01, label='Negative strength', interactive=True)
|
||||
cb_replace_underscore = gr.Checkbox(value=True, label='Replace underscore in tag with whitespace', interactive=True)
|
||||
cb_escape_brackets = gr.Checkbox(value=True, label='Escape brackets in tag', interactive=True)
|
||||
btn_generate = gr.Button(value='Generate', variant='primary')
|
||||
|
|
@ -68,7 +79,8 @@ def on_ui_tabs():
|
|||
gr.HTML(value='Generation Settings')
|
||||
choices = wd_like.get_model_names()
|
||||
with gr.Column():
|
||||
dd_database = gr.Dropdown(choices=choices, value=None, interactive=True, label='Database')
|
||||
if choices: wd_like.load_data(choices[-1])
|
||||
dd_database = gr.Dropdown(choices=choices, value=choices[-1] if choices else None, interactive=True, label='Database')
|
||||
sl_tag_range = gr.Slider(0, 8, 0, step=1, interactive=True, label='Tag count filter')
|
||||
txt_tag_range = gr.HTML(get_tag_range_txt(0))
|
||||
with gr.Column():
|
||||
|
|
@ -108,6 +120,8 @@ def on_ui_tabs():
|
|||
fn=generate_prompt,
|
||||
inputs=[
|
||||
tb_input,
|
||||
tb_input_neg,
|
||||
sl_negative_strength,
|
||||
sl_tag_range,
|
||||
rb_prob_conversion_method,
|
||||
sl_power,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
def dynamic_import(path: str):
|
||||
import os
|
||||
from modules import scripts, script_loading
|
||||
path = os.path.abspath(os.path.join(scripts.basedir(), path))
|
||||
return script_loading.load_module(path)
|
||||
|
|
@ -40,8 +40,5 @@ class PromptGenerator:
|
|||
raise NotImplementedError()
|
||||
def ready(self) -> bool:
|
||||
raise NotImplementedError()
|
||||
def __call__(self, text: str, settings: GenerationSettings) -> List[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
from .wd_like import WDLike
|
||||
def __call__(self, text: str, text_neg: str, neg_weight: float, settings: GenerationSettings) -> List[str]:
|
||||
raise NotImplementedError()
|
||||
|
|
@ -4,15 +4,15 @@ import csv
|
|||
from typing import Dict
|
||||
import numpy as np
|
||||
|
||||
from .. import settings
|
||||
import scripts.t2p.settings as settings
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, database_path: str, re_filename: re.Pattern[str]):
|
||||
def __init__(self, database_path: str, re_filename: re.Pattern):
|
||||
self.read_files(database_path, re_filename)
|
||||
|
||||
|
||||
def read_files(self, database_path: str, re_filename: re.Pattern[str]):
|
||||
def read_files(self, database_path: str, re_filename: re.Pattern):
|
||||
self.clear()
|
||||
self.database_path = database_path
|
||||
fn, _ = os.path.splitext(os.path.basename(database_path))
|
||||
|
|
@ -20,13 +20,13 @@ class Database:
|
|||
self.size_name = m.group(1)
|
||||
self.model_name = m.group(2)
|
||||
if self.model_name not in settings.TOKENIZER_NAMES:
|
||||
print(f'Cannot use database in {database_path}; Incompatible model name "{self.model_name}"')
|
||||
print(f'[text2prompt] Cannot use database in {database_path}; Incompatible model name "{self.model_name}"')
|
||||
self.clear()
|
||||
return
|
||||
|
||||
tag_path = os.path.join(os.path.dirname(database_path), f'{self.size_name}_tags.txt')
|
||||
if not os.path.isfile(tag_path):
|
||||
print(f'Cannot use database in {database_path}; No tag file exists')
|
||||
print(f'[text2prompt] Cannot use database in {database_path}; No tag file exists')
|
||||
self.clear()
|
||||
return
|
||||
|
||||
|
|
@ -35,14 +35,14 @@ class Database:
|
|||
|
||||
tag_idx_path = os.path.join(os.path.dirname(database_path), f'{self.size_name}_tagidx.csv')
|
||||
if not os.path.isfile(tag_idx_path):
|
||||
print(f'Cannot read tag indices file. Tag count filter cannot be used.')
|
||||
print(f'[text2prompt] Cannot read tag indices file. Tag count filter cannot be used.')
|
||||
else:
|
||||
with open(tag_idx_path, mode='r', encoding='utf8', newline='') as f:
|
||||
cr = csv.reader(f)
|
||||
for row in cr:
|
||||
self.tag_idx.append((int(row[0]), int(row[1])))
|
||||
self.tag_idx.sort(key=lambda t : t[0])
|
||||
self.tag_idx = [(0, len(self.tags))] + self.tag_idx
|
||||
self.tag_idx = [(0, len(self.tags) - 1)] + self.tag_idx
|
||||
|
||||
|
||||
def clear(self):
|
||||
|
|
@ -77,12 +77,12 @@ class Database:
|
|||
|
||||
|
||||
class DatabaseLoader:
|
||||
def __init__(self, path: str, re_filename: re.Pattern[str]):
|
||||
def __init__(self, path: str, re_filename: re.Pattern):
|
||||
self.datas: Dict[str, Database] = dict()
|
||||
self.preload(path, re_filename)
|
||||
|
||||
|
||||
def preload(self, path: str, re_filename: re.Pattern[str]):
|
||||
def preload(self, path: str, re_filename: re.Pattern):
|
||||
dirs = os.listdir(path)
|
||||
for d in dirs:
|
||||
filepath = os.path.join(path, d)
|
||||
|
|
@ -91,8 +91,9 @@ class DatabaseLoader:
|
|||
if ext == '.npz':
|
||||
ds = Database(filepath, re_filename)
|
||||
self.datas[ds.name()] = ds
|
||||
print('[text2prompt] Loaded following databases')
|
||||
print(sorted(self.datas.keys()))
|
||||
print('[text2prompt] Following databases are available:')
|
||||
for name in sorted(self.datas.keys()):
|
||||
print(f' {name}')
|
||||
|
||||
|
||||
def load(self, database_name: str):
|
||||
|
|
|
|||
|
|
@ -5,10 +5,15 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from . import PromptGenerator, GenerationSettings, SamplingMethod, ProbabilityConversion
|
||||
from .database_loader import DatabaseLoader
|
||||
from .. import settings
|
||||
import scripts.t2p.settings as settings
|
||||
|
||||
if settings.DEVELOP:
|
||||
import scripts.t2p.prompt_generator as pgen
|
||||
import scripts.t2p.prompt_generator.database_loader as dloader
|
||||
else:
|
||||
from scripts.t2p.dynamic_import import dynamic_import
|
||||
pgen = dynamic_import('scripts/t2p/prompt_generator/__init__.py')
|
||||
dloader = dynamic_import('scripts/t2p/prompt_generator/database_loader.py')
|
||||
|
||||
NUM_CHOICE = 10
|
||||
K_VALUE = 50
|
||||
|
|
@ -20,10 +25,10 @@ def mean_pooling(model_output, attention_mask):
|
|||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
||||
|
||||
|
||||
class WDLike(PromptGenerator):
|
||||
class WDLike(pgen.PromptGenerator):
|
||||
def __init__(self):
|
||||
self.clear()
|
||||
self.database_loader = DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)
|
||||
self.database_loader = dloader.DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)
|
||||
|
||||
def get_model_names(self):
|
||||
return sorted(self.database_loader.datas.keys())
|
||||
|
|
@ -38,13 +43,13 @@ class WDLike(PromptGenerator):
|
|||
self.loaded_model_name = None
|
||||
|
||||
def load_data(self, database_name: str):
|
||||
print('Loading database...')
|
||||
print(f'[text2prompt] Loading database with name "{database_name}"...')
|
||||
self.database = self.database_loader.load(database_name)
|
||||
print('Loaded')
|
||||
print('[text2prompt] Database loaded')
|
||||
|
||||
def load_model(self):
|
||||
if self.database is None:
|
||||
print('Cannot load model; Database is not loaded.')
|
||||
print('[text2prompt] Cannot load model; Database is not loaded.')
|
||||
return
|
||||
from modules.devices import device
|
||||
# brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
|
||||
|
|
@ -52,11 +57,11 @@ class WDLike(PromptGenerator):
|
|||
if self.loaded_model_name and self.loaded_model_name == self.database.model_name:
|
||||
return
|
||||
else:
|
||||
print('Loading model...')
|
||||
print(f'[text2prompt] Loading model with name "{self.database.model_name}"...')
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name])
|
||||
self.model = AutoModel.from_pretrained(settings.TOKENIZER_MODELS[self.database.model_name]).to(device)
|
||||
self.loaded_model_name = self.database.model_name
|
||||
print('model loaded')
|
||||
print('[text2prompt] Model loaded')
|
||||
|
||||
def unload_model(self):
|
||||
if self.tokenizer is not None:
|
||||
|
|
@ -71,7 +76,7 @@ class WDLike(PromptGenerator):
|
|||
and self.tokenizer is not None \
|
||||
and self.loaded_model_name and self.loaded_model_name == self.database.model_name
|
||||
|
||||
def __call__(self, text: str, opts: GenerationSettings) -> List[str]:
|
||||
def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.GenerationSettings) -> List[str]:
|
||||
if not self.ready(): return ''
|
||||
|
||||
i = max(0, min(opts.tag_range, len(self.database.tag_idx) - 1))
|
||||
|
|
@ -85,50 +90,57 @@ class WDLike(PromptGenerator):
|
|||
# brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
|
||||
# Tokenize sentences
|
||||
encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)
|
||||
if text_neg:
|
||||
encoded_input_neg = self.tokenizer(text_neg, padding=True, truncation=True, return_tensors='pt').to(device)
|
||||
# Compute token embeddings
|
||||
with torch.no_grad():
|
||||
model_output = self.model(**encoded_input)
|
||||
if text_neg:
|
||||
model_output_neg = self.model(**encoded_input_neg)
|
||||
|
||||
# Perform pooling
|
||||
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
||||
if text_neg:
|
||||
sentence_embeddings -= neg_weight*mean_pooling(model_output_neg, encoded_input_neg['attention_mask'])
|
||||
|
||||
# Normalize embeddings
|
||||
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
|
||||
# --------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
# Get cosine similarity between given text and tag descriptions
|
||||
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
tag_tokens_dev = torch.from_numpy(self.tokens).to(device)
|
||||
tag_tokens_dev = torch.from_numpy(self.tokens).type(torch.float32).to(device)
|
||||
similarity: torch.Tensor = cos(sentence_embeddings[0], tag_tokens_dev)
|
||||
|
||||
# Convert similarity into probability
|
||||
if opts.conversion == ProbabilityConversion.CUTOFF_AND_POWER:
|
||||
if opts.conversion is pgen.ProbabilityConversion.CUTOFF_AND_POWER:
|
||||
probs_cpu = torch.clamp(similarity.detach().cpu(), 0, 1) ** opts.prob_power
|
||||
elif opts.conversion == ProbabilityConversion.SOFTMAX:
|
||||
elif opts.conversion is pgen.ProbabilityConversion.SOFTMAX:
|
||||
probs_cpu = torch.softmax(similarity.detach().cpu(), dim=0)
|
||||
|
||||
probs_cpu = probs_cpu / probs_cpu.sum(dim=0)
|
||||
probs_cpu = probs_cpu.nan_to_num()
|
||||
|
||||
results = None
|
||||
|
||||
if opts.sampling == SamplingMethod.NONE:
|
||||
if opts.sampling is pgen.SamplingMethod.NONE:
|
||||
tags_np = np.array(self.tags)
|
||||
opts.n = min(tags_np.shape[0], opts.n)
|
||||
if opts.n <= 0: return []
|
||||
if opts.weighted:
|
||||
probs_np = probs_cpu.detach().numpy()
|
||||
probs_np /= np.sum(probs_np)
|
||||
|
||||
if np.count_nonzero(probs_np) <= opts.n:
|
||||
results = tags_np
|
||||
num_nonzero = np.count_nonzero(probs_np)
|
||||
if num_nonzero <= opts.n:
|
||||
if num_nonzero > 0:
|
||||
results=np.random.choice(tags_np, num_nonzero, replace=False, p=probs_np)
|
||||
else:
|
||||
results = np.random.choice(tags_np, opts.n, replace=False)
|
||||
else:
|
||||
results = np.random.choice(a=tags_np, size=opts.n, replace=False, p=probs_np)
|
||||
else:
|
||||
# Just sample randomly
|
||||
results = np.random.choice(a=tags_np, size=opts.n, replace=False)
|
||||
|
||||
elif opts.sampling == SamplingMethod.TOP_K:
|
||||
elif opts.sampling is pgen.SamplingMethod.TOP_K:
|
||||
probs, indices = probs_cpu.topk(opts.k)
|
||||
indices = indices.detach().numpy().tolist()
|
||||
if len(indices) <= 0: return []
|
||||
|
|
@ -138,16 +150,20 @@ class WDLike(PromptGenerator):
|
|||
if opts.weighted:
|
||||
probs_np = probs.detach().numpy()
|
||||
probs_np /= np.sum(probs_np)
|
||||
|
||||
if np.count_nonzero(probs_np) <= opts.n:
|
||||
results = tags_np
|
||||
probs_np = np.nan_to_num(probs_np)
|
||||
num_nonzero = np.count_nonzero(probs_np)
|
||||
if num_nonzero <= opts.n:
|
||||
if num_nonzero > 0:
|
||||
results=np.random.choice(tags_np, num_nonzero, replace=False, p=probs_np)
|
||||
else:
|
||||
results = np.random.choice(tags_np, opts.n, replace=False)
|
||||
else:
|
||||
results = np.random.choice(tags_np, opts.n, replace=False, p=probs_np)
|
||||
else:
|
||||
results = np.random.choice(tags_np, opts.n, replace=False)
|
||||
|
||||
# brought from https://nn.labml.ai/sampling/nucleus.html
|
||||
elif opts.sampling == SamplingMethod.TOP_P:
|
||||
elif opts.sampling is pgen.SamplingMethod.TOP_P:
|
||||
sorted_probs, sorted_indices = probs_cpu.sort(descending=True)
|
||||
cs_probs = torch.cumsum(sorted_probs, dim=0)
|
||||
nucleus = cs_probs < opts.p
|
||||
|
|
@ -163,9 +179,13 @@ class WDLike(PromptGenerator):
|
|||
if opts.weighted:
|
||||
probs_np = np.array([sorted_probs[i] for i in indices])
|
||||
probs_np /= np.sum(probs_np)
|
||||
|
||||
if np.count_nonzero(probs_np) <= opts.n:
|
||||
results = tags_np
|
||||
probs_np = np.nan_to_num(probs_np)
|
||||
num_nonzero = np.count_nonzero(probs_np)
|
||||
if num_nonzero <= opts.n:
|
||||
if num_nonzero > 0:
|
||||
results=np.random.choice(tags_np, num_nonzero, replace=False, p=probs_np)
|
||||
else:
|
||||
results = np.random.choice(tags_np, opts.n, replace=False)
|
||||
else:
|
||||
results = np.random.choice(tags_np, opts.n, replace=False, p=probs_np)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ import os
|
|||
import re
|
||||
from modules import scripts
|
||||
|
||||
# to use intellisense on vscode
|
||||
DEVELOP = False
|
||||
|
||||
def get_abspath(path: str):
|
||||
return os.path.abspath(os.path.join(scripts.basedir(), path))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue