Remove torch.backends.cuda.sdp_kernel context
parent 8fec5cb42e
commit 99e476a88d
@@ -246,13 +246,14 @@ class Attention(nn.Module):
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=USE_FLASH_ATTN,
-            # if Flash attention kernel is off, then math kernel needs to be enabled
-            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-            enable_mem_efficient=OLD_GPU,
-        ):
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        # with torch.backends.cuda.sdp_kernel(
+        #     enable_flash=USE_FLASH_ATTN,
+        #     # if Flash attention kernel is off, then math kernel needs to be enabled
+        #     enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
+        #     enable_mem_efficient=OLD_GPU,
+        # ):
+        #     out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 
         out = self._recombine_heads(out)
         out = self.out_proj(out)
 
@@ -313,13 +314,14 @@ class RoPEAttention(Attention):
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=USE_FLASH_ATTN,
-            # if Flash attention kernel is off, then math kernel needs to be enabled
-            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-            enable_mem_efficient=OLD_GPU,
-        ):
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        # with torch.backends.cuda.sdp_kernel(
+        #     enable_flash=USE_FLASH_ATTN,
+        #     # if Flash attention kernel is off, then math kernel needs to be enabled
+        #     enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
+        #     enable_mem_efficient=OLD_GPU,
+        # ):
+        #     out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 
         out = self._recombine_heads(out)
         out = self.out_proj(out)
 
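Both hunks make the same change: the torch.backends.cuda.sdp_kernel context manager (deprecated in newer PyTorch) is commented out, and F.scaled_dot_product_attention is called directly, so PyTorch selects the attention kernel on its own. Below is a minimal, self-contained sketch, not part of this commit, of the resulting call, plus the non-deprecated replacement API torch.nn.attention.sdpa_kernel for cases where kernel selection must still be constrained (assumes PyTorch >= 2.0 for scaled_dot_product_attention and >= 2.3 for sdpa_kernel; the tensor shapes are illustrative only):

    import torch
    import torch.nn.functional as F

    # Toy q/k/v tensors shaped (batch, heads, tokens, head_dim).
    q = torch.randn(1, 8, 64, 32)
    k = torch.randn(1, 8, 64, 32)
    v = torch.randn(1, 8, 64, 32)

    # After this commit: no context manager; PyTorch picks the best
    # available kernel (flash / memory-efficient / math) by itself.
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

    # If kernel selection still needs to be pinned down, the
    # non-deprecated API (PyTorch >= 2.3) is torch.nn.attention.sdpa_kernel;
    # passing a list of backends lets PyTorch fall back when one
    # (e.g. flash attention on CPU or older GPUs) is unavailable.
    from torch.nn.attention import SDPBackend, sdpa_kernel

    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)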