# This code is borrowed from Automatic1111's webui with modifications
# AGPL v3.0 (c) 2023 AUTOMATIC1111, read the full license here
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/LICENSE.txt

# Modified by kabachuha and incorporated into the AGPL v3.0 license of the project
# Copyright (C) 2023 by Artem Khrapov (kabachuha)
# Read LICENSE for usage terms.

from collections import namedtuple

import math
import torch

import open_clip
from typing import Optional

from modules import prompt_parser, devices, sd_hijack
from modules.shared import opts
import os
from ldm.util import instantiate_from_config

tokenizer = open_clip.tokenizer._tokenizer
from modules import textual_inversion


class PromptChunk:
    """
    This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
    If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
    Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
    so just 75 tokens from prompt.
    """

    def __init__(self):
        # token ids of this chunk
        self.tokens = []
        # one weight per entry of `tokens`
        self.multipliers = []
        # list of PromptChunkFix(offset, embedding) for textual inversion embeddings placed in this chunk
        self.fixes = []


class HijackDummy:
    """Minimal stand-in object holding the attributes this module reads from / writes to on `self.hijack`."""

    fixes = None
    # NOTE(review): class-level list, shared by every HijackDummy instance; forward() appends to it.
    comments = []
    layers = None
    circular_enabled = False
    clip = None
    optimization_method = None
    embedding_db = textual_inversion.textual_inversion.EmbeddingDatabase()


class Invoke(object):
    """String constants describing who invoked model loading; see `from_pretrained`."""

    KEY = 'invoked_by'
    PRETRAINED = 'from_pretrained'
    PIPELINE = 'pipeline'
    TRAINER = 'trainer'
    LOCAL_TRAINER = 'local_trainer'
    PREPROCESSOR = 'preprocessor'


# offset: position inside the chunk's token list where the embedding starts; embedding: the embedding object itself
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])


class FrozenOpenCLIPEmbedder(torch.nn.Module):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    LAYERS = ['last', 'penultimate']

    def __init__(self, arch='ViT-H-14', version='open_clip_pytorch_model.bin', device='cuda', max_length=77,
                 freeze=True, layer='last'):
        """
        Builds the OpenCLIP model (on CPU), drops its visual tower and prepares the tokenizer constants
        used for prompt chunking.

        arch and version are passed to open_clip.create_model_and_transforms as the architecture name and
        the `pretrained` argument; device and max_length are only stored on the instance here;
        freeze switches the model to eval mode and disables gradients; layer must be one of LAYERS and
        selects how many final transformer blocks are skipped (0 for 'last', 1 for 'penultimate').
        """
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=version)
        # only the text encoder is needed
        del model.visual
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == 'last':
            self.layer_idx = 0
        elif self.layer == 'penultimate':
            self.layer_idx = 1
        else:
            raise NotImplementedError()

        # ^ vanilla

        # The BPE tokenizer marks word-final pieces with a '</w>' suffix, so a comma in a prompt is
        # encoded as ',</w>'; comparing against the bare ',' entry would practically never match.
        self.comma_token = tokenizer.encoder[',</w>']
        self.id_start = tokenizer.encoder["<start_of_text>"]
        self.id_end = tokenizer.encoder["<end_of_text>"]
        self.id_pad = 0

        # ^ with custom words

        self.hijack = HijackDummy()
        # number of prompt tokens per chunk (start and end tokens are added on top of this)
        self.chunk_length = 75

    def tokenize(self, texts):
        """
        Tokenizes every text of the list with the OpenCLIP tokenizer and returns a list of token id lists.

        Raises AssertionError if opts.use_old_emphasis_implementation is enabled, as it is not supported here.
        """
        # raised explicitly (instead of `assert`) so that the check also works when python runs with -O
        if getattr(opts, 'use_old_emphasis_implementation', False):
            raise AssertionError('Old emphasis implementation not supported for Open Clip')

        return [tokenizer.encode(text) for text in texts]

    def encode_with_transformer(self, text):
        """
        Runs a tensor of token ids through token embedding, positional embedding, the (possibly truncated)
        transformer stack and the final layer norm; returns the result in NLD layout.
        """
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def encode_with_transformers(self, tokens):
        """Thin wrapper around encode_with_transformer(); returns its result unchanged."""
        # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
        z = self.encode_with_transformer(tokens)

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        """
        Tokenizes init_text and returns its token embeddings with the batch dimension squeezed out.
        nvpt is accepted but not used.
        """
        ids = tokenizer.encode(init_text)
        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
        # NOTE(review): relies on token_embedding exposing a `.wrapped` module - presumably installed by
        # the webui embedding hijack; confirm before calling this on an un-hijacked model.
        embedded = self.model.token_embedding.wrapped(ids).squeeze(0)

        return embedded

    def empty_chunk(self):
        """creates an empty PromptChunk and returns it"""

        chunk = PromptChunk()
        # start token followed by end tokens: chunk_length + 2 entries in total
        chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
        chunk.multipliers = [1.0] * (self.chunk_length + 2)
        return chunk

    def get_target_prompt_token_count(self, token_count):
        """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""

        return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length

    def tokenize_line(self, line):
        """
        this transforms a single prompt into a list of PromptChunk objects - as many as needed to
        represent the prompt.
        Returns the list and the total number of tokens in the prompt.
        """

        if opts.enable_emphasis:
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        chunks = []
        chunk = PromptChunk()
        token_count = 0
        # index (inside the current chunk) of the most recent comma token, -1 if there is none
        last_comma = -1

        def next_chunk(is_last=False):
            """puts current chunk into the list of results and produces the next one - empty;
            if is_last is true, the end-of-text padding tokens at the end won't add to token_count"""
            nonlocal token_count
            nonlocal last_comma
            nonlocal chunk

            if is_last:
                token_count += len(chunk.tokens)
            else:
                token_count += self.chunk_length

            # pad the chunk with end tokens up to chunk_length
            to_add = self.chunk_length - len(chunk.tokens)
            if to_add > 0:
                chunk.tokens += [self.id_end] * to_add
                chunk.multipliers += [1.0] * to_add

            chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
            chunk.multipliers = [1.0] + chunk.multipliers + [1.0]

            last_comma = -1
            chunks.append(chunk)
            chunk = PromptChunk()

        for tokens, (text, weight) in zip(tokenized, parsed):
            # a BREAK keyword (reported by the prompt parser with weight -1) forces a new chunk
            if text == 'BREAK' and weight == -1:
                next_chunk()
                continue

            position = 0
            while position < len(tokens):
                token = tokens[position]

                if token == self.comma_token:
                    last_comma = len(chunk.tokens)

                # this is when we are at the end of allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
                # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
                elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
                    break_location = last_comma + 1

                    reloc_tokens = chunk.tokens[break_location:]
                    reloc_mults = chunk.multipliers[break_location:]

                    chunk.tokens = chunk.tokens[:break_location]
                    chunk.multipliers = chunk.multipliers[:break_location]

                    next_chunk()
                    chunk.tokens = reloc_tokens
                    chunk.multipliers = reloc_mults

                if len(chunk.tokens) == self.chunk_length:
                    next_chunk()

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
                if embedding is None:
                    chunk.tokens.append(token)
                    chunk.multipliers.append(weight)
                    position += 1
                    continue

                # an embedding is never split between chunks: start a new chunk if it does not fit
                emb_len = int(embedding.vec.shape[0])
                if len(chunk.tokens) + emb_len > self.chunk_length:
                    next_chunk()

                chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))

                # placeholder token ids for the positions occupied by the embedding vectors
                chunk.tokens += [0] * emb_len
                chunk.multipliers += [weight] * emb_len
                position += embedding_length_in_tokens

        if len(chunk.tokens) > 0 or len(chunks) == 0:
            next_chunk(is_last=True)

        return chunks, token_count

    def process_texts(self, texts):
        """
        Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
        length, in tokens, of all texts.
        """

        token_count = 0

        # the cache lives for the duration of this call only: identical texts are tokenized once
        cache = {}
        batch_chunks = []
        for line in texts:
            if line in cache:
                chunks = cache[line]
            else:
                chunks, current_token_count = self.tokenize_line(line)
                token_count = max(current_token_count, token_count)

                cache[line] = chunks

            batch_chunks.append(chunks)

        return batch_chunks, token_count

    def freeze(self):
        """Puts the wrapped model into eval mode and disables gradients for all parameters of this module."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Applies the transformer residual blocks to x in order, skipping the last `self.layer_idx` blocks."""
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        """Alias for calling the module: returns self(text)."""
        return self(text)

    def get_learned_conditioning(self, text):
        """Alias for encode()."""
        return self.encode(text)

    # NOTE(review): declared with `cls` but without @classmethod, so it only works when called on an
    # instance (which is then bound to `cls`); `cls` itself is not used in the body. Left as is to keep
    # the calling convention unchanged.
    def from_pretrained(cls,
                        model_name_or_path: str,
                        revision: Optional[str] = None,
                        cfg_dict=None,
                        device: Optional[str] = None,
                        **kwargs):
        """Instantiate a model from a local directory.

        Args:
            model_name_or_path(str): A local model dir. It must exist; loading from a remote
                model repo is not implemented in this version.
            revision(str, `optional`): Accepted for interface compatibility; not used here.
            cfg_dict(Config): The model config. It is required here, as reading the config out of
                `model_name_or_path` is not implemented in this version.
            device(str, `optional`): The device to load the model.
            **kwargs:
                task(str, `optional`): Removed from kwargs; not used for building the model in this version.
                model_prefetched, invoked_by (`optional`): Removed from kwargs when set to a non-None value.
                Other kwargs will be directly fed into the `model` key, to replace the default configs.

        Returns:
            A model instance.

        Raises:
            FileNotFoundError: if `model_name_or_path` does not exist.
            ValueError: if `cfg_dict` is None.
        """
        # these bookkeeping keys must not leak into the model config below
        prefetched = kwargs.get('model_prefetched')
        if prefetched is not None:
            kwargs.pop('model_prefetched')
        invoked_by = kwargs.get(Invoke.KEY)
        if invoked_by is not None:
            kwargs.pop(Invoke.KEY)
        else:
            invoked_by = Invoke.PRETRAINED

        # fail early with a clear message instead of hitting an unbound local variable further down
        if not os.path.exists(model_name_or_path):
            raise FileNotFoundError(f'Model directory does not exist: {model_name_or_path}')
        local_model_dir = model_name_or_path

        if cfg_dict is None:
            raise ValueError('cfg_dict is required: reading the config from the model directory is not supported')
        cfg = cfg_dict
        # else:
        #     cfg = Config.from_file(
        #         osp.join(local_model_dir, ModelFile.CONFIGURATION))

        task_name = cfg.task
        if 'task' in kwargs:
            task_name = kwargs.pop('task')
        model_cfg = cfg.model
        if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
            model_cfg.type = model_cfg.model_type
        model_cfg.model_dir = local_model_dir
        print("plugins", cfg.safe_get('plugins'))

        # install and import remote repos before build
        # register_plugins_repo(cfg.safe_get('plugins'))
        # register_modelhub_repo(local_model_dir, cfg.get('allow_remote', False))

        for k, v in kwargs.items():
            model_cfg[k] = v
        if device is not None:
            model_cfg.device = device

        # if task_name is Tasks.backbone:
        #     model_cfg.init_backbone = True
        #     model = build_backbone(model_cfg)
        # else:
        model = instantiate_from_config(model_cfg)
        # model = build_model(model_cfg, task_name=task_name)

        # dynamically add pipeline info to model for pipeline inference
        if hasattr(cfg, 'pipeline'):
            model.pipeline = cfg.pipeline

        if not hasattr(model, 'cfg'):
            model.cfg = cfg

        model_cfg.pop('model_dir', None)
        model.name = model_name_or_path
        model.model_dir = local_model_dir
        return model

    def forward(self, texts):
        """
        Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
        Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
        An example shape returned by this function can be: (2, 77, 768).
        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
        is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
        """

        batch_chunks, token_count = self.process_texts(texts)

        used_embeddings = {}
        chunk_count = max([len(x) for x in batch_chunks])

        zs = []
        for i in range(chunk_count):
            # texts with fewer chunks than the longest one are padded with empty chunks
            batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]

            tokens = [x.tokens for x in batch_chunk]
            multipliers = [x.multipliers for x in batch_chunk]
            self.hijack.fixes = [x.fixes for x in batch_chunk]

            for fixes in self.hijack.fixes:
                for position, embedding in fixes:
                    used_embeddings[embedding.name] = embedding

            z = self.process_tokens(tokens, multipliers)
            zs.append(z)

        if len(used_embeddings) > 0:
            embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
            self.hijack.comments.append(f"Used embeddings: {embeddings_list}")

        # concatenate the per-chunk results along the token dimension
        return torch.hstack(zs)

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """
        sends one single prompt chunk to be encoded by transformers neural network.
        remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
        Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
        corresponds to one token.
        """
        tokens = torch.asarray(remade_batch_tokens).to(devices.device)

        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
        if self.id_end != self.id_pad:
            for batch_pos, token_row in enumerate(remade_batch_tokens):
                # everything after the first end token is replaced with the padding token
                index = token_row.index(self.id_end)
                tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad

        z = self.encode_with_transformers(tokens)

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
        original_mean = z.mean()
        z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z = z * (original_mean / new_mean)

        return z