93 lines
3.6 KiB
Python
93 lines
3.6 KiB
Python
from typing import TYPE_CHECKING
|
|
|
|
from torch import nn, Tensor
|
|
|
|
from modules.processing import StableDiffusionProcessing
|
|
|
|
from scripts.lib import layerinfo
|
|
from scripts.lib.feature_extractor import FeatureExtractorBase
|
|
from scripts.lib.features.featureinfo import FeatureInfo
|
|
from scripts.lib.features.utils import feature_to_grid_images, save_features
|
|
from scripts.lib.report import message as E
|
|
from scripts.lib.utils import *
|
|
from scripts.lib.colorizer import Colorizer
|
|
|
|
if TYPE_CHECKING:
|
|
from scripts.dumpunet import Script
|
|
|
|
class FeatureExtractor(FeatureExtractorBase[FeatureInfo]):
|
|
|
|
def __init__(
|
|
self,
|
|
runner: "Script",
|
|
enabled: bool,
|
|
total_steps: int,
|
|
layer_input: str,
|
|
step_input: str,
|
|
path: str|None,
|
|
):
|
|
super().__init__(runner, enabled, total_steps, layer_input, step_input, path)
|
|
|
|
def hook_unet(self, p: StableDiffusionProcessing, unet: nn.Module):
|
|
|
|
#time_embed : nn.modules.container.Sequential
|
|
#input_blocks : nn.modules.container.ModuleList
|
|
#middle_block : ldm.modules.diffusionmodules.openaimodel.TimestepEmbedSequential
|
|
#output_blocks : nn.modules.container.ModuleList
|
|
#out_ : nn.modules.container.Sequential
|
|
#time_embed = unet.time_embed
|
|
#input_blocks = unet.input_blocks
|
|
#middle_block = unet.middle_block
|
|
#output_blocks = unet.output_blocks
|
|
#out_ = unet.out
|
|
#summary(unet, (4, 512, 512))
|
|
|
|
def create_hook(layername: str):
|
|
|
|
def forward_hook(module, inputs, outputs):
|
|
if self.steps_on_batch in self.steps:
|
|
self.log(f"{self.steps_on_batch} {layername} {inputs[0].size()} {outputs.size()}")
|
|
|
|
images_per_batch = p.batch_size
|
|
|
|
for image_index, output in enumerate(
|
|
outputs.detach().clone()[:images_per_batch],
|
|
(self.batch_num-1) * images_per_batch
|
|
):
|
|
features = self.extracted_features[image_index][self.steps_on_batch]
|
|
features.add(
|
|
layername,
|
|
FeatureInfo(
|
|
[ x.size() for x in inputs if type(x) == Tensor ],
|
|
output.size(),
|
|
output
|
|
)
|
|
)
|
|
return forward_hook
|
|
|
|
for layer in self.layers:
|
|
self.log(f"U-Net: hooking {layer}...")
|
|
target = get_unet_layer(unet, layer)
|
|
self.hook_layer(target, create_hook(layer))
|
|
|
|
def feature_to_grid_images(self, feature: FeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, add_average: bool, color: Colorizer):
|
|
return feature_to_grid_images(feature, layer, width, height, add_average, color)
|
|
|
|
def save_features(self, feature: FeatureInfo, layer: str, img_idx: int, step: int, width: int, height: int, path: str, basename: str):
|
|
save_features(feature, path, basename)
|
|
|
|
def get_unet_layer(unet, layername: str) -> nn.modules.Module:
|
|
idx = layerinfo.input_index(layername)
|
|
if idx is not None:
|
|
return unet.input_blocks[idx]
|
|
|
|
idx = layerinfo.middle_index(layername)
|
|
if idx is not None:
|
|
return unet.middle_block
|
|
|
|
idx = layerinfo.output_index(layername)
|
|
if idx is not None:
|
|
return unet.output_blocks[idx]
|
|
|
|
raise ValueError(E(f"Invalid layer name: {layername}"))
|