mirror of https://github.com/vladmandic/automatic
53 lines
1.9 KiB
Python
53 lines
1.9 KiB
Python
import torch
|
|
from modules.flash_attn_triton_amd.fwd_prefill import attention_prefill_forward_triton_impl
|
|
from modules.flash_attn_triton_amd.utils import MetaData
|
|
|
|
|
|
def fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
) -> None:
    """Dispatch an attention forward pass to the Triton prefill implementation.

    A ``MetaData`` descriptor is built from the arguments, validated against
    the tensors with ``check_args``, and then unpacked positionally into
    ``attention_prefill_forward_triton_impl``.

    Args:
        q: Query tensor; its dimension 1 is used as the maximum query length.
        k: Key tensor; its dimension 1 is used as the maximum key length.
        v: Value tensor.
        out: Tensor handed to the implementation alongside q/k/v.
            NOTE(review): presumably the kernel writes its result here --
            confirm against the implementation.
        dropout_p: Dropout probability; dropout is only requested on the
            descriptor when this is strictly greater than 0.0.
        softmax_scale: Forwarded to ``MetaData`` as ``sm_scale``.
        causal: When truthy, causal masking is requested on the descriptor.

    Returns:
        None. Whatever the implementation returns is discarded.

    NOTE(review): the layout is declared as "bshd", so the tensors are
    presumably (batch, seq, heads, dim) -- confirm with the callers.
    """
    # Sequence lengths come from dimension 1 of the query/key tensors.
    seqlen_q = q.shape[1]
    seqlen_k = k.shape[1]

    # Describe the problem for the kernel.
    meta = MetaData(sm_scale=softmax_scale)
    meta.max_seqlens_q = seqlen_q
    meta.max_seqlens_k = seqlen_k
    meta.layout = "bshd"

    # Optional features are switched on only when requested.
    if causal:
        meta.need_causal(True)
    if dropout_p > 0.0:
        meta.need_dropout(dropout_p)

    # Validate the tensors against the descriptor before launching.
    meta.check_args(q, k, v, out)

    # Every argument is positional, so the order below must match the
    # implementation's signature exactly.
    attention_prefill_forward_triton_impl(
        q, k, v, out,
        meta.sm_scale,
        meta.alibi_slopes,
        meta.causal,
        None,  # NOTE(review): fixed literal, presumably the attention bias -- confirm
        meta.layout,
        meta.cu_seqlens_q, meta.cu_seqlens_k,
        meta.max_seqlens_q, meta.max_seqlens_k,
        meta.cache_seqlens, meta.cache_batch_idx,
        meta.dropout_p,
        meta.philox_seed, meta.philox_offset,
        False,  # NOTE(review): fixed literal, presumably a "return softmax/scores" flag -- confirm
        meta.use_exp2,
    )
|
|
|
|
# varlen
|