automatic/pipelines/hdm/xut/modules/transformer.py

80 lines
2.4 KiB
Python

import torch.nn as nn
from .layers import SwiGLU
from .attention import SelfAttention, CrossAttention
from .norm import RMSNorm
from .adaln import AdaLN
class TransformerBlock(nn.Module):
    """Transformer block: self-attention, optional cross-attention, SwiGLU MLP.

    Each sublayer is preceded by a pre-norm module that is called as
    ``norm(x, *cond)`` and whose result is unpacked as ``(normalized, gate)``.
    The sublayer output is multiplied by ``gate`` and added back.

    Args:
        dim: Feature width handed to the attention modules, the MLP
            (as both its first and last argument) and the pre-norms.
        ctx_dim: Context width handed to ``CrossAttention``. ``None``
            disables cross-attention for this block.
        heads: Forwarded to ``SelfAttention`` / ``CrossAttention``.
        dim_head: Forwarded to ``SelfAttention`` / ``CrossAttention``.
        mlp_dim: Second argument of ``SwiGLU`` (presumably the hidden
            width -- confirm against ``layers.SwiGLU``).
        pos_dim: Forwarded to ``SelfAttention`` / ``CrossAttention``.
        use_adaln: Build the pre-norms as ``AdaLN`` modules instead of
            plain ``norm_layer(dim)`` instances.
        use_shared_adaln: Forwarded to ``AdaLN`` as ``shared=``.
        ctx_from_self: When true, cross-attention is masked with
            ``x_mask`` instead of ``ctx_mask``.
        norm_layer: Factory for the normalisation layer; called as
            ``norm_layer(dim)`` or passed through to ``AdaLN``.
    """

    def __init__(
        self,
        dim,
        ctx_dim,
        heads,
        dim_head,
        mlp_dim,
        pos_dim,
        use_adaln=False,
        use_shared_adaln=False,
        ctx_from_self=False,
        norm_layer=RMSNorm,
    ):
        super().__init__()
        self.use_adaln = use_adaln
        # Always defined, so the attribute also exists on blocks that have
        # no cross-attention (previously it was only set when ctx_dim was
        # given).
        self.ctx_from_self = ctx_from_self

        # NOTE: submodules are created in the same order as before so that
        # parameter-initialisation RNG consumption and state_dict key order
        # stay unchanged.
        self.attn = SelfAttention(dim, heads, dim_head, pos_dim)
        if ctx_dim is None:
            self.xattn_pre_norm = None
            self.xattn = None
        else:
            self.xattn = CrossAttention(dim, ctx_dim, heads, dim_head, pos_dim)
        self.mlp = SwiGLU(dim, mlp_dim, dim)

        def make_pre_norm():
            # One place that decides which kind of pre-norm a sublayer gets.
            if use_adaln:
                return AdaLN(
                    dim, dim, norm_layer=norm_layer, shared=use_shared_adaln
                )
            return norm_layer(dim)

        self.attn_pre_norm = make_pre_norm()
        self.mlp_pre_norm = make_pre_norm()
        if self.xattn is not None:
            self.xattn_pre_norm = make_pre_norm()

    @staticmethod
    def _cond_args(y, shared_adaln, index):
        """Return the extra positional arguments for one pre-norm call."""
        if y is None:
            return ()
        if shared_adaln is None:
            return (y,)
        return (y, shared_adaln[index])

    def forward(
        self,
        x,
        ctx,
        pos_map=None,
        ctx_pos_map=None,
        y=None,
        x_mask=None,
        ctx_mask=None,
        shared_adaln=None,
    ):
        """Run self-attention, optional cross-attention and the MLP on ``x``.

        Args:
            x: Input passed through every sublayer.
            ctx: Context passed to cross-attention (unused when the block
                was built with ``ctx_dim=None``).
            pos_map: Forwarded to both attention modules.
            ctx_pos_map: Forwarded to cross-attention.
            y: Optional conditioning input; when given it is passed to every
                pre-norm as an extra positional argument.
            x_mask: Mask for self-attention; also used for cross-attention
                when ``ctx_from_self`` is set.
            ctx_mask: Mask for cross-attention.
            shared_adaln: Optional indexable of shared modulation inputs.
                Element ``0`` goes to the self-attention pre-norm, element
                ``1`` to the cross-attention pre-norm and element ``-1`` to
                the MLP pre-norm.

        Returns:
            The output of the last residual update.

        Raises:
            ValueError: If ``shared_adaln`` is given without ``y``.
        """
        if shared_adaln is not None and y is None:
            # Previously this surfaced as an opaque IndexError on an empty
            # list.
            raise ValueError("shared_adaln requires the conditioning input y")

        # NOTE(review): x is rebound to the pre-norm output, so each residual
        # adds onto the normalised tensor rather than the pre-norm input.
        # Kept exactly as in the original; confirm this is intentional.
        x, gate = self.attn_pre_norm(x, *self._cond_args(y, shared_adaln, 0))
        x = x + self.attn(x, pos_map, mask=x_mask) * gate

        if self.xattn is not None:
            x, gate = self.xattn_pre_norm(
                x, *self._cond_args(y, shared_adaln, 1)
            )
            if self.ctx_from_self:
                ctx_mask = x_mask
            x = x + self.xattn(x, ctx, pos_map, ctx_pos_map, mask=ctx_mask) * gate

        # The MLP always takes the last shared entry, whether or not a
        # cross-attention sublayer exists.
        x, gate = self.mlp_pre_norm(x, *self._cond_args(y, shared_adaln, -1))
        x = x + self.mlp(x) * gate
        return x