import torch.nn as nn

from .layers import SwiGLU
from .attention import SelfAttention, CrossAttention
from .norm import RMSNorm
from .adaln import AdaLN

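# Generic transformer block: pre-norm self-attention, optional cross-attention
# (enabled when ctx_dim is given), and a SwiGLU MLP, each applied as a gated
# residual update. Pre-norms are either plain norm layers or AdaLN modules
# driven by a conditioning vector y (optionally with shared AdaLN parameters).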
class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        ctx_dim,
        heads,
        dim_head,
        mlp_dim,
        pos_dim,
        use_adaln=False,
        use_shared_adaln=False,
        ctx_from_self=False,
        norm_layer=RMSNorm,
    ):
        super().__init__()
        self.use_adaln = use_adaln
        self.attn = SelfAttention(dim, heads, dim_head, pos_dim)
        if ctx_dim is None:
            # no conditioning stream: skip cross-attention entirely
            self.xattn_pre_norm = None
            self.xattn = None
        else:
            self.ctx_from_self = ctx_from_self
            self.xattn = CrossAttention(dim, ctx_dim, heads, dim_head, pos_dim)
        self.mlp = SwiGLU(dim, mlp_dim, dim)

        if self.use_adaln:
            # AdaLN pre-norms: each returns the normalized input plus a gate
            # computed from the conditioning vector(s)
            self.attn_pre_norm = AdaLN(
                dim, dim, norm_layer=norm_layer, shared=use_shared_adaln
            )
            self.mlp_pre_norm = AdaLN(
                dim, dim, norm_layer=norm_layer, shared=use_shared_adaln
            )
            if self.xattn is not None:
                self.xattn_pre_norm = AdaLN(
                    dim, dim, norm_layer=norm_layer, shared=use_shared_adaln
                )
        else:
            self.attn_pre_norm = norm_layer(dim)
            self.mlp_pre_norm = norm_layer(dim)
            if self.xattn is not None:
                self.xattn_pre_norm = norm_layer(dim)

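    # Pre-norms are called as `normed, gate = pre_norm(x, *y)`; when AdaLN is
    # disabled, norm_layer is therefore expected to return the same
    # (normed, gate) pair. With shared AdaLN, `shared_adaln` appears to hold
    # one shared conditioning entry per sub-layer (attn, xattn, mlp).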
    def forward(
        self,
        x,
        ctx,
        pos_map=None,
        ctx_pos_map=None,
        y=None,
        x_mask=None,
        ctx_mask=None,
        shared_adaln=None,
    ):
        # assemble AdaLN conditioning: per-block y, plus the shared parameters
        # for the current sub-layer when shared AdaLN is used (note that
        # passing shared_adaln assumes y was provided as well)
        y = [y] if y is not None else []
        y = y if shared_adaln is None else [y[0], shared_adaln[0]]

        # self-attention with gated residual
        x, gate = self.attn_pre_norm(x, *y)
        x = x + self.attn(x, pos_map, mask=x_mask) * gate

        # optional cross-attention with gated residual
        if self.xattn is not None:
            if shared_adaln is not None:
                y[1] = shared_adaln[1]
            x, gate = self.xattn_pre_norm(x, *y)
            if self.ctx_from_self:
                # context tokens come from the self stream, so reuse its mask
                ctx_mask = x_mask
            x = x + self.xattn(x, ctx, pos_map, ctx_pos_map, mask=ctx_mask) * gate

        # MLP with gated residual
        if shared_adaln is not None:
            y[1] = shared_adaln[-1]
        x, gate = self.mlp_pre_norm(x, *y)
        x = x + self.mlp(x) * gate
        return x
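

# Usage sketch (hypothetical shapes and values; assumes the sibling modules
# follow the call signatures used above):
#
#   block = TransformerBlock(
#       dim=512, ctx_dim=768, heads=8, dim_head=64,
#       mlp_dim=2048, pos_dim=2, use_adaln=True,
#   )
#   x = torch.randn(2, 256, 512)    # (batch, tokens, dim)
#   ctx = torch.randn(2, 77, 768)   # (batch, ctx tokens, ctx_dim)
#   y = torch.randn(2, 512)         # AdaLN conditioning vector
#   out = block(x, ctx, y=y)        # same shape as x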