add AttentionExtractor
parent
b1bc99506b
commit
15b425c5a0
|
|
@ -19,6 +19,11 @@ onUiUpdate(() => {
|
||||||
'#dumpunet-{}-features-steps': 'Steps which U-Net features should be extracted. See tooltip for notations',
|
'#dumpunet-{}-features-steps': 'Steps which U-Net features should be extracted. See tooltip for notations',
|
||||||
'#dumpunet-{}-features-dumppath': 'Raw binary files are dumped to here, one image per step per layer.',
|
'#dumpunet-{}-features-dumppath': 'Raw binary files are dumped to here, one image per step per layer.',
|
||||||
|
|
||||||
|
'#dumpunet-{}-attention-checkbox': 'Extract attention layer\'s features and add their maps to output images.',
|
||||||
|
'#dumpunet-{}-attention-layer': 'U-Net layers <code>(IN00-IN11, M00, OUT00-OUT11)</code> which features should be extracted. See tooltip for notations.',
|
||||||
|
'#dumpunet-{}-attention-steps': 'Steps which features should be extracted. See tooltip for notations',
|
||||||
|
'#dumpunet-{}-attention-dumppath': 'Raw binary files are dumped to here, one image per step per layer.',
|
||||||
|
|
||||||
'#dumpunet-{}-layerprompt-checkbox': 'When checked, <code>(~: ... :~)</code> notation is enabled.',
|
'#dumpunet-{}-layerprompt-checkbox': 'When checked, <code>(~: ... :~)</code> notation is enabled.',
|
||||||
'#dumpunet-{}-layerprompt-diff-layer': 'Layers <code>(IN00-IN11, M00, OUT00-OUT11)</code> which features should be extracted. See tooltip for notations.',
|
'#dumpunet-{}-layerprompt-diff-layer': 'Layers <code>(IN00-IN11, M00, OUT00-OUT11)</code> which features should be extracted. See tooltip for notations.',
|
||||||
'#dumpunet-{}-layerprompt-diff-steps': 'Steps which features should be extracted. See tooltip for notations',
|
'#dumpunet-{}-layerprompt-diff-steps': 'Steps which features should be extracted. See tooltip for notations',
|
||||||
|
|
@ -28,6 +33,8 @@ onUiUpdate(() => {
|
||||||
const hints = {
|
const hints = {
|
||||||
'#dumpunet-{}-features-layer textarea': 'IN00: add one layer to output\nIN00,IN01: add layers to output\nIN00-IN02: add range to output\nIN00-OUT05(+2): add range to output with specified steps\n',
|
'#dumpunet-{}-features-layer textarea': 'IN00: add one layer to output\nIN00,IN01: add layers to output\nIN00-IN02: add range to output\nIN00-OUT05(+2): add range to output with specified steps\n',
|
||||||
'#dumpunet-{}-features-steps textarea': '5: extracted at steps=5\n5,10: extracted at steps=5 and steps=10\n5-10: extracted when step is in 5..10 (inclusive)\n5-10(+2): extracts when step is 5,7,9\n',
|
'#dumpunet-{}-features-steps textarea': '5: extracted at steps=5\n5,10: extracted at steps=5 and steps=10\n5-10: extracted when step is in 5..10 (inclusive)\n5-10(+2): extracts when step is 5,7,9\n',
|
||||||
|
'#dumpunet-{}-attention-layer textarea': 'IN00: add one layer to output\nIN00,IN01: add layers to output\nIN00-IN02: add range to output\nIN00-OUT05(+2): add range to output with specified steps\n',
|
||||||
|
'#dumpunet-{}-attention-steps textarea': '5: extracted at steps=5\n5,10: extracted at steps=5 and steps=10\n5-10: extracted when step is in 5..10 (inclusive)\n5-10(+2): extracts when step is 5,7,9\n',
|
||||||
'#dumpunet-{}-layerprompt-diff-layer textarea': 'IN00: add one layer to output\nIN00,IN01: add layers to output\nIN00-IN02: add range to output\nIN00-OUT05(+2): add range to output with specified steps\n',
|
'#dumpunet-{}-layerprompt-diff-layer textarea': 'IN00: add one layer to output\nIN00,IN01: add layers to output\nIN00-IN02: add range to output\nIN00-OUT05(+2): add range to output with specified steps\n',
|
||||||
'#dumpunet-{}-layerprompt-diff-steps textarea': '5: extracted at steps=5\n5,10: extracted at steps=5 and steps=10\n5-10: extracted when step is in 5..10 (inclusive)\n5-10(+2): extracts when step is 5,7,9\n',
|
'#dumpunet-{}-layerprompt-diff-steps textarea': '5: extracted at steps=5\n5,10: extracted at steps=5 and steps=10\n5-10: extracted when step is in 5..10 (inclusive)\n5-10(+2): extracts when step is 5,7,9\n',
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,7 @@ onUiUpdate(() => {
|
||||||
|
|
||||||
const updates = [];
|
const updates = [];
|
||||||
for (let mode of ['txt2img', 'img2img']) {
|
for (let mode of ['txt2img', 'img2img']) {
|
||||||
for (let tab of ['features', 'layerprompt']) {
|
for (let tab of ['features', 'attention', 'layerprompt']) {
|
||||||
const layer_input_ele =
|
const layer_input_ele =
|
||||||
app.querySelector(`#dumpunet-${mode}-${tab}-layer textarea`)
|
app.querySelector(`#dumpunet-${mode}-${tab}-layer textarea`)
|
||||||
|| app.querySelector(`#dumpunet-${mode}-${tab}-diff-layer textarea`);
|
|| app.querySelector(`#dumpunet-${mode}-${tab}-diff-layer textarea`);
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from scripts.lib.features.extractor import FeatureExtractor
|
||||||
from scripts.lib.features.utils import feature_diff, feature_to_grid_images
|
from scripts.lib.features.utils import feature_diff, feature_to_grid_images
|
||||||
from scripts.lib.tutils import save_tensor
|
from scripts.lib.tutils import save_tensor
|
||||||
from scripts.lib.layer_prompt.prompt import LayerPrompt
|
from scripts.lib.layer_prompt.prompt import LayerPrompt
|
||||||
|
from scripts.lib.attention.extractor import AttentionExtractor
|
||||||
from scripts.lib.report import message as E
|
from scripts.lib.report import message as E
|
||||||
from scripts.lib import putils
|
from scripts.lib import putils
|
||||||
|
|
||||||
|
|
@ -46,6 +47,13 @@ class Script(scripts.Script):
|
||||||
result.unet.dump.enabled,
|
result.unet.dump.enabled,
|
||||||
result.unet.dump.path,
|
result.unet.dump.path,
|
||||||
|
|
||||||
|
result.attn.enabled,
|
||||||
|
result.attn.settings.layers,
|
||||||
|
result.attn.settings.steps,
|
||||||
|
result.attn.settings.color,
|
||||||
|
result.attn.dump.enabled,
|
||||||
|
result.attn.dump.path,
|
||||||
|
|
||||||
result.lp.enabled,
|
result.lp.enabled,
|
||||||
result.lp.diff_enabled,
|
result.lp.diff_enabled,
|
||||||
result.lp.diff_settings.layers,
|
result.lp.diff_settings.layers,
|
||||||
|
|
@ -95,6 +103,13 @@ class Script(scripts.Script):
|
||||||
path_on: bool,
|
path_on: bool,
|
||||||
path: str,
|
path: str,
|
||||||
|
|
||||||
|
attn_enabled: bool,
|
||||||
|
attn_layers: str,
|
||||||
|
attn_steps: str,
|
||||||
|
attn_color: bool,
|
||||||
|
attn_path_on: bool,
|
||||||
|
attn_path: str,
|
||||||
|
|
||||||
layerprompt_enabled: bool,
|
layerprompt_enabled: bool,
|
||||||
layerprompt_diff_enabled: bool,
|
layerprompt_diff_enabled: bool,
|
||||||
lp_diff_layers: str,
|
lp_diff_layers: str,
|
||||||
|
|
@ -106,7 +121,7 @@ class Script(scripts.Script):
|
||||||
debug: bool,
|
debug: bool,
|
||||||
):
|
):
|
||||||
|
|
||||||
if not unet_features_enabled and not layerprompt_enabled:
|
if not unet_features_enabled and not attn_enabled and not layerprompt_enabled:
|
||||||
return process_images(p)
|
return process_images(p)
|
||||||
|
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
|
|
@ -134,6 +149,15 @@ class Script(scripts.Script):
|
||||||
layerprompt_enabled,
|
layerprompt_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
at = AttentionExtractor(
|
||||||
|
self,
|
||||||
|
attn_enabled,
|
||||||
|
p.steps,
|
||||||
|
attn_layers,
|
||||||
|
attn_steps,
|
||||||
|
attn_path if attn_path_on else None
|
||||||
|
)
|
||||||
|
|
||||||
if layerprompt_diff_enabled:
|
if layerprompt_diff_enabled:
|
||||||
fix_seed(p)
|
fix_seed(p)
|
||||||
|
|
||||||
|
|
@ -142,15 +166,17 @@ class Script(scripts.Script):
|
||||||
|
|
||||||
# layer prompt disabled
|
# layer prompt disabled
|
||||||
lp0 = LayerPrompt(self, layerprompt_enabled, remove_layer_prompts=True)
|
lp0 = LayerPrompt(self, layerprompt_enabled, remove_layer_prompts=True)
|
||||||
proc1 = exec(p1, lp0, [ex, exlp])
|
proc1 = exec(p1, lp0, [ex, exlp, at])
|
||||||
features1 = ex.extracted_features
|
features1 = ex.extracted_features
|
||||||
diff1 = exlp.extracted_features
|
diff1 = exlp.extracted_features
|
||||||
proc1 = ex.add_images(p1, proc1, features1, color)
|
proc1 = ex.add_images(p1, proc1, features1, color)
|
||||||
|
proc1 = at.add_images(p1, proc1, at.extracted_features, attn_color)
|
||||||
# layer prompt enabled
|
# layer prompt enabled
|
||||||
proc2 = exec(p2, lp, [ex, exlp])
|
proc2 = exec(p2, lp, [ex, exlp, at])
|
||||||
features2 = ex.extracted_features
|
features2 = ex.extracted_features
|
||||||
diff2 = exlp.extracted_features
|
diff2 = exlp.extracted_features
|
||||||
proc2 = ex.add_images(p2, proc2, features2, color)
|
proc2 = ex.add_images(p2, proc2, features2, color)
|
||||||
|
proc2 = at.add_images(p2, proc2, at.extracted_features, attn_color)
|
||||||
|
|
||||||
assert len(proc1.images) == len(proc2.images)
|
assert len(proc1.images) == len(proc2.images)
|
||||||
|
|
||||||
|
|
@ -175,10 +201,9 @@ class Script(scripts.Script):
|
||||||
save_tensor(tensor, diff_path, basename)
|
save_tensor(tensor, diff_path, basename)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
proc = exec(p, lp, [ex])
|
proc = exec(p, lp, [ex, at])
|
||||||
features = ex.extracted_features
|
proc = ex.add_images(p, proc, ex.extracted_features, color)
|
||||||
if unet_features_enabled:
|
proc = at.add_images(p, proc, at.extracted_features, attn_color)
|
||||||
proc = ex.add_images(p, proc, features, color)
|
|
||||||
|
|
||||||
return proc
|
return proc
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,180 @@
|
||||||
|
import math
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from torch import nn, Tensor, einsum
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock, CrossAttention, MemoryEfficientCrossAttention # type: ignore
|
||||||
|
from modules.processing import StableDiffusionProcessing
|
||||||
|
from modules.hypernetworks import hypernetwork
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
from scripts.lib.feature_extractor import FeatureExtractorBase
|
||||||
|
from scripts.lib.features.featureinfo import MultiImageFeatures
|
||||||
|
from scripts.lib.features.extractor import get_unet_layer
|
||||||
|
from scripts.lib.attention.featureinfo import AttnFeatureInfo
|
||||||
|
from scripts.lib import layerinfo, tutils
|
||||||
|
from scripts.lib.utils import *
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from scripts.dumpunet import Script
|
||||||
|
|
||||||
|
class AttentionExtractor(FeatureExtractorBase):
|
||||||
|
|
||||||
|
# image_index -> step -> Features
|
||||||
|
extracted_features: MultiImageFeatures[AttnFeatureInfo]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
runner: "Script",
|
||||||
|
enabled: bool,
|
||||||
|
total_steps: int,
|
||||||
|
layer_input: str,
|
||||||
|
step_input: str,
|
||||||
|
path: str|None,
|
||||||
|
):
|
||||||
|
super().__init__(runner, enabled, total_steps, layer_input, step_input, path)
|
||||||
|
self.extracted_features = MultiImageFeatures()
|
||||||
|
|
||||||
|
def hook_unet(self, p: StableDiffusionProcessing, unet: nn.Module):
|
||||||
|
|
||||||
|
def create_hook(layername: str, block: BasicTransformerBlock, n: int, depth: int, c: int):
|
||||||
|
|
||||||
|
def forward(module, fn, x, context=None, *args, **kwargs):
|
||||||
|
result = fn(x, context=context, *args, **kwargs)
|
||||||
|
|
||||||
|
if self.steps_on_batch in self.steps:
|
||||||
|
if c == 2:
|
||||||
|
# process for only cross-attention
|
||||||
|
self.log(f"{self.steps_on_batch:>03} {layername}-{n}-{depth}-attn{c} ({'cross' if (block.disable_self_attn or 1 < c) else 'self'})")
|
||||||
|
self.log(f" | {shape(x),shape(context)} -> {shape(result)}")
|
||||||
|
|
||||||
|
qks, vqks = self.process_attention(module, x, context)
|
||||||
|
# qk := (batch, head, token, height*width)
|
||||||
|
# vqk := (batch, height*width, ch)
|
||||||
|
|
||||||
|
images_per_batch = qks.shape[0] // 2
|
||||||
|
assert qks.shape[0] == vqks.shape[0]
|
||||||
|
assert qks.shape[0] % 2 == 0
|
||||||
|
|
||||||
|
for image_index, (vk, vqk) in enumerate(
|
||||||
|
zip(qks[:images_per_batch], vqks[:images_per_batch]),
|
||||||
|
(self.batch_num-1) * images_per_batch
|
||||||
|
):
|
||||||
|
features = self.extracted_features[image_index][self.steps_on_batch]
|
||||||
|
features.add(
|
||||||
|
layername,
|
||||||
|
AttnFeatureInfo(vk, vqk)
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return forward
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
self.log(f"Attention: hooking {layer}...")
|
||||||
|
for n, d, block, attn1, attn2 in get_unet_attn_layers(unet, layer):
|
||||||
|
self.hook_forward(attn1, create_hook(layer, block, n, d, 1))
|
||||||
|
self.hook_forward(attn2, create_hook(layer, block, n, d, 2))
|
||||||
|
|
||||||
|
return super().hook_unet(p, unet)
|
||||||
|
|
||||||
|
def process_attention(self, module, x, context):
|
||||||
|
# q_in : unet features ([2, 4096, 320])
|
||||||
|
# k_in, v_in : embedding vector kv (cross-attention) ([2, 77, 320]) or unet features kv (self-attention) ([2, 4096, 320])
|
||||||
|
# q,k,v : head-separated q_in, k_in and v_in
|
||||||
|
|
||||||
|
ctx_k, ctx_v = hypernetwork.apply_hypernetwork(
|
||||||
|
shared.loaded_hypernetwork,
|
||||||
|
context if context is not None else x
|
||||||
|
)
|
||||||
|
|
||||||
|
q_in = module.to_q(x)
|
||||||
|
k_in = module.to_k(ctx_k)
|
||||||
|
v_in = module.to_v(ctx_v)
|
||||||
|
|
||||||
|
q: Tensor
|
||||||
|
k: Tensor
|
||||||
|
v: Tensor
|
||||||
|
|
||||||
|
q, k, v = map( # type: ignore
|
||||||
|
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=module.heads),
|
||||||
|
(q_in, k_in, v_in)
|
||||||
|
)
|
||||||
|
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q, k) * module.scale
|
||||||
|
sim = sim.softmax(dim=-1)
|
||||||
|
# sim.shape == '(b h) i j'
|
||||||
|
|
||||||
|
o_in = einsum('b i j, b j d -> b i d', sim, v)
|
||||||
|
o = rearrange(o_in, '(b h) n d -> b n (h d)', h=module.heads)
|
||||||
|
|
||||||
|
qk: Tensor = rearrange(sim, '(b h) d t -> b h t d', h=module.heads).detach().clone()
|
||||||
|
vqk: Tensor = o.detach().clone()
|
||||||
|
|
||||||
|
self.log(f" | q: {shape(q_in)} # {shape(q)}")
|
||||||
|
self.log(f" | k: {shape(k_in)} # {shape(k)}")
|
||||||
|
self.log(f" | v: {shape(v_in)} # {shape(v)}")
|
||||||
|
#self.log(f" | qk: {shape(qk)} # {shape(sim)}")
|
||||||
|
#self.log(f" | vqk: {shape(vqk)}")
|
||||||
|
|
||||||
|
del q_in, k_in, v_in, q, k, v, sim, o_in, o
|
||||||
|
|
||||||
|
return qk, vqk
|
||||||
|
|
||||||
|
def feature_to_grid_images(self, feature: AttnFeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, color: bool):
|
||||||
|
#return feature_to_grid_images(feature, layer, width, height, color)
|
||||||
|
w, h, ch = get_shape(layer, width, height)
|
||||||
|
# qk
|
||||||
|
qk = feature.qk
|
||||||
|
heads_qk, ch_qk, n_qk = qk.shape
|
||||||
|
assert ch_qk == 77
|
||||||
|
assert w * h == n_qk, f"w={w}, h={h}, n_qk={n_qk}"
|
||||||
|
qk1 = rearrange(qk, 'a t (h w) -> (a t) h w', h=h).contiguous()
|
||||||
|
# vqk
|
||||||
|
vqk = feature.vqk
|
||||||
|
n_vqk, ch_vqk = vqk.shape
|
||||||
|
assert w * h == n_vqk, f"w={w}, h={h}, n_qk={n_vqk}"
|
||||||
|
assert ch == ch_vqk, f"ch={ch}, ch_vqk={ch_vqk}"
|
||||||
|
vqk1 = rearrange(vqk, '(h w) c -> c h w', h=h).contiguous()
|
||||||
|
|
||||||
|
#print(img_idx, step, layer, qk1.shape, vqk1.shape)
|
||||||
|
return tutils.tensor_to_image(qk1, ch_qk, heads_qk, color)
|
||||||
|
|
||||||
|
def save_features(self, feature: AttnFeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, path: str, basename: str):
|
||||||
|
w, h, ch = get_shape(layer, width, height)
|
||||||
|
qk = rearrange(feature.qk, 'a t (h w) -> (a t) h w', h=h).contiguous()
|
||||||
|
tutils.save_tensor(qk, path, basename)
|
||||||
|
|
||||||
|
def get_shape(layer: str, width: int, height: int):
|
||||||
|
assert layer in layerinfo.Settings
|
||||||
|
(ich, ih, iw), (och, oh, ow) = layerinfo.Settings[layer]
|
||||||
|
nw, nh = [max(1, math.ceil(x / 64)) for x in [width, height]]
|
||||||
|
return iw*nw, ih*nh, och
|
||||||
|
|
||||||
|
def get_unet_attn_layers(unet, layername: str):
|
||||||
|
unet_block = get_unet_layer(unet, layername)
|
||||||
|
|
||||||
|
def each_transformer(unet_block):
|
||||||
|
for block in unet_block.children():
|
||||||
|
if isinstance(block, SpatialTransformer):
|
||||||
|
yield block
|
||||||
|
|
||||||
|
def each_basic_block(trans):
|
||||||
|
for block in trans.children():
|
||||||
|
if isinstance(block, BasicTransformerBlock):
|
||||||
|
yield block
|
||||||
|
|
||||||
|
for n, trans in enumerate(each_transformer(unet_block)):
|
||||||
|
for depth, basic_block in enumerate(each_basic_block(trans.transformer_blocks)):
|
||||||
|
attn1: CrossAttention|MemoryEfficientCrossAttention
|
||||||
|
attn2: CrossAttention|MemoryEfficientCrossAttention
|
||||||
|
|
||||||
|
attn1, attn2 = basic_block.attn1, basic_block.attn2
|
||||||
|
assert isinstance(attn1, CrossAttention) or isinstance(attn1, MemoryEfficientCrossAttention)
|
||||||
|
assert isinstance(attn2, CrossAttention) or isinstance(attn2, MemoryEfficientCrossAttention)
|
||||||
|
|
||||||
|
yield n, depth, basic_block, attn1, attn2
|
||||||
|
|
||||||
|
def shape(t: Tensor|None) -> tuple|None:
|
||||||
|
return tuple(t.shape) if t is not None else None
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AttnFeatureInfo:
|
||||||
|
qk: Tensor
|
||||||
|
vqk: Tensor
|
||||||
|
|
@ -78,6 +78,14 @@ class UNet:
|
||||||
dump: DumpSetting
|
dump: DumpSetting
|
||||||
info: Info
|
info: Info
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Attn:
|
||||||
|
tab: Tab
|
||||||
|
enabled: Checkbox
|
||||||
|
settings: OutputSetting
|
||||||
|
dump: DumpSetting
|
||||||
|
info: Info
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LayerPrompt:
|
class LayerPrompt:
|
||||||
tab: Tab
|
tab: Tab
|
||||||
|
|
@ -95,6 +103,7 @@ class Debug:
|
||||||
@dataclass
|
@dataclass
|
||||||
class UI:
|
class UI:
|
||||||
unet: UNet
|
unet: UNet
|
||||||
|
attn: Attn
|
||||||
lp: LayerPrompt
|
lp: LayerPrompt
|
||||||
debug: Debug
|
debug: Debug
|
||||||
|
|
||||||
|
|
@ -109,6 +118,7 @@ class UI:
|
||||||
with Group(elem_id=id("ui")):
|
with Group(elem_id=id("ui")):
|
||||||
result = UI(
|
result = UI(
|
||||||
build_unet(id),
|
build_unet(id),
|
||||||
|
build_attn(id),
|
||||||
build_layerprompt(id),
|
build_layerprompt(id),
|
||||||
build_debug(runner, id),
|
build_debug(runner, id),
|
||||||
)
|
)
|
||||||
|
|
@ -140,6 +150,31 @@ def build_unet(id_: Callable[[str],str]):
|
||||||
info
|
info
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def build_attn(id_: Callable[[str],str]):
|
||||||
|
id = lambda s: id_(f"attention-{s}")
|
||||||
|
|
||||||
|
with Tab("Attention", elem_id=id("tab")) as tab:
|
||||||
|
enabled = Checkbox(
|
||||||
|
label="Extract attention layers' features",
|
||||||
|
value=False,
|
||||||
|
elem_id=id("checkbox")
|
||||||
|
)
|
||||||
|
|
||||||
|
settings = OutputSetting.build(id)
|
||||||
|
|
||||||
|
with Accordion(label="Dump Setting", open=False):
|
||||||
|
dump = DumpSetting.build("Dump feature tensors to files", id)
|
||||||
|
|
||||||
|
info = build_info(id)
|
||||||
|
|
||||||
|
return Attn(
|
||||||
|
tab,
|
||||||
|
enabled,
|
||||||
|
settings,
|
||||||
|
dump,
|
||||||
|
info
|
||||||
|
)
|
||||||
|
|
||||||
def build_layerprompt(id_: Callable[[str],str]):
|
def build_layerprompt(id_: Callable[[str],str]):
|
||||||
id = lambda s: id_(f"layerprompt-{s}")
|
id = lambda s: id_(f"layerprompt-{s}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,12 +12,34 @@ from scripts.lib.report import message as E
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from scripts.dumpunet import Script
|
from scripts.dumpunet import Script
|
||||||
|
|
||||||
|
class ForwardHook:
|
||||||
|
|
||||||
|
def __init__(self, module: nn.Module, fn: Callable[[nn.Module, Callable[..., Any]], Any]):
|
||||||
|
self.o = module.forward
|
||||||
|
self.fn = fn
|
||||||
|
self.module = module
|
||||||
|
self.module.forward = self.forward
|
||||||
|
|
||||||
|
def remove(self):
|
||||||
|
if self.module is not None and self.o is not None:
|
||||||
|
self.module.forward = self.o
|
||||||
|
self.module = None
|
||||||
|
self.o = None
|
||||||
|
self.fn = None
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
if self.module is not None and self.o is not None:
|
||||||
|
if self.fn is not None:
|
||||||
|
return self.fn(self.module, self.o, *args, **kwargs)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class ExtractorBase:
|
class ExtractorBase:
|
||||||
|
|
||||||
def __init__(self, runner: "Script", enabled: bool):
|
def __init__(self, runner: "Script", enabled: bool):
|
||||||
self._runner = runner
|
self._runner = runner
|
||||||
self._enabled = enabled
|
self._enabled = enabled
|
||||||
self._handles: list[RemovableHandle] = []
|
self._handles: list[RemovableHandle|ForwardHook] = []
|
||||||
self._batch_num = 0
|
self._batch_num = 0
|
||||||
self._steps_on_batch = 0
|
self._steps_on_batch = 0
|
||||||
|
|
||||||
|
|
@ -152,6 +174,15 @@ class ExtractorBase:
|
||||||
assert isinstance(module, nn.Module)
|
assert isinstance(module, nn.Module)
|
||||||
self._handles.append(module.register_forward_pre_hook(fn))
|
self._handles.append(module.register_forward_pre_hook(fn))
|
||||||
|
|
||||||
|
def hook_forward(
|
||||||
|
self,
|
||||||
|
module: nn.Module|Any,
|
||||||
|
fn: Callable[..., Any]
|
||||||
|
):
|
||||||
|
assert module is not None
|
||||||
|
assert isinstance(module, nn.Module)
|
||||||
|
self._handles.append(ForwardHook(module, fn))
|
||||||
|
|
||||||
def log(self, msg: str):
|
def log(self, msg: str):
|
||||||
if self._runner.debug:
|
if self._runner.debug:
|
||||||
print(E(msg), file=sys.stderr)
|
print(E(msg), file=sys.stderr)
|
||||||
|
|
|
||||||
|
|
@ -133,22 +133,22 @@ class FeatureExtractorBase(Generic[TInfo], ExtractorBase):
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
canvases = self.feature_to_grid_images(feature, layer, p.width, p.height, color)
|
canvases = self.feature_to_grid_images(feature, layer, idx, step, p.width, p.height, color)
|
||||||
for canvas in canvases:
|
for canvas in canvases:
|
||||||
builder.add(canvas, *args, {"Layer Name": layer, "Feature Steps": step})
|
builder.add(canvas, *args, {"Layer Name": layer, "Feature Steps": step})
|
||||||
|
|
||||||
if self.path is not None:
|
if self.path is not None:
|
||||||
basename = f"{idx:03}-{layer}-{step:03}-{{ch:04}}-{t0}"
|
basename = f"{idx:03}-{layer}-{step:03}-{{ch:04}}-{t0}"
|
||||||
self.save_features(feature, self.path, basename)
|
self.save_features(feature, layer, idx, step, p.width, p.height, self.path, basename)
|
||||||
|
|
||||||
shared.total_tqdm.update()
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
return builder.to_proc(p, proc)
|
return builder.to_proc(p, proc)
|
||||||
|
|
||||||
def feature_to_grid_images(self, feature: TInfo, layer: str, width: int, height: int, color: bool):
|
def feature_to_grid_images(self, feature: TInfo, layer: str, img_idx: int, step: int, width: int, height: int, color: bool):
|
||||||
raise NotImplementedError(f"{self.__class__.__name__}.feature_to_grid_images")
|
raise NotImplementedError(f"{self.__class__.__name__}.feature_to_grid_images")
|
||||||
|
|
||||||
def save_features(self, feature: TInfo, path: str, basename: str):
|
def save_features(self, feature: TInfo, layer: str, img_idx: int, step: int, width: int, height: int, path: str, basename: str):
|
||||||
raise NotImplementedError(f"{self.__class__.__name__}.save_features")
|
raise NotImplementedError(f"{self.__class__.__name__}.save_features")
|
||||||
|
|
||||||
def _fixup(self, proc: Processed):
|
def _fixup(self, proc: Processed):
|
||||||
|
|
|
||||||
|
|
@ -69,10 +69,10 @@ class FeatureExtractor(FeatureExtractorBase[FeatureInfo]):
|
||||||
target = get_unet_layer(unet, layer)
|
target = get_unet_layer(unet, layer)
|
||||||
self.hook_layer(target, create_hook(layer))
|
self.hook_layer(target, create_hook(layer))
|
||||||
|
|
||||||
def feature_to_grid_images(self, feature: FeatureInfo, layer: str, width: int, height: int, color: bool):
|
def feature_to_grid_images(self, feature: FeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, color: bool):
|
||||||
return feature_to_grid_images(feature, layer, width, height, color)
|
return feature_to_grid_images(feature, layer, width, height, color)
|
||||||
|
|
||||||
def save_features(self, feature: FeatureInfo, path: str, basename: str):
|
def save_features(self, feature: FeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, path: str, basename: str):
|
||||||
save_features(feature, path, basename)
|
save_features(feature, path, basename)
|
||||||
|
|
||||||
def get_unet_layer(unet, layername: str) -> nn.modules.Module:
|
def get_unet_layer(unet, layername: str) -> nn.modules.Module:
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,9 @@
|
||||||
import math
|
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from scripts.lib import tutils
|
from scripts.lib import tutils
|
||||||
from scripts.lib import layerinfo
|
from scripts.lib.features.featureinfo import FeatureInfo, MultiImageFeatures
|
||||||
from scripts.lib.features.featureinfo import FeatureInfo, Features, MultiImageFeatures
|
|
||||||
from scripts.lib.report import message as E
|
|
||||||
|
|
||||||
def feature_diff(
|
def feature_diff(
|
||||||
features1: MultiImageFeatures[FeatureInfo],
|
features1: MultiImageFeatures[FeatureInfo],
|
||||||
|
|
@ -55,10 +52,8 @@ def feature_to_grid_images(
|
||||||
if isinstance(feature, FeatureInfo):
|
if isinstance(feature, FeatureInfo):
|
||||||
tensor = feature.output
|
tensor = feature.output
|
||||||
assert isinstance(tensor, Tensor)
|
assert isinstance(tensor, Tensor)
|
||||||
assert len(tensor.size()) == 3
|
|
||||||
|
|
||||||
grid_x, grid_y = _get_grid_num(layer, width, height)
|
canvases = tutils.tensor_to_grid_images(tensor, layer, width, height, color)
|
||||||
canvases = tutils.tensor_to_image(tensor, grid_x, grid_y, color)
|
|
||||||
return canvases
|
return canvases
|
||||||
|
|
||||||
def save_features(
|
def save_features(
|
||||||
|
|
@ -67,30 +62,3 @@ def save_features(
|
||||||
basename: str
|
basename: str
|
||||||
):
|
):
|
||||||
tutils.save_tensor(feature.output, save_dir, basename)
|
tutils.save_tensor(feature.output, save_dir, basename)
|
||||||
|
|
||||||
def _get_grid_num(layer: str, width: int, height: int):
|
|
||||||
assert layer is not None and layer != "", E("<Layers> must not be empty.")
|
|
||||||
assert layer in layerinfo.Settings, E(f"Invalid <Layers> value: {layer}.")
|
|
||||||
_, (ch, mh, mw) = layerinfo.Settings[layer]
|
|
||||||
iw = math.ceil(width / 64)
|
|
||||||
ih = math.ceil(height / 64)
|
|
||||||
w = mw * iw
|
|
||||||
h = mh * ih
|
|
||||||
# w : width of a feature map
|
|
||||||
# h : height of a feature map
|
|
||||||
# ch: a number of a feature map
|
|
||||||
n = [w, h]
|
|
||||||
while ch % 2 == 0:
|
|
||||||
n[n[0]>n[1]] *= 2
|
|
||||||
ch //= 2
|
|
||||||
n[n[0]>n[1]] *= ch
|
|
||||||
if n[0] > n[1]:
|
|
||||||
while n[0] > n[1] * 2 and (n[0] // w) % 2 == 0:
|
|
||||||
n[0] //= 2
|
|
||||||
n[1] *= 2
|
|
||||||
else:
|
|
||||||
while n[0] * 2 < n[1] and (n[1] // h) % 2 == 0:
|
|
||||||
n[0] *= 2
|
|
||||||
n[1] //= 2
|
|
||||||
|
|
||||||
return n[0] // w, n[1] // h
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
import math
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -6,6 +7,20 @@ from PIL import Image
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
||||||
|
from scripts.lib import layerinfo
|
||||||
|
from scripts.lib.report import message as E
|
||||||
|
|
||||||
|
def tensor_to_grid_images(
|
||||||
|
tensor: Tensor,
|
||||||
|
layer: str,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
color: bool
|
||||||
|
):
|
||||||
|
grid_x, grid_y = get_grid_num(layer, width, height)
|
||||||
|
canvases = tensor_to_image(tensor, grid_x, grid_y, color)
|
||||||
|
return canvases
|
||||||
|
|
||||||
def tensor_to_image(
|
def tensor_to_image(
|
||||||
tensor: Tensor,
|
tensor: Tensor,
|
||||||
grid_x: int,
|
grid_x: int,
|
||||||
|
|
@ -93,3 +108,30 @@ def _tensor_to_image(array: np.ndarray, color: bool):
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return np.clip(np.abs(array) * 256, 0, 255).astype(np.uint8)
|
return np.clip(np.abs(array) * 256, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
|
def get_grid_num(layer: str, width: int, height: int):
|
||||||
|
assert layer is not None and layer != "", E("<Layers> must not be empty.")
|
||||||
|
assert layer in layerinfo.Settings, E(f"Invalid <Layers> value: {layer}.")
|
||||||
|
_, (ch, mh, mw) = layerinfo.Settings[layer]
|
||||||
|
iw = math.ceil(width / 64)
|
||||||
|
ih = math.ceil(height / 64)
|
||||||
|
w = mw * iw
|
||||||
|
h = mh * ih
|
||||||
|
# w : width of a feature map
|
||||||
|
# h : height of a feature map
|
||||||
|
# ch: a number of a feature map
|
||||||
|
n = [w, h]
|
||||||
|
while ch % 2 == 0:
|
||||||
|
n[n[0]>n[1]] *= 2
|
||||||
|
ch //= 2
|
||||||
|
n[n[0]>n[1]] *= ch
|
||||||
|
if n[0] > n[1]:
|
||||||
|
while n[0] > n[1] * 2 and (n[0] // w) % 2 == 0:
|
||||||
|
n[0] //= 2
|
||||||
|
n[1] *= 2
|
||||||
|
else:
|
||||||
|
while n[0] * 2 < n[1] and (n[1] // h) % 2 == 0:
|
||||||
|
n[0] *= 2
|
||||||
|
n[1] //= 2
|
||||||
|
|
||||||
|
return n[0] // w, n[1] // h
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@
|
||||||
#dumpunet-img2img-features-layerinfo,
|
#dumpunet-img2img-features-layerinfo,
|
||||||
#dumpunet-txt2img-layerprompt-layerinfo,
|
#dumpunet-txt2img-layerprompt-layerinfo,
|
||||||
#dumpunet-img2img-layerprompt-layerinfo,
|
#dumpunet-img2img-layerprompt-layerinfo,
|
||||||
|
#dumpunet-txt2img-attention-layerinfo,
|
||||||
|
#dumpunet-img2img-attention-layerinfo,
|
||||||
#dumpunet-txt2img-layerprompt-errors,
|
#dumpunet-txt2img-layerprompt-errors,
|
||||||
#dumpunet-img2img-layerprompt-errors {
|
#dumpunet-img2img-layerprompt-errors {
|
||||||
font-family: monospace;
|
font-family: monospace;
|
||||||
|
|
@ -9,6 +11,8 @@
|
||||||
|
|
||||||
#dumpunet-txt2img-features-checkbox,
|
#dumpunet-txt2img-features-checkbox,
|
||||||
#dumpunet-img2img-features-checkbox,
|
#dumpunet-img2img-features-checkbox,
|
||||||
|
#dumpunet-txt2img-attention-checkbox,
|
||||||
|
#dumpunet-img2img-attention-checkbox,
|
||||||
#dumpunet-txt2img-layerprompt-checkbox,
|
#dumpunet-txt2img-layerprompt-checkbox,
|
||||||
#dumpunet-img2img-layerprompt-checkbox {
|
#dumpunet-img2img-layerprompt-checkbox {
|
||||||
/* background-color: #fff8f0; */
|
/* background-color: #fff8f0; */
|
||||||
|
|
@ -17,6 +21,8 @@
|
||||||
|
|
||||||
.dark #dumpunet-txt2img-features-checkbox,
|
.dark #dumpunet-txt2img-features-checkbox,
|
||||||
.dark #dumpunet-img2img-features-checkbox,
|
.dark #dumpunet-img2img-features-checkbox,
|
||||||
|
.dark #dumpunet-txt2img-attention-checkbox,
|
||||||
|
.dark #dumpunet-img2img-attention-checkbox,
|
||||||
.dark #dumpunet-txt2img-layerprompt-checkbox,
|
.dark #dumpunet-txt2img-layerprompt-checkbox,
|
||||||
.dark #dumpunet-img2img-layerprompt-checkbox {
|
.dark #dumpunet-img2img-layerprompt-checkbox {
|
||||||
/*background-color: inherit;
|
/*background-color: inherit;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue