first commit

master
benkyoujouzu 2022-11-25 12:09:13 +08:00
commit 65b97cc6d1
1 changed files with 124 additions and 0 deletions

124
scripts/vxa.py Normal file
View File

@ -0,0 +1,124 @@
import os
import gradio as gr
from modules import scripts, script_callbacks
from PIL import Image
import numpy as np
import torch
import modules.shared as shared
from modules import devices
from torch import nn, einsum
from einops import rearrange
import math
from ldm.modules.attention import CrossAttention
# Registry of the loaded model's CrossAttention modules, keyed by their
# dotted module path (populated in add_tab / update_layer_names).
hidden_layers = {}
# Subset of registry keys containing "attn2" (cross-attention), shown in the UI.
hidden_layer_names = []
default_hidden_layer_name = "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2"
# The gradio Dropdown component for layer selection; created in add_tab().
hidden_layer_select = None
def update_layer_names(model):
    """Rebuild the module-level cross-attention layer registry for *model*.

    Registered via script_callbacks.on_model_loaded so the layer dropdown
    tracks the currently loaded checkpoint.
    """
    # Rebind the module-level registries. Without `global` these assignments
    # only created function locals, so switching checkpoints never actually
    # refreshed the registry or the dropdown.
    global hidden_layers, hidden_layer_names
    hidden_layers = {}
    for name, module in model.named_modules():
        if isinstance(module, CrossAttention):
            hidden_layers[name] = module
    hidden_layer_names = [n for n in hidden_layers if "attn2" in n]
    if hidden_layer_select is not None:
        # gradio components take `choices=` (plural); the original `choice=`
        # kwarg was silently ignored and the dropdown never updated.
        hidden_layer_select.update(value=default_hidden_layer_name,
                                   choices=hidden_layer_names)
def get_attn(emb, ret):
    """Build a forward hook that captures a CrossAttention layer's attention map.

    emb -- conditioning embedding used as the key/context input.
    ret -- dict that receives the softmaxed attention tensor under "out".
    """
    def hook(module, inputs, _output):
        heads = module.heads
        query = module.to_q(inputs[0])
        key = module.to_k(emb)

        def split_heads(t):
            # (b, n, heads*d) -> (b*heads, n, d); torch-native equivalent of
            # einops rearrange('b n (h d) -> (b h) n d').
            b, n, _ = t.shape
            return t.view(b, n, heads, -1).permute(0, 2, 1, 3).reshape(b * heads, n, -1)

        query = split_heads(query)
        key = split_heads(key)
        sim = torch.einsum('b i d, b j d -> b i j', query, key) * module.scale
        ret["out"] = sim.softmax(dim=-1)
    return hook
def generate_vxa(image, prompt, idx, time, layer_name, output_mode):
    """Visualize cross-attention between *prompt* tokens and *image*.

    image       -- HxWx3 uint8 array from the gradio Image widget.
    prompt      -- text whose token embeddings supply the attention keys.
    idx         -- comma-separated token indices; "" sums over all tokens
                   (skipping index 0, the start-of-text token).
    time        -- diffusion time-embedding value, as a string from the UI.
    layer_name  -- key into the module-level `hidden_layers` registry.
    output_mode -- "masked" multiplies the image by the attention map;
                   "grey" renders the attention map itself in greyscale.

    Returns an HxWx3 uint8 array. On unparseable `time` or `idx` input the
    unmodified copy of the input image is returned (best-effort UI behavior).
    """
    output = image.copy()
    image = image.astype(np.float32) / 255.0
    image = np.moveaxis(image, 2, 0)
    image = torch.from_numpy(image).unsqueeze(0)
    model = shared.sd_model
    layer = hidden_layers[layer_name]
    cond_model = model.cond_stage_model
    with torch.no_grad(), devices.autocast():
        image = image.to(devices.device)
        latent = model.get_first_stage_encoding(model.encode_first_stage(image))
        try:
            t = torch.tensor([float(time)]).to(devices.device)
        except (TypeError, ValueError):
            # Bad time-embedding text from the UI: bail out gracefully.
            return output
        emb = cond_model([prompt])
        attn_out = {}
        # Hook records the layer's attention map during the forward pass.
        handle = layer.register_forward_hook(get_attn(emb, attn_out))
        try:
            model.model.diffusion_model(latent, t, emb)
        finally:
            handle.remove()
    if idx == "":
        # All tokens except index 0 (start-of-text), summed over heads.
        img = attn_out["out"][:, :, 1:].sum(-1).sum(0)
    else:
        try:
            idxs = [int(s) for s in idx.strip().split(',') if s != '']
            img = attn_out["out"][:, :, idxs].sum(-1).sum(0)
        except (ValueError, IndexError):
            # Non-integer or out-of-range indices: bail out gracefully.
            return output
    # Spatial downscale factor between the input image and the attention map.
    scale = round(math.sqrt((image.shape[2] * image.shape[3]) / img.shape[0]))
    h = image.shape[2] // scale
    w = image.shape[3] // scale
    img = img.reshape(h, w) / img.max()
    img = img.to("cpu").numpy()
    # Nearest-neighbour upsample of the attention map to the output size —
    # vectorized replacement for the original per-pixel Python loops
    # (identical values: mask[i, j] == img[i // scale, j // scale]).
    mask = np.repeat(np.repeat(img, scale, axis=0), scale, axis=1)
    mask = mask[:output.shape[0], :output.shape[1]]
    if output_mode == "masked":
        output = output.astype(np.float64) * mask[:, :, None]
    elif output_mode == "grey":
        output = np.stack([mask * 255.0] * 3, axis=-1)
    output = output.astype(np.uint8)
    return output
def add_tab():
    """Build the "VXA" gradio tab and wire the visualize button.

    Returns the (block, title, elem_id) tuple expected by
    script_callbacks.on_ui_tabs.
    """
    # These must be module-level: update_layer_names() reads
    # hidden_layer_select to refresh the dropdown on checkpoint switches.
    # As locals they left the module-level value permanently None.
    global hidden_layer_names, hidden_layer_select
    with gr.Blocks(analytics_enabled=False) as visualize_cross_attention:
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(elem_id="vxa_input_image")
            with gr.Column():
                vxa_output = gr.Image(elem_id="vxa_output", interactive=False)
                vxa_generate = gr.Button(value="Visualize Cross-Attention", elem_id="vxa_gen_btn")
        with gr.Row():
            vxa_prompt = gr.Textbox(label="Prompt", placeholder="Prompt to be visualized")
        with gr.Row():
            vxa_token_indices = gr.Textbox(value="", label="Indices of tokens to be visualized", placeholder="Example: 1, 3 means the sum of the first and the third tokens. 1 is suggested for a single token. Leave blank to visualize all tokens.")
        with gr.Row():
            vxa_time_embedding = gr.Textbox(value="1.0", label="Time embedding")
        with gr.Row():
            # Seed the registry from the currently loaded model so the
            # dropdown has choices before the first on_model_loaded callback.
            for n, m in shared.sd_model.named_modules():
                if isinstance(m, CrossAttention):
                    hidden_layers[n] = m
            hidden_layer_names = [s for s in hidden_layers if "attn2" in s]
            hidden_layer_select = gr.Dropdown(value=default_hidden_layer_name, label="Cross-attention layer", choices=hidden_layer_names)
        with gr.Row():
            vxa_output_mode = gr.Dropdown(value="masked", label="Output mode", choices=["masked", "grey"])
        vxa_generate.click(
            fn=generate_vxa,
            inputs=[input_image, vxa_prompt, vxa_token_indices, vxa_time_embedding, hidden_layer_select, vxa_output_mode],
            outputs=[vxa_output],
        )
    return (visualize_cross_attention, "VXA", "visualize_cross_attention"),
# Register the UI tab and keep the layer registry in sync with the
# currently loaded checkpoint.
script_callbacks.on_ui_tabs(add_tab)
script_callbacks.on_model_loaded(update_layer_names)