# automatic/pipelines/hdm/xut/modules/attention.py

import math
from functools import cache

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import xformers
    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False

if XFORMERS_AVAILABLE:
    from xformers.ops import memory_efficient_attention
else:
    memory_efficient_attention = None

from .. import env
from ..utils import compile_wrapper
from .axial_rope import AxialRoPE

if not env.USE_XFORMERS:
    memory_efficient_attention = None

if env.USE_VANILLA:
    # Pure-PyTorch replacement for xformers' memory_efficient_attention.
    # It keeps the same (batch, seq, heads, head_dim) input/output layout.
    @compile_wrapper
    def memory_efficient_attention(query, key, value, attn_bias=None, p=0.0):
        scale = 1.0 / query.shape[-1] ** 0.5
        query = query * scale
        # Move heads in front of the sequence dim for the matmuls.
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        attn = query @ key.transpose(-2, -1)
        if attn_bias is not None:
            attn = attn + attn_bias
        attn = attn.softmax(-1)
        attn = F.dropout(attn, p)
        attn = attn @ value
        # Back to (batch, seq, heads, head_dim).
        return attn.transpose(1, 2).contiguous()
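
# Layout check (sketch, not part of the original module): both the xformers kernel
# and the fallback above take q/k/v as (batch, seq, heads, head_dim), while
# F.scaled_dot_product_attention expects (batch, heads, seq, head_dim). Assuming
# env.USE_VANILLA is set, so that memory_efficient_attention refers to the
# fallback defined above, the two should agree on hypothetical shapes:
#
#   q = k = v = torch.randn(2, 16, 8, 32)
#   out_fallback = memory_efficient_attention(q, k, v)
#   out_sdpa = F.scaled_dot_product_attention(
#       q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
#   ).transpose(1, 2)
#   torch.testing.assert_close(out_fallback, out_sdpa, rtol=1e-3, atol=1e-3)
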

class SelfAttention(nn.Module):
    def __init__(self, dim, n_heads=8, head_dim=-1, pos_dim=2):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        # An explicit head_dim wins; n_heads is then recomputed from it.
        self.head_dim = head_dim if head_dim > 0 else dim // n_heads
        self.n_heads = dim // self.head_dim
        assert (
            self.n_heads * self.head_dim == dim
        ), "dim must be divisible by n_heads or head_dim"
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.out = nn.Linear(dim, dim)
        self.rope = AxialRoPE(self.head_dim, self.n_heads, pos_dim)
        self.attn = memory_efficient_attention or F.scaled_dot_product_attention
        self.xformers = memory_efficient_attention is not None

    def forward(self, x, pos_map=None, mask=None):
        b, n, _, h = *x.shape, self.n_heads
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        if pos_map is not None:
            # AxialRoPE works on (b, h, n, d); transpose back afterwards for
            # xformers, which expects (b, n, h, d).
            q = self.rope(q.reshape(b, n, h, -1).transpose(1, 2), pos_map)
            k = self.rope(k.reshape(b, n, h, -1).transpose(1, 2), pos_map)
            v = v.reshape(b, n, h, -1)
            if self.xformers:
                q = q.transpose(1, 2)
                k = k.transpose(1, 2)
            else:
                v = v.transpose(1, 2)
        else:
            q, k, v = map(lambda t: t.reshape(b, n, h, -1), (q, k, v))
            if not self.xformers:
                # F.scaled_dot_product_attention expects (b, h, n, d).
                q = q.transpose(1, 2)
                k = k.transpose(1, 2)
                v = v.transpose(1, 2)
        if mask is not None:
            if mask.ndim == 2:
                mask = mask[None, None]
            elif mask.ndim == 3:
                mask = mask[:, None]
            if n % 8 and self.xformers:
                # xformers wants the bias storage aligned to multiples of 8 along
                # the key dim: allocate an aligned buffer and hand back a view of
                # its first n columns.
                align_n = math.ceil(n / 8) * 8
                mask_align = torch.empty(
                    *mask.shape[:3], align_n, device=mask.device, dtype=mask.dtype
                )
                mask_align[..., :n] = mask
                mask = mask_align.to(q).expand(b, h, n, align_n)[..., :n]
            else:
                mask = mask.to(q).expand(b, h, n, n)
        attn = self.attn(q, k, v, mask)
        if not self.xformers:
            attn = attn.transpose(1, 2)
        attn = attn.reshape(b, n, h * self.head_dim)
        attn = self.out(attn)
        return attn
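
# Usage sketch (hypothetical shapes, not part of the original module):
# SelfAttention consumes token sequences of shape (batch, seq, dim); pos_map is
# only needed when rotary position embedding should be applied via AxialRoPE.
# The mask argument is expected to be an additive bias (0 where allowed, -inf
# where masked), e.g. the bias returned by prefix_causal_attention_mask defined
# later in this module.
#
#   attn = SelfAttention(dim=64, n_heads=8)
#   x = torch.randn(2, 16, 64)
#   out = attn(x)                                   # (2, 16, 64)
#   _, bias = prefix_causal_attention_mask(16, 16, dtype=x.dtype)
#   out = attn(x, mask=bias)                        # causal self-attention
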

class CrossAttention(nn.Module):
    def __init__(self, dim, ctx_dim, n_heads=8, head_dim=-1, pos_dim=2):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = head_dim if head_dim > 0 else dim // n_heads
        self.n_heads = dim // self.head_dim
        assert (
            self.n_heads * self.head_dim == dim
        ), "dim must be divisible by n_heads or head_dim"
        self.q = nn.Linear(dim, dim, bias=False)
        self.kv = nn.Linear(ctx_dim, dim * 2, bias=False)
        self.out = nn.Linear(dim, dim)
        self.rope = AxialRoPE(self.head_dim, self.n_heads, pos_dim)
        self.attn = memory_efficient_attention or F.scaled_dot_product_attention
        self.xformers = memory_efficient_attention is not None

    def forward(self, x, ctx, pos_map=None, ctx_pos_map=None, mask=None):
        b, n, _, h = *x.shape, self.n_heads
        ctx_n = ctx.shape[1]
        q = self.q(x)
        k, v = self.kv(ctx).chunk(2, dim=-1)
        if pos_map is not None:
            q = self.rope(q.reshape(b, n, h, -1).transpose(1, 2), pos_map)
            q = q if not self.xformers else q.transpose(1, 2)
        else:
            q = q.reshape(b, n, h, -1)
            q = q if self.xformers else q.transpose(1, 2)
        if ctx_pos_map is not None:
            k = self.rope(k.reshape(b, ctx_n, h, -1).transpose(1, 2), ctx_pos_map)
            k = k if not self.xformers else k.transpose(1, 2)
        else:
            k = k.reshape(b, ctx_n, h, -1)
            k = k if self.xformers else k.transpose(1, 2)
        v = v.reshape(b, ctx_n, h, -1)
        v = v if self.xformers else v.transpose(1, 2)
        if mask is not None:
            if mask.ndim == 2:
                mask = mask[None, None]
            elif mask.ndim == 3:
                mask = mask[:, None]
            if ctx_n % 8 and self.xformers:
                align_n = math.ceil(ctx_n / 8) * 8
                mask_align = torch.empty(
                    *mask.shape[:3], align_n, device=mask.device, dtype=mask.dtype
                )
                mask_align[..., :ctx_n] = mask
                mask = mask_align.to(q).expand(b, h, n, align_n)[..., :ctx_n]
            else:
                mask = mask.to(q).expand(b, h, n, ctx_n)
        attn = self.attn(q, k, v, mask)
        if not self.xformers:
            attn = attn.transpose(1, 2)
        attn = attn.reshape(b, n, h * self.head_dim)
        attn = self.out(attn)
        return attn
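
# Usage sketch (hypothetical shapes, not part of the original module): the query
# sequence x and the context ctx may have different lengths and widths; keys and
# values are projected from ctx_dim down to dim.
#
#   xattn = CrossAttention(dim=64, ctx_dim=128, n_heads=8)
#   x = torch.randn(2, 16, 64)      # e.g. latent/image tokens
#   ctx = torch.randn(2, 77, 128)   # e.g. text-encoder hidden states
#   out = xattn(x, ctx)             # (2, 16, 64)
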

class AttentionPooling(CrossAttention):
    def __init__(self, dim, n_heads=8, head_dim=-1, pos_dim=2):
        super().__init__(dim, dim, n_heads, head_dim, pos_dim)
        self.query_token = nn.Parameter(torch.randn(1, 1, dim) * 1 / dim**0.5)

    def forward(self, x, pos_map=None, mask=None):
        query = self.query_token.expand(x.shape[0], -1, -1)
        return super().forward(query, x, None, pos_map, mask).squeeze(1)

class AttentiveProbe(CrossAttention):
    def __init__(self, dim, out_dim, n_heads=8, head_dim=-1, pos_dim=2, n_probes=1):
        super().__init__(dim, dim, n_heads, head_dim, pos_dim)
        self.query_token = nn.Parameter(torch.randn(1, n_probes, dim) * 1 / dim**0.5)
        self.token_proj = nn.Linear(dim * n_probes, out_dim)

    def forward(self, x, pos_map=None, mask=None):
        query = self.query_token.expand(x.shape[0], -1, -1)
        output_embedding = super().forward(query, x, None, pos_map, mask)
        output_embedding = output_embedding.flatten(-2, -1)
        return self.token_proj(output_embedding)
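
# Usage sketch (hypothetical shapes, not part of the original module): both heads
# attend from learned query tokens onto the input sequence. AttentionPooling
# returns a single pooled vector per sample; AttentiveProbe flattens n_probes
# attended tokens and projects them to out_dim.
#
#   pool = AttentionPooling(dim=64, n_heads=8)
#   probe = AttentiveProbe(dim=64, out_dim=10, n_heads=8, n_probes=4)
#   x = torch.randn(2, 16, 64)
#   pooled = pool(x)                # (2, 64)
#   logits = probe(x)               # (2, 10)
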

@cache
def prefix_causal_attention_mask(
    q_len, kv_len, prefix_len=0, is_self_attn=False, dtype=None, device=None
):
    """
    **Made by claude 3.7 sonnet without thinking**

    Generate attention masks and biases for transformer models.

    Parameters:
    -----------
    q_len : int
        Length of the query sequence
    kv_len : int
        Length of the key/value sequence
    prefix_len : int, optional
        Length of the prefix for which we allow full attention (no causal masking)
        Default: 0 (standard causal mask)
    is_self_attn : bool, optional
        Whether this is for self-attention (q_len == kv_len and they represent the same sequence)
        Enables faster mask generation
        Default: False
    dtype : torch.dtype, optional
        Data type for the output tensors
        Default: None (will use torch.bool for mask, torch.float for bias)
    device : torch.device, optional
        Device on which to create the tensors
        Default: None (will use the default torch device)

    Returns:
    --------
    tuple: (attention_mask, attention_bias)
        - attention_mask: Boolean tensor of shape (q_len, kv_len) where True values indicate
          positions that should be attended to
        - attention_bias: Tensor of the same shape with the specified dtype (or float), containing
          0.0 for positions to attend to and -float('inf') for positions to mask out
    """
    # Fast path for self-attention with no prefix
    if is_self_attn and prefix_len == 0:
        # Simple lower triangular matrix for standard causal self-attention
        attention_mask = torch.tril(
            torch.ones(q_len, q_len, dtype=torch.bool, device=device)
        )
    # Fast path for self-attention with prefix
    elif is_self_attn and prefix_len > 0:
        attention_mask = torch.tril(
            torch.ones(q_len, q_len, dtype=torch.bool, device=device)
        )
        # Add the prefix part (allow full attention to the prefix)
        if prefix_len < q_len:
            # Set the prefix columns to all True (we use indexing which is faster than cat)
            attention_mask[:, :prefix_len] = True
    # General case for cross-attention or when the fast path is not used
    else:
        # Create base causal mask (lower triangular):
        # each query position i can attend to key positions j where j <= i
        causal_mask = torch.tril(
            torch.ones(q_len, kv_len, dtype=torch.bool, device=device)
        )
        # If there's a prefix, allow full attention within that prefix
        if prefix_len > 0:
            # Combine masks:
            # - For the prefix part of kv, use all True
            # - For the rest, use the causal mask
            if prefix_len < kv_len:
                attention_mask = torch.cat(
                    [
                        torch.ones(q_len, prefix_len, dtype=torch.bool, device=device),
                        causal_mask[:, prefix_len:],
                    ],
                    dim=1,
                )
            else:
                # If prefix_len >= kv_len, the entire sequence gets full attention
                attention_mask = torch.ones(
                    q_len, kv_len, dtype=torch.bool, device=device
                )
        else:
            # Without a prefix, just use the causal mask
            attention_mask = causal_mask
    # Convert boolean mask to attention bias: True -> 0.0, False -> -inf
    float_dtype = torch.float if dtype is None else dtype
    attention_bias = torch.zeros_like(attention_mask, dtype=float_dtype, device=device)
    attention_bias = attention_bias.masked_fill(~attention_mask, float("-inf"))
    return attention_mask, attention_bias
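
# Worked example (sketch, not part of the original module): for q_len=4, kv_len=4,
# prefix_len=2 the boolean mask is
#
#   [[1, 1, 0, 0],
#    [1, 1, 0, 0],
#    [1, 1, 1, 0],
#    [1, 1, 1, 1]]
#
# i.e. every query may attend to the 2-token prefix, while positions after the
# prefix remain causal. The bias is 0.0 where the mask is True and -inf elsewhere.
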

# Example usage:
if __name__ == "__main__":
    # Standard causal mask for sequence length 6
    mask, bias = prefix_causal_attention_mask(q_len=6, kv_len=6)
    print("Standard causal mask:")
    print(mask)
    print("\nStandard causal bias:")
    print(bias)

    # Same with self-attention flag
    mask_self, bias_self = prefix_causal_attention_mask(
        q_len=6, kv_len=6, is_self_attn=True
    )
    print("\nSelf-attention causal mask (should be identical):")
    print(mask_self)
    print("Masks are identical:", torch.all(mask == mask_self).item())

    # Causal mask with prefix_len=3 (first 3 tokens get full attention)
    mask, bias = prefix_causal_attention_mask(q_len=6, kv_len=6, prefix_len=3)
    print("\nCausal mask with prefix_len=3:")
    print(mask)
    print("\nCausal bias with prefix_len=3:")
    print(bias)

    # Same with self-attention flag
    mask_self, bias_self = prefix_causal_attention_mask(
        q_len=6, kv_len=6, prefix_len=3, is_self_attn=True
    )
    print("\nSelf-attention mask with prefix_len=3 (should be identical):")
    print(mask_self)
    print("Masks are identical:", torch.all(mask == mask_self).item())

    # Handling different q_len and kv_len (for cross-attention)
    mask, bias = prefix_causal_attention_mask(q_len=4, kv_len=6, prefix_len=3)
    print("\nCross-attention mask with q_len=4, kv_len=6, prefix_len=3:")
    print(mask)
    print("\nCross-attention bias:")
    print(bias)

    # Smoke test: run SelfAttention with the additive bias and check the backward pass
    self_attn = SelfAttention(64, 8).cuda().half()
    x = torch.randn(1, 16, 64).cuda().half()
    mask, bias = prefix_causal_attention_mask(
        16, 16, is_self_attn=True, device=x.device, dtype=x.dtype
    )
    test_out = self_attn(x, mask=bias)
    torch.sum(test_out).backward()
    print(x.shape, mask.shape, bias.shape)
    print(test_out.shape)
    print(torch.isnan(test_out).any())
    print(torch.norm(next(self_attn.parameters()).grad))