246 lines
10 KiB
Python
246 lines
10 KiB
Python
import sys
|
|
import math
|
|
from typing import TYPE_CHECKING
|
|
|
|
import torch
|
|
from torch import nn, Tensor, einsum
|
|
from einops import rearrange
|
|
|
|
from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock, CrossAttention, MemoryEfficientCrossAttention # type: ignore
|
|
from modules.processing import StableDiffusionProcessing
|
|
from modules.hypernetworks import hypernetwork
|
|
from modules import shared
|
|
|
|
from scripts.lib.feature_extractor import FeatureExtractorBase
|
|
from scripts.lib.features.featureinfo import MultiImageFeatures
|
|
from scripts.lib.features.extractor import get_unet_layer
|
|
from scripts.lib.attention.featureinfo import AttnFeatureInfo
|
|
from scripts.lib import layerinfo, tutils
|
|
from scripts.lib.colorizer import Colorizer
|
|
from scripts.lib.utils import *
|
|
from scripts.lib.report import message as E
|
|
|
|
if TYPE_CHECKING:
|
|
from scripts.dumpunet import Script
|
|
|
|
class AttentionExtractor(FeatureExtractorBase):
    """Feature extractor that captures attention tensors from U-Net attention modules.

    Both attention modules (``attn1`` and ``attn2``) of every transformer block
    found in the selected layers are hooked, but only the second one
    (``attn2``) is processed.  For it, the head-separated keys ("K"), the
    softmaxed attention map ("Q*K") and the attention output ("V*Q*K") are
    recomputed and stored per image and per step.
    """

    # Token count that the second-to-last axis of K and Q*K must be a multiple of.
    # NOTE(review): presumably the text-encoder chunk length (the shape comments in
    # process_attention show a [2, 77, 320] embedding) -- confirm.
    _TOKENS_PER_CHUNK = 77

    # Names of the features to render; this class checks for "K", "Q*K" and "V*Q*K".
    features_to_save: list[str]

    # image_index -> step -> Features
    extracted_features: MultiImageFeatures[AttnFeatureInfo]

    def __init__(
        self,
        runner: "Script",
        enabled: bool,
        total_steps: int,
        layer_input: str,
        step_input: str,
        features: list[str],
        path: str|None,
    ):
        """Create the extractor.

        ``features`` selects what ``feature_to_grid_images`` renders.  When it is
        empty (or None) the extractor is forced to disabled and, if it had been
        requested as enabled, a warning is printed to stderr.  All remaining
        arguments are forwarded unchanged to ``FeatureExtractorBase.__init__``.
        """
        if not features and enabled:
            enabled = False
            self._print_warning("Attention: Disabled because no features are selected. Select features in <Output features>.")

        super().__init__(runner, enabled, total_steps, layer_input, step_input, path)
        # Never store None so that the `in` checks on this attribute stay valid.
        self.features_to_save = features if features is not None else []
        self.extracted_features = MultiImageFeatures()

    @staticmethod
    def _print_warning(text: str) -> None:
        """Write ``E(text)`` to stderr wrapped in the ANSI yellow / reset escape codes."""
        print("\033[33m", E(text), "\033[0m", sep="", file=sys.stderr)

    def hook_unet(self, p: StableDiffusionProcessing, unet: nn.Module):
        """Install forward hooks on the attention modules of the selected layers.

        After hooking, ``self.layers`` is narrowed to the layers that actually
        yielded attention blocks (first-seen order, no duplicates), and the call
        is then delegated to the base class implementation.
        """

        def create_hook(layername: str, block: BasicTransformerBlock, n: int, depth: int, c: int):
            # n: index of the transformer inside the layer, depth: index of the
            # basic block inside that transformer, c: 1 for attn1, 2 for attn2.

            def forward(module, fn, x, context=None, *args, **kwargs):
                # `context` is forwarded positionally: combining `context=context`
                # with `*args` would bind args[0] to fn's second positional
                # parameter and raise "multiple values for argument".
                # Assumes fn takes (x, context, ...) like this hook does.
                result = fn(x, context, *args, **kwargs)

                # process for only cross-attention, and only on the selected steps
                if self.steps_on_batch not in self.steps or c != 2:
                    return result

                self.log(f"{self.steps_on_batch:>03} {layername}-{n}-{depth}-attn{c} ({'cross' if (block.disable_self_attn or 1 < c) else 'self'})")
                self.log(f" | {shape(x),shape(context)} -> {shape(result)}")

                ks, qks, vqks = self.process_attention(module, x, context)
                # k   := (batch, head, token, ch/head)
                # qk  := (batch, head, token, height*width)
                # vqk := (batch, height*width, ch)

                images_per_batch = p.batch_size
                assert qks.shape[0] == vqks.shape[0], f"{qks.shape}, {vqks.shape}"

                # Only the first `batch_size` entries are kept.
                # NOTE(review): presumably the rest of the batch is the other
                # half of the cond/uncond pair -- confirm.
                per_image = zip(
                    ks[:images_per_batch],
                    qks[:images_per_batch],
                    vqks[:images_per_batch],
                )
                # Image indices continue across batches (batch_num is 1-based here).
                first_image_index = (self.batch_num - 1) * images_per_batch
                for image_index, (k, qk, vqk) in enumerate(per_image, first_image_index):
                    features = self.extracted_features[image_index][self.steps_on_batch]
                    features.add(layername, AttnFeatureInfo(k, qk, vqk))

                return result

            return forward

        hooked_layers: list[str] = []
        for layer in self.layers:
            for n, d, block, attn1, attn2 in get_unet_attn_layers(unet, layer):
                self.log(f"Attention: hooking {layer}...")
                hooked_layers.append(layer)
                self.hook_forward(attn1, create_hook(layer, block, n, d, 1))
                self.hook_forward(attn2, create_hook(layer, block, n, d, 2))
        # De-duplicate while preserving first-seen order.
        self.layers = list(dict.fromkeys(hooked_layers))

        return super().hook_unet(p, unet)

    def process_attention(self, module, x, context):
        """Recompute the attention of ``module`` for ``x`` and return its intermediates.

        Returns a tuple ``(k, qk, vqk)`` of detached copies:

        * ``k``   -- keys, shaped (batch, head, token, ch/head)
        * ``qk``  -- softmax(q k^T * scale), shaped (batch, head, token, query)
        * ``vqk`` -- attention output before the output projection,
          shaped (batch, query, ch)

        Raises AssertionError when the ``hypernetwork`` module exposes neither
        ``apply_hypernetworks`` nor ``apply_hypernetwork``.
        """
        # q_in : unet features ([2, 4096, 320])
        # k_in, v_in : embedding vector kv (cross-attention) ([2, 77, 320]) or unet features kv (self-attention) ([2, 4096, 320])
        # q,k,v : head-separated q_in, k_in and v_in

        # Without a context the module attends over x itself; apply the fallback
        # once so that both hypernetwork code paths behave the same.
        if context is None:
            context = x

        if getattr(hypernetwork, "apply_hypernetworks", None) is not None:
            ctx_k, ctx_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
        elif getattr(hypernetwork, "apply_hypernetwork", None) is not None:
            ctx_k, ctx_v = hypernetwork.apply_hypernetwork(  # type: ignore
                shared.loaded_hypernetwork,  # type: ignore
                context
            )
        else:
            # Explicit raise (not `assert False`) so the check survives `python -O`.
            raise AssertionError("not supported version")

        q_in = module.to_q(x)
        k_in = module.to_k(ctx_k)
        v_in = module.to_v(ctx_v)

        q: Tensor
        k: Tensor
        v: Tensor

        q, k, v = map(  # type: ignore
            lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=module.heads),
            (q_in, k_in, v_in)
        )

        sim = einsum('b i d, b j d -> b i j', q, k) * module.scale
        sim = sim.softmax(dim=-1)
        # sim.shape == '(b h) i j'

        o_in = einsum('b i j, b j d -> b i d', sim, v)
        o = rearrange(o_in, '(b h) n d -> b n (h d)', h=module.heads)

        # Detached copies, so the stored features do not alias the working tensors.
        kk: Tensor = rearrange(k, '(b h) t d -> b h t d', h=module.heads).detach().clone()
        # The last two axes are swapped here: (query, token) -> (token, query).
        qk: Tensor = rearrange(sim, '(b h) d t -> b h t d', h=module.heads).detach().clone()
        vqk: Tensor = o.detach().clone()

        self.log(f" | q: {shape(q_in)} # {shape(q)}")
        self.log(f" | k: {shape(k_in)} # {shape(k)}")
        self.log(f" | v: {shape(v_in)} # {shape(v)}")

        return kk, qk, vqk

    def add_images(
        self,
        p: StableDiffusionProcessing,
        builder,
        extracted_features: MultiImageFeatures[AttnFeatureInfo],
        add_average: bool,
        color: Colorizer
    ):
        """Delegate to the base class unless there is nothing to add.

        Returns None without doing anything when the extractor is disabled or
        ``shared.state.interrupted`` is set; prints a warning to stderr and
        returns None when ``extracted_features`` is empty.
        """
        if not self.enabled:
            return
        if shared.state.interrupted:
            return
        if len(extracted_features) == 0:
            self._print_warning("Attention: no images are extracted")
            return
        return super().add_images(p, builder, extracted_features, add_average, color)

    def feature_to_grid_images(self, feature: AttnFeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, add_average: bool, color: Colorizer):
        """Render the selected features of one layer and return them as a single list.

        Which of "K", "Q*K" and "V*Q*K" are rendered is decided by
        ``self.features_to_save``.  ``img_idx`` and ``step`` are not used here.
        The shape assertions check the stored tensors against the layer
        geometry returned by ``get_shape``.
        """
        w, h, ch = get_shape(layer, width, height)
        images = []

        if "K" in self.features_to_save:
            k = feature.k
            heads_k, ch_k, _ = k.shape
            assert ch_k % self._TOKENS_PER_CHUNK == 0, f"ch_k={ch_k}"
            k1 = rearrange(k, 'a t n -> a n t').contiguous()
            k_images = tutils.tensor_to_image(k1, 1, heads_k, color, add_average)
            images.extend(k_images)
            del k1  # release the intermediate before the next feature is processed

        if "Q*K" in self.features_to_save:
            qk = feature.qk
            heads_qk, ch_qk, n_qk = qk.shape
            assert ch_qk % self._TOKENS_PER_CHUNK == 0, f"ch_qk={ch_qk}"
            assert w * h == n_qk, f"w={w}, h={h}, n_qk={n_qk}"
            qk1 = rearrange(qk, 'a t (h w) -> (a t) h w', h=h).contiguous()
            qk_images = tutils.tensor_to_image(qk1, ch_qk, heads_qk, color, False)
            if add_average:
                # Mean over the head axis; shape = (ch, h, w)
                qk_avg = torch.mean(
                    rearrange(qk, 'a t (h w) -> a t h w', h=h).contiguous(),
                    0
                )
                qk_avg_image = tutils.tensor_to_image(qk_avg, ch_qk, 1, color, False)
                # The averaged images go in front of the per-head images.
                qk_images = qk_avg_image + qk_images
            images.extend(qk_images)
            del qk1

        if "V*Q*K" in self.features_to_save:
            vqk = feature.vqk
            n_vqk, ch_vqk = vqk.shape
            assert w * h == n_vqk, f"w={w}, h={h}, n_vqk={n_vqk}"
            assert ch == ch_vqk, f"ch={ch}, ch_vqk={ch_vqk}"
            vqk1 = rearrange(vqk, '(h w) c -> c h w', h=h).contiguous()
            vqk_images = tutils.tensor_to_grid_images(vqk1, layer, width, height, color, add_average)
            images.extend(vqk_images)
            del vqk1

        return images

    def save_features(self, feature: AttnFeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, path: str, basename: str):
        """Save the Q*K tensor of ``feature`` through ``tutils.save_tensor``.

        The tensor is reshaped to ((head token), h, w) first.  Only ``feature.qk``
        is written, regardless of ``self.features_to_save``; ``img_idx`` and
        ``step`` are not used here.
        """
        _, h, _ = get_shape(layer, width, height)
        qk = rearrange(feature.qk, 'a t (h w) -> (a t) h w', h=h).contiguous()
        tutils.save_tensor(qk, path, basename)
|
|
|
|
def get_shape(layer: str, width: int, height: int):
    """Return ``(w, h, ch)`` of the feature map of ``layer`` for an image size.

    ``layerinfo.Settings[layer]`` is unpacked as two 3-tuples.  The last two
    entries of the first tuple are multiplied by the number of 64-pixel tiles
    (rounded up, at least one) along the image height and width; the first
    entry of the second tuple is returned unchanged as ``ch``.
    NOTE(review): the tuples presumably hold (channels, height, width) of the
    layer's input and output for a single tile -- confirm against layerinfo.
    """
    settings = layerinfo.Settings
    assert layer in settings

    input_dims, output_dims = settings[layer]
    _, base_h, base_w = input_dims
    out_channels, _, _ = output_dims

    tiles_w = max(1, math.ceil(width / 64))
    tiles_h = max(1, math.ceil(height / 64))

    return base_w * tiles_w, base_h * tiles_h, out_channels
|
|
|
|
def get_unet_attn_layers(unet, layername: str):
    """Yield the attention modules found in the U-Net layer named ``layername``.

    The layer is looked up with ``get_unet_layer``.  Every direct child that is
    a ``SpatialTransformer`` is visited, and inside it every direct child of
    ``transformer_blocks`` that is a ``BasicTransformerBlock``.

    Yields tuples ``(n, depth, block, attn1, attn2)`` where ``n`` counts the
    transformers in the layer, ``depth`` counts the basic blocks in the
    transformer, and ``attn1`` / ``attn2`` are the block's attention modules.
    Both are asserted to be ``CrossAttention`` or
    ``MemoryEfficientCrossAttention`` instances.
    """
    attention_types = (CrossAttention, MemoryEfficientCrossAttention)
    layer_module = get_unet_layer(unet, layername)

    transformers = (
        child for child in layer_module.children()
        if isinstance(child, SpatialTransformer)
    )
    for n, transformer in enumerate(transformers):
        basic_blocks = (
            child for child in transformer.transformer_blocks.children()
            if isinstance(child, BasicTransformerBlock)
        )
        for depth, basic_block in enumerate(basic_blocks):
            attn1 = basic_block.attn1
            attn2 = basic_block.attn2
            assert isinstance(attn1, attention_types)
            assert isinstance(attn2, attention_types)

            yield n, depth, basic_block, attn1, attn2
|
|
|
|
def shape(t: Tensor|None) -> tuple|None:
|
|
return tuple(t.shape) if t is not None else None
|