"""Tag suggestion driven by sentence-embedding similarity.

``WDLike`` embeds the input text with a HuggingFace model, measures the cosine
similarity between that embedding and the vectors stored in the loaded
database, converts the similarities into probabilities and samples tags.
"""

from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

from . import PromptGenerator, GenerationSettings, SamplingMethod, ProbabilityConversion
from .database_loader import DatabaseLoader
from .. import settings

NUM_CHOICE = 10
K_VALUE = 50


# brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings, weighting each position by the attention mask.

    Args:
        model_output: Model output whose first element holds the token embeddings.
        attention_mask: Mask produced by the tokenizer; it is expanded to the
            size of the token embeddings before use.

    Returns:
        The masked sum over dimension 1 divided by the mask sum over the same
        dimension (clamped to at least 1e-9 so the division never hits zero).
    """
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    masked_sum = torch.sum(token_embeddings * input_mask_expanded, 1)
    mask_count = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return masked_sum / mask_count


def _sample_tags(tags_np, probs_np, n, weighted):
    """Draw up to ``n`` distinct tags from ``tags_np``.

    Args:
        tags_np: 1-D numpy array of candidate tags.
        probs_np: Array-like of non-negative weights aligned with ``tags_np``
            (only read when ``weighted`` is true; need not be normalised).
        n: Requested number of tags; capped at the number of candidates.
        weighted: Sample proportionally to ``probs_np`` when true, uniformly
            otherwise.

    Returns:
        A list of tags. It is empty when ``n`` is not positive or when weighted
        sampling is requested but the weights have no usable positive mass.
    """
    n = min(tags_np.shape[0], n)
    if n <= 0:
        return []

    if not weighted:
        # Just sample randomly
        return np.random.choice(tags_np, n, replace=False).tolist()

    # float64 keeps the renormalised weights within numpy's sum-to-one tolerance.
    weights = np.asarray(probs_np, dtype=np.float64)
    total = weights.sum()
    if not np.isfinite(total) or total <= 0:
        return []

    positive = weights > 0
    if np.count_nonzero(positive) <= n:
        # np.random.choice cannot draw n items without replacement from fewer
        # than n non-zero weights, so return exactly the tags that can be drawn.
        return tags_np[positive].tolist()

    return np.random.choice(tags_np, n, replace=False, p=weights / total).tolist()


class WDLike(PromptGenerator):
    """Prompt generator that picks database tags similar to a given text."""

    def __init__(self):
        self.clear()
        self.database_loader = DatabaseLoader(settings.DATABASE_PATH_DANBOORU, settings.RE_TOKENFILE_DANBOORU)

    def get_model_names(self):
        """Return the sorted keys of the database loader's ``datas`` mapping."""
        return sorted(self.database_loader.datas.keys())

    def clear(self):
        """Reset every piece of loaded state to ``None``.

        NOTE(review): this also drops ``database_loader``; only ``__init__``
        recreates it, so calling ``clear()`` later leaves the instance without
        a loader.
        """
        self.database_loader = None
        self.database = None
        self.tags = None
        self.tokens = None
        self.tokenizer = None
        self.model = None
        self.loaded_model_name = None

    def load_data(self, database_name: str):
        """Load the database called ``database_name`` through the database loader."""
        print('Loading database...')
        self.database = self.database_loader.load(database_name)
        print('Loaded')

    def load_model(self):
        """Load the tokenizer and model required by the current database.

        Does nothing (apart from printing a message) when no database is loaded,
        and skips the download when the matching model is already in memory.
        """
        if self.database is None:
            print('Cannot load model; Database is not loaded.')
            return

        from modules.devices import device

        model_name = self.database.model_name
        already_loaded = (
            self.model is not None
            and self.tokenizer is not None
            and self.loaded_model_name
            and self.loaded_model_name == model_name
        )
        if already_loaded:
            return

        # brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
        # Load model from HuggingFace Hub
        print('Loading model...')
        pretrained = settings.TOKENIZER_MODELS[model_name]
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained)
        self.model = AutoModel.from_pretrained(pretrained).to(device)
        self.loaded_model_name = model_name
        print('model loaded')

    def unload_model(self):
        """Drop the tokenizer and model so they can be garbage collected.

        The attributes are reset to ``None`` rather than deleted so ``ready()``
        keeps working, and ``loaded_model_name`` is cleared so a later
        ``load_model()`` actually reloads.
        """
        self.tokenizer = None
        self.model = None
        self.loaded_model_name = None

    def ready(self) -> bool:
        """Return True when the database and its matching model are both loaded."""
        return bool(
            self.database is not None
            and self.database.loaded()
            and self.model is not None
            and self.tokenizer is not None
            and self.loaded_model_name
            and self.loaded_model_name == self.database.model_name
        )

    def __call__(self, text: str, opts: GenerationSettings) -> List[str]:
        """Suggest tags for ``text``.

        Args:
            text: Text to embed and compare against the database.
            opts: Generation settings. Reads ``tag_range``, ``conversion``,
                ``prob_power``, ``sampling``, ``n``, ``k``, ``p`` and
                ``weighted``; the object is not modified.

        Returns:
            The selected tags, or an empty list when the generator is not ready,
            no tag is available, or the sampling method is not recognised.

        Raises:
            ValueError: If ``opts.conversion`` is not a supported conversion.
        """
        if not self.ready():
            return []

        # Restrict the candidates to the first `r` rows selected by opts.tag_range.
        i = max(0, min(opts.tag_range, len(self.database.tag_idx) - 1))
        r = self.database.tag_idx[i][1]
        self.tokens = self.database.data[:r, :]
        self.tags = self.database.tags[:r]

        from modules.devices import device

        # --------------------------------------------------------------------------------
        # brought from https://huggingface.co/sentence-transformers/all-mpnet-base-v2#usage-huggingface-transformers
        # Tokenize sentences
        encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)

        # Compute token embeddings
        with torch.no_grad():
            model_output = self.model(**encoded_input)

        # Perform pooling
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

        # Normalize embeddings
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        # --------------------------------------------------------------------------------

        # Get cosine similarity between given text and tag descriptions
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        tag_tokens_dev = torch.from_numpy(self.tokens).to(device)
        similarity: torch.Tensor = cos(sentence_embeddings[0], tag_tokens_dev)
        similarity_cpu = similarity.detach().cpu()

        # Convert similarity into probability
        if opts.conversion == ProbabilityConversion.CUTOFF_AND_POWER:
            probs_cpu = torch.clamp(similarity_cpu, 0, 1) ** opts.prob_power
        elif opts.conversion == ProbabilityConversion.SOFTMAX:
            probs_cpu = torch.softmax(similarity_cpu, dim=0)
        else:
            raise ValueError(f'Unsupported probability conversion: {opts.conversion!r}')

        if probs_cpu.numel() == 0:
            # No candidate tags at all; nothing to sample from.
            return []

        # Normalise to sum to one; skip when the total is zero to avoid NaNs.
        total = float(probs_cpu.sum(dim=0))
        if total > 0:
            probs_cpu = probs_cpu / total

        if opts.sampling == SamplingMethod.NONE:
            return _sample_tags(np.array(self.tags), probs_cpu.numpy(), opts.n, opts.weighted)

        if opts.sampling == SamplingMethod.TOP_K:
            # topk raises when k exceeds the number of elements, so clamp it.
            k = max(0, min(int(opts.k), probs_cpu.shape[0]))
            top_probs, top_indices = probs_cpu.topk(k)
            indices = top_indices.tolist()
            if not indices:
                return []
            tags_np = np.array([self.tags[idx] for idx in indices])
            return _sample_tags(tags_np, top_probs.numpy(), opts.n, opts.weighted)

        # brought from https://nn.labml.ai/sampling/nucleus.html
        if opts.sampling == SamplingMethod.TOP_P:
            sorted_probs, sorted_indices = probs_cpu.sort(descending=True)
            cs_probs = torch.cumsum(sorted_probs, dim=0)
            nucleus = cs_probs < opts.p
            # Shift the mask right by one so the entry that crosses `p` is kept
            # and the most probable entry is always included.
            nucleus = torch.cat([nucleus.new_ones(1), nucleus[:-1]])
            indices = sorted_indices[nucleus].tolist()
            if not indices:
                return []
            tags_np = np.array([self.tags[idx] for idx in indices])
            # Select the probabilities with the same mask so they stay aligned
            # with the selected tags.
            nucleus_probs = sorted_probs[nucleus].numpy()
            return _sample_tags(tags_np, nucleus_probs, opts.n, opts.weighted)

        return []