Remove torch.backends.cuda.sdp_kernel context

main
Uminosachi 2024-07-31 18:30:54 +09:00
parent 8fec5cb42e
commit 99e476a88d
1 changed file with 16 additions and 14 deletions

@@ -246,13 +246,14 @@ class Attention(nn.Module):
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=USE_FLASH_ATTN,
-            # if Flash attention kernel is off, then math kernel needs to be enabled
-            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-            enable_mem_efficient=OLD_GPU,
-        ):
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        # with torch.backends.cuda.sdp_kernel(
+        #     enable_flash=USE_FLASH_ATTN,
+        #     # if Flash attention kernel is off, then math kernel needs to be enabled
+        #     enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
+        #     enable_mem_efficient=OLD_GPU,
+        # ):
+        #     out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
         out = self._recombine_heads(out)
         out = self.out_proj(out)
@@ -313,13 +314,14 @@ class RoPEAttention(Attention):
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=USE_FLASH_ATTN,
-            # if Flash attention kernel is off, then math kernel needs to be enabled
-            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-            enable_mem_efficient=OLD_GPU,
-        ):
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        # with torch.backends.cuda.sdp_kernel(
+        #     enable_flash=USE_FLASH_ATTN,
+        #     # if Flash attention kernel is off, then math kernel needs to be enabled
+        #     enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
+        #     enable_mem_efficient=OLD_GPU,
+        # ):
+        #     out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
         out = self._recombine_heads(out)
         out = self.out_proj(out)
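
For context: torch.backends.cuda.sdp_kernel is deprecated in recent PyTorch releases, and calling F.scaled_dot_product_attention without the context manager, as this commit does, lets PyTorch pick the best available kernel on its own. If explicit backend selection were still wanted, a minimal sketch using the newer torch.nn.attention.sdpa_kernel API (PyTorch 2.3+) could look like the following; the flag names and the helper function are assumptions mirroring the removed block, not part of this commit.

import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical flags standing in for the module-level USE_FLASH_ATTN,
# OLD_GPU, and MATH_KERNEL_ON constants referenced by the removed code.
USE_FLASH_ATTN = True
OLD_GPU = False
MATH_KERNEL_ON = False


def sdpa_with_backend_selection(q, k, v, dropout_p=0.0):
    # Translate the old boolean flags into an explicit list of allowed backends.
    backends = []
    if USE_FLASH_ATTN:
        backends.append(SDPBackend.FLASH_ATTENTION)
    if OLD_GPU:
        backends.append(SDPBackend.EFFICIENT_ATTENTION)
    if (OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON:
        backends.append(SDPBackend.MATH)
    if not backends:
        # Fall back to the math kernel rather than disabling every backend.
        backends = [SDPBackend.MATH]
    # Restrict scaled_dot_product_attention to the selected backends for this call only.
    with sdpa_kernel(backends):
        return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

The commit itself takes the simpler route and drops backend selection entirely, which avoids the deprecation warning and leaves kernel choice to PyTorch's own dispatch.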