add AttentionExtractor

master
hnmr293 2023-01-21 01:11:13 +09:00
parent b1bc99506b
commit 15b425c5a0
12 changed files with 351 additions and 49 deletions

View File

@ -19,6 +19,11 @@ onUiUpdate(() => {
'#dumpunet-{}-features-steps': 'Steps which U-Net features should be extracted. See tooltip for notations', '#dumpunet-{}-features-steps': 'Steps which U-Net features should be extracted. See tooltip for notations',
'#dumpunet-{}-features-dumppath': 'Raw binary files are dumped to here, one image per step per layer.', '#dumpunet-{}-features-dumppath': 'Raw binary files are dumped to here, one image per step per layer.',
'#dumpunet-{}-attention-checkbox': 'Extract attention layer\'s features and add their maps to output images.',
'#dumpunet-{}-attention-layer': 'U-Net layers <code>(IN00-IN11, M00, OUT00-OUT11)</code> which features should be extracted. See tooltip for notations.',
'#dumpunet-{}-attention-steps': 'Steps which features should be extracted. See tooltip for notations',
'#dumpunet-{}-attention-dumppath': 'Raw binary files are dumped to here, one image per step per layer.',
'#dumpunet-{}-layerprompt-checkbox': 'When checked, <code>(~: ... :~)</code> notation is enabled.', '#dumpunet-{}-layerprompt-checkbox': 'When checked, <code>(~: ... :~)</code> notation is enabled.',
'#dumpunet-{}-layerprompt-diff-layer': 'Layers <code>(IN00-IN11, M00, OUT00-OUT11)</code> which features should be extracted. See tooltip for notations.', '#dumpunet-{}-layerprompt-diff-layer': 'Layers <code>(IN00-IN11, M00, OUT00-OUT11)</code> which features should be extracted. See tooltip for notations.',
'#dumpunet-{}-layerprompt-diff-steps': 'Steps which features should be extracted. See tooltip for notations', '#dumpunet-{}-layerprompt-diff-steps': 'Steps which features should be extracted. See tooltip for notations',
@ -28,6 +33,8 @@ onUiUpdate(() => {
const hints = { const hints = {
'#dumpunet-{}-features-layer textarea': 'IN00: add one layer to output\nIN00,IN01: add layers to output\nIN00-IN02: add range to output\nIN00-OUT05(+2): add range to output with specified steps\n', '#dumpunet-{}-features-layer textarea': 'IN00: add one layer to output\nIN00,IN01: add layers to output\nIN00-IN02: add range to output\nIN00-OUT05(+2): add range to output with specified steps\n',
'#dumpunet-{}-features-steps textarea': '5: extracted at steps=5\n5,10: extracted at steps=5 and steps=10\n5-10: extracted when step is in 5..10 (inclusive)\n5-10(+2): extracts when step is 5,7,9\n', '#dumpunet-{}-features-steps textarea': '5: extracted at steps=5\n5,10: extracted at steps=5 and steps=10\n5-10: extracted when step is in 5..10 (inclusive)\n5-10(+2): extracts when step is 5,7,9\n',
'#dumpunet-{}-attention-layer textarea': 'IN00: add one layer to output\nIN00,IN01: add layers to output\nIN00-IN02: add range to output\nIN00-OUT05(+2): add range to output with specified steps\n',
'#dumpunet-{}-attention-steps textarea': '5: extracted at steps=5\n5,10: extracted at steps=5 and steps=10\n5-10: extracted when step is in 5..10 (inclusive)\n5-10(+2): extracts when step is 5,7,9\n',
'#dumpunet-{}-layerprompt-diff-layer textarea': 'IN00: add one layer to output\nIN00,IN01: add layers to output\nIN00-IN02: add range to output\nIN00-OUT05(+2): add range to output with specified steps\n', '#dumpunet-{}-layerprompt-diff-layer textarea': 'IN00: add one layer to output\nIN00,IN01: add layers to output\nIN00-IN02: add range to output\nIN00-OUT05(+2): add range to output with specified steps\n',
'#dumpunet-{}-layerprompt-diff-steps textarea': '5: extracted at steps=5\n5,10: extracted at steps=5 and steps=10\n5-10: extracted when step is in 5..10 (inclusive)\n5-10(+2): extracts when step is 5,7,9\n', '#dumpunet-{}-layerprompt-diff-steps textarea': '5: extracted at steps=5\n5,10: extracted at steps=5 and steps=10\n5-10: extracted when step is in 5..10 (inclusive)\n5-10(+2): extracts when step is 5,7,9\n',
}; };

View File

@ -81,7 +81,7 @@ onUiUpdate(() => {
const updates = []; const updates = [];
for (let mode of ['txt2img', 'img2img']) { for (let mode of ['txt2img', 'img2img']) {
for (let tab of ['features', 'layerprompt']) { for (let tab of ['features', 'attention', 'layerprompt']) {
const layer_input_ele = const layer_input_ele =
app.querySelector(`#dumpunet-${mode}-${tab}-layer textarea`) app.querySelector(`#dumpunet-${mode}-${tab}-layer textarea`)
|| app.querySelector(`#dumpunet-${mode}-${tab}-diff-layer textarea`); || app.querySelector(`#dumpunet-${mode}-${tab}-diff-layer textarea`);

View File

@ -13,6 +13,7 @@ from scripts.lib.features.extractor import FeatureExtractor
from scripts.lib.features.utils import feature_diff, feature_to_grid_images from scripts.lib.features.utils import feature_diff, feature_to_grid_images
from scripts.lib.tutils import save_tensor from scripts.lib.tutils import save_tensor
from scripts.lib.layer_prompt.prompt import LayerPrompt from scripts.lib.layer_prompt.prompt import LayerPrompt
from scripts.lib.attention.extractor import AttentionExtractor
from scripts.lib.report import message as E from scripts.lib.report import message as E
from scripts.lib import putils from scripts.lib import putils
@ -46,6 +47,13 @@ class Script(scripts.Script):
result.unet.dump.enabled, result.unet.dump.enabled,
result.unet.dump.path, result.unet.dump.path,
result.attn.enabled,
result.attn.settings.layers,
result.attn.settings.steps,
result.attn.settings.color,
result.attn.dump.enabled,
result.attn.dump.path,
result.lp.enabled, result.lp.enabled,
result.lp.diff_enabled, result.lp.diff_enabled,
result.lp.diff_settings.layers, result.lp.diff_settings.layers,
@ -95,6 +103,13 @@ class Script(scripts.Script):
path_on: bool, path_on: bool,
path: str, path: str,
attn_enabled: bool,
attn_layers: str,
attn_steps: str,
attn_color: bool,
attn_path_on: bool,
attn_path: str,
layerprompt_enabled: bool, layerprompt_enabled: bool,
layerprompt_diff_enabled: bool, layerprompt_diff_enabled: bool,
lp_diff_layers: str, lp_diff_layers: str,
@ -106,7 +121,7 @@ class Script(scripts.Script):
debug: bool, debug: bool,
): ):
if not unet_features_enabled and not layerprompt_enabled: if not unet_features_enabled and not attn_enabled and not layerprompt_enabled:
return process_images(p) return process_images(p)
self.debug = debug self.debug = debug
@ -134,6 +149,15 @@ class Script(scripts.Script):
layerprompt_enabled, layerprompt_enabled,
) )
at = AttentionExtractor(
self,
attn_enabled,
p.steps,
attn_layers,
attn_steps,
attn_path if attn_path_on else None
)
if layerprompt_diff_enabled: if layerprompt_diff_enabled:
fix_seed(p) fix_seed(p)
@ -142,15 +166,17 @@ class Script(scripts.Script):
# layer prompt disabled # layer prompt disabled
lp0 = LayerPrompt(self, layerprompt_enabled, remove_layer_prompts=True) lp0 = LayerPrompt(self, layerprompt_enabled, remove_layer_prompts=True)
proc1 = exec(p1, lp0, [ex, exlp]) proc1 = exec(p1, lp0, [ex, exlp, at])
features1 = ex.extracted_features features1 = ex.extracted_features
diff1 = exlp.extracted_features diff1 = exlp.extracted_features
proc1 = ex.add_images(p1, proc1, features1, color) proc1 = ex.add_images(p1, proc1, features1, color)
proc1 = at.add_images(p1, proc1, at.extracted_features, attn_color)
# layer prompt enabled # layer prompt enabled
proc2 = exec(p2, lp, [ex, exlp]) proc2 = exec(p2, lp, [ex, exlp, at])
features2 = ex.extracted_features features2 = ex.extracted_features
diff2 = exlp.extracted_features diff2 = exlp.extracted_features
proc2 = ex.add_images(p2, proc2, features2, color) proc2 = ex.add_images(p2, proc2, features2, color)
proc2 = at.add_images(p2, proc2, at.extracted_features, attn_color)
assert len(proc1.images) == len(proc2.images) assert len(proc1.images) == len(proc2.images)
@ -175,10 +201,9 @@ class Script(scripts.Script):
save_tensor(tensor, diff_path, basename) save_tensor(tensor, diff_path, basename)
else: else:
proc = exec(p, lp, [ex]) proc = exec(p, lp, [ex, at])
features = ex.extracted_features proc = ex.add_images(p, proc, ex.extracted_features, color)
if unet_features_enabled: proc = at.add_images(p, proc, at.extracted_features, attn_color)
proc = ex.add_images(p, proc, features, color)
return proc return proc

View File

@ -0,0 +1,180 @@
import math
from typing import TYPE_CHECKING
from torch import nn, Tensor, einsum
from einops import rearrange
from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock, CrossAttention, MemoryEfficientCrossAttention # type: ignore
from modules.processing import StableDiffusionProcessing
from modules.hypernetworks import hypernetwork
from modules import shared
from scripts.lib.feature_extractor import FeatureExtractorBase
from scripts.lib.features.featureinfo import MultiImageFeatures
from scripts.lib.features.extractor import get_unet_layer
from scripts.lib.attention.featureinfo import AttnFeatureInfo
from scripts.lib import layerinfo, tutils
from scripts.lib.utils import *
if TYPE_CHECKING:
from scripts.dumpunet import Script
class AttentionExtractor(FeatureExtractorBase):
    """Extracts attention-layer features from the U-Net while sampling runs.

    For each selected layer/step it records, per image, an AttnFeatureInfo
    holding the softmaxed attention map (qk) and the attended output (vqk)
    of the block's second attention module (attn2).
    """

    # image_index -> step -> Features
    extracted_features: MultiImageFeatures[AttnFeatureInfo]

    def __init__(
        self,
        runner: "Script",
        enabled: bool,
        total_steps: int,
        layer_input: str,
        step_input: str,
        path: str|None,
    ):
        # layer_input / step_input are the raw textbox notations
        # (e.g. "IN00-OUT05(+2)"); parsing happens in FeatureExtractorBase.
        # `path` is None when raw-tensor dumping is disabled.
        super().__init__(runner, enabled, total_steps, layer_input, step_input, path)
        self.extracted_features = MultiImageFeatures()

    def hook_unet(self, p: StableDiffusionProcessing, unet: nn.Module):
        """Install forward wrappers on attn1/attn2 of every transformer block
        in each selected U-Net layer, then defer to the base class."""

        def create_hook(layername: str, block: BasicTransformerBlock, n: int, depth: int, c: int):
            # c == 1 -> block.attn1, c == 2 -> block.attn2.
            def forward(module, fn, x, context=None, *args, **kwargs):
                # Run the real forward first; extraction never alters the result.
                result = fn(x, context=context, *args, **kwargs)
                if self.steps_on_batch in self.steps:
                    if c == 2:
                        # process for only cross-attention
                        self.log(f"{self.steps_on_batch:>03} {layername}-{n}-{depth}-attn{c} ({'cross' if (block.disable_self_attn or 1 < c) else 'self'})")
                        self.log(f" | {shape(x),shape(context)} -> {shape(result)}")
                        qks, vqks = self.process_attention(module, x, context)
                        # qk := (batch, head, token, height*width)
                        # vqk := (batch, height*width, ch)
                        # NOTE(review): assumes the batch is two concatenated
                        # halves (cond/uncond) and keeps only the first half —
                        # TODO confirm which half against the sampler.
                        images_per_batch = qks.shape[0] // 2
                        assert qks.shape[0] == vqks.shape[0]
                        assert qks.shape[0] % 2 == 0
                        for image_index, (vk, vqk) in enumerate(
                            zip(qks[:images_per_batch], vqks[:images_per_batch]),
                            (self.batch_num-1) * images_per_batch
                        ):
                            features = self.extracted_features[image_index][self.steps_on_batch]
                            features.add(
                                layername,
                                AttnFeatureInfo(vk, vqk)
                            )
                return result
            return forward

        for layer in self.layers:
            self.log(f"Attention: hooking {layer}...")
            for n, d, block, attn1, attn2 in get_unet_attn_layers(unet, layer):
                self.hook_forward(attn1, create_hook(layer, block, n, d, 1))
                self.hook_forward(attn2, create_hook(layer, block, n, d, 2))

        return super().hook_unet(p, unet)

    def process_attention(self, module, x, context):
        """Recompute `module`'s attention from its q/k/v projections and
        return (qk, vqk) as detached clones.

        qk  — softmaxed attention rearranged to (batch, head, token, h*w)
        vqk — attended output merged back to (batch, h*w, ch)
        """
        # q_in : unet features ([2, 4096, 320])
        # k_in, v_in : embedding vector kv (cross-attention) ([2, 77, 320]) or unet features kv (self-attention) ([2, 4096, 320])
        # q,k,v : head-separated q_in, k_in and v_in
        ctx_k, ctx_v = hypernetwork.apply_hypernetwork(
            shared.loaded_hypernetwork,
            context if context is not None else x
        )
        q_in = module.to_q(x)
        k_in = module.to_k(ctx_k)
        v_in = module.to_v(ctx_v)
        q: Tensor
        k: Tensor
        v: Tensor
        q, k, v = map( # type: ignore
            lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=module.heads),
            (q_in, k_in, v_in)
        )
        sim = einsum('b i d, b j d -> b i j', q, k) * module.scale
        sim = sim.softmax(dim=-1)
        # sim.shape == '(b h) i j'
        o_in = einsum('b i j, b j d -> b i d', sim, v)
        o = rearrange(o_in, '(b h) n d -> b n (h d)', h=module.heads)
        qk: Tensor = rearrange(sim, '(b h) d t -> b h t d', h=module.heads).detach().clone()
        vqk: Tensor = o.detach().clone()
        self.log(f" | q: {shape(q_in)} # {shape(q)}")
        self.log(f" | k: {shape(k_in)} # {shape(k)}")
        self.log(f" | v: {shape(v_in)} # {shape(v)}")
        #self.log(f" | qk: {shape(qk)} # {shape(sim)}")
        #self.log(f" | vqk: {shape(vqk)}")
        # Free intermediates eagerly — these live on the GPU during sampling.
        del q_in, k_in, v_in, q, k, v, sim, o_in, o
        return qk, vqk

    def feature_to_grid_images(self, feature: AttnFeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, color: bool):
        """Render the recorded qk attention maps of one layer/step as grid images."""
        #return feature_to_grid_images(feature, layer, width, height, color)
        w, h, ch = get_shape(layer, width, height)
        # qk
        qk = feature.qk
        heads_qk, ch_qk, n_qk = qk.shape
        # NOTE(review): 77 is presumably the CLIP context length — TODO confirm.
        assert ch_qk == 77
        assert w * h == n_qk, f"w={w}, h={h}, n_qk={n_qk}"
        qk1 = rearrange(qk, 'a t (h w) -> (a t) h w', h=h).contiguous()
        # vqk
        vqk = feature.vqk
        n_vqk, ch_vqk = vqk.shape
        assert w * h == n_vqk, f"w={w}, h={h}, n_qk={n_vqk}"
        assert ch == ch_vqk, f"ch={ch}, ch_vqk={ch_vqk}"
        vqk1 = rearrange(vqk, '(h w) c -> c h w', h=h).contiguous()
        #print(img_idx, step, layer, qk1.shape, vqk1.shape)
        # NOTE(review): vqk1 is validated/reshaped but not rendered — only the
        # qk maps are turned into images here.
        return tutils.tensor_to_image(qk1, ch_qk, heads_qk, color)

    def save_features(self, feature: AttnFeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, path: str, basename: str):
        """Dump the qk map as a raw tensor file; vqk is not saved."""
        w, h, ch = get_shape(layer, width, height)
        qk = rearrange(feature.qk, 'a t (h w) -> (a t) h w', h=h).contiguous()
        tutils.save_tensor(qk, path, basename)
def get_shape(layer: str, width: int, height: int):
    """Return (map_width, map_height, out_channels) of `layer`'s feature maps
    for an image of the given pixel size.

    The per-64px base sizes come from layerinfo.Settings; each dimension is
    scaled by how many 64px units the requested image spans (minimum 1).
    """
    assert layer in layerinfo.Settings
    (ich, ih, iw), (och, oh, ow) = layerinfo.Settings[layer]
    nw = max(1, math.ceil(width / 64))
    nh = max(1, math.ceil(height / 64))
    return iw*nw, ih*nh, och
def get_unet_attn_layers(unet, layername: str):
    """Yield (n, depth, basic_block, attn1, attn2) for every transformer
    block inside the named U-Net layer.

    n is the index of the SpatialTransformer within the layer; depth is the
    index of the BasicTransformerBlock within that transformer.
    """
    unet_block = get_unet_layer(unet, layername)

    def each_transformer(unet_block):
        # Direct children only — nested transformers (if any) are not visited.
        for block in unet_block.children():
            if isinstance(block, SpatialTransformer):
                yield block

    def each_basic_block(trans):
        for block in trans.children():
            if isinstance(block, BasicTransformerBlock):
                yield block

    for n, trans in enumerate(each_transformer(unet_block)):
        for depth, basic_block in enumerate(each_basic_block(trans.transformer_blocks)):
            attn1: CrossAttention|MemoryEfficientCrossAttention
            attn2: CrossAttention|MemoryEfficientCrossAttention
            attn1, attn2 = basic_block.attn1, basic_block.attn2
            # Guard against unexpected attention implementations (e.g. after
            # another extension has replaced the modules).
            assert isinstance(attn1, CrossAttention) or isinstance(attn1, MemoryEfficientCrossAttention)
            assert isinstance(attn2, CrossAttention) or isinstance(attn2, MemoryEfficientCrossAttention)
            yield n, depth, basic_block, attn1, attn2
def shape(t: Tensor|None) -> tuple|None:
return tuple(t.shape) if t is not None else None

View File

@ -0,0 +1,8 @@
from dataclasses import dataclass
from torch import Tensor
@dataclass
class AttnFeatureInfo:
    """One attention-layer sample captured for a single image at a single step."""
    # qk: softmaxed attention map — per the extractor's comment, shaped
    # (head, token, height*width) after the batch dim is peeled off; TODO confirm.
    qk: Tensor
    # vqk: attended output (attention applied to V) — per the extractor's
    # comment, shaped (height*width, ch); TODO confirm.
    vqk: Tensor

View File

@ -78,6 +78,14 @@ class UNet:
dump: DumpSetting dump: DumpSetting
info: Info info: Info
@dataclass
class Attn:
tab: Tab
enabled: Checkbox
settings: OutputSetting
dump: DumpSetting
info: Info
@dataclass @dataclass
class LayerPrompt: class LayerPrompt:
tab: Tab tab: Tab
@ -95,6 +103,7 @@ class Debug:
@dataclass @dataclass
class UI: class UI:
unet: UNet unet: UNet
attn: Attn
lp: LayerPrompt lp: LayerPrompt
debug: Debug debug: Debug
@ -109,6 +118,7 @@ class UI:
with Group(elem_id=id("ui")): with Group(elem_id=id("ui")):
result = UI( result = UI(
build_unet(id), build_unet(id),
build_attn(id),
build_layerprompt(id), build_layerprompt(id),
build_debug(runner, id), build_debug(runner, id),
) )
@ -140,6 +150,31 @@ def build_unet(id_: Callable[[str],str]):
info info
) )
def build_attn(id_: Callable[[str],str]):
    """Build the "Attention" tab of the UI and return its widgets as an Attn.

    `id_` maps a suffix to a full element id; every widget here is namespaced
    under "attention-" so the JS hint/tooltip code can find it.
    """
    id = lambda s: id_(f"attention-{s}")
    with Tab("Attention", elem_id=id("tab")) as tab:
        enabled = Checkbox(
            label="Extract attention layers' features",
            value=False,
            elem_id=id("checkbox")
        )
        # Layer/step/color controls shared with the features tab.
        settings = OutputSetting.build(id)
        with Accordion(label="Dump Setting", open=False):
            dump = DumpSetting.build("Dump feature tensors to files", id)
        info = build_info(id)
    return Attn(
        tab,
        enabled,
        settings,
        dump,
        info
    )
def build_layerprompt(id_: Callable[[str],str]): def build_layerprompt(id_: Callable[[str],str]):
id = lambda s: id_(f"layerprompt-{s}") id = lambda s: id_(f"layerprompt-{s}")

View File

@ -12,12 +12,34 @@ from scripts.lib.report import message as E
if TYPE_CHECKING: if TYPE_CHECKING:
from scripts.dumpunet import Script from scripts.dumpunet import Script
class ForwardHook:
    """Removable monkey-patch of a module's `forward`.

    While installed, calls to the module invoke `fn(module, original_forward,
    *args, **kwargs)` instead of the original forward. `remove()` restores
    the original and is safe to call more than once.
    """

    def __init__(self, module: nn.Module, fn: Callable[[nn.Module, Callable[..., Any]], Any]):
        # Keep the original bound forward so remove() can restore it.
        self.o = module.forward
        self.fn = fn
        self.module = module
        self.module.forward = self.forward

    def remove(self):
        """Restore the original forward and drop all references (idempotent)."""
        if self.module is not None and self.o is not None:
            self.module.forward = self.o
        self.module = None
        self.o = None
        self.fn = None

    def forward(self, *args, **kwargs):
        """Replacement forward: delegate to the wrapper while the hook is live."""
        if self.module is None or self.o is None or self.fn is None:
            # Stale call after remove() — nothing sensible to do.
            return None
        return self.fn(self.module, self.o, *args, **kwargs)
class ExtractorBase: class ExtractorBase:
def __init__(self, runner: "Script", enabled: bool): def __init__(self, runner: "Script", enabled: bool):
self._runner = runner self._runner = runner
self._enabled = enabled self._enabled = enabled
self._handles: list[RemovableHandle] = [] self._handles: list[RemovableHandle|ForwardHook] = []
self._batch_num = 0 self._batch_num = 0
self._steps_on_batch = 0 self._steps_on_batch = 0
@ -152,6 +174,15 @@ class ExtractorBase:
assert isinstance(module, nn.Module) assert isinstance(module, nn.Module)
self._handles.append(module.register_forward_pre_hook(fn)) self._handles.append(module.register_forward_pre_hook(fn))
    def hook_forward(
        self,
        module: nn.Module|Any,
        fn: Callable[..., Any]
    ):
        """Replace `module.forward` with a ForwardHook wrapper.

        Unlike register_forward_hook, the wrapper receives (module,
        original_forward, *args, **kwargs) and fully controls the call.
        The hook is tracked in self._handles for later removal.
        """
        assert module is not None
        assert isinstance(module, nn.Module)
        self._handles.append(ForwardHook(module, fn))
def log(self, msg: str): def log(self, msg: str):
if self._runner.debug: if self._runner.debug:
print(E(msg), file=sys.stderr) print(E(msg), file=sys.stderr)

View File

@ -133,22 +133,22 @@ class FeatureExtractorBase(Generic[TInfo], ExtractorBase):
if shared.state.interrupted: if shared.state.interrupted:
break break
canvases = self.feature_to_grid_images(feature, layer, p.width, p.height, color) canvases = self.feature_to_grid_images(feature, layer, idx, step, p.width, p.height, color)
for canvas in canvases: for canvas in canvases:
builder.add(canvas, *args, {"Layer Name": layer, "Feature Steps": step}) builder.add(canvas, *args, {"Layer Name": layer, "Feature Steps": step})
if self.path is not None: if self.path is not None:
basename = f"{idx:03}-{layer}-{step:03}-{{ch:04}}-{t0}" basename = f"{idx:03}-{layer}-{step:03}-{{ch:04}}-{t0}"
self.save_features(feature, self.path, basename) self.save_features(feature, layer, idx, step, p.width, p.height, self.path, basename)
shared.total_tqdm.update() shared.total_tqdm.update()
return builder.to_proc(p, proc) return builder.to_proc(p, proc)
def feature_to_grid_images(self, feature: TInfo, layer: str, width: int, height: int, color: bool): def feature_to_grid_images(self, feature: TInfo, layer: str, img_idx: int, step: int, width: int, height: int, color: bool):
raise NotImplementedError(f"{self.__class__.__name__}.feature_to_grid_images") raise NotImplementedError(f"{self.__class__.__name__}.feature_to_grid_images")
def save_features(self, feature: TInfo, path: str, basename: str): def save_features(self, feature: TInfo, layer: str, img_idx: int, step: int, width: int, height: int, path: str, basename: str):
raise NotImplementedError(f"{self.__class__.__name__}.save_features") raise NotImplementedError(f"{self.__class__.__name__}.save_features")
def _fixup(self, proc: Processed): def _fixup(self, proc: Processed):

View File

@ -69,10 +69,10 @@ class FeatureExtractor(FeatureExtractorBase[FeatureInfo]):
target = get_unet_layer(unet, layer) target = get_unet_layer(unet, layer)
self.hook_layer(target, create_hook(layer)) self.hook_layer(target, create_hook(layer))
def feature_to_grid_images(self, feature: FeatureInfo, layer: str, width: int, height: int, color: bool): def feature_to_grid_images(self, feature: FeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, color: bool):
return feature_to_grid_images(feature, layer, width, height, color) return feature_to_grid_images(feature, layer, width, height, color)
def save_features(self, feature: FeatureInfo, path: str, basename: str): def save_features(self, feature: FeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, path: str, basename: str):
save_features(feature, path, basename) save_features(feature, path, basename)
def get_unet_layer(unet, layername: str) -> nn.modules.Module: def get_unet_layer(unet, layername: str) -> nn.modules.Module:

View File

@ -1,12 +1,9 @@
import math
from typing import Generator from typing import Generator
from torch import Tensor from torch import Tensor
from scripts.lib import tutils from scripts.lib import tutils
from scripts.lib import layerinfo from scripts.lib.features.featureinfo import FeatureInfo, MultiImageFeatures
from scripts.lib.features.featureinfo import FeatureInfo, Features, MultiImageFeatures
from scripts.lib.report import message as E
def feature_diff( def feature_diff(
features1: MultiImageFeatures[FeatureInfo], features1: MultiImageFeatures[FeatureInfo],
@ -55,10 +52,8 @@ def feature_to_grid_images(
if isinstance(feature, FeatureInfo): if isinstance(feature, FeatureInfo):
tensor = feature.output tensor = feature.output
assert isinstance(tensor, Tensor) assert isinstance(tensor, Tensor)
assert len(tensor.size()) == 3
grid_x, grid_y = _get_grid_num(layer, width, height) canvases = tutils.tensor_to_grid_images(tensor, layer, width, height, color)
canvases = tutils.tensor_to_image(tensor, grid_x, grid_y, color)
return canvases return canvases
def save_features( def save_features(
@ -67,30 +62,3 @@ def save_features(
basename: str basename: str
): ):
tutils.save_tensor(feature.output, save_dir, basename) tutils.save_tensor(feature.output, save_dir, basename)
def _get_grid_num(layer: str, width: int, height: int):
assert layer is not None and layer != "", E("<Layers> must not be empty.")
assert layer in layerinfo.Settings, E(f"Invalid <Layers> value: {layer}.")
_, (ch, mh, mw) = layerinfo.Settings[layer]
iw = math.ceil(width / 64)
ih = math.ceil(height / 64)
w = mw * iw
h = mh * ih
# w : width of a feature map
# h : height of a feature map
# ch: a number of a feature map
n = [w, h]
while ch % 2 == 0:
n[n[0]>n[1]] *= 2
ch //= 2
n[n[0]>n[1]] *= ch
if n[0] > n[1]:
while n[0] > n[1] * 2 and (n[0] // w) % 2 == 0:
n[0] //= 2
n[1] *= 2
else:
while n[0] * 2 < n[1] and (n[1] // h) % 2 == 0:
n[0] *= 2
n[1] //= 2
return n[0] // w, n[1] // h

View File

@ -1,4 +1,5 @@
import os import os
import math
from torch import Tensor from torch import Tensor
import numpy as np import numpy as np
@ -6,6 +7,20 @@ from PIL import Image
from modules import shared from modules import shared
from scripts.lib import layerinfo
from scripts.lib.report import message as E
def tensor_to_grid_images(
    tensor: Tensor,
    layer: str,
    width: int,
    height: int,
    color: bool
):
    """Render a feature tensor as grid images, with the grid layout
    (columns x rows) chosen from the layer's map size and the image size."""
    cols, rows = get_grid_num(layer, width, height)
    return tensor_to_image(tensor, cols, rows, color)
def tensor_to_image( def tensor_to_image(
tensor: Tensor, tensor: Tensor,
grid_x: int, grid_x: int,
@ -93,3 +108,30 @@ def _tensor_to_image(array: np.ndarray, color: bool):
else: else:
return np.clip(np.abs(array) * 256, 0, 255).astype(np.uint8) return np.clip(np.abs(array) * 256, 0, 255).astype(np.uint8)
def get_grid_num(layer: str, width: int, height: int):
    """Choose a (grid_x, grid_y) tiling that lays out all `ch` feature maps
    of `layer` in a roughly square overall canvas.

    Works on total pixel dimensions: starting from one map (w x h), factors
    of `ch` are multiplied onto whichever total side is currently smaller,
    then the aspect ratio is rebalanced by moving factors of 2 between the
    sides. Returns tile counts (total // map size) per axis.
    """
    assert layer is not None and layer != "", E("<Layers> must not be empty.")
    assert layer in layerinfo.Settings, E(f"Invalid <Layers> value: {layer}.")
    _, (ch, mh, mw) = layerinfo.Settings[layer]
    # Scale the per-64px base map size to the requested image size.
    iw = math.ceil(width / 64)
    ih = math.ceil(height / 64)
    w = mw * iw
    h = mh * ih
    # w : width of a feature map
    # h : height of a feature map
    # ch: a number of a feature map
    n = [w, h]
    # Distribute each factor of 2 of ch onto the smaller total side
    # (n[0]>n[1] indexes the smaller one: False==0 picks width, True==1 height).
    while ch % 2 == 0:
        n[n[0]>n[1]] *= 2
        ch //= 2
    # The remaining odd factor of ch goes onto the smaller side as well.
    n[n[0]>n[1]] *= ch
    # Rebalance: halve the longer side / double the shorter one while the
    # ratio exceeds 2:1 and the tile count on the longer axis stays integral.
    if n[0] > n[1]:
        while n[0] > n[1] * 2 and (n[0] // w) % 2 == 0:
            n[0] //= 2
            n[1] *= 2
    else:
        while n[0] * 2 < n[1] and (n[1] // h) % 2 == 0:
            n[0] *= 2
            n[1] //= 2
    return n[0] // w, n[1] // h

View File

@ -2,6 +2,8 @@
#dumpunet-img2img-features-layerinfo, #dumpunet-img2img-features-layerinfo,
#dumpunet-txt2img-layerprompt-layerinfo, #dumpunet-txt2img-layerprompt-layerinfo,
#dumpunet-img2img-layerprompt-layerinfo, #dumpunet-img2img-layerprompt-layerinfo,
#dumpunet-txt2img-attention-layerinfo,
#dumpunet-img2img-attention-layerinfo,
#dumpunet-txt2img-layerprompt-errors, #dumpunet-txt2img-layerprompt-errors,
#dumpunet-img2img-layerprompt-errors { #dumpunet-img2img-layerprompt-errors {
font-family: monospace; font-family: monospace;
@ -9,6 +11,8 @@
#dumpunet-txt2img-features-checkbox, #dumpunet-txt2img-features-checkbox,
#dumpunet-img2img-features-checkbox, #dumpunet-img2img-features-checkbox,
#dumpunet-txt2img-attention-checkbox,
#dumpunet-img2img-attention-checkbox,
#dumpunet-txt2img-layerprompt-checkbox, #dumpunet-txt2img-layerprompt-checkbox,
#dumpunet-img2img-layerprompt-checkbox { #dumpunet-img2img-layerprompt-checkbox {
/* background-color: #fff8f0; */ /* background-color: #fff8f0; */
@ -17,6 +21,8 @@
.dark #dumpunet-txt2img-features-checkbox, .dark #dumpunet-txt2img-features-checkbox,
.dark #dumpunet-img2img-features-checkbox, .dark #dumpunet-img2img-features-checkbox,
.dark #dumpunet-txt2img-attention-checkbox,
.dark #dumpunet-img2img-attention-checkbox,
.dark #dumpunet-txt2img-layerprompt-checkbox, .dark #dumpunet-txt2img-layerprompt-checkbox,
.dark #dumpunet-img2img-layerprompt-checkbox { .dark #dumpunet-img2img-layerprompt-checkbox {
/*background-color: inherit; /*background-color: inherit;