# Mirror of https://github.com/vladmandic/automatic (28 lines, 917 B, Python)
from typing import Type
|
|
import torch
|
|
from modules.dml.hijack.utils import catch_nan
|
|
|
|
|
|
def make_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """Build a ToMe-enabled subclass of ``block_class`` on the fly.

    The returned class overrides ``_forward`` so that ``attn1``, ``attn2`` and
    ``ff`` are each wrapped by the callables returned from
    ``tomesd.patch.compute_merge`` (merge before the sub-layer, unmerge after,
    following tomesd's ``m_*``/``u_*`` naming). Each sub-layer result is added
    back onto ``x`` as a residual. The ``attn2`` step is run through
    ``catch_nan`` from ``modules.dml.hijack.utils``; see that helper for how it
    treats NaN outputs.

    The instance is expected to provide ``_tome_info``, ``norm1``/``norm2``/
    ``norm3``, ``attn1``/``attn2``, ``ff`` and ``disable_self_attn`` —
    presumably supplied by the base block class and by tomesd when it patches
    the module (not visible here; verify against tomesd).

    Args:
        block_class: The transformer block class to subclass.

    Returns:
        A new class named ``ToMeBlock`` deriving from ``block_class``.
    """

    class ToMeBlock(block_class):
        # Keep a handle on the unpatched class so the patch can be undone later.
        _parent = block_class

        def _forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
            (
                merge_attn1, merge_attn2, merge_ff,
                unmerge_attn1, unmerge_attn2, unmerge_ff,
            ) = tomesd.patch.compute_merge(x, self._tome_info)

            # attn1 only receives the external context when the block sets disable_self_attn.
            attn1_context = context if self.disable_self_attn else None
            x = unmerge_attn1(self.attn1(merge_attn1(self.norm1(x)), context=attn1_context)) + x

            def attn2_with_residual() -> torch.Tensor:
                # Reads the current ``x`` at call time, i.e. the value after the attn1 step.
                return unmerge_attn2(self.attn2(merge_attn2(self.norm2(x)), context=context)) + x

            # The cross-attention step goes through the DirectML NaN helper.
            x = catch_nan(attn2_with_residual)

            x = unmerge_ff(self.ff(merge_ff(self.norm3(x)))) + x
            return x

    return ToMeBlock
|
|
|
|
# tomesd is an optional dependency. When it imports cleanly, its block factory
# is replaced with the DirectML-aware ``make_tome_block`` from this module.
# Any failure (e.g. tomesd not installed) is deliberately ignored so that
# loading continues without the patch.
try:
    import tomesd
    setattr(tomesd.patch, "make_tome_block", make_tome_block)
except Exception:  # intentionally broad: this patch is strictly best-effort
    pass
|