parent ef8f3cbcdc
commit 084e08c6e2
@@ -54,7 +54,7 @@ class SplitMHA(nn.Module):
         if mask is not None and mask.ndim == 2:
             mask = mask[:, None, None, :]  # [B, T] -> [B, 1, 1, T] for SDPA broadcast
         dtype = q.dtype  # manual_cast may produce mixed dtypes
-        out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask)
+        out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False)
         return self.out_proj(out)

@@ -40,7 +40,7 @@ class SAMAttention(nn.Module):
         q = self.q_proj(q)
         k = self.k_proj(k)
         v = self.v_proj(v)
-        return self.out_proj(optimized_attention(q, k, v, self.num_heads))
+        return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))


 class TwoWayAttentionBlock(nn.Module):

@@ -179,7 +179,7 @@ class Attention(nn.Module):
         q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)
         if self.use_rope and freqs_cis is not None:
             q, k = apply_rope(q, k, freqs_cis)
-        return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
+        return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False))


 class Block(nn.Module):

@@ -364,7 +364,7 @@ class SplitAttn(nn.Module):
         v = self.v_proj(v)
         if rope is not None:
             q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
-        out = optimized_attention(q, k, v, self.num_heads)
+        out = optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)
         return self.out_proj(out)

@@ -657,7 +657,7 @@ class DecoupledMemoryAttnLayer(nn.Module):
         v = self.self_attn_v_proj(normed)
         if rope is not None:
             q, k = apply_rope_memory(q, k, rope, self.num_heads, 0)
-        x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads))
+        x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))

         # Decoupled cross-attention: fuse image and memory projections
         normed = self.norm2(x)

@@ -668,7 +668,7 @@ class DecoupledMemoryAttnLayer(nn.Module):
         v = self.cross_attn_v_proj(memory)
         if rope is not None:
             q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
-        x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads))
+        x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))

         # FFN
         x = x + self.linear2(F.gelu(self.linear1(self.norm3(x))))
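
Every hunk in this commit does the same thing: it passes low_precision_attention=False to optimized_attention. The backend of optimized_attention is not part of this diff, so the sketch below is only a rough illustration of what such a flag could do, assuming a thin wrapper around torch.nn.functional.scaled_dot_product_attention in which low_precision_attention=False upcasts q/k/v to float32 before the attention math and casts the result back. The function name, default value, and reshape conventions here are assumptions for illustration, not the actual implementation.

import torch
import torch.nn.functional as F

def optimized_attention_sketch(q, k, v, heads, mask=None, skip_reshape=False,
                               low_precision_attention=True):
    # Hypothetical stand-in for optimized_attention; the real backend is not
    # shown in this diff. Inputs are [B, T, heads*dim] unless skip_reshape is
    # True, in which case they are already [B, heads, T, dim].
    if not skip_reshape:
        b = q.shape[0]
        q, k, v = (x.view(b, -1, heads, x.shape[-1] // heads).transpose(1, 2)
                   for x in (q, k, v))
    out_dtype = q.dtype
    if not low_precision_attention:
        # Assumed meaning of the flag: keep the softmax/matmul in float32
        # even when the surrounding model runs in fp16/bf16.
        q, k, v = q.float(), k.float(), v.float()
        if mask is not None and mask.dtype.is_floating_point:
            mask = mask.float()  # SDPA wants a float mask to match q's dtype
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    out = out.to(out_dtype)
    # Merge heads back: [B, heads, T, dim] -> [B, T, heads*dim]
    b, h, t, d = out.shape
    return out.transpose(1, 2).reshape(b, t, h * d)

Under that reading, the SAM and memory-attention calls above explicitly opt out of any low-precision attention path, which is consistent with the dtype handling already present in SplitMHA (the manual_cast comment and the k.to(dtype)/v.to(dtype) casts).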