automatic/modules/flash_attn_triton_amd/interface_fa.py

import torch
from modules.flash_attn_triton_amd.fwd_prefill import attention_prefill_forward_triton_impl
from modules.flash_attn_triton_amd.utils import MetaData
def fwd(q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: torch.Tensor,
        dropout_p: float,
        softmax_scale: float,
        causal: bool
        ):
    # Setup metadata; "bshd" = (batch, seqlen, heads, head_dim) tensor layout
    metadata = MetaData(sm_scale=softmax_scale)
    metadata.max_seqlens_q = q.shape[1]
    metadata.max_seqlens_k = k.shape[1]
    metadata.layout = "bshd"
    if causal:
        metadata.need_causal(True)
    if dropout_p > 0.0:
        metadata.need_dropout(dropout_p)
    # validate q/k/v/out against the metadata before launching the kernel
    metadata.check_args(q, k, v, out)
    # call the Triton prefill implementation; the result is written into `out`
    attention_prefill_forward_triton_impl(
        q,
        k,
        v,
        out,
        metadata.sm_scale,
        metadata.alibi_slopes,
        metadata.causal,
        None,
        metadata.layout,
        metadata.cu_seqlens_q,
        metadata.cu_seqlens_k,
        metadata.max_seqlens_q,
        metadata.max_seqlens_k,
        metadata.cache_seqlens,
        metadata.cache_batch_idx,
        metadata.dropout_p,
        metadata.philox_seed,
        metadata.philox_offset,
        False,
        metadata.use_exp2)
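
# Example usage: a minimal sketch, assuming a ROCm/CUDA device and fp16
# tensors in the "bshd" layout (batch, seqlen, heads, head_dim) that fwd
# expects. The shapes and scale below are illustrative assumptions, not
# requirements imposed by this file.
if __name__ == "__main__":
    batch, seqlen, heads, head_dim = 2, 128, 8, 64
    q = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = torch.empty_like(q)  # fwd writes the attention output here
    fwd(q, k, v, out, dropout_p=0.0, softmax_scale=head_dim ** -0.5, causal=True)
    print(out.shape)  # torch.Size([2, 128, 8, 64])
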
# varlen