139 lines
5.9 KiB
Python
139 lines
5.9 KiB
Python
from typing import List
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from transformers import AutoModel, AutoTokenizer
|
|
|
|
from . import PromptGenerator, GenerationSettings, SamplingMethod, ProbablityConversion
|
|
from .. import settings
|
|
|
|
|
|
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably a default for the number of tags to pick — confirm against callers.
NUM_CHOICE = 10
|
|
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably a default `k` for top-k sampling — confirm against callers.
K_VALUE = 50
|
|
|
|
# brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
|
|
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked positions.

    Args:
        model_output: Indexable model result whose element 0 holds the
            per-token embeddings.
        attention_mask: Mask tensor; after a trailing axis is appended it is
            expanded to the shape of the token embeddings and used as a
            multiplicative weight.

    Returns:
        The mask-weighted sum of the embeddings over axis 1, divided by the
        mask sum over the same axis (clamped to at least 1e-9 so a fully
        masked sequence does not divide by zero).
    """
    # Element 0 of the model output carries every token embedding.
    embeddings = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    weighted_total = (embeddings * weights).sum(dim=1)
    weight_total = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_total / weight_total
|
|
|
|
|
|
class WDLike(PromptGenerator):
    """Prompt generator that picks tags resembling the given text.

    The text is embedded with the model named by ``settings.WDLIKE_MODEL_NAME``
    and compared, by cosine similarity, against the array loaded from
    ``settings.WDLIKE_TOKEN_PATH``. The similarities are turned into
    probabilities and tags are drawn according to the sampling options.

    NOTE(review): the token array is presumably one embedding row per tag, in
    the same order as the tag file — confirm against the data files.
    """

    def __init__(self):
        """Create the generator and load the tag data; the model is not loaded yet."""
        # Tag names, one per line of the tag file.
        self.tags = []
        # Array loaded from the token file; compared against the text embedding.
        self.tokens = None
        # Both stay None until load_model() is called.
        self.tokenizer = None
        self.model = None
        self.load_data()

    def load_data(self):
        """Read the tag list and the token array from the paths in ``settings``."""
        with open(settings.WDLIKE_TAG_PATH, mode='r', encoding='utf8', newline='\n') as f:
            self.tags = [line.strip() for line in f]
        with open(settings.WDLIKE_TOKEN_PATH, mode='rb') as f:
            self.tokens = np.load(f)

    def load_model(self):
        """Load the tokenizer and the model, moving the model to the active device."""
        from modules.devices import device

        # brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
        # Load model from HuggingFace Hub
        self.tokenizer = AutoTokenizer.from_pretrained(settings.WDLIKE_MODEL_NAME)
        self.model = AutoModel.from_pretrained(settings.WDLIKE_MODEL_NAME).to(device)

    def __call__(self, text: str, opts: GenerationSettings) -> List[str]:
        """Return up to ``opts.n`` tags chosen for ``text``.

        Args:
            text: Text to compare against the tags.
            opts: Generation options. The attributes read are ``conversion``,
                ``prob_power``, ``sampling``, ``k``, ``p``, ``n`` and
                ``weighted``. The object is not modified.

        Returns:
            A list of tag strings. It is empty when the model has not been
            loaded, when there are no tags, when ``opts.n`` is not positive,
            or when ``opts.sampling`` is not a handled method.

        Raises:
            ValueError: If ``opts.conversion`` is not a handled conversion.
        """
        # Explicit None checks: truthiness of tokenizer/model objects may be
        # driven by __len__ and is not a reliable "is it loaded" signal.
        if self.model is None or self.tokenizer is None or len(self.tags) == 0:
            return []

        from modules.devices import device

        probs_cpu = self._tag_probabilities(text, opts, device)
        tag_indices, tag_probs = self._select_candidates(probs_cpu, opts)
        if tag_indices is None or len(tag_indices) == 0:
            return []

        tags_np = np.array([self.tags[i] for i in tag_indices])
        # Keep the clamped count local so the caller's settings are untouched.
        n = min(tags_np.shape[0], opts.n)
        return self._draw(tags_np, tag_probs, n, opts.weighted)

    def _embed(self, text, device):
        """Return the L2-normalised, mean-pooled embedding(s) of ``text``."""
        # brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
        # Tokenize sentences
        encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)
        # Compute token embeddings
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        # Perform pooling
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        # Normalize embeddings
        return F.normalize(sentence_embeddings, p=2, dim=1)

    def _tag_probabilities(self, text, opts, device):
        """Return a 1-D CPU tensor with one selection probability per token row."""
        sentence_embeddings = self._embed(text, device)

        # Get cosine similarity between given text and tag descriptions
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        tag_tokens_dev = torch.from_numpy(self.tokens).to(device)
        similarity: torch.Tensor = cos(sentence_embeddings[0], tag_tokens_dev)
        similarity_cpu = similarity.detach().cpu()

        # Convert similarity into probability
        if opts.conversion == ProbablityConversion.CUTOFF_AND_POWER:
            probs_cpu = torch.clamp(similarity_cpu, 0, 1) ** opts.prob_power
        elif opts.conversion == ProbablityConversion.SOFTMAX:
            probs_cpu = torch.softmax(similarity_cpu, dim=0)
        else:
            raise ValueError(f'unsupported probability conversion: {opts.conversion!r}')

        # Skip the division when every weight is zero (e.g. all similarities
        # were clamped away); dividing would fill the tensor with NaN.
        total = probs_cpu.sum(dim=0)
        if total > 0:
            probs_cpu = probs_cpu / total
        return probs_cpu

    def _select_candidates(self, probs_cpu, opts):
        """Return ``(tag indices, matching probabilities)`` for ``opts.sampling``.

        Returns ``(None, None)`` when the sampling method is not handled.
        """
        if opts.sampling == SamplingMethod.NONE:
            # Every tag is a candidate.
            return list(range(len(self.tags))), probs_cpu.detach().numpy()

        if opts.sampling == SamplingMethod.TOP_K:
            # topk() raises for k outside [0, number of entries].
            k = max(0, min(int(opts.k), probs_cpu.shape[0]))
            top_probs, top_indices = probs_cpu.topk(k)
            return top_indices.detach().numpy().tolist(), top_probs.detach().numpy()

        # brought from https://nn.labml.ai/sampling/nucleus.html
        if opts.sampling == SamplingMethod.TOP_P:
            sorted_probs, sorted_indices = probs_cpu.sort(descending=True)
            below_p = torch.cumsum(sorted_probs, dim=0) < opts.p
            # Shift the mask right by one so the entry that crosses the
            # threshold is kept too, and at least one entry always survives.
            nucleus = torch.cat([below_p.new_ones(1), below_p[:-1]])
            # Both tensors are in sorted order, so one mask keeps each tag
            # index paired with its own probability.
            return (sorted_indices[nucleus].detach().numpy().tolist(),
                    sorted_probs[nucleus].detach().numpy())

        return None, None

    @staticmethod
    def _draw(tags_np, probs_np, n, weighted):
        """Draw ``n`` distinct tags from ``tags_np``, optionally weighted by ``probs_np``."""
        if n <= 0:
            return []
        if not weighted:
            # Just sample randomly
            return np.random.choice(tags_np, n, replace=False).tolist()

        # float64 keeps the renormalised weights within numpy's sum-to-one tolerance.
        weights = np.asarray(probs_np, dtype=np.float64)
        positive = weights > 0
        if np.count_nonzero(positive) <= n:
            # Sampling without replacement needs more non-zero weights than
            # draws; otherwise every tag with positive weight is the answer.
            return tags_np[positive].tolist()
        weights = weights / weights.sum()
        return np.random.choice(tags_np, n, replace=False, p=weights).tolist()