stable-diffusion-webui-text.../scripts/t2p/prompt_generator/wd_like.py

179 lines
7.6 KiB
Python

from typing import List
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
import scripts.t2p.settings as settings
# Bind `pgen` and `dloader` to the prompt-generator package and its database
# loader. Outside development mode the modules are loaded from their file
# paths through `dynamic_import`; in development mode they are imported as
# regular packages.
if not settings.DEVELOP:
    from scripts.t2p.dynamic_import import dynamic_import
    pgen = dynamic_import('scripts/t2p/prompt_generator/__init__.py')
    dloader = dynamic_import('scripts/t2p/prompt_generator/database_loader.py')
else:
    import scripts.t2p.prompt_generator as pgen
    import scripts.t2p.prompt_generator.database_loader as dloader
# Module-level constants. Neither name is referenced by the code visible in
# this file -- presumably defaults consumed by other modules; TODO confirm.
NUM_CHOICE = 10
K_VALUE = 50
# brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked tokens.

    Args:
        model_output: Indexable model output whose first element holds the
            per-token embeddings.
        attention_mask: Mask tensor; it is expanded over the last axis to
            the shape of the token embeddings and used as a weight.

    Returns:
        The mask-weighted sum over axis 1 divided by the mask total, with the
        divisor clamped to at least 1e-9 so fully masked rows yield zeros
        rather than a division by zero.
    """
    embeddings = model_output[0]  # first element contains all token embeddings
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    masked_sum = (embeddings * mask).sum(1)
    mask_total = mask.sum(1).clamp(min=1e-9)
    return masked_sum / mask_total
class WDLike(pgen.PromptGenerator):
    """Prompt generator that suggests database tags similar to a given text.

    The text is embedded with a HuggingFace model, compared by cosine
    similarity against the rows of the loaded database, converted into
    probabilities, and finally a handful of tags is sampled from them.
    """

    def __init__(self):
        self.clear()
        self.database_loader = dloader.DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)

    def get_model_names(self):
        """Return the sorted keys of the database loader's ``datas`` mapping."""
        return sorted(self.database_loader.datas.keys())

    def clear(self):
        """Reset every piece of instance state (including the loader) to None."""
        self.database_loader = None
        self.database = None
        self.tags = None
        self.tokens = None
        self.tokenizer = None
        self.model = None
        self.loaded_model_name = None

    def load_data(self, database_name: str):
        """Load the database called ``database_name`` through the loader."""
        print(f'[text2prompt] Loading database with name "{database_name}"...')
        self.database = self.database_loader.load(database_name)
        print('[text2prompt] Database loaded')

    def load_model(self):
        """Load the tokenizer and model that match the current database.

        Does nothing (apart from a log line) when no database is loaded, and
        skips the download/initialisation when the matching model is already
        held in memory.
        """
        if self.database is None:
            print('[text2prompt] Cannot load model; Database is not loaded.')
            return

        from modules.devices import device

        model_name = self.database.model_name
        # Only skip when the objects are really present, not merely when the
        # remembered name matches.
        already_loaded = (
            self.model is not None
            and self.tokenizer is not None
            and self.loaded_model_name
            and self.loaded_model_name == model_name
        )
        if already_loaded:
            return

        # brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
        # Load model from HuggingFace Hub
        print(f'[text2prompt] Loading model with name "{model_name}"...')
        self.tokenizer = AutoTokenizer.from_pretrained(settings.TOKENIZER_MODELS[model_name])
        self.model = AutoModel.from_pretrained(settings.TOKENIZER_MODELS[model_name]).to(device)
        self.loaded_model_name = model_name
        print('[text2prompt] Model loaded')

    def unload_model(self):
        """Drop the tokenizer and model so their memory can be reclaimed.

        The attributes are rebound to None rather than deleted: deleting them
        made ``ready()`` raise AttributeError and made a second call to this
        method crash. The remembered model name is cleared as well, otherwise
        ``load_model()`` would believe the model is still loaded and skip the
        reload.
        """
        self.tokenizer = None
        self.model = None
        self.loaded_model_name = None

    def ready(self) -> bool:
        """Return True when a loaded database and its matching model are available."""
        return bool(
            self.database is not None
            and self.database.loaded()
            and self.model is not None
            and self.tokenizer is not None
            and self.loaded_model_name
            and self.loaded_model_name == self.database.model_name
        )

    @staticmethod
    def _sample_tags(tags_np: np.ndarray, n: int, probs_np=None) -> List[str]:
        """Draw at most ``n`` distinct tags, optionally weighted by ``probs_np``.

        Args:
            tags_np: 1-D array of candidate tags.
            n: Requested number of tags; capped at the number of candidates.
            probs_np: Non-negative weight per candidate, or None for a
                uniform draw. Weights need not be normalised.

        Returns:
            The chosen tags as a list. When no more than ``n`` candidates
            carry a positive weight, exactly those candidates are returned
            (``np.random.choice`` cannot draw more distinct items than there
            are non-zero probabilities).
        """
        n = min(tags_np.shape[0], n)
        if n <= 0:
            return []
        if probs_np is None:
            return np.random.choice(tags_np, n, replace=False).tolist()

        # float64 keeps the normalised sum within np.random.choice's tolerance.
        weights = np.asarray(probs_np, dtype=np.float64)
        positive = weights > 0
        if np.count_nonzero(positive) <= n:
            return tags_np[positive].tolist()
        return np.random.choice(tags_np, n, replace=False, p=weights / weights.sum()).tolist()

    def __call__(self, text: str, opts: pgen.GenerationSettings) -> List[str]:
        """Generate tags for ``text`` according to ``opts``.

        Args:
            text: Free-form text to find similar tags for.
            opts: Generation settings. The fields read here are ``tag_range``,
                ``conversion``, ``prob_power``, ``sampling``, ``n``, ``k``,
                ``p`` and ``weighted``. ``opts`` is not modified.

        Returns:
            A list of tags; empty when the generator is not ready, when there
            is nothing to choose from, or when the sampling method is unknown.

        Raises:
            ValueError: If ``opts.conversion`` is not a supported conversion.
        """
        if not self.ready():
            return []

        # Restrict the candidates to the first `r` rows selected by tag_range.
        i = max(0, min(opts.tag_range, len(self.database.tag_idx) - 1))
        r = self.database.tag_idx[i][1]
        self.tokens = self.database.data[:r, :]
        self.tags = self.database.tags[:r]

        from modules.devices import device

        # --------------------------------------------------------------------------------------------------------------------------------
        # brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
        # Tokenize sentences
        encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)
        # Compute token embeddings
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        # Perform pooling
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        # Normalize embeddings
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        # --------------------------------------------------------------------------------------------------------------------------------

        # Get cosine similarity between given text and tag descriptions
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        tag_tokens_dev = torch.from_numpy(self.tokens).to(device)
        similarity: torch.Tensor = cos(sentence_embeddings[0], tag_tokens_dev)
        similarity_cpu = similarity.detach().cpu()

        # Convert similarity into probability
        if opts.conversion is pgen.ProbabilityConversion.CUTOFF_AND_POWER:
            probs_cpu = torch.clamp(similarity_cpu, 0, 1) ** opts.prob_power
        elif opts.conversion is pgen.ProbabilityConversion.SOFTMAX:
            probs_cpu = torch.softmax(similarity_cpu, dim=0)
        else:
            raise ValueError(f'Unsupported probability conversion: {opts.conversion!r}')

        if probs_cpu.numel() == 0:
            return []

        # Skip the normalisation when every probability is zero; dividing
        # would turn all of them into NaN.
        total = probs_cpu.sum(dim=0)
        if total > 0:
            probs_cpu = probs_cpu / total

        if opts.sampling is pgen.SamplingMethod.NONE:
            tags_np = np.array(self.tags)
            weights = probs_cpu.detach().numpy() if opts.weighted else None
            return self._sample_tags(tags_np, opts.n, weights)

        if opts.sampling is pgen.SamplingMethod.TOP_K:
            # topk raises when k exceeds the number of candidates.
            k = max(0, min(opts.k, probs_cpu.shape[0]))
            probs, indices = probs_cpu.topk(k)
            indices = indices.tolist()
            if not indices:
                return []
            tags_np = np.array([self.tags[idx] for idx in indices])
            weights = probs.detach().numpy() if opts.weighted else None
            return self._sample_tags(tags_np, opts.n, weights)

        # brought from https://nn.labml.ai/sampling/nucleus.html
        if opts.sampling is pgen.SamplingMethod.TOP_P:
            sorted_probs, sorted_indices = probs_cpu.sort(descending=True)
            cs_probs = torch.cumsum(sorted_probs, dim=0)
            nucleus = cs_probs < opts.p
            # Shift the mask right by one so the entry that crosses `p` is
            # kept as well; the most probable tag is therefore always kept.
            nucleus = torch.cat([nucleus.new_ones(1), nucleus[:-1]])
            indices = sorted_indices[nucleus].tolist()
            if not indices:
                return []
            tags_np = np.array([self.tags[idx] for idx in indices])
            # Select probabilities by sorted position (the mask), not by the
            # original tag index, so each weight belongs to its own tag.
            weights = sorted_probs[nucleus].detach().numpy() if opts.weighted else None
            return self._sample_tags(tags_np, opts.n, weights)

        return []