add AttentionExtractor

master
hnmr293 2023-01-21 01:11:13 +09:00
parent b1bc99506b
commit 15b425c5a0
12 changed files with 351 additions and 49 deletions

View File

@ -19,6 +19,11 @@ onUiUpdate(() => {
'#dumpunet-{}-features-steps': 'Steps which U-Net features should be extracted. See tooltip for notations',
'#dumpunet-{}-features-dumppath': 'Raw binary files are dumped to here, one image per step per layer.',
'#dumpunet-{}-attention-checkbox': 'Extract attention layer\'s features and add their maps to output images.',
'#dumpunet-{}-attention-layer': 'U-Net layers <code>(IN00-IN11, M00, OUT00-OUT11)</code> which features should be extracted. See tooltip for notations.',
'#dumpunet-{}-attention-steps': 'Steps which features should be extracted. See tooltip for notations',
'#dumpunet-{}-attention-dumppath': 'Raw binary files are dumped to here, one image per step per layer.',
'#dumpunet-{}-layerprompt-checkbox': 'When checked, <code>(~: ... :~)</code> notation is enabled.',
'#dumpunet-{}-layerprompt-diff-layer': 'Layers <code>(IN00-IN11, M00, OUT00-OUT11)</code> which features should be extracted. See tooltip for notations.',
'#dumpunet-{}-layerprompt-diff-steps': 'Steps which features should be extracted. See tooltip for notations',
@ -28,6 +33,8 @@ onUiUpdate(() => {
const hints = {
'#dumpunet-{}-features-layer textarea': 'IN00: add one layer to output\nIN00,IN01: add layers to output\nIN00-IN02: add range to output\nIN00-OUT05(+2): add range to output with specified steps\n',
'#dumpunet-{}-features-steps textarea': '5: extracted at steps=5\n5,10: extracted at steps=5 and steps=10\n5-10: extracted when step is in 5..10 (inclusive)\n5-10(+2): extracts when step is 5,7,9\n',
'#dumpunet-{}-attention-layer textarea': 'IN00: add one layer to output\nIN00,IN01: add layers to output\nIN00-IN02: add range to output\nIN00-OUT05(+2): add range to output with specified steps\n',
'#dumpunet-{}-attention-steps textarea': '5: extracted at steps=5\n5,10: extracted at steps=5 and steps=10\n5-10: extracted when step is in 5..10 (inclusive)\n5-10(+2): extracts when step is 5,7,9\n',
'#dumpunet-{}-layerprompt-diff-layer textarea': 'IN00: add one layer to output\nIN00,IN01: add layers to output\nIN00-IN02: add range to output\nIN00-OUT05(+2): add range to output with specified steps\n',
'#dumpunet-{}-layerprompt-diff-steps textarea': '5: extracted at steps=5\n5,10: extracted at steps=5 and steps=10\n5-10: extracted when step is in 5..10 (inclusive)\n5-10(+2): extracts when step is 5,7,9\n',
};

View File

@ -81,7 +81,7 @@ onUiUpdate(() => {
const updates = [];
for (let mode of ['txt2img', 'img2img']) {
for (let tab of ['features', 'layerprompt']) {
for (let tab of ['features', 'attention', 'layerprompt']) {
const layer_input_ele =
app.querySelector(`#dumpunet-${mode}-${tab}-layer textarea`)
|| app.querySelector(`#dumpunet-${mode}-${tab}-diff-layer textarea`);

View File

@ -13,6 +13,7 @@ from scripts.lib.features.extractor import FeatureExtractor
from scripts.lib.features.utils import feature_diff, feature_to_grid_images
from scripts.lib.tutils import save_tensor
from scripts.lib.layer_prompt.prompt import LayerPrompt
from scripts.lib.attention.extractor import AttentionExtractor
from scripts.lib.report import message as E
from scripts.lib import putils
@ -46,6 +47,13 @@ class Script(scripts.Script):
result.unet.dump.enabled,
result.unet.dump.path,
result.attn.enabled,
result.attn.settings.layers,
result.attn.settings.steps,
result.attn.settings.color,
result.attn.dump.enabled,
result.attn.dump.path,
result.lp.enabled,
result.lp.diff_enabled,
result.lp.diff_settings.layers,
@ -95,6 +103,13 @@ class Script(scripts.Script):
path_on: bool,
path: str,
attn_enabled: bool,
attn_layers: str,
attn_steps: str,
attn_color: bool,
attn_path_on: bool,
attn_path: str,
layerprompt_enabled: bool,
layerprompt_diff_enabled: bool,
lp_diff_layers: str,
@ -106,7 +121,7 @@ class Script(scripts.Script):
debug: bool,
):
if not unet_features_enabled and not layerprompt_enabled:
if not unet_features_enabled and not attn_enabled and not layerprompt_enabled:
return process_images(p)
self.debug = debug
@ -134,6 +149,15 @@ class Script(scripts.Script):
layerprompt_enabled,
)
at = AttentionExtractor(
self,
attn_enabled,
p.steps,
attn_layers,
attn_steps,
attn_path if attn_path_on else None
)
if layerprompt_diff_enabled:
fix_seed(p)
@ -142,15 +166,17 @@ class Script(scripts.Script):
# layer prompt disabled
lp0 = LayerPrompt(self, layerprompt_enabled, remove_layer_prompts=True)
proc1 = exec(p1, lp0, [ex, exlp])
proc1 = exec(p1, lp0, [ex, exlp, at])
features1 = ex.extracted_features
diff1 = exlp.extracted_features
proc1 = ex.add_images(p1, proc1, features1, color)
proc1 = at.add_images(p1, proc1, at.extracted_features, attn_color)
# layer prompt enabled
proc2 = exec(p2, lp, [ex, exlp])
proc2 = exec(p2, lp, [ex, exlp, at])
features2 = ex.extracted_features
diff2 = exlp.extracted_features
proc2 = ex.add_images(p2, proc2, features2, color)
proc2 = at.add_images(p2, proc2, at.extracted_features, attn_color)
assert len(proc1.images) == len(proc2.images)
@ -175,10 +201,9 @@ class Script(scripts.Script):
save_tensor(tensor, diff_path, basename)
else:
proc = exec(p, lp, [ex])
features = ex.extracted_features
if unet_features_enabled:
proc = ex.add_images(p, proc, features, color)
proc = exec(p, lp, [ex, at])
proc = ex.add_images(p, proc, ex.extracted_features, color)
proc = at.add_images(p, proc, at.extracted_features, attn_color)
return proc

View File

@ -0,0 +1,180 @@
import math
from typing import TYPE_CHECKING
from torch import nn, Tensor, einsum
from einops import rearrange
from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock, CrossAttention, MemoryEfficientCrossAttention # type: ignore
from modules.processing import StableDiffusionProcessing
from modules.hypernetworks import hypernetwork
from modules import shared
from scripts.lib.feature_extractor import FeatureExtractorBase
from scripts.lib.features.featureinfo import MultiImageFeatures
from scripts.lib.features.extractor import get_unet_layer
from scripts.lib.attention.featureinfo import AttnFeatureInfo
from scripts.lib import layerinfo, tutils
from scripts.lib.utils import *
if TYPE_CHECKING:
from scripts.dumpunet import Script
class AttentionExtractor(FeatureExtractorBase):
    """Extracts cross-attention maps from the U-Net's transformer blocks.

    Hooks the ``attn1``/``attn2`` modules of every BasicTransformerBlock under
    the selected U-Net layers and, at the selected steps, re-runs the attention
    computation to capture per-head attention maps (``qk``) and the attention
    output (``vqk``) for each generated image.
    """

    # image_index -> step -> Features
    extracted_features: MultiImageFeatures[AttnFeatureInfo]

    def __init__(
        self,
        runner: "Script",
        enabled: bool,
        total_steps: int,
        layer_input: str,   # layer spec, e.g. "IN00-OUT05(+2)" (parsed by the base class)
        step_input: str,    # step spec, e.g. "5-10(+2)" (parsed by the base class)
        path: str|None,     # dump directory for raw tensors, or None to disable dumping
    ):
        super().__init__(runner, enabled, total_steps, layer_input, step_input, path)
        self.extracted_features = MultiImageFeatures()

    def hook_unet(self, p: StableDiffusionProcessing, unet: nn.Module):
        """Install forward hooks on the attention modules of every selected layer."""

        def create_hook(layername: str, block: BasicTransformerBlock, n: int, depth: int, c: int):
            # c == 1: attn1 (self-attention), c == 2: attn2 (cross-attention).
            def forward(module, fn, x, context=None, *args, **kwargs):
                # Run the real forward first; extraction is a side effect only.
                result = fn(x, context=context, *args, **kwargs)
                if self.steps_on_batch in self.steps:
                    if c == 2:
                        # process for only cross-attention
                        # (attn1 is hooked as well, but its hook is a pure pass-through)
                        self.log(f"{self.steps_on_batch:>03} {layername}-{n}-{depth}-attn{c} ({'cross' if (block.disable_self_attn or 1 < c) else 'self'})")
                        self.log(f" | {shape(x),shape(context)} -> {shape(result)}")
                        qks, vqks = self.process_attention(module, x, context)
                        # qk := (batch, head, token, height*width)
                        # vqk := (batch, height*width, ch)
                        # NOTE(review): the batch is assumed to hold two chunks
                        # (cond/uncond) and only the first half is recorded —
                        # confirm against the webui's batching order.
                        images_per_batch = qks.shape[0] // 2
                        assert qks.shape[0] == vqks.shape[0]
                        assert qks.shape[0] % 2 == 0
                        for image_index, (vk, vqk) in enumerate(
                            zip(qks[:images_per_batch], vqks[:images_per_batch]),
                            (self.batch_num-1) * images_per_batch
                        ):
                            features = self.extracted_features[image_index][self.steps_on_batch]
                            features.add(
                                layername,
                                AttnFeatureInfo(vk, vqk)
                            )
                return result
            return forward

        for layer in self.layers:
            self.log(f"Attention: hooking {layer}...")
            for n, d, block, attn1, attn2 in get_unet_attn_layers(unet, layer):
                self.hook_forward(attn1, create_hook(layer, block, n, d, 1))
                self.hook_forward(attn2, create_hook(layer, block, n, d, 2))
        return super().hook_unet(p, unet)

    def process_attention(self, module, x, context):
        """Recompute the module's attention and return ``(qk, vqk)``.

        ``qk`` is the softmaxed attention map rearranged to
        (batch, head, token, height*width); ``vqk`` is the attention output
        (batch, height*width, ch). Both are detached clones, so they survive
        past the step without holding the autograd graph.
        """
        # q_in : unet features ([2, 4096, 320])
        # k_in, v_in : embedding vector kv (cross-attention) ([2, 77, 320]) or unet features kv (self-attention) ([2, 4096, 320])
        # q,k,v : head-separated q_in, k_in and v_in
        ctx_k, ctx_v = hypernetwork.apply_hypernetwork(
            shared.loaded_hypernetwork,
            context if context is not None else x
        )
        q_in = module.to_q(x)
        k_in = module.to_k(ctx_k)
        v_in = module.to_v(ctx_v)
        q: Tensor
        k: Tensor
        v: Tensor
        q, k, v = map( # type: ignore
            lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=module.heads),
            (q_in, k_in, v_in)
        )
        sim = einsum('b i d, b j d -> b i j', q, k) * module.scale
        sim = sim.softmax(dim=-1)
        # sim.shape == '(b h) i j'
        o_in = einsum('b i j, b j d -> b i d', sim, v)
        o = rearrange(o_in, '(b h) n d -> b n (h d)', h=module.heads)
        qk: Tensor = rearrange(sim, '(b h) d t -> b h t d', h=module.heads).detach().clone()
        vqk: Tensor = o.detach().clone()
        self.log(f" | q: {shape(q_in)} # {shape(q)}")
        self.log(f" | k: {shape(k_in)} # {shape(k)}")
        self.log(f" | v: {shape(v_in)} # {shape(v)}")
        #self.log(f" | qk: {shape(qk)} # {shape(sim)}")
        #self.log(f" | vqk: {shape(vqk)}")
        # Free intermediates promptly — this recomputation doubles attention memory.
        del q_in, k_in, v_in, q, k, v, sim, o_in, o
        return qk, vqk

    def feature_to_grid_images(self, feature: AttnFeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, color: bool):
        """Render the stored qk attention maps of one layer/step as grid images."""
        #return feature_to_grid_images(feature, layer, width, height, color)
        w, h, ch = get_shape(layer, width, height)
        # qk
        qk = feature.qk
        heads_qk, ch_qk, n_qk = qk.shape
        assert ch_qk == 77  # 77: presumably the CLIP token count — confirm
        assert w * h == n_qk, f"w={w}, h={h}, n_qk={n_qk}"
        qk1 = rearrange(qk, 'a t (h w) -> (a t) h w', h=h).contiguous()
        # vqk
        vqk = feature.vqk
        n_vqk, ch_vqk = vqk.shape
        assert w * h == n_vqk, f"w={w}, h={h}, n_qk={n_vqk}"
        assert ch == ch_vqk, f"ch={ch}, ch_vqk={ch_vqk}"
        vqk1 = rearrange(vqk, '(h w) c -> c h w', h=h).contiguous()
        # NOTE(review): vqk1 is shape-checked but never rendered — only the qk
        # maps are returned (the commented print suggests work in progress).
        #print(img_idx, step, layer, qk1.shape, vqk1.shape)
        return tutils.tensor_to_image(qk1, ch_qk, heads_qk, color)

    def save_features(self, feature: AttnFeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, path: str, basename: str):
        """Dump the qk attention map as a raw tensor file under ``path``."""
        w, h, ch = get_shape(layer, width, height)
        qk = rearrange(feature.qk, 'a t (h w) -> (a t) h w', h=h).contiguous()
        tutils.save_tensor(qk, path, basename)
def get_shape(layer: str, width: int, height: int):
    """Return (feature-map width, feature-map height, output channels) for a layer.

    The per-layer base sizes come from ``layerinfo.Settings``; they are scaled
    by how many 64-pixel tiles the requested image spans in each direction.
    """
    assert layer in layerinfo.Settings
    (in_ch, in_h, in_w), (out_ch, out_h, out_w) = layerinfo.Settings[layer]
    tiles_w = max(1, math.ceil(width / 64))
    tiles_h = max(1, math.ceil(height / 64))
    return in_w * tiles_w, in_h * tiles_h, out_ch
def get_unet_attn_layers(unet, layername: str):
    """Yield ``(n, depth, block, attn1, attn2)`` for the named U-Net layer.

    ``n`` indexes the SpatialTransformer children of the layer, ``depth``
    indexes the BasicTransformerBlocks inside it; ``attn1``/``attn2`` are the
    block's self- and cross-attention modules.
    """
    target = get_unet_layer(unet, layername)
    spatials = (m for m in target.children() if isinstance(m, SpatialTransformer))
    for n, trans in enumerate(spatials):
        basics = (
            m for m in trans.transformer_blocks.children()
            if isinstance(m, BasicTransformerBlock)
        )
        for depth, basic_block in enumerate(basics):
            a1: CrossAttention|MemoryEfficientCrossAttention = basic_block.attn1
            a2: CrossAttention|MemoryEfficientCrossAttention = basic_block.attn2
            assert isinstance(a1, (CrossAttention, MemoryEfficientCrossAttention))
            assert isinstance(a2, (CrossAttention, MemoryEfficientCrossAttention))
            yield n, depth, basic_block, a1, a2
def shape(t: Tensor|None) -> tuple|None:
return tuple(t.shape) if t is not None else None

View File

@ -0,0 +1,8 @@
from dataclasses import dataclass
from torch import Tensor
@dataclass
class AttnFeatureInfo:
    """Attention features captured for one layer at one step of one image."""
    # Softmaxed attention map; per image: (head, token, height*width) —
    # see AttentionExtractor.process_attention.
    qk: Tensor
    # Attention output, heads merged; per image: (height*width, ch).
    vqk: Tensor

View File

@ -78,6 +78,14 @@ class UNet:
dump: DumpSetting
info: Info
@dataclass
class Attn:
    """UI widgets of the "Attention" tab (mirrors the structure of ``UNet``)."""
    # the tab container itself
    tab: Tab
    # master on/off checkbox for attention extraction
    enabled: Checkbox
    # layers / steps / color settings (read as settings.layers/.steps/.color)
    settings: OutputSetting
    # raw tensor dump settings (dump.enabled / dump.path)
    dump: DumpSetting
    info: Info
@dataclass
class LayerPrompt:
tab: Tab
@ -95,6 +103,7 @@ class Debug:
@dataclass
class UI:
unet: UNet
attn: Attn
lp: LayerPrompt
debug: Debug
@ -109,6 +118,7 @@ class UI:
with Group(elem_id=id("ui")):
result = UI(
build_unet(id),
build_attn(id),
build_layerprompt(id),
build_debug(runner, id),
)
@ -140,6 +150,31 @@ def build_unet(id_: Callable[[str],str]):
info
)
def build_attn(id_: Callable[[str],str]):
    """Build the "Attention" tab of the extension UI and return it as an ``Attn``.

    ``id_`` maps a short suffix to a full element id; ids built here get an
    extra "attention-" prefix (e.g. "attention-checkbox").
    """
    id = lambda s: id_(f"attention-{s}")
    # NOTE(review): nesting below reconstructed (source indentation was lost);
    # `info` is assumed to sit at Tab level, outside the Accordion — verify
    # against build_unet.
    with Tab("Attention", elem_id=id("tab")) as tab:
        enabled = Checkbox(
            label="Extract attention layers' features",
            value=False,
            elem_id=id("checkbox")
        )
        settings = OutputSetting.build(id)
        with Accordion(label="Dump Setting", open=False):
            dump = DumpSetting.build("Dump feature tensors to files", id)
        info = build_info(id)
    return Attn(
        tab,
        enabled,
        settings,
        dump,
        info
    )
def build_layerprompt(id_: Callable[[str],str]):
id = lambda s: id_(f"layerprompt-{s}")

View File

@ -12,12 +12,34 @@ from scripts.lib.report import message as E
if TYPE_CHECKING:
from scripts.dumpunet import Script
class ForwardHook:
    """Monkey-patches ``module.forward`` so ``fn(module, original, *a, **kw)``
    runs in its place; ``remove()`` restores the original.

    Exposes the same ``remove()`` surface as torch's RemovableHandle so it can
    live in the same handle list.
    """

    def __init__(self, module: nn.Module, fn: Callable[[nn.Module, Callable[..., Any]], Any]):
        # Capture the original bound forward before installing the trampoline.
        self.o = module.forward
        self.fn = fn
        self.module = module
        self.module.forward = self.forward

    def remove(self):
        """Restore the original forward and drop all references."""
        if self.module is not None and self.o is not None:
            self.module.forward = self.o
        self.module = None
        self.o = None
        self.fn = None

    def forward(self, *args, **kwargs):
        """Trampoline installed as ``module.forward``; delegates to ``fn``."""
        # After remove() everything is None: behave as a silent no-op.
        if self.module is None or self.o is None or self.fn is None:
            return None
        return self.fn(self.module, self.o, *args, **kwargs)
class ExtractorBase:
def __init__(self, runner: "Script", enabled: bool):
self._runner = runner
self._enabled = enabled
self._handles: list[RemovableHandle] = []
self._handles: list[RemovableHandle|ForwardHook] = []
self._batch_num = 0
self._steps_on_batch = 0
@ -152,6 +174,15 @@ class ExtractorBase:
assert isinstance(module, nn.Module)
self._handles.append(module.register_forward_pre_hook(fn))
def hook_forward(
self,
module: nn.Module|Any,
fn: Callable[..., Any]
):
assert module is not None
assert isinstance(module, nn.Module)
self._handles.append(ForwardHook(module, fn))
def log(self, msg: str):
if self._runner.debug:
print(E(msg), file=sys.stderr)

View File

@ -133,22 +133,22 @@ class FeatureExtractorBase(Generic[TInfo], ExtractorBase):
if shared.state.interrupted:
break
canvases = self.feature_to_grid_images(feature, layer, p.width, p.height, color)
canvases = self.feature_to_grid_images(feature, layer, idx, step, p.width, p.height, color)
for canvas in canvases:
builder.add(canvas, *args, {"Layer Name": layer, "Feature Steps": step})
if self.path is not None:
basename = f"{idx:03}-{layer}-{step:03}-{{ch:04}}-{t0}"
self.save_features(feature, self.path, basename)
self.save_features(feature, layer, idx, step, p.width, p.height, self.path, basename)
shared.total_tqdm.update()
return builder.to_proc(p, proc)
def feature_to_grid_images(self, feature: TInfo, layer: str, width: int, height: int, color: bool):
def feature_to_grid_images(self, feature: TInfo, layer: str, img_idx: int, step: int, width: int, height: int, color: bool):
raise NotImplementedError(f"{self.__class__.__name__}.feature_to_grid_images")
def save_features(self, feature: TInfo, path: str, basename: str):
def save_features(self, feature: TInfo, layer: str, img_idx: int, step: int, width: int, height: int, path: str, basename: str):
raise NotImplementedError(f"{self.__class__.__name__}.save_features")
def _fixup(self, proc: Processed):

View File

@ -69,10 +69,10 @@ class FeatureExtractor(FeatureExtractorBase[FeatureInfo]):
target = get_unet_layer(unet, layer)
self.hook_layer(target, create_hook(layer))
def feature_to_grid_images(self, feature: FeatureInfo, layer: str, width: int, height: int, color: bool):
def feature_to_grid_images(self, feature: FeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, color: bool):
return feature_to_grid_images(feature, layer, width, height, color)
def save_features(self, feature: FeatureInfo, path: str, basename: str):
def save_features(self, feature: FeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, path: str, basename: str):
save_features(feature, path, basename)
def get_unet_layer(unet, layername: str) -> nn.modules.Module:

View File

@ -1,12 +1,9 @@
import math
from typing import Generator
from torch import Tensor
from scripts.lib import tutils
from scripts.lib import layerinfo
from scripts.lib.features.featureinfo import FeatureInfo, Features, MultiImageFeatures
from scripts.lib.report import message as E
from scripts.lib.features.featureinfo import FeatureInfo, MultiImageFeatures
def feature_diff(
features1: MultiImageFeatures[FeatureInfo],
@ -55,10 +52,8 @@ def feature_to_grid_images(
if isinstance(feature, FeatureInfo):
tensor = feature.output
assert isinstance(tensor, Tensor)
assert len(tensor.size()) == 3
grid_x, grid_y = _get_grid_num(layer, width, height)
canvases = tutils.tensor_to_image(tensor, grid_x, grid_y, color)
canvases = tutils.tensor_to_grid_images(tensor, layer, width, height, color)
return canvases
def save_features(
@ -67,30 +62,3 @@ def save_features(
basename: str
):
tutils.save_tensor(feature.output, save_dir, basename)
def _get_grid_num(layer: str, width: int, height: int):
assert layer is not None and layer != "", E("<Layers> must not be empty.")
assert layer in layerinfo.Settings, E(f"Invalid <Layers> value: {layer}.")
_, (ch, mh, mw) = layerinfo.Settings[layer]
iw = math.ceil(width / 64)
ih = math.ceil(height / 64)
w = mw * iw
h = mh * ih
# w : width of a feature map
# h : height of a feature map
# ch: a number of a feature map
n = [w, h]
while ch % 2 == 0:
n[n[0]>n[1]] *= 2
ch //= 2
n[n[0]>n[1]] *= ch
if n[0] > n[1]:
while n[0] > n[1] * 2 and (n[0] // w) % 2 == 0:
n[0] //= 2
n[1] *= 2
else:
while n[0] * 2 < n[1] and (n[1] // h) % 2 == 0:
n[0] *= 2
n[1] //= 2
return n[0] // w, n[1] // h

View File

@ -1,4 +1,5 @@
import os
import math
from torch import Tensor
import numpy as np
@ -6,6 +7,20 @@ from PIL import Image
from modules import shared
from scripts.lib import layerinfo
from scripts.lib.report import message as E
def tensor_to_grid_images(
    tensor: Tensor,
    layer: str,
    width: int,
    height: int,
    color: bool
):
    """Render a feature tensor as grid images, with the grid layout derived
    from the layer name and the requested image size via ``get_grid_num``."""
    cols, rows = get_grid_num(layer, width, height)
    return tensor_to_image(tensor, cols, rows, color)
def tensor_to_image(
tensor: Tensor,
grid_x: int,
@ -93,3 +108,30 @@ def _tensor_to_image(array: np.ndarray, color: bool):
else:
return np.clip(np.abs(array) * 256, 0, 255).astype(np.uint8)
def get_grid_num(layer: str, width: int, height: int):
    """Choose a roughly-square ``(grid_x, grid_y)`` layout for a layer's maps.

    Distributes the layer's channel count over the two axes of a canvas of
    w×h feature maps, then returns the grid as (columns, rows) counted in
    feature maps.
    """
    assert layer is not None and layer != "", E("<Layers> must not be empty.")
    assert layer in layerinfo.Settings, E(f"Invalid <Layers> value: {layer}.")
    _, (ch, mh, mw) = layerinfo.Settings[layer]
    # number of 64-pixel tiles the image spans in each direction
    iw = math.ceil(width / 64)
    ih = math.ceil(height / 64)
    w = mw * iw
    h = mh * ih
    # w : width of a feature map
    # h : height of a feature map
    # ch: a number of a feature map
    n = [w, h]
    # Split ch's factors of two between the axes; n[n[0]>n[1]] indexes by the
    # bool (False=0 → width, True=1 → height), so each factor always grows the
    # currently-smaller canvas side.
    while ch % 2 == 0:
        n[n[0]>n[1]] *= 2
        ch //= 2
    # the remaining odd factor also goes to the smaller side
    n[n[0]>n[1]] *= ch
    # Rebalance: move factors of two between the axes (the grid count on the
    # shrinking axis must stay even so the area is preserved exactly) until
    # the canvas is within 2x of square.
    if n[0] > n[1]:
        while n[0] > n[1] * 2 and (n[0] // w) % 2 == 0:
            n[0] //= 2
            n[1] *= 2
    else:
        while n[0] * 2 < n[1] and (n[1] // h) % 2 == 0:
            n[0] *= 2
            n[1] //= 2
    return n[0] // w, n[1] // h

View File

@ -2,6 +2,8 @@
#dumpunet-img2img-features-layerinfo,
#dumpunet-txt2img-layerprompt-layerinfo,
#dumpunet-img2img-layerprompt-layerinfo,
#dumpunet-txt2img-attention-layerinfo,
#dumpunet-img2img-attention-layerinfo,
#dumpunet-txt2img-layerprompt-errors,
#dumpunet-img2img-layerprompt-errors {
font-family: monospace;
@ -9,6 +11,8 @@
#dumpunet-txt2img-features-checkbox,
#dumpunet-img2img-features-checkbox,
#dumpunet-txt2img-attention-checkbox,
#dumpunet-img2img-attention-checkbox,
#dumpunet-txt2img-layerprompt-checkbox,
#dumpunet-img2img-layerprompt-checkbox {
/* background-color: #fff8f0; */
@ -17,6 +21,8 @@
.dark #dumpunet-txt2img-features-checkbox,
.dark #dumpunet-img2img-features-checkbox,
.dark #dumpunet-txt2img-attention-checkbox,
.dark #dumpunet-img2img-attention-checkbox,
.dark #dumpunet-txt2img-layerprompt-checkbox,
.dark #dumpunet-img2img-layerprompt-checkbox {
/*background-color: inherit;