Disable sageattention for SAM3 (#13529)

Causes NaNs
pull/13169/head^2
Jukka Seppänen 2026-04-23 21:14:42 +03:00 committed by GitHub
parent ef8f3cbcdc
commit 084e08c6e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 6 deletions

View File

@ -54,7 +54,7 @@ class SplitMHA(nn.Module):
if mask is not None and mask.ndim == 2:
mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast
dtype = q.dtype # manual_cast may produce mixed dtypes
out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask)
out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False)
return self.out_proj(out)

View File

@ -40,7 +40,7 @@ class SAMAttention(nn.Module):
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
return self.out_proj(optimized_attention(q, k, v, self.num_heads))
return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
class TwoWayAttentionBlock(nn.Module):
@ -179,7 +179,7 @@ class Attention(nn.Module):
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)
if self.use_rope and freqs_cis is not None:
q, k = apply_rope(q, k, freqs_cis)
return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False))
class Block(nn.Module):

View File

@ -364,7 +364,7 @@ class SplitAttn(nn.Module):
v = self.v_proj(v)
if rope is not None:
q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
out = optimized_attention(q, k, v, self.num_heads)
out = optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)
return self.out_proj(out)
@ -657,7 +657,7 @@ class DecoupledMemoryAttnLayer(nn.Module):
v = self.self_attn_v_proj(normed)
if rope is not None:
q, k = apply_rope_memory(q, k, rope, self.num_heads, 0)
x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads))
x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
# Decoupled cross-attention: fuse image and memory projections
normed = self.norm2(x)
@ -668,7 +668,7 @@ class DecoupledMemoryAttnLayer(nn.Module):
v = self.cross_attn_v_proj(memory)
if rope is not None:
q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads))
x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
# FFN
x = x + self.linear2(F.gelu(self.linear1(self.norm3(x))))