import open_clip.tokenizer
import torch

from modules import sd_hijack_clip, devices

# Reuse the tokenizer instance open_clip creates at import time instead of
# constructing another one.
tokenizer = open_clip.tokenizer._tokenizer  # pylint: disable=protected-access


class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
    """Custom-words prompt embedder for an OpenCLIP text encoder.

    Specializes ``sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase`` for
    OpenCLIP by supplying OpenCLIP's tokenizer and special-token ids, and by
    delegating the transformer pass to the wrapped model.
    """

    def __init__(self, wrapped, hijack):
        """Initialise the base class and record OpenCLIP's special token ids.

        Args:
            wrapped: the OpenCLIP text embedder being wrapped. It is used via
                ``encode_with_transformer`` and ``model.token_embedding.wrapped``.
            hijack: forwarded unchanged to the base class.
        """
        super().__init__(wrapped, hijack)

        encoder = tokenizer.encoder
        # The BPE vocabulary marks word-final pieces with a "</w>" suffix.
        # A standalone comma in a prompt therefore encodes to ",</w>", not to
        # the bare "," piece. A direct lookup also avoids scanning the whole
        # vocabulary.
        self.comma_token = encoder[",</w>"]
        # Start/end special tokens are distinct vocabulary entries. Looking up
        # "" would raise KeyError, and would give both the same id even if it
        # did not.
        self.id_start = encoder["<start_of_text>"]
        self.id_end = encoder["<end_of_text>"]
        self.id_pad = 0

    def tokenize(self, texts):
        """Return one list of BPE token ids per input string.

        Each entry is exactly ``tokenizer.encode(text)``. The start/end/pad ids
        are kept separately on the instance rather than inserted here.
        """
        return [tokenizer.encode(text) for text in texts]

    def encode_with_transformers(self, tokens):
        """Run ``tokens`` through the wrapped model's text transformer.

        This is a thin delegate to ``self.wrapped.encode_with_transformer``.
        """
        return self.wrapped.encode_with_transformer(tokens)

    def encode_embedding_init_text(self, init_text, nvpt):  # pylint: disable=unused-argument
        """Return the token embeddings for ``init_text``, one row per token.

        Args:
            init_text: text to tokenize and embed.
            nvpt: accepted for interface compatibility with the base class; unused.

        Returns:
            The output of the wrapped token-embedding layer for a batch of one,
            with that batch dimension squeezed away. Presumably shaped
            (num_tokens, embed_dim); TODO confirm against the wrapped layer.
        """
        ids = tokenizer.encode(init_text)
        # Wrap in a list to form a batch of one and build an int tensor on the
        # active device so the embedding lookup runs where the model lives.
        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
        return self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)