255 lines
9.0 KiB
Python
255 lines
9.0 KiB
Python
import os
|
|
import time
|
|
|
|
from torch import nn, Tensor
|
|
from torch.utils.hooks import RemovableHandle
|
|
|
|
from modules.processing import Processed, StableDiffusionProcessing
|
|
from modules import shared
|
|
|
|
from scripts.dumpunet import layerinfo
|
|
from scripts.dumpunet.features.featureinfo import FeatureInfo, MultiImageFeatures
|
|
from scripts.dumpunet.features.process import feature_to_grid_images, save_features
|
|
from scripts.dumpunet.ui import retrieve_layers, retrieve_steps
|
|
from scripts.dumpunet.report import message as E
|
|
from scripts.dumpunet.utils import *
|
|
|
|
class FeatureExtractor:
    """Captures intermediate U-Net feature maps during sampling via forward hooks.

    Lifecycle: construct (validates inputs, resolves layer/step selections),
    ``setup(p)`` to register hooks on the model, run generation, then
    ``add_images(...)`` to fold the captured features into a new ``Processed``
    result. Use as a context manager so hooks are always removed on exit.
    """

    # image_index -> step -> Features (filled in by the forward hooks)
    extracted_features: MultiImageFeatures

    # sampling steps (1-based) at which features are captured
    steps: list[int]

    # layer names (as understood by `layerinfo`) whose outputs are captured
    layers: list[str]

    def __init__(
        self,
        runner,
        enabled: bool,
        total_steps: int,
        layer_input: str,
        step_input: str,
        path: str|None,
    ):
        """Validate UI inputs and resolve layer/step selections.

        :param runner: owning script object; this class reads/increments its
            ``steps_on_batch`` and reads ``batch_num`` from inside hooks.
        :param enabled: when False, the extractor is a no-op and all other
            inputs are ignored.
        :param total_steps: total sampling steps; used as the default step
            range when ``step_input`` is empty.
        :param layer_input: comma/range-style layer spec (parsed by
            ``retrieve_layers``); must be non-empty when enabled.
        :param step_input: step spec (parsed by ``retrieve_steps``); empty
            means "every step from 1 to total_steps".
        :param path: optional output directory for binary dumps; created
            (mkdir -p) if missing.
        """
        self._runner = runner
        self._enabled = enabled
        self._handles: list[RemovableHandle] = []

        self.extracted_features = MultiImageFeatures()
        self.steps = []
        self.layers = []
        self.path = None

        # Disabled: leave everything empty so setup()/add_images() no-op.
        if not self._enabled:
            return

        assert layer_input is not None and layer_input != "", E("<Layers> must not be empty.")
        if path is not None:
            assert path != "", E("<Output path> must not be empty.")
            # mkdir -p path
            if os.path.exists(path):
                assert os.path.isdir(path), E("<Output path> already exists and is not a directory.")
            else:
                os.makedirs(path, exist_ok=True)

        self.layers = retrieve_layers(layer_input)
        self.steps = (
            retrieve_steps(step_input)
            # empty step spec -> capture every step
            or list(range(1, total_steps+1))
        )
        self.path = path

    def __enter__(self):
        # NOTE(review): returns None, so `with FeatureExtractor(...) as x`
        # binds x to None; callers must keep their own reference.
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        """Remove all registered forward hooks and drop captured features."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

        # Reset to a fresh state so the instance could be reused.
        self._handles = []
        self.extracted_features = MultiImageFeatures()

    def setup(
        self,
        p: StableDiffusionProcessing,
    ):
        """Register forward hooks on the U-Net of ``p``'s model.

        A hook on ``time_embed`` counts sampling steps (it fires once per
        U-Net invocation); one hook per selected layer records that layer's
        output tensor for the steps listed in ``self.steps``.
        """
        if not self._enabled:
            return

        unet = p.sd_model.model.diffusion_model # type: ignore

        #time_embed : nn.modules.container.Sequential
        #input_blocks : nn.modules.container.ModuleList
        #middle_block : ldm.modules.diffusionmodules.openaimodel.TimestepEmbedSequential
        #output_blocks : nn.modules.container.ModuleList
        #out_ : nn.modules.container.Sequential

        #time_embed = unet.time_embed
        #input_blocks = unet.input_blocks
        #middle_block = unet.middle_block
        #output_blocks = unet.output_blocks
        #out_ = unet.out

        #summary(unet, (4, 512, 512))

        def start_step(module, inputs, outputs):
            # time_embed runs exactly once per U-Net forward pass, so this
            # increments the runner's per-batch step counter each step.
            self._runner.steps_on_batch += 1

        def create_hook(layername: str):
            # Bind `layername` per hook (avoids the late-binding-closure trap).

            def forward_hook(module, inputs, outputs):
                # print(f"{self._runner.steps_on_batch} {layername} {inputs[0].size()} {outputs.size()}")

                if self._runner.steps_on_batch in self.steps:
                    # two same outputs per sample???
                    # NOTE(review): halving batch dim presumably separates
                    # cond/uncond duplicates — confirm against the sampler.
                    images_per_batch = outputs.size()[0] // 2

                    # Global image index = (batch number - 1) * batch size + offset.
                    for image_index, output in enumerate(
                        outputs.detach().clone()[:images_per_batch],
                        (self._runner.batch_num-1) * images_per_batch
                    ):
                        features = self.extracted_features[image_index][self._runner.steps_on_batch]
                        features.add(
                            layername,
                            FeatureInfo(
                                # input sizes: only tensor inputs (hooks may
                                # also receive timestep embeddings etc.)
                                [ x.size() for x in inputs if type(x) == Tensor ],
                                output.size(),
                                output
                            )
                        )
            return forward_hook

        self._handles.append(unet.time_embed.register_forward_hook(start_step))
        for layer in self.layers:
            target = get_unet_layer(unet, layer)
            self._handles.append(target.register_forward_hook(create_hook(layer)))

    def add_images(
        self,
        p: StableDiffusionProcessing,
        proc: Processed,
        extracted_features: MultiImageFeatures,
        color: bool
    ) -> Processed:
        """Build a new ``Processed`` interleaving generated images with
        per-step/per-layer feature grids.

        For each generated image, its feature grids (one per captured step
        and layer) are appended right after it, with layer/step info added
        to the infotext. Optionally dumps raw features to ``self.path``.

        :param p: the processing object (used for width/height and as the
            base of the returned ``Processed``).
        :param proc: the original generation result; its ``infotexts`` list
            is mutated (``pop(0)`` per preview image).
        :param extracted_features: captured features, image -> step -> layer.
        :param color: whether feature grids are rendered in color.
        :return: a new ``Processed`` containing originals plus feature grids.
        """
        if not self._enabled:
            return proc

        if shared.state.interrupted:
            return proc

        # Split off leading preview images (e.g. the grid image) from the
        # per-sample images we actually captured features for.
        index0 = proc.index_of_first_image
        preview_images, rest_images = proc.images[:index0], proc.images[index0:]

        assert rest_images is not None and len(rest_images) != 0, E("empty output?")

        # Now `rest_images` is the list of the images we are interested in.

        # Parallel accumulators for the new Processed.
        images = []
        seeds = []
        subseeds = []
        prompts = []
        neg_prompts = []
        infotexts = []

        def add_image(image, seed, subseed, prompt, neg_prompt, infotext, layername=None, feature_steps=None):
            # Append one image plus its metadata; for feature-grid images the
            # layer name / step is appended to the infotext.
            images.append(image)
            seeds.append(seed)
            subseeds.append(subseed)
            prompts.append(prompt)
            neg_prompts.append(neg_prompt)
            info = infotext
            if layername is not None or feature_steps is not None:
                if info:
                    info += "\n"
                if layername is not None:
                    info += f"Layer Name: {layername}"
                if feature_steps is not None:
                    if layername is not None: info += ", "
                    info += f"Feature Steps: {feature_steps}"

            infotexts.append(info)

        # Re-add preview images first, consuming their infotexts from proc.
        for image in preview_images:
            preview_info = proc.infotexts.pop(0)
            add_image(image, proc.seed, proc.subseed, proc.prompt, proc.negative_prompt, preview_info)

        # For Dynamic Prompt Extension
        # which is not append subseeds...
        while len(proc.all_subseeds) < len(proc.all_seeds):
            proc.all_subseeds.append(proc.all_subseeds[0] if 0 < len(proc.all_subseeds) else 0)

        # All per-image metadata lists must line up with the images.
        assert all([
            len(rest_images) == len(x) for x
            in [
                proc.all_seeds,
                proc.all_subseeds,
                proc.all_prompts,
                proc.all_negative_prompts,
                proc.infotexts
            ]
        ]), E(f"#images={len(rest_images)}, #seeds={len(proc.all_seeds)}, #subseeds={len(proc.all_subseeds)}, #pr={len(proc.all_prompts)}, #npr={len(proc.all_negative_prompts)}, #info={len(proc.infotexts)}")

        # Features sorted by image index; must match the image count.
        sorted_step_features = list(sorted_values(extracted_features))
        assert len(rest_images) == len(sorted_step_features), E(f"#images={len(rest_images)}, #features={len(sorted_step_features)}")

        t0 = int(time.time()) # for binary files' name
        shared.total_tqdm.clear()
        shared.total_tqdm.updateTotal(len(sorted_step_features) * len(self.steps) * len(self.layers))

        image_args = zip(
            proc.all_seeds,
            proc.all_subseeds,
            proc.all_prompts,
            proc.all_negative_prompts,
            proc.infotexts
        )

        for idx, (image, step_features, args) in enumerate(zip(rest_images, sorted_step_features, image_args)):
            # Original image first, then its feature grids.
            add_image(image, *args)

            for step, features in sorted_items(step_features):
                for layer, feature in features:

                    if shared.state.interrupted:
                        break

                    canvases = feature_to_grid_images(feature, layer, p.width, p.height, color)
                    for canvas in canvases:
                        add_image(canvas, *args, layername=layer, feature_steps=step)

                    # Optionally dump raw feature tensors to disk.
                    if self.path is not None:
                        basename = f"{idx:03}-{layer}-{step:03}-{{ch:04}}-{t0}"
                        save_features(feature, self.path, basename)

                    shared.total_tqdm.update()

        return Processed(
            p,
            images,
            seed=proc.seed,
            info=proc.info,
            subseed=proc.subseed,
            all_seeds=seeds,
            all_subseeds=subseeds,
            all_prompts=prompts,
            all_negative_prompts=neg_prompts,
            infotexts=infotexts,
        )
|
|
|
|
def get_unet_layer(unet, layername: str) -> nn.modules.Module:
    """Resolve a layer name to the corresponding U-Net submodule.

    Tries the input-, middle-, and output-block namespaces in order;
    the first resolver that recognizes the name wins.

    :param unet: the diffusion U-Net module.
    :param layername: a layer name understood by ``layerinfo``.
    :return: the matching submodule.
    :raises ValueError: if the name is not recognized by any resolver.
    """
    resolvers = (
        (layerinfo.input_index,  lambda i: unet.input_blocks[i]),
        # middle_block is a single module; its index is ignored.
        (layerinfo.middle_index, lambda i: unet.middle_block),
        (layerinfo.output_index, lambda i: unet.output_blocks[i]),
    )

    for to_index, pick in resolvers:
        idx = to_index(layername)
        if idx is not None:
            return pick(idx)

    raise ValueError(E(f"Invalid layer name: {layername}"))
|