stable-diffusion-webui-dump.../scripts/lib/features/featureinfo.py

84 lines
2.2 KiB
Python

from typing import Generic, TypeVar
from dataclasses import dataclass
from collections import defaultdict
import torch
from torch import Tensor
import numpy as np
from scripts.lib import layerinfo
from scripts.lib.report import message as E
from scripts.lib.utils import *
@dataclass
class FeatureInfo:
    """Record of one captured layer: its input shapes, output shape and output.

    A plain ``@dataclass``: values are stored exactly as passed, with no
    copying or validation.
    """

    # shapes of the layer's inputs (presumably one entry per input tensor)
    input_dims: list[torch.Size]
    # shape of the layer's output
    output_dims: torch.Size
    # the output tensor itself
    output: Tensor
#@dataclass
#class AttnFeatureInfo:
# qk: np.ndarray
# vqk: np.ndarray
TInfo = TypeVar("TInfo")
class Features(Generic[TInfo]):
# layer -> UNetFeatureInfo
features : dict[str, TInfo]
def __init__(self):
self.features = dict()
def __getitem__(self, layer: int|str):
v = None
if isinstance(layer, int):
v = self.get_by_index(layer)
elif isinstance(layer, str):
v = self.get_by_name(layer)
if v is None:
raise KeyError(E(f"invalid key: {type(layer)} {layer}"))
return v
def __iter__(self):
return sorted_items(self.features)
def __contains__(self, key: int|str):
if isinstance(key, int):
key = layerinfo.name(key) or ""
return key in self.features
def layers(self):
return sorted_keys(self.features)
def get_by_name(self, layer: str) -> TInfo|None:
if layer in self.features:
return self.features[layer]
return None
def get_by_index(self, layer: int) -> TInfo|None:
name = layerinfo.name(layer)
if name is None:
return None
return self.get_by_name(name)
def add(self, layer: int|str, info: TInfo):
if isinstance(layer, int):
name = layerinfo.name(layer)
if name is None:
raise ValueError(E(f"invalid layer name: {layer}"))
layer = name
self.features[layer] = info
class MultiStepFeatures(Generic[TInfo], defaultdict[int, Features[TInfo]]):
    """Auto-populating map from an int key (presumably a step index) to `Features`.

    Looking up a missing key inserts and returns a fresh ``Features()``.
    """

    def __init__(self, default_factory=None, *args, **kwargs):
        """Create the map; ``MultiStepFeatures()`` is the normal call.

        ``defaultdict.copy()``, ``copy.copy``/``copy.deepcopy`` and unpickling
        rebuild the object as ``type(self)(default_factory, ...)``, so the
        constructor accepts ``defaultdict``'s arguments. With a zero-argument
        ``__init__`` those operations raise ``TypeError``.

        Args:
            default_factory: factory for missing keys; ``None`` means ``Features``.
            *args, **kwargs: initial contents, forwarded to ``defaultdict``.
        """
        if default_factory is None:
            # Use the class itself rather than a lambda so the factory (and
            # therefore the whole container) can be pickled.
            default_factory = Features
        super().__init__(default_factory, *args, **kwargs)
class MultiImageFeatures(Generic[TInfo], defaultdict[int, MultiStepFeatures[TInfo]]):
    """Auto-populating map from an int key (presumably an image index) to `MultiStepFeatures`.

    Looking up a missing key inserts and returns a fresh ``MultiStepFeatures()``.
    """

    def __init__(self, default_factory=None, *args, **kwargs):
        """Create the map; ``MultiImageFeatures()`` is the normal call.

        ``defaultdict.copy()``, ``copy.copy``/``copy.deepcopy`` and unpickling
        rebuild the object as ``type(self)(default_factory, ...)``, so the
        constructor accepts ``defaultdict``'s arguments. With a zero-argument
        ``__init__`` those operations raise ``TypeError``.

        Args:
            default_factory: factory for missing keys; ``None`` means
                ``MultiStepFeatures``.
            *args, **kwargs: initial contents, forwarded to ``defaultdict``.
        """
        if default_factory is None:
            # Use the class itself rather than a lambda so the factory (and
            # therefore the whole container) can be pickled.
            default_factory = MultiStepFeatures
        super().__init__(default_factory, *args, **kwargs)