stable-diffusion-webui-dump.../scripts/dumpunet/features/featureinfo.py

75 lines
2.0 KiB
Python

from dataclasses import dataclass
from collections import defaultdict
import torch
from torch import Tensor
from scripts.dumpunet import layerinfo
from scripts.dumpunet.report import message as E
from scripts.dumpunet.utils import *
@dataclass()
class FeatureInfo:
    """Sizes and output tensor captured for one layer entry of ``Features``.

    Attributes:
        input_dims: list of ``torch.Size`` values (presumably one per layer
            input — confirm against the code that builds this object).
        output_dims: ``torch.Size`` of the recorded output.
        output: the recorded output tensor.
    """

    # Sizes of the inputs seen by the layer.
    input_dims: list[torch.Size]
    # Size of the output tensor.
    output_dims: torch.Size
    # The output tensor itself.
    output: Tensor
class Features:
    """Container mapping a layer name to the ``FeatureInfo`` stored for it.

    Layers may be addressed by name (``str``) or by integer index; integer
    indices are translated to names with ``layerinfo.name``.
    """

    # layer name -> FeatureInfo
    features: dict[str, FeatureInfo]

    def __init__(self):
        self.features = {}

    def __getitem__(self, layer: int | str):
        """Return the ``FeatureInfo`` stored for *layer* (index or name).

        Raises:
            KeyError: if *layer* is neither ``int`` nor ``str``, or if no
                feature is stored for it.
        """
        v = None
        if isinstance(layer, int):
            v = self.get_by_index(layer)
        elif isinstance(layer, str):
            v = self.get_by_name(layer)
        if v is None:
            raise KeyError(E(f"invalid key: {type(layer)} {layer}"))
        return v

    def __iter__(self):
        """Iterate over the stored entries as produced by ``sorted_items``."""
        # ``__iter__`` must return an iterator. ``iter`` guarantees that even if
        # ``sorted_items`` hands back a list or another plain iterable, and it
        # is a no-op when it already returns a generator.
        return iter(sorted_items(self.features))

    def __contains__(self, key: int | str):
        """Return whether a feature is stored for *key* (index or name)."""
        if isinstance(key, int):
            name = layerinfo.name(key)
            # Consistent with ``get_by_index``: an index without a known
            # layer name never matches, even if a "" key happens to exist.
            if name is None:
                return False
            key = name
        return key in self.features

    def layers(self):
        """Return the stored layer names as produced by ``sorted_keys``."""
        return sorted_keys(self.features)

    def get_by_name(self, layer: str) -> FeatureInfo | None:
        """Return the feature stored under name *layer*, or ``None``."""
        return self.features.get(layer)

    def get_by_index(self, layer: int) -> FeatureInfo | None:
        """Return the feature for layer index *layer*, or ``None``.

        ``None`` is returned both when ``layerinfo.name`` knows no name for
        the index and when nothing is stored under that name.
        """
        name = layerinfo.name(layer)
        if name is None:
            return None
        return self.get_by_name(name)

    def add(self, layer: int | str, info: FeatureInfo):
        """Store *info* for *layer* (index or name), replacing any previous entry.

        Raises:
            ValueError: if *layer* is an int that ``layerinfo.name`` cannot
                resolve to a layer name.
        """
        if isinstance(layer, int):
            name = layerinfo.name(layer)
            if name is None:
                raise ValueError(E(f"invalid layer name: {layer}"))
            layer = name
        self.features[layer] = info
class MultiStepFeatures(defaultdict[int, Features]):
    """``defaultdict`` of int -> ``Features``; missing keys get a new, empty ``Features``.

    The meaning of the int key is set by callers (presumably the sampling
    step, judging by the class name — confirm with callers).
    """

    def __init__(self, default_factory=None, *args, **kwargs):
        # ``defaultdict``'s copy and pickle support (``__copy__``/``__reduce__``,
        # used by copy.copy, copy.deepcopy and pickle) re-creates instances by
        # calling ``type(self)(default_factory, ...)``. The original no-argument
        # ``__init__`` made all of those raise TypeError. Accepting the
        # standard defaultdict arguments fixes that. Called with no arguments,
        # this behaves exactly as before.
        if default_factory is None:
            # Pass the class itself rather than a lambda so the factory is
            # picklable; calling it yields a fresh ``Features`` just the same.
            default_factory = Features
        super().__init__(default_factory, *args, **kwargs)
class MultiImageFeatures(defaultdict[int, MultiStepFeatures]):
    """``defaultdict`` of int -> ``MultiStepFeatures``; missing keys get a new, empty one.

    The meaning of the int key is set by callers (presumably an image index,
    judging by the class name — confirm with callers).
    """

    def __init__(self, default_factory=None, *args, **kwargs):
        # ``defaultdict``'s copy and pickle support (``__copy__``/``__reduce__``,
        # used by copy.copy, copy.deepcopy and pickle) re-creates instances by
        # calling ``type(self)(default_factory, ...)``. The original no-argument
        # ``__init__`` made all of those raise TypeError. Accepting the
        # standard defaultdict arguments fixes that. Called with no arguments,
        # this behaves exactly as before.
        if default_factory is None:
            # Pass the class itself rather than a lambda so the factory is
            # picklable; calling it yields a fresh ``MultiStepFeatures``.
            default_factory = MultiStepFeatures
        super().__init__(default_factory, *args, **kwargs)