from modules import sd_hijack_clip, sd_hijack, shared from modules.sd_hijack import StableDiffusionModelHijack, EmbeddingsWithFixes, apply_optimizations try: from modules.sd_hijack import fix_checkpoint def clear_any_hijacks(): StableDiffusionModelHijack.hijack = default_hijack except (ModuleNotFoundError, ImportError): from modules.sd_hijack_checkpoint import add, remove def fix_checkpoint(): add() def clear_any_hijacks(): remove() StableDiffusionModelHijack.hijack = default_hijack import ldm.modules.encoders.modules default_hijack = StableDiffusionModelHijack.hijack def trigger_sd_hijack(enabled, pretrained_key): clear_any_hijacks() if not enabled or pretrained_key == '': pretrained_key = 'openai/clip-vit-large-patch14' StableDiffusionModelHijack.hijack = create_lambda(pretrained_key) print("Hijacked clip text model!") sd_hijack.model_hijack.undo_hijack(shared.sd_model) sd_hijack.model_hijack.hijack(shared.sd_model) if not enabled: StableDiffusionModelHijack.hijack = default_hijack def create_lambda(model): def hijack_lambda(self, m): if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: from transformers import CLIPTextModel, CLIPTokenizer print(f"Changing CLIP model to {model}") try: m.cond_stage_model.transformer = CLIPTextModel.from_pretrained( model).to(m.cond_stage_model.transformer.device) m.cond_stage_model.transformer.requires_grad_(False) m.cond_stage_model.tokenizer = CLIPTokenizer.from_pretrained( model) except: print(f"Cannot initiate from given model key {model}!") model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) self.optimization_method = apply_optimizations() self.clip = m.cond_stage_model fix_checkpoint() def flatten(el): flattened = [flatten(children) for children in el.children()] res = [el] for c in flattened: res += c return res self.layers = flatten(m) else: print("CLIP change can be only applied to FrozenCLIPEmbedder class") return default_hijack(self, m) return hijack_lambda