import os
import gradio as gr
from modules import scripts, script_callbacks
from PIL import Image
import numpy as np
import torch
import modules.shared as shared
from modules import devices
from torch import nn, einsum
from einops import rearrange
import math
from ldm.modules.attention import CrossAttention


hidden_layers = {}
|
||||
hidden_layer_names = []
|
||||
default_hidden_layer_name = "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2"
|
||||
hidden_layer_select = None
|
||||
|
||||
def update_layer_names(model):
|
||||
hidden_layers = {}
|
||||
for n, m in model.named_modules():
|
||||
if(isinstance(m, CrossAttention)):
|
||||
hidden_layers[n] = m
|
||||
hidden_layer_names = list(filter(lambda s : "attn2" in s, hidden_layers.keys()))
|
||||
if hidden_layer_select != None:
|
||||
hidden_layer_select.update(value=default_hidden_layer_name, choice=hidden_layer_names)
|
||||
|
||||
def get_attn(emb, ret):
|
||||
def hook(self, sin, sout):
|
||||
h = self.heads
|
||||
q = self.to_q(sin[0])
|
||||
context = emb
|
||||
k = self.to_k(context)
|
||||
q, k = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k))
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
attn = sim.softmax(dim=-1)
|
||||
ret["out"] = attn
|
||||
return hook
|
||||
|
||||
def generate_vxa(image, prompt, idx, time, layer_name, output_mode):
|
||||
output = image.copy()
|
||||
image = image.astype(np.float32) / 255.0
|
||||
image = np.moveaxis(image, 2, 0)
|
||||
image = torch.from_numpy(image).unsqueeze(0)
|
||||
|
||||
model = shared.sd_model
|
||||
layer = hidden_layers[layer_name]
|
||||
cond_model = model.cond_stage_model
|
||||
with torch.no_grad(), devices.autocast():
|
||||
image = image.to(devices.device)
|
||||
latent = model.get_first_stage_encoding(model.encode_first_stage(image))
|
||||
try:
|
||||
t = torch.tensor([float(time)]).to(devices.device)
|
||||
except:
|
||||
return output
|
||||
emb = cond_model([prompt])
|
||||
|
||||
attn_out = {}
|
||||
handle = layer.register_forward_hook(get_attn(emb, attn_out))
|
||||
try:
|
||||
model.model.diffusion_model(latent, t, emb)
|
||||
finally:
|
||||
handle.remove()
|
||||
|
||||
if (idx == ""):
|
||||
img = attn_out["out"][:,:,1:].sum(-1).sum(0)
|
||||
else:
|
||||
try:
|
||||
idxs = list(map(int, filter(lambda x : x != '', idx.strip().split(','))))
|
||||
img = attn_out["out"][:,:,idxs].sum(-1).sum(0)
|
||||
except:
|
||||
return output
|
||||
|
||||
scale = round(math.sqrt((image.shape[2] * image.shape[3]) / img.shape[0]))
|
||||
h = image.shape[2] // scale
|
||||
w = image.shape[3] // scale
|
||||
img = img.reshape(h, w) / img.max()
|
||||
img = img.to("cpu").numpy()
|
||||
output = output.astype(np.float64)
|
||||
if output_mode == "masked":
|
||||
for i in range(output.shape[0]):
|
||||
for j in range(output.shape[1]):
|
||||
output[i][j] *= img[i // scale][j // scale]
|
||||
elif output_mode == "grey":
|
||||
for i in range(output.shape[0]):
|
||||
for j in range(output.shape[1]):
|
||||
output[i][j] = [img[i // scale][j // scale] * 255.0] * 3
|
||||
output = output.astype(np.uint8)
|
||||
return output
|
||||
|
||||
def add_tab():
|
||||
with gr.Blocks(analytics_enabled=False) as visualize_cross_attention:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
input_image = gr.Image(elem_id="vxa_input_image")
|
||||
with gr.Column():
|
||||
vxa_output = gr.Image(elem_id = "vxa_output", interactive=False)
|
||||
vxa_generate = gr.Button(value="Visualize Cross-Attention", elem_id="vxa_gen_btn")
|
||||
with gr.Row():
|
||||
vxa_prompt = gr.Textbox(label="Prompt", placeholder="Prompt to be visualized")
|
||||
with gr.Row():
|
||||
vxa_token_indices = gr.Textbox(value="", label="Indices of tokens to be visualized", placeholder="Example: 1, 3 means the sum of the first and the third tokens. 1 is suggected for a single token. Leave blank to visualize all tokens.")
|
||||
with gr.Row():
|
||||
vxa_time_embedding = gr.Textbox(value="1.0", label="Time embedding")
|
||||
with gr.Row():
|
||||
for n, m in shared.sd_model.named_modules():
|
||||
if(isinstance(m, CrossAttention)):
|
||||
hidden_layers[n] = m
|
||||
hidden_layer_names = list(filter(lambda s : "attn2" in s, hidden_layers.keys()))
|
||||
hidden_layer_select = gr.Dropdown(value=default_hidden_layer_name, label="Cross-attention layer", choices=hidden_layer_names)
|
||||
with gr.Row():
|
||||
vxa_output_mode = gr.Dropdown(value="masked", label="Output mode", choices=["masked", "grey"])
|
||||
|
||||
vxa_generate.click(
|
||||
fn=generate_vxa,
|
||||
inputs=[input_image, vxa_prompt, vxa_token_indices, vxa_time_embedding, hidden_layer_select, vxa_output_mode],
|
||||
outputs=[vxa_output],
|
||||
)
|
||||
|
||||
return (visualize_cross_attention, "VXA", "visualize_cross_attention"),
|
||||
|
||||
script_callbacks.on_ui_tabs(add_tab)
|
||||
script_callbacks.on_model_loaded(update_layer_names)