58 lines
2.3 KiB
Python
58 lines
2.3 KiB
Python
from modules import sd_hijack_clip, sd_hijack, shared
|
|
from modules.sd_hijack import StableDiffusionModelHijack, EmbeddingsWithFixes, apply_optimizations, fix_checkpoint
|
|
import ldm.modules.encoders.modules
|
|
|
|
# Stock hijack method, captured when this module is imported so it can be
# reinstated (clear_any_hijacks / trigger_sd_hijack) and used as a fallback.
default_hijack = getattr(StableDiffusionModelHijack, "hijack")
|
|
|
|
def trigger_sd_hijack(enabled, pretrained_key):
    """Re-apply the model hijack to ``shared.sd_model``, optionally swapping CLIP.

    The current hijack is undone and re-applied with a temporary
    ``StableDiffusionModelHijack.hijack`` built by ``create_lambda``.

    Args:
        enabled: when falsy, the default CLIP key is loaded instead of
            ``pretrained_key`` and the stock hijack method is reinstated
            afterwards.
        pretrained_key: model key passed to ``from_pretrained``. ``None``,
            ``''`` or a whitespace-only key falls back to the default key.
    """
    clear_any_hijacks()

    if not enabled or not (pretrained_key or '').strip():
        pretrained_key = 'openai/clip-vit-large-patch14'

    StableDiffusionModelHijack.hijack = create_lambda(pretrained_key)
    print("Hijacked clip text model!")

    try:
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
        sd_hijack.model_hijack.hijack(shared.sd_model)
    finally:
        # Even if re-hijacking raises, do not leave the temporary CLIP-swapping
        # method installed when the feature is disabled.
        if not enabled:
            StableDiffusionModelHijack.hijack = default_hijack
|
|
|
|
|
|
def clear_any_hijacks():
    """Reinstall the module-level ``default_hijack`` as ``StableDiffusionModelHijack.hijack``."""
    setattr(StableDiffusionModelHijack, "hijack", default_hijack)
|
|
|
|
def create_lambda(model):
|
|
def hijack_lambda(self, m):
|
|
if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
|
|
from transformers import CLIPTextModel, CLIPTokenizer
|
|
print(f"Changing CLIP model to {model}")
|
|
try:
|
|
m.cond_stage_model.transformer = CLIPTextModel.from_pretrained(
|
|
model).to(m.cond_stage_model.transformer.device)
|
|
m.cond_stage_model.transformer.requires_grad_(False)
|
|
m.cond_stage_model.tokenizer = CLIPTokenizer.from_pretrained(
|
|
model)
|
|
except:
|
|
print(f"Cannot initiate from given model key {model}!")
|
|
|
|
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
|
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
|
|
|
self.optimization_method = apply_optimizations()
|
|
|
|
self.clip = m.cond_stage_model
|
|
|
|
fix_checkpoint()
|
|
|
|
def flatten(el):
|
|
flattened = [flatten(children) for children in el.children()]
|
|
res = [el]
|
|
for c in flattened:
|
|
res += c
|
|
return res
|
|
|
|
self.layers = flatten(m)
|
|
else:
|
|
print("CLIP change can be only applied to FrozenCLIPEmbedder class")
|
|
return default_hijack(self, m)
|
|
return hijack_lambda
|