zluda flash attention 2 via triton

pull/3845/head
Seunghoon Lee 2025-03-22 23:10:59 +09:00
parent 85bda171f4
commit ab9e87d848
10 changed files with 3233 additions and 8 deletions

View File

@ -637,6 +637,40 @@ SOFTWARE.
limitations under the License.
</pre>
<h2><a href="https://github.com/Dao-AILab/flash-attention/blob/main/LICENSE">Flash Attention</a></h2>
<small>Fast and memory-efficient exact attention</small>
<pre>
BSD 3-Clause License
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
</pre>
<h2><a href="https://github.com/explosion/curated-transformers/blob/main/LICENSE">Curated transformers</a></h2>
<small>The MPS workaround for nn.Linear on macOS 13.2.X is based on the MPS workaround for nn.Linear created by danieldk for Curated transformers</small>
<pre>

View File

@ -0,0 +1,606 @@
import torch
import triton
import triton.language as tl
from modules.flash_attn_triton_amd.utils import get_shape_from_layout, get_strides_from_layout
@triton.jit
def _bwd_preprocess_use_o(
Out,
DO,
Delta,
stride_oz, stride_oh, stride_om, stride_ok,
stride_doz, stride_doh, stride_dom, stride_dok, # pylint: disable=unused-argument
stride_deltaz, stride_deltah, stride_deltam,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
N_CTX_Q: tl.constexpr,
Z: tl.constexpr, # pylint: disable=unused-argument
H: tl.constexpr,
IS_VARLEN: tl.constexpr
):
pid_m = tl.program_id(0)
pid_bh = tl.program_id(1)
# Compute batch and head indices
off_z = pid_bh // H
off_h = pid_bh % H
if IS_VARLEN:
# Compute sequence lengths for the current batch
q_start = tl.load(cu_seqlens_q + off_z)
q_end = tl.load(cu_seqlens_q + off_z + 1)
k_start = tl.load(cu_seqlens_k + off_z)
k_end = tl.load(cu_seqlens_k + off_z + 1)
# Compute actual sequence lengths
N_CTX_Q = q_end - q_start
N_CTX_K = k_end - k_start # pylint: disable=unused-variable
else:
q_start = 0
k_start = 0
N_CTX_Q = max_seqlen_q
N_CTX_K = max_seqlen_k # pylint: disable=unused-variable
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_d = tl.arange(0, BLOCK_DMODEL)
# create masks
mask_m = off_m < N_CTX_Q
mask_d = off_d < ACTUAL_BLOCK_DMODEL
# compute offsets
o_offset = Out + off_z * stride_oz + off_h * stride_oh + q_start * stride_om
do_offset = DO + off_z * stride_oz + off_h * stride_oh + q_start * stride_om
# compute pointers
out_ptrs = o_offset + off_m[:, None] * stride_om + off_d[None, :] * stride_ok
do_ptrs = do_offset + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok
# load
o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)
do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)
# compute delta
delta = tl.sum(o * do, axis=1)
# write-back delta
delta_offset = Delta + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
delta_ptrs = delta_offset + off_m * stride_deltam
tl.store(delta_ptrs, delta, mask=mask_m)
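# The preprocessing kernel above computes, per query row i, delta_i = sum_d O[i, d] * dO[i, d].
# This is the "D" term reused by the backward kernel below in ds = p * (dp - D[:, None]), and it
# matches the reference implementation in this PR (delta = torch.sum(o * do, axis=-1)).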
@triton.jit
def _bwd_kernel_one_col_block(
Q,
K,
V,
sm_scale,
Out, DO, DQ, DK, DV, L, D, # pylint: disable=unused-argument
q_offset,
k_offset,
v_offset,
do_offset,
dq_offset,
dk_offset,
dv_offset,
d_offset,
l_offset,
stride_dq_all, stride_qz, stride_qh, # pylint: disable=unused-argument
stride_qm,
stride_qk,
stride_kz, stride_kh, # pylint: disable=unused-argument
stride_kn,
stride_kk,
stride_vz, stride_vh, # pylint: disable=unused-argument
stride_vn,
stride_vk,
stride_deltaz, stride_deltah, # pylint: disable=unused-argument
stride_deltam,
Z, H, # pylint: disable=unused-argument
N_CTX_Q,
N_CTX_K,
off_h, off_z, off_hz, # pylint: disable=unused-argument
start_n,
num_block_m,
num_block_n, # pylint: disable=unused-argument
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
USE_EXP2: tl.constexpr,
):
if CAUSAL:
# TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M
lo = 0
else:
lo = 0
# initialize col and head offsets
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
# masks
mask_n = offs_n < N_CTX_K
mask_d = offs_d < ACTUAL_BLOCK_DMODEL
kv_mask = mask_n[:, None] & mask_d[None, :]
# initialize grad accumulators
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
# load k and v once per column block
k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
# loop over rows
for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M):
offs_m = start_m + tl.arange(0, BLOCK_M)
q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
dq_ptrs = dq_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
do_ptrs = do_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
# update mask as row block changes
mask_m = offs_m < N_CTX_Q
q_mask = mask_m[:, None] & mask_d[None, :]
# load q, k, v, do on-chip
q = tl.load(q_ptrs, mask=q_mask, other=0.0)
do = tl.load(do_ptrs, mask=q_mask, other=0.0)
# recompute p = softmax(qk * sm_scale, dim=-1); its transpose is used below for dv
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
if CAUSAL:
col_offset = N_CTX_Q - N_CTX_K
causal_mask = offs_m[:, None] >= (col_offset + offs_n[None, :])
qk = tl.where(causal_mask, qk, float("-inf"))
l_ptrs = l_offset + offs_m * stride_deltam
l_i = tl.load(l_ptrs, mask=mask_m)
# compute p
if USE_EXP2:
RCP_LN2: tl.constexpr = 1.4426950408889634
qk *= sm_scale * RCP_LN2
l_i *= RCP_LN2
p = tl.math.exp2(qk - l_i[:, None])
else:
qk *= sm_scale
p = tl.math.exp(qk - l_i[:, None])
# mask block in the cases where the data is smaller than the block size
p_mask = mask_m[:, None] & mask_n[None, :]
p = tl.where(p_mask, p, 0.0)
# compute dv
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
# compute dp
dp = tl.dot(do, tl.trans(v))
# compute ds: ds = p * (dp - delta[:, None])
d_ptrs = d_offset + offs_m * stride_deltam
Di = tl.load(d_ptrs, mask=mask_m)
ds = (p * (dp - Di[:, None])) * sm_scale
ds = tl.where(p_mask, ds, 0.0).to(Q.dtype.element_ty)
# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds), q)
# compute dq
if SEQUENCE_PARALLEL:
dq = tl.dot(ds, k)
else:
dq = tl.load(dq_ptrs, mask=q_mask, other=0.0)
dq += tl.dot(ds, k)
tl.store(dq_ptrs, dq.to(Q.dtype.element_ty), mask=q_mask)
# write-back dv and dk
dk_ptrs = dk_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
# write-back
tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask)
tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask)
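# The column-block loop above recomputes p from q, k and the saved log-sum-exp L
# (p = exp(qk * sm_scale - L), or the equivalent exp2 form when USE_EXP2), then accumulates:
#   dv += p^T @ do
#   dp  = do @ v^T
#   ds  = p * (dp - D[:, None]) * sm_scale
#   dk += ds^T @ q
#   dq += ds @ k   (with SEQUENCE_PARALLEL each column block writes its own dq slice,
#                   which is summed on the host afterwards)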
@triton.jit
def _bwd_kernel(
Q,
K,
V,
sm_scale,
Out,
DO,
DQ,
DK,
DV,
L,
D,
stride_dq_all,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vn,
stride_vk,
stride_deltaz,
stride_deltah,
stride_deltam,
Z,
H,
num_block_m,
num_block_n,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
USE_EXP2: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
# program ids
off_hz = tl.program_id(0)
if SEQUENCE_PARALLEL:
start_n = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
if IS_VARLEN:
# Compute sequence lengths for the current batch
q_start = tl.load(cu_seqlens_q + off_z)
q_end = tl.load(cu_seqlens_q + off_z + 1)
k_start = tl.load(cu_seqlens_k + off_z)
k_end = tl.load(cu_seqlens_k + off_z + 1)
# Compute actual sequence lengths
N_CTX_Q = q_end - q_start
N_CTX_K = k_end - k_start
else:
q_start = 0
k_start = 0
N_CTX_Q = max_seqlen_q
N_CTX_K = max_seqlen_k
# input tensor offsets
q_offset = Q + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
k_offset = K + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
v_offset = V + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
do_offset = DO + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
l_offset = L + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
d_offset = D + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
# output tensor offsets
dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
if SEQUENCE_PARALLEL:
dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
else:
dq_offset = DQ + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
# inner loop
if SEQUENCE_PARALLEL:
_bwd_kernel_one_col_block(
Q,
K,
V,
sm_scale,
Out,
DO,
DQ,
DK,
DV,
L,
D,
q_offset,
k_offset,
v_offset,
do_offset,
dq_offset,
dk_offset,
dv_offset,
d_offset,
l_offset,
stride_dq_all,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vn,
stride_vk,
stride_deltaz,
stride_deltah,
stride_deltam,
Z,
H,
N_CTX_Q,
N_CTX_K,
off_h,
off_z,
off_hz,
start_n,
num_block_m,
num_block_n,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
USE_EXP2=USE_EXP2,
)
else:
for start_n in range(0, num_block_n):
_bwd_kernel_one_col_block(
Q,
K,
V,
sm_scale,
Out,
DO,
DQ,
DK,
DV,
L,
D,
q_offset,
k_offset,
v_offset,
do_offset,
dq_offset,
dk_offset,
dv_offset,
d_offset,
l_offset,
stride_dq_all,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vn,
stride_vk,
stride_deltaz,
stride_deltah,
stride_deltam,
Z,
H,
N_CTX_Q,
N_CTX_K,
off_h,
off_z,
off_hz,
start_n,
num_block_m,
num_block_n,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
USE_EXP2=USE_EXP2,
)
# NOTE: smaller blocks have lower accuracy (more accumulation error). 128 x 128 seems accurate but leads to OOM; 64 x 64 has some accumulation error but no OOM.
def attention_prefill_backward_triton_impl(
do,
q,
k,
v,
o,
softmax_lse,
dq,
dk,
dv,
sm_scale: float,
alibi_slopes, # pylint: disable=unused-argument
causal,
layout: str,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q: int,
max_seqlen_k: int,
use_exp2: bool,
sequence_parallel = True,
):
# make contiguous
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
softmax_lse = softmax_lse.contiguous()
# get strides and shape
batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # pylint: disable=unused-variable
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
stride_qz, stride_qh, stride_qm, stride_qk = q_strides
stride_kz, stride_kh, stride_kn, stride_kk = k_strides
stride_vz, stride_vh, stride_vn, stride_vk = v_strides
stride_oz, stride_oh, stride_om, stride_ok = o_strides
batch_headsize = batch * nheads_q
is_varlen = layout == "thd"
# FIXME: some configs lead to oom for some reason when using 64 x 64 blocks
if max_seqlen_q <= 32 or max_seqlen_k <= 32:
BLOCK_M = 32
BLOCK_N = 32
else:
BLOCK_M = 64
BLOCK_N = 64
num_warps = 4 # NOTE: original is 8; changing it to 1 caused issues, be careful
num_stages = 1
waves_per_eu = 1
# divide up the problem
num_blocks_m = triton.cdiv(max_seqlen_q, BLOCK_M)
num_blocks_n = triton.cdiv(max_seqlen_k, BLOCK_N)
# round the head size up to the next power of 2, with a minimum of 16.
padded_d_model = 1 << (head_size - 1).bit_length()
padded_d_model = max(padded_d_model, 16)
BLOCK_DMODEL = padded_d_model
ACTUAL_BLOCK_DMODEL = head_size
do = do.contiguous()
# NOTE: we might need to copy the output tensors back if they are not contiguous or have other issues
copy_back = {"dq": False, "dk": False, "dv": False}
dq_og = None
# deal with dq
if dq is None:
if sequence_parallel:
dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
else:
dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype)
else:
dq_og = dq
if not dq.is_contiguous():
dq = dq.contiguous()
copy_back["dq"] = True
if sequence_parallel:
dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
copy_back["dq"] = True
else:
# NOTE: the kernel does in-place accumulation, so dq has to be zeros. This avoids the case where we are passed an empty dq that is not all zeros
dq.zero_()
stride_dq_all = dq.stride()[0]
dk_og = None
dv_og = None
# deal with dk, dv
if (dk is None) or (dv is None):
dk = torch.empty_like(k)
dv = torch.empty_like(v)
else:
if not dk.is_contiguous():
dk_og = dk
dk = dk.contiguous()
copy_back["dk"] = True
if not dv.is_contiguous():
dv_og = dv
dv = dv.contiguous()
copy_back["dv"] = True
# assert contiguous
assert do.is_contiguous()
assert q.is_contiguous()
assert k.is_contiguous()
assert v.is_contiguous()
assert o.is_contiguous()
assert softmax_lse.is_contiguous()
# init delta
delta = torch.empty_like(softmax_lse)
if is_varlen:
stride_deltam, stride_deltah = delta.stride()
stride_deltaz = 0
else:
stride_deltaz, stride_deltah, stride_deltam = delta.stride()
_bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
o,
do,
delta,
stride_oz, stride_oh, stride_om, stride_ok,
stride_oz, stride_oh, stride_om, stride_ok,
stride_deltaz, stride_deltah, stride_deltam,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
N_CTX_Q=max_seqlen_q,
Z=batch,
H=nheads_q,
IS_VARLEN=is_varlen
)
_bwd_kernel[(batch_headsize, num_blocks_n if sequence_parallel else 1)](
q,
k,
v,
sm_scale,
o,
do,
dq,
dk,
dv,
softmax_lse,
delta,
stride_dq_all,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
stride_deltaz, stride_deltah, stride_deltam,
batch,
nheads_q,
num_blocks_m,
num_blocks_n,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
SEQUENCE_PARALLEL=sequence_parallel,
CAUSAL=causal,
USE_EXP2=use_exp2,
num_warps=num_warps,
num_stages=num_stages,
waves_per_eu = waves_per_eu,
IS_VARLEN=is_varlen
)
if sequence_parallel:
dq = dq.sum(dim=0)
if copy_back["dq"]:
dq_og.copy_(dq)
dq = dq_og
if copy_back["dk"]:
dk_og.copy_(dk)
dk = dk_og
if copy_back["dv"]:
dv_og.copy_(dv)
dv = dv_og
return dq, dk, dv, delta, None, None
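# Minimal smoke-test sketch for the wrapper above (not part of the upstream flash-attention API).
# It fabricates o and softmax_lse with a naive PyTorch attention so the kernel gets consistent
# inputs, then runs the Triton backward in "bhsd" layout. Assumptions: a ROCm/Triton-capable GPU,
# that the utils helpers accept None for cu_seqlens in the non-varlen layouts, and that
# softmax_lse is the base-e log-sum-exp of the scaled scores, as in the reference implementation
# added later in this PR.
def _prefill_backward_smoke_test():
    b, h, s, d = 1, 2, 128, 64
    dev = "cuda"
    q = torch.randn(b, h, s, d, device=dev, dtype=torch.float16)
    k = torch.randn(b, h, s, d, device=dev, dtype=torch.float16)
    v = torch.randn(b, h, s, d, device=dev, dtype=torch.float16)
    sm_scale = d ** -0.5
    # naive forward in float32 to obtain o and the per-row log-sum-exp
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
    softmax_lse = torch.logsumexp(scores, dim=-1)  # [b, h, s], float32
    o = torch.matmul(torch.softmax(scores, dim=-1), v.float()).to(q.dtype)
    do = torch.randn_like(o)
    dq, dk, dv, delta, _, _ = attention_prefill_backward_triton_impl(
        do, q, k, v, o, softmax_lse,
        None, None, None,  # dq, dk, dv are allocated by the wrapper
        sm_scale,
        None,              # alibi_slopes (unused)
        False,             # causal
        "bhsd",            # layout
        None, None,        # cu_seqlens_q / cu_seqlens_k
        s, s,              # max_seqlen_q / max_seqlen_k
        False,             # use_exp2
    )
    return dq, dk, dv, delta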

View File

@ -0,0 +1,271 @@
import math
import torch
def attention_backward_core_ref_impl(
do, q, k, v, o, softmax_lse, sm_scale, causal, use_exp2
):
# cast to float32
do = do.to(torch.float32)
q = q.to(torch.float32)
k = k.to(torch.float32)
v = v.to(torch.float32)
o = o.to(torch.float32)
softmax_lse = softmax_lse.to(torch.float32)
# recompute attention_scores; make sure it matches the forward impl, i.e. it uses float32
attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
# scale scores
attention_scaled_scores = sm_scale * attention_scores
# Apply causal mask if necessary
if causal:
L_q, L_k = q.shape[1], k.shape[1]
row_idx = torch.arange(L_q, device=q.device).unsqueeze(1)
col_idx = torch.arange(L_k, device=q.device).unsqueeze(0)
col_offset = L_q-L_k
causal_mask = row_idx >= (col_offset + col_idx)
# set -inf where the causal mask is false
attention_scaled_scores = attention_scaled_scores.masked_fill(
torch.logical_not(causal_mask.unsqueeze(0)), float('-inf')
)
# compute probabilities using softmax_lse
if use_exp2:
RCP_LN = 1 / math.log(2)
attention_scaled_scores_base2 = attention_scaled_scores * RCP_LN
softmax_lse_base2 = softmax_lse * RCP_LN
softmax_lse_3d = softmax_lse_base2.unsqueeze(-1)
p = torch.exp2(attention_scaled_scores_base2 - softmax_lse_3d)
else:
softmax_lse_3d = softmax_lse.unsqueeze(-1)
p = torch.exp(attention_scaled_scores - softmax_lse_3d)
# compute gradient wrt v
dv = torch.matmul(p.transpose(-2, -1), do.to(torch.float32))
# compute dp
dp = torch.matmul(do, v.transpose(-2, -1))
# calculate ds using dp
delta = torch.sum(o * do, axis=-1).to(torch.float32) # what OAI kernel uses
delta_3d = delta.unsqueeze(-1)
ds = (p * (dp - delta_3d)) * sm_scale
# compute gradient wrt k
dk = torch.matmul(ds.transpose(-2, -1), q.to(torch.float32))
# compute gradient wrt q
dq = torch.matmul(ds, k.to(torch.float32))
# cast back to original dtype
dq = dq.to(torch.float16)
dk = dk.to(torch.float16)
dv = dv.to(torch.float16)
# remove d dim with size 1
delta = delta_3d.squeeze(-1)
return dq, dk, dv, delta
def attention_varlen_backward_pytorch_ref_impl(
do,
q,
k,
v,
o,
softmax_lse,
sm_scale,
causal,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q, max_seqlen_k, # pylint: disable=unused-argument
use_exp2,
):
# Ensure the layout is 'thd'
if layout != 'thd':
raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.")
batch_size = cu_seqlens_q.shape[0] - 1
num_heads = q.shape[1]
head_dim = q.shape[2] # pylint: disable=unused-variable
# Pre-allocate outputs
total_L_q = q.shape[0]
total_L_k = k.shape[0] # pylint: disable=unused-variable
dq = torch.zeros_like(q)
dk = torch.zeros_like(k)
dv = torch.zeros_like(v)
# delta has the same shape as softmax_lse: [total_L_q, num_heads]
delta = torch.zeros((total_L_q, num_heads), dtype=torch.float32, device=o.device)
for i in range(batch_size):
# Get the start and end indices for the current sequence
start_q = cu_seqlens_q[i].item()
end_q = cu_seqlens_q[i + 1].item()
start_k = cu_seqlens_k[i].item()
end_k = cu_seqlens_k[i + 1].item()
# Extract q_i, k_i, v_i, do_i, o_i, softmax_lse_i
q_i = q[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim]
k_i = k[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim]
v_i = v[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim]
do_i = do[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim]
o_i = o[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim]
# softmax_lse has shape [total_L_q, num_heads]
softmax_lse_i = softmax_lse[start_q:end_q, :] # [L_q_i, num_heads]
softmax_lse_i = softmax_lse_i.transpose(0, 1) # [num_heads, L_q_i]
# Permute to [num_heads, L_q_i, head_dim]
q_i = q_i.permute(1, 0, 2)
k_i = k_i.permute(1, 0, 2)
v_i = v_i.permute(1, 0, 2)
do_i = do_i.permute(1, 0, 2)
o_i = o_i.permute(1, 0, 2)
# softmax_lse_i is already in [num_heads, L_q_i]
# Call the core backward function for this sequence
dq_i, dk_i, dv_i, delta_i = attention_backward_core_ref_impl(
do_i,
q_i,
k_i,
v_i,
o_i,
softmax_lse_i,
sm_scale,
causal,
use_exp2
)
# Convert back to 'thd' layout
dq_i = dq_i.permute(1, 0, 2) # [L_q_i, num_heads, head_dim]
dk_i = dk_i.permute(1, 0, 2) # [L_k_i, num_heads, head_dim]
dv_i = dv_i.permute(1, 0, 2) # [L_k_i, num_heads, head_dim]
# Place outputs in pre-allocated tensors
dq[start_q:end_q, :, :] = dq_i
dk[start_k:end_k, :, :] += dk_i # Accumulate gradients for shared keys
dv[start_k:end_k, :, :] += dv_i # Accumulate gradients for shared values
# delta_i has shape [num_heads, L_q_i]
delta_i = delta_i.transpose(1, 0) # [L_q_i, num_heads]
delta[start_q:end_q, :] = delta_i
return dq, dk, dv, delta
def attention_vanilla_backward_pytorch_ref_impl(
do,
q,
k,
v,
o,
softmax_lse,
sm_scale,
causal,
layout,
use_exp2,
):
if layout == "bshd":
do = do.transpose(1, 2).contiguous()
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()
o = o.transpose(1, 2).contiguous()
elif layout == "bhsd":
pass
else:
raise ValueError(f"Unknown layout {layout}")
# Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format
batch_size, num_heads, seq_len_q, head_dim = q.shape
seq_len_k = k.shape[2]
# Merge batch and heads dimensions
do = do.reshape(batch_size * num_heads, seq_len_q, head_dim)
q = q.reshape(batch_size * num_heads, seq_len_q, head_dim)
k = k.reshape(batch_size * num_heads, seq_len_k, head_dim)
v = v.reshape(batch_size * num_heads, seq_len_k, head_dim)
softmax_lse = softmax_lse.reshape(batch_size * num_heads, seq_len_q)
o = o.reshape(batch_size * num_heads, seq_len_q, head_dim)
dq, dk, dv, delta = attention_backward_core_ref_impl(
do,
q,
k,
v,
o,
softmax_lse,
sm_scale,
causal,
use_exp2
)
# Reshape outputs back to [batch_size, num_heads, seq_len, head_dim]
dq = dq.reshape(batch_size, num_heads, seq_len_q, head_dim)
dk = dk.reshape(batch_size, num_heads, seq_len_k, head_dim)
dv = dv.reshape(batch_size, num_heads, seq_len_k, head_dim)
delta = delta.reshape(batch_size, num_heads, seq_len_q)
# Go back to original layout
if layout == "bshd":
dq = dq.transpose(1, 2)
dk = dk.transpose(1, 2)
dv = dv.transpose(1, 2)
elif layout == "bhsd":
pass
else:
raise ValueError(f"Unknown layout {layout}")
return dq, dk, dv, delta
def attention_backward_pytorch_ref_impl(
do,
q,
k,
v,
o,
softmax_lse,
sm_scale,
causal,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
use_exp2
):
if layout == "thd":
dq, dk, dv, delta = attention_varlen_backward_pytorch_ref_impl(
do,
q,
k,
v,
o,
softmax_lse,
sm_scale,
causal,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
use_exp2,
)
else:
dq, dk, dv, delta = attention_vanilla_backward_pytorch_ref_impl(
do,
q,
k,
v,
o,
softmax_lse,
sm_scale,
causal,
layout,
use_exp2,
)
return dq, dk, dv, delta
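# Self-check sketch for the reference path above: build a tiny "bhsd" problem, derive o and the
# base-e softmax_lse from a naive forward, and compare the reference gradients against
# torch.autograd. The tolerance absorbs the float16 cast at the end of
# attention_backward_core_ref_impl. Runs on CPU.
def _check_ref_backward_against_autograd():
    torch.manual_seed(0)
    b, h, sq, sk, d = 1, 2, 8, 8, 16
    q = torch.randn(b, h, sq, d, dtype=torch.float32, requires_grad=True)
    k = torch.randn(b, h, sk, d, dtype=torch.float32, requires_grad=True)
    v = torch.randn(b, h, sk, d, dtype=torch.float32, requires_grad=True)
    sm_scale = 1.0 / math.sqrt(d)
    scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    softmax_lse = torch.logsumexp(scores, dim=-1)  # [b, h, sq]
    o = torch.matmul(torch.softmax(scores, dim=-1), v)
    do = torch.randn_like(o)
    dq_ref, dk_ref, dv_ref, _ = attention_backward_pytorch_ref_impl(
        do, q, k, v, o, softmax_lse, sm_scale,
        causal=False, layout="bhsd",
        cu_seqlens_q=None, cu_seqlens_k=None,
        max_seqlen_q=sq, max_seqlen_k=sk, use_exp2=False,
    )
    dq_auto, dk_auto, dv_auto = torch.autograd.grad(o, (q, k, v), do)
    for ref, auto in ((dq_ref, dq_auto), (dk_ref, dk_auto), (dv_ref, dv_auto)):
        assert torch.allclose(ref.float(), auto, atol=2e-2, rtol=2e-2)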

View File

@ -0,0 +1,700 @@
import torch
import triton
import triton.language as tl
from modules.flash_attn_triton_amd.utils import _strides, get_padded_headsize
@triton.jit
def _fwd_kernel_splitK(
Q,
K,
V,
sm_scale,
Out_splitK, # [B, H, split_k, Mq, K]
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
K_new,
V_new,
Cache_seqlens,
Cache_batch_idx,
Alibi_slopes,
stride_qz,
stride_qm,
stride_qg,
stride_qh,
stride_qd,
stride_kz,
stride_kn,
stride_kg,
stride_kh,
stride_kd,
stride_vz,
stride_vn,
stride_vg,
stride_vh,
stride_vd,
stride_osk_zhg,
stride_osk_s,
stride_osk_m,
stride_osk_d, # pylint: disable=unused-argument
stride_mzhg,
stride_m2,
stride_ms,
stride_mm, # pylint: disable=unused-argument
stride_kn_z,
stride_kn_n,
stride_kn_g,
stride_kn_h,
stride_kn_d,
stride_vn_z,
stride_vn_n,
stride_vn_g,
stride_vn_h,
stride_vn_d,
stride_az,
stride_ah,
Z, # pylint: disable=unused-argument
N_CTX_Q,
N_CTX_K,
N_CTX_NEW,
BLOCK_N_PER_SPLIT,
H_q: tl.constexpr,
H_kv: tl.constexpr,
G_q: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
BOUNDS_CHECKS_N: tl.constexpr,
USE_CACHE_SEQLENs: tl.constexpr,
USE_CACHE_BATCH_IDX: tl.constexpr,
NEW_KV: tl.constexpr,
IS_GQA: tl.constexpr,
IS_CAUSAL: tl.constexpr,
USE_ALIBI: tl.constexpr,
):
# Padding
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
if PADDED_HEAD:
d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL
start_m = tl.program_id(0)
off_zhg = tl.program_id(1)
off_z = off_zhg // (H_q * G_q)
off_h_q = (off_zhg // G_q) % H_q
off_g_q = off_zhg % G_q
splitk_idx = tl.program_id(2)
# pick batch index
if USE_CACHE_BATCH_IDX:
cache_batch_idx = tl.load(Cache_batch_idx + off_z)
else:
cache_batch_idx = off_z
# Load ALiBi slope if enabled
if USE_ALIBI:
a_offset = off_z * stride_az + off_h_q * stride_ah
alibi_slope = tl.load(Alibi_slopes + a_offset)
else:
alibi_slope = None
lo = splitk_idx * BLOCK_N_PER_SPLIT
if USE_CACHE_SEQLENs:
cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z)
if NEW_KV:
kv_len = cache_seqlen_last_idx + N_CTX_NEW
else:
kv_len = cache_seqlen_last_idx
else:
kv_len = N_CTX_K
hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)
HEAD_RATIO: tl.constexpr = H_q // H_kv
if IS_GQA:
k_head_idx = off_h_q // HEAD_RATIO
v_head_idx = k_head_idx
else:
k_head_idx = off_h_q
v_head_idx = off_h_q
# calculate base offset
k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg
v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg
# Copy new Keys and Values into Cache
if NEW_KV:
knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g
# Determine the starting position for new data in the cache
if USE_CACHE_SEQLENs:
start_idx = tl.load(Cache_seqlens + off_z)
else:
start_idx = N_CTX_K - N_CTX_NEW
# Copy new Keys
for i in range(0, N_CTX_NEW, BLOCK_N):
# Load from K_new
k_new_block = tl.load(
knew_base +
tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
(tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
(tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
other=0
)
# Store to K
tl.store(
k_base +
tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd +
(tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn,
k_new_block,
mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
(tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
)
# Copy new Values
vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g
for i in range(0, N_CTX_NEW, BLOCK_N):
# Load from V_new
v_new_block = tl.load(
vnew_base +
(tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n +
tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d,
mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
(tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
other=0
)
# Store to V
tl.store(
v_base +
(tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn +
tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd,
v_new_block,
mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
(tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
)
Q_block_ptr = tl.make_block_ptr(
base=Q + off_h_q * stride_qh + off_z * stride_qz + off_g_q * stride_qg,
shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL),
strides=(stride_qm, stride_qd),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
base=k_base,
shape=(ACTUAL_BLOCK_DMODEL, hi),
strides=(stride_kd, stride_kn),
offsets=(0, lo),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=v_base,
shape=(hi, ACTUAL_BLOCK_DMODEL),
strides=(stride_vn, stride_vd),
offsets=(lo, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
K_scale_shift_block_ptr = None
V_scale_shift_block_ptr = None
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821
# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout
q = tl.load( # noqa: F821
tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
q = (q * qk_scale).to(q.dtype)
if PADDED_HEAD:
q = tl.where(d_mask[None, :], q, 0.0)
# loop over k, v and update accumulator
for start_n in range(lo, hi, BLOCK_N):
k, v = load_k_v_group(
K_block_ptr,
V_block_ptr,
K_scale_shift_block_ptr,
V_scale_shift_block_ptr,
BOUNDS_CHECKS_N,
1,
BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL,
Q.dtype.element_ty,
0,
)
if PADDED_HEAD:
k = tl.where(d_mask[:, None], k, 0.0)
v = tl.where(d_mask[None, :], v, 0.0)
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k) # noqa: F821
if USE_ALIBI:
row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
col_idx = start_n + tl.arange(0, BLOCK_N)
# Compute relative positions
relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :])
relative_pos = tl.abs(relative_pos)
# Compute ALiBi bias
alibi_bias = -1 * alibi_slope * relative_pos
qk += (alibi_bias * 1.44269504)
# Apply causal mask if IS_CAUSAL is True
if IS_CAUSAL:
row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
col_idx = start_n + tl.arange(0, BLOCK_N)
# create a N_CTX_Q x kv_len causal mask
col_offset = N_CTX_Q - kv_len
causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :])
# Apply the mask
qk = tl.where(causal_mask, qk, float("-inf"))
# TODO: This is slow, and only needed at the last iteration.
# Maybe we can unroll the last iteration instead?
if BOUNDS_CHECKS_N:
qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
if IS_CAUSAL:
alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf")))
else:
alpha = tl.math.exp2(m_i - m_i_new)
# avoid NaN from subtracting -inf from -inf on fully masked rows
if IS_CAUSAL:
qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf"))
else:
qk = qk - m_i_new[:, None]
p = tl.math.exp2(qk)
# -- update m_i and l_i --
l_i = l_i * alpha + tl.sum(p, 1)
m_i = m_i_new
p = p.to(Q.dtype.element_ty)
# -- scale and update acc --
acc *= alpha[:, None]
acc += tl.dot(p.to(v.dtype), v)
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
# write back O
O_block_ptr = tl.make_block_ptr(
base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
shape=(N_CTX_Q, BLOCK_DMODEL),
strides=(stride_osk_m, 1),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
tl.store(
tl.advance(O_block_ptr, (0, 0)),
acc,
boundary_check=(0, ),
)
# Write metadata for split-K reduction
Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M +
tl.arange(0, BLOCK_M))
tl.store(Metadata_ptr, m_i)
tl.store(Metadata_ptr + stride_m2, l_i)
@triton.jit
def load_k_v_group(
K_block_ptr,
V_block_ptr,
K_scale_shift_block_ptr, V_scale_shift_block_ptr, # pylint: disable=unused-argument
BOUNDS_CHECKS_N: tl.constexpr,
PACKED_PER_VAL: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # pylint: disable=unused-argument
ACTUAL_BLOCK_DMODEL: tl.constexpr,
dtype: tl.constexpr, # pylint: disable=unused-argument
group_id: tl.constexpr,
):
# Load K/V for a given block
# Advance to the current quantization group
K_block_ptr = tl.advance(K_block_ptr, (ACTUAL_BLOCK_DMODEL * group_id, 0))
V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id))
# -- load k, v --
k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ())
v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ())
return k, v
@triton.jit
def cast_uint32_to_half2(scale_shift):
# Extract two float16 packed into one int32
scale = scale_shift & 0xFFFF
shift = scale_shift >> 16
scale = scale.to(tl.uint16).to(tl.float16, bitcast=True)
shift = shift.to(tl.uint16).to(tl.float16, bitcast=True)
return scale, shift
@triton.jit
def dequantize(
x_,
scale,
shift,
PACKED_PER_VAL: tl.constexpr = 8,
):
# PACKED_PER_VAL is the number of values packed into
# each element x_. For example, for int4 quantization
# and x_ of type int32, PACKED_PER_VAL is 8.
BLOCK_N: tl.constexpr = x_.shape[0]
BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
offsets = tl.arange(0, PACKED_PER_VAL) * 4
quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL)
quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL))
# Trick - instead of converting int4 to float16 we view it as float16
# and then multiply by 32768 * 512 == 2**24
quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
quant_offset = (quant_offset * 32768.0).to(tl.float16)
scale_512 = scale * 512
dequant = quant_offset * scale_512 + shift
return dequant
@triton.jit
def _splitK_reduce(
Out_splitK, # [B, H, split_k, Mq, K]
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
Out, # [B, H, M, K]
LSE, # [B, H, M]
stride_osk_zhg,
stride_osk_s,
stride_osk_m,
stride_osk_k,
stride_mzhg,
stride_m2,
stride_ms,
stride_mm,
stride_oz,
stride_oh,
stride_og,
stride_om,
stride_ok, # pylint: disable=unused-argument
stride_lse_zhg,
stride_lse_m, M_ceil: tl.constexpr, # pylint: disable=unused-argument
BLOCK_SIZE: tl.constexpr,
H: tl.constexpr,
G: tl.constexpr,
split_k: tl.constexpr,
splitK_pow2: tl.constexpr,
use_mask: tl.constexpr,
IS_CAUSAL: tl.constexpr,
):
off_zhg = tl.program_id(0)
off_z = off_zhg // (H * G)
off_h = (off_zhg // G) % H
off_g = off_zhg % G
off_m = tl.program_id(1)
off_k = tl.program_id(2)
# read chunk
spk_idx = tl.arange(0, splitK_pow2)
kidx = tl.arange(0, BLOCK_SIZE)
Metadata_ptr = Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm
o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE +
stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k)
# read max values of each splitK
if use_mask:
spk_mask = spk_idx < split_k
l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf"))
l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0)
acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0)
else:
l_m = tl.load(Metadata_ptr)
l_sum = tl.load(Metadata_ptr + stride_m2)
acc = tl.load(o_ptr)
g_m = tl.max(l_m, axis=0)
if IS_CAUSAL:
l_m_offset = l_m - g_m
alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0)
else:
alpha = tl.math.exp2(l_m - g_m)
# read sum
l_sum *= alpha
g_sum = tl.sum(l_sum, axis=0)
acc = acc * alpha[:, None]
if IS_CAUSAL:
# Avoid division by zero
g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0)
acc_out = tl.sum(acc, axis=0) / g_sum_safe
else:
acc_out = tl.sum(acc, axis=0) / g_sum
# Store output
Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m +
off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE))
tl.store(Out_ptr, acc_out)
# Store lse
l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m
if IS_CAUSAL:
lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m)
tl.store(l_ptrs, lse)
else:
tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504)
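# The reduction above merges the split-K partials: g_m is the global running max across splits,
# each partial accumulator and row-sum is rescaled by alpha_i = 2^(m_i - g_m), and the output is
# sum_i(acc_i * alpha_i) / sum_i(l_i * alpha_i). The stored LSE is (g_m + log2(g_sum)) / log2(e),
# i.e. converted back from base 2 to base e.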
def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
# Scale and shift are such that quantization linearly maps
# int4 values range [0..15] to input values range min(k)..max(k)
# individually for every row
k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups)
max_vals = torch.max(k, dim=-1, keepdim=True).values
min_vals = torch.min(k, dim=-1, keepdim=True).values
scale_k: torch.Tensor = (max_vals - min_vals) / 15
shift_k = torch.min(k, dim=-1, keepdim=True).values
scale_k = scale_k.to(torch.float16)
shift_k = shift_k.to(torch.float16)
in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5
in_bytes = in_bytes.to(torch.uint8)
in_int4 = in_bytes & 0xF
in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4)
scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1)
k_quant = torch.concat(
[
scale_shift.flatten(start_dim=-2),
in_int4_packed.flatten(start_dim=-2),
],
dim=-1,
).view(torch.int16)
return k_quant
def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
k_i16 = quant_k.view(torch.int16)
k_ui8 = k_i16.view(torch.uint8)
ss_size = num_groups * 4
scale_shift_ui8 = k_ui8[..., 0:ss_size]
scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4)
scale = scale_shift_ui8[..., 0:2].view(torch.float16)
shift = scale_shift_ui8[..., 2:4].view(torch.float16)
kv_ui8 = k_ui8[..., ss_size:]
k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1)
k1_i4 = k_ui8 & 0xF
k2_i4 = (k_ui8 & 0xF0) >> 4
k_shape = k1_i4.shape
k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape)
k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape)
out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device)
out[..., ::2] = k1_f16
out[..., 1::2] = k2_f16
out = out.reshape(*k_shape[:-2], -1)
return out
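# Round-trip sketch for the two helpers above (runs on CPU): per row, int4 quantization maps
# min(k)..max(k) onto the integer range 0..15, so the reconstruction error should stay within
# one quantization step (max - min) / 15, plus a little float16 noise.
def _int4_roundtrip_check():
    torch.manual_seed(0)
    k = torch.randn(2, 4, 32, 64, dtype=torch.float16)
    k_quant = quantize_kv_int4(k, num_groups=1)
    k_rec = dequantize_kv_fp16(k_quant, num_groups=1).reshape(k.shape)
    step = (k.amax(dim=-1, keepdim=True) - k.amin(dim=-1, keepdim=True)) / 15
    assert (k_rec - k).abs().max() <= step.max() + 1e-2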
def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
"""Heuristic for the number of splits"""
bh = max(B * H, 1) # NOTE: handle the B*H=0 case
split_k = max(Mk, 1024) // bh
max_chunk_size = 64
while split_k > 0 and Mk / split_k < max_chunk_size:
split_k = split_k // 2
while B * H * G * split_k >= 1024:
split_k = split_k // 2
split_k = min(split_k, 512)
split_k = max(split_k, 1)
return split_k
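# Worked example: B=1, G=1, H=8, Mk=4096 -> bh=8 and the initial split_k is 4096 // 8 = 512;
# it is halved while Mk / split_k < 64, ending at 64; 1*8*1*64 < 1024 so the second loop does
# nothing, and after clamping to [1, 512] the heuristic returns 64 splits.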
def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new):
# kernel config
BLOCK_M = 16
BLOCK_N = 64
SPLIT_K = None
NUM_QUANT_GROUPS = 1 # pylint: disable=unused-variable
# kernel expects "bsghd"
original_layout = layout
if layout == "bshd":
q = q.unsqueeze(2)
k = k.unsqueeze(2)
v = v.unsqueeze(2)
if new_kv:
k_new = k_new.unsqueeze(2)
v_new = v_new.unsqueeze(2)
layout = "bsghd"
elif layout == "bhsd":
q = q.permute(0, 2, 1, 3).unsqueeze(2)
k = k.permute(0, 2, 1, 3).unsqueeze(2)
v = v.permute(0, 2, 1, 3).unsqueeze(2)
if new_kv:
k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2)
v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2)
layout = "bsghd"
elif layout == "bsghd":
pass
elif layout is None:
raise ValueError("Layout not given")
assert layout == "bsghd"
# get dims
batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape
_, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape # pylint: disable=unused-variable
_, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape # pylint: disable=unused-variable
assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}"
# get padded size
dim_padded = get_padded_headsize(dim_k)
# Handle MQA/GQA case
if heads_per_group_q > heads_per_group_k:
is_gqa = True
elif heads_per_group_q < heads_per_group_k:
raise ValueError("heads_per_group_q < heads_per_group_k")
else:
is_gqa = False
assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}"
if SPLIT_K is not None:
split_k = SPLIT_K
else:
# Use heuristics
split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_k) # NOTE: should the split think about seqlens?
seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M
out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device)
metadata = torch.empty([batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device)
lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), device=q.device, dtype=torch.float32)
grid = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * n_group_q * heads_per_group_q, split_k)
num_warps = 1
split_size = (seqlen_k + split_k - 1) // split_k
use_cache_seqlens = cache_seqlens is not None
# TODO: enable quantization
_fwd_kernel_splitK[grid](
Q=q,
K=k,
V=v,
sm_scale=sm_scale,
Out_splitK=out_splitk,
Metadata=metadata,
K_new = k_new,
V_new = v_new,
Cache_seqlens=cache_seqlens,
Cache_batch_idx=cache_batch_idx,
Alibi_slopes=alibi_slopes,
**_strides(q, "qz", "qm", "qg", "qh", "qd"),
**_strides(k, "kz", "kn", "kg", "kh", "kd"),
**_strides(v, "vz", "vn", "vg", "vh", "vd"),
**_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"),
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
**_strides(k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"),
**_strides(v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"),
**_strides(alibi_slopes, "az", "ah"),
Z=batch_size,
H_q=heads_per_group_q,
H_kv=heads_per_group_k,
G_q=n_group_q,
N_CTX_Q=seqlen_q,
N_CTX_K=seqlen_k,
N_CTX_NEW=k_new.shape[1] if new_kv else None,
BLOCK_N_PER_SPLIT=split_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=dim_padded,
ACTUAL_BLOCK_DMODEL=dim_k,
BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens,
USE_CACHE_SEQLENs=use_cache_seqlens,
USE_CACHE_BATCH_IDX=cache_batch_idx is not None,
NEW_KV=new_kv,
IS_GQA=is_gqa,
IS_CAUSAL=causal,
USE_ALIBI=False if alibi_slopes is None else True,
num_warps=num_warps,
num_stages=1,
)
out = torch.empty((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype)
# Merge together
splitK_pow2 = triton.next_power_of_2(split_k)
use_mask = splitK_pow2 > split_k
if batch_size * n_group_q * heads_per_group_q * seqlen_q >= 512:
k_block_num = 1
else:
k_block_num = 2
assert dim_padded % k_block_num == 0
k_block_size = dim_padded // k_block_num
grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num)
_splitK_reduce[grid](
out_splitk,
metadata,
out,
lse,
**_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
**_strides(out, "oz", "om", "og", "oh", "ok"),
**_strides(lse, "lse_zhg", "lse_m"),
M_ceil=seqlen_q_ceil,
BLOCK_SIZE=k_block_size,
G=n_group_q,
H=heads_per_group_q,
# TODO: Tune num_warps
split_k=split_k,
splitK_pow2=splitK_pow2,
use_mask=use_mask,
IS_CAUSAL=causal,
num_warps=4)
lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q])
if q.ndim == 4:
# BMGHK -> BMHK
assert n_group_q == 1
out = out[:, :, 0]
lse = lse[:, 0]
if seqlen_k == 0:
out.zero_()
out = out.reshape(batch_size, heads_per_group_q * n_group_q, -1, dim_padded).contiguous()
# output is batch_size, heads_per_group_q * group_q, seqlen_q, dim_q
if original_layout == "bshd":
# out=out.transpose(1, 2).contiguous() # this screws up heads and data.
# the data is laid out properly. Just need to reshape dims
out = out.reshape(batch_size, seqlen_q, -1, dim_padded)
return out.narrow(-1, 0, dim_k), lse
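# Hedged usage sketch for the decode path above (requires a Triton-capable AMD GPU): a single
# new query token attends over a pre-filled KV cache in "bshd" layout, with no ALiBi, no new
# KV appending and no cache indirection. Whether None is accepted for the unused cache/ALiBi
# arguments depends on the _strides helper from utils, which is not shown in this diff.
def _decode_smoke_test():
    b, hq, hk, d, cache_len = 2, 8, 8, 64, 256
    dev = "cuda"
    q = torch.randn(b, 1, hq, d, device=dev, dtype=torch.float16)  # one decode step per batch
    k_cache = torch.randn(b, cache_len, hk, d, device=dev, dtype=torch.float16)
    v_cache = torch.randn(b, cache_len, hk, d, device=dev, dtype=torch.float16)
    out, lse = attention_decode_forward_triton_impl(
        q, k_cache, v_cache, d ** -0.5,
        causal=False,
        alibi_slopes=None,
        layout="bshd",
        cache_seqlens=None,
        cache_batch_idx=None,
        new_kv=False,
        k_new=None,
        v_new=None,
    )
    return out, lse  # out has shape [b, 1, hq, d] for the "bshd" layout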

View File

@ -0,0 +1,634 @@
import torch
import triton
import triton.language as tl
from modules.flash_attn_triton_amd.utils import get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, AUTOTUNE
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): # pylint: disable=unused-argument
ms = tl.arange(0, m)
ns = tl.arange(0, n)
return philox_offset + ms[:, None] * stride + ns[None, :]
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
rng_keep = rng_output > dropout_p
return rng_keep
# Convenience function to load with optional boundary checks.
# "First" is the major dim, "second" is the minor dim.
@triton.jit
def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second):
if offset_first is not None and offset_second is not None:
mask = (offset_first[:, None] < boundary_first) & \
(offset_second[None, :] < boundary_second)
tensor = tl.load(ptrs, mask=mask, other=0.0)
elif offset_first is not None:
mask = offset_first[:, None] < boundary_first
tensor = tl.load(ptrs, mask=mask, other=0.0)
elif offset_second is not None:
mask = offset_second[None, :] < boundary_second
tensor = tl.load(ptrs, mask=mask, other=0.0)
else:
tensor = tl.load(ptrs)
return tensor
@triton.jit
def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False):
# when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix
# for causal mask we want something like this where (1 is kept and 0 is masked)
# seqlen_q = 2 and seqlen_k = 5
# 1 1 1 1 0
# 1 1 1 1 1
# seqlen_q = 5 and seqlen_k = 2
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
# for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal
# e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
# 1. offs_m[:,None] = [[0],
# [1],
# 2. offs_m[:,None] + seqlen_k = [[5],
# [6],
# 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
# [4],
# 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1],
# [4], [ 4, 3, 2, 1, 0]]
# 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
# [ -4, -3, -2, -1, 0]],
relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :]
alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
if transpose:
return alibi_block.T
else:
return alibi_block
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m,
actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, exp_scores_ptrs,
block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs, # pylint: disable=unused-argument
IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_EXP2: tl.constexpr,
RETURN_SCORES: tl.constexpr):
if USE_EXP2:
RCP_LN2: tl.constexpr = 1.4426950408889634
# loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
if MASK_STEPS:
k_offs_n = start_n + tl.arange(0, BLOCK_N)
else:
k_offs_n = None
k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k)
if PRE_LOAD_V:
# We can use the same offsets as k, just with dims transposed.
v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
if MASK_STEPS:
# If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn;
# the last step might get wasted, but that is okay. Check whether this masking
# works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
size_n = start_n + OFFS_N[None, :]
mask = size_n < boundary_m[:, None]
qk = tl.where(mask, qk, float("-inf"))
# -- compute qk ----
qk += tl.dot(q, k)
qk_scaled = qk * SM_SCALE
if RETURN_SCORES:
score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
tl.store(score_ptrs, qk_scaled, mask=score_mask)
if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf"))
if bias_ptrs is not None:
bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None
bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k)
qk_scaled += bias
if alibi_slope is not None:
# Compute the global position of each token within the sequence
global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
global_n_positions = start_n + tl.arange(0, BLOCK_N)
alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions,
global_n_positions)
qk_scaled += alibi_block
# get max scores so far
m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1))
# scale and subtract max
q_shifted = qk_scaled - m_ij[:, None]
if RETURN_SCORES:
# NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
scores_scaled_shifted_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
tl.store(scores_scaled_shifted_ptrs, q_shifted, mask=scores_scaled_shifted_mask)
# Compute scaled QK and softmax probabilities
if USE_EXP2:
p = tl.math.exp2(q_shifted * RCP_LN2)
else:
p = tl.math.exp(q_shifted)
# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT:
philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N
keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k)
if RETURN_SCORES:
# NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
tl.store(exp_scores_ptrs, tl.where(keep, p, -p), mask=exp_score_mask)
p = tl.where(keep, p, 0.0)
elif RETURN_SCORES:
# NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
tl.store(exp_scores_ptrs, p, mask=exp_score_mask)
# -- update output accumulator --
# alpha is the adjustment factor for acc and l_i as we loop and find new maxes
# store the diff in maxes to adjust acc and l_i as we discover new maxes
m_diff = m_i - m_ij
if USE_EXP2:
alpha = tl.math.exp2(m_diff * RCP_LN2)
else:
alpha = tl.math.exp(m_diff)
acc = acc * alpha[:, None]
v = None
if not PRE_LOAD_V:
v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
# -- update m_i and l_i
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(v.type.element_ty), v)
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
if bias_ptrs is not None:
bias_ptrs += BLOCK_N * stride_bn
if RETURN_SCORES:
score_ptrs += BLOCK_N
scores_scaled_shifted_ptrs += BLOCK_N
exp_scores_ptrs += BLOCK_N
return acc, l_i, m_i
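# The inner loop above is the streaming-softmax update: for each K/V block,
#   m_ij  = max(m_i, rowmax(qk_scaled))
#   alpha = exp(m_i - m_ij)          (exp2 with the RCP_LN2 rescale when USE_EXP2)
#   p     = exp(qk_scaled - m_ij)
#   l_i   = l_i * alpha + rowsum(p)
#   acc   = acc * alpha + p @ v
# so that, ignoring dropout, acc / l_i equals softmax(qk_scaled) @ v once all blocks are done.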
def get_cdna_autotune_configs():
return [
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
# Fall-back config.
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']
def get_rdna_autotune_configs():
return [
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
# Fall-back config.
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']
def get_autotune_configs():
if AUTOTUNE:
if is_rdna():
return get_rdna_autotune_configs()
elif is_cdna():
return get_cdna_autotune_configs()
else:
raise ValueError("Unknown Device Type")
else:
return [
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False},
num_stages=1,
num_warps=4,
),
], [
"IS_CAUSAL",
"dropout_p",
"MAX_SEQLENS_Q",
"MAX_SEQLENS_K",
"ACTUAL_BLOCK_DMODEL",
"VARLEN",
"HQ",
"HK",
]
autotune_configs, autotune_keys = get_autotune_configs()
@triton.autotune(
configs=autotune_configs,
key=autotune_keys,
# use_cuda_graph=True,
)
@triton.jit
def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, # pylint: disable=unused-argument
stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p, philox_seed, philox_offset_base, scores, scores_scaled_shifted, exp_scores, alibi_slopes, HQ: tl.constexpr,
HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
# print("cu_seqlens_q_start:", cu_seqlens_q_start)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q:
return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
else:
cu_seqlens_q_start = 0
cu_seqlens_k_start = 0
seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
if IS_CAUSAL:
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn matrix
n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this WG is part of
# the blocks that are all 0. We exit early.
if n_blocks <= 0:
o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om
o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
o_ptrs_mask = offs_m[:, None] < seqlen_q
# We still need to write 0s to the result
tl.store(o_ptrs, acc, mask=o_ptrs_mask)
# The tensor allocated for L is based on MAX_SEQLENS_Q as that is
# statically known.
l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m
l_ptrs = l_offset + offs_m * stride_lse_m
l = tl.full([BLOCK_M], value=0.0, dtype=tl.float32)
# mask_m_offsets = start_m + tl.arange(0, BLOCK_M)
# lse_mask = mask_m_offsets < causal_start_idx
# softmax_lse = tl.where(lse_mask, 0.0, softmax_lse)
l_ptrs_mask = offs_m < MAX_SEQLENS_Q
tl.store(l_ptrs, l, mask=l_ptrs_mask)
# TODO: Should dropout and return encoded softmax be handled here too?
return
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK
if GROUP_SIZE != 1:
off_h_k = off_h_q // GROUP_SIZE
else:
off_h_k = off_h_q
n_extra_tokens = 0
# print("n_extra_tokens:", n_extra_tokens)
# print("seqlen_k:", seqlen_k)
# print("BLOCK_N:", BLOCK_N)
# return
if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
# Compute pointers for all the tensors used in this kernel.
q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn
v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn
if USE_BIAS:
# Note: this might get large enough to overflow on some configs
bias_offset = off_h_q * stride_bh
bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn
else:
bias_ptrs = None
if USE_ALIBI:
a_offset = off_z * stride_az + off_h_q * stride_ah
alibi_slope = tl.load(alibi_slopes + a_offset)
else:
alibi_slope = None
if RETURN_SCORES:
scores_offset = scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
score_ptrs = scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
scores_scaled_shifted_offset = scores_scaled_shifted + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
scores_scaled_shifted_ptrs = scores_scaled_shifted_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
exp_scores_offset = exp_scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
exp_scores_ptrs = exp_scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
else:
score_ptrs = None
scores_scaled_shifted_ptrs = None
exp_scores_ptrs = None
if ENABLE_DROPOUT:
off_hz = off_z * HQ + off_h_q
batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k
else:
batch_philox_offset = 0
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# Q is loaded once at the beginning and shared by all N blocks.
q_ptrs_mask = offs_m[:, None] < seqlen_q
if PADDED_HEAD:
q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)
# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
else:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional block.
# In this case we might exceed n_blocks so pick the min.
masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks
block_min = 0
block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its actual
# value because there is no masking. Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset,
exp_scores_ptrs,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min, block_max, 0, 0, 0, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs,
# IS_CAUSAL, ....
False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD,
ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
block_min = block_max
block_max = n_blocks * BLOCK_N
tl.debug_barrier()
# Remaining blocks, if any, are the masked blocks.
if masked_blocks > 0:
if IS_CAUSAL:
offs_n_causal = offs_n + (seqlen_q - seqlen_k)
else:
offs_n_causal = 0
k_ptrs += n_full_blocks * BLOCK_N * stride_kn
v_ptrs += n_full_blocks * BLOCK_N * stride_vk
if USE_BIAS:
bias_ptrs += n_full_blocks * BLOCK_N * stride_bn
if RETURN_SCORES:
score_ptrs += n_full_blocks * BLOCK_N
scores_scaled_shifted_ptrs += n_full_blocks * BLOCK_N
exp_scores_ptrs += n_full_blocks * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset,
exp_scores_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs,
IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD,
ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
# epilogue
# This lets the compiler do the Newton-Raphson reciprocal on l_i rather than on acc, which is much larger.
l_recip = 1 / l_i[:, None]
acc = acc * l_recip
if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL:
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
z: tl.tensor = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE (log-sum-exp), the log of the softmax normalization constant
l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m
l_ptrs = l_offset + offs_m * stride_lse_m
if USE_EXP2:
RCP_LN2: tl.constexpr = 1.4426950408889634
LN2: tl.constexpr = 0.6931471824645996
# compute log-sum-exp in base 2 units
mi_base2 = m_i * RCP_LN2
softmax_lse = mi_base2 + tl.math.log2(l_i)
# convert back to natural units
softmax_lse *= LN2
else:
softmax_lse = m_i + tl.math.log(l_i)
if IS_CAUSAL:
# zero out nans caused by -infs when doing causal
lse_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx
softmax_lse = tl.where(lse_mask, 0.0, softmax_lse)
# If seqlen_q is not a multiple of BLOCK_M, we need to mask out the last few rows.
# This only happens for the last M block; for the others, overflow_size is negative.
overflow_size = end_m_idx - seqlen_q
if overflow_size > 0:
boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32)
l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary
tl.store(l_ptrs, softmax_lse, mask=l_ptrs_mask) # the log of the normalization constant
else:
tl.store(l_ptrs, softmax_lse) # the log of the normalization constant
# write back O
o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om
o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1)
if overflow_size > 0:
o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q)
if PADDED_HEAD:
o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)
def attention_prefill_forward_triton_impl(
q,
k,
v,
o,
sm_scale,
alibi_slopes,
causal,
bias,
dropout_p,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
return_scores,
use_exp2):
# check if varlen
is_varlen = layout == "thd"
# NOTE: a large bias tensor leads to overflow during pointer arithmetic
if bias is not None:
assert bias.numel() < 2**31
batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k) # pylint: disable=unused-variable
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
# Round head_size up to the next power of 2 for the padded model dim used by the kernel.
padded_d_model = 1 << (head_size - 1).bit_length()
# Smallest head_dim supported is 16. If smaller, the tile in the
# kernel is padded - there is no padding in memory for any dims.
padded_d_model = max(padded_d_model, 16)
grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) # pylint: disable=unnecessary-lambda-assignment
if return_scores:
scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
dtype=torch.float32)
scores_scaled_shifted = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
dtype=torch.float32)
scores_strides = (scores.stride(0), scores.stride(1), scores.stride(2), scores.stride(3))
else:
scores = None
scores_scaled_shifted = None
scores_strides = (0, 0 , 0 , 0)
# exp_scores is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out
# to give a consistent starting point and then populate it with the output of softmax with the sign bit set according
# to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing
# only. This return holds no useful output aside from debugging.
if return_scores:
exp_scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
dtype=torch.float32)
else:
exp_scores = None
# stores LSE, the log of the normalization constant, i.e. the log of the sum of exponentials of the (unnormalized) scores
if is_varlen:
softmax_lse = torch.empty((q.shape[0], nheads_q), device=q.device, dtype=torch.float32)
stride_lse_m, stride_lse_h = softmax_lse.stride()
stride_lse_z = 0
else:
softmax_lse = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32)
stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride()
# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42
if bias is not None:
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2), bias.stride(3))
else:
bias_strides = (0, 0, 0, 0)
if alibi_slopes is not None:
alibi_strides = (alibi_slopes.stride(0), alibi_slopes.stride(1))
else:
alibi_strides = (0, 0)
attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, scores=scores,
scores_scaled_shifted=scores_scaled_shifted, exp_scores=exp_scores, alibi_slopes=alibi_slopes,
HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen,
BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_scores)
return o, softmax_lse, exp_scores, grid, head_size, philox_seed, philox_offset, scores, scores_scaled_shifted
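# Hedged smoke-test sketch (not called anywhere in this module): exercises the Triton
# prefill forward path on a tiny "bshd" problem. It assumes a ROCm/CUDA device with
# Triton available; the helper name is illustrative only, not part of the public API.
def _prefill_forward_smoke_test():
    batch, nheads, seqlen, head_dim = 1, 2, 64, 64
    q = torch.randn(batch, seqlen, nheads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    o = torch.empty_like(q)
    out, softmax_lse, *_ = attention_prefill_forward_triton_impl(
        q, k, v, o,
        sm_scale=head_dim ** -0.5,
        alibi_slopes=None,
        causal=False,
        bias=None,
        dropout_p=0.0,
        layout="bshd",
        cu_seqlens_q=None,
        cu_seqlens_k=None,
        max_seqlens_q=seqlen,
        max_seqlens_k=seqlen,
        return_scores=False,
        use_exp2=False,
    )
    return out, softmax_lse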

View File

@@ -0,0 +1,258 @@
import math
import torch
def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2):
# Compute attention scores
attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
# Scale scores
attention_scaled_scores = sm_scale * attention_scores
# Apply causal mask if necessary
if causal:
L_q, L_k = q.shape[1], k.shape[1]
row_idx = torch.arange(L_q, device=q.device).unsqueeze(1)
col_idx = torch.arange(L_k, device=q.device).unsqueeze(0)
col_offset = L_q-L_k
causal_mask = row_idx >= (col_offset + col_idx)
# Set scores to -inf where the causal mask is False
attention_scaled_scores = attention_scaled_scores.masked_fill(
torch.logical_not(causal_mask.unsqueeze(0)), float('-inf')
)
# Compute max for numerical stability
max_scores = torch.max(attention_scaled_scores, dim=-1, keepdim=True)[0]
if causal:
# Replace -inf in max_scores with zeros to avoid NaN in subtraction
max_scores = torch.where(
torch.isinf(max_scores), torch.zeros_like(max_scores), max_scores
)
# Shift scores
attention_shifted_scaled_scores = attention_scaled_scores - max_scores
# Exponentiate
if use_exp2:
RCP_LN = 1 / math.log(2)
exp_scores = torch.exp2(RCP_LN * attention_shifted_scaled_scores)
else:
exp_scores = torch.exp(attention_shifted_scaled_scores)
# Sum of exponentials
sum_exp_scores = torch.sum(exp_scores, dim=-1, keepdim=True)
if causal:
# If the sum of exp scores is 0.0, all scores were -inf; softmax and softmax_lse are undefined. Setting the sum to 1 handles this all--inf case cleanly.
sum_exp_scores = torch.where(
sum_exp_scores == 0,
torch.ones_like(sum_exp_scores),
sum_exp_scores
)
# Compute softmax probabilities
softmax = exp_scores / sum_exp_scores
# Compute log-sum-exp
if use_exp2:
LN2 = math.log(2)
RCP_LN = 1 / math.log(2)
max_scores_base2 = max_scores * RCP_LN
softmax_lse_base2 = max_scores_base2 + torch.log2(sum_exp_scores)
softmax_lse = softmax_lse_base2 * LN2
softmax_lse.squeeze_(-1)
else:
softmax_lse = max_scores + torch.log(sum_exp_scores)
softmax_lse = softmax_lse.squeeze(-1)
# Compute output
o = torch.matmul(softmax, v.to(torch.float32)).to(torch.float16)
return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores
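# Hedged sketch of the identity behind the use_exp2 path (illustrative only, not used
# by the reference implementations): exp(x) == 2**(x * log2(e)), and the base-2
# log-sum-exp converts back to natural units by multiplying with ln(2).
def _exp2_identity_example():
    x = torch.randn(8, dtype=torch.float32)
    rcp_ln2 = 1.0 / math.log(2)  # log2(e)
    # exp via exp2
    assert torch.allclose(torch.exp2(rcp_ln2 * x), torch.exp(x), atol=1e-6)
    # base-2 LSE converted back to natural-log units matches torch.logsumexp
    m = x.max()
    lse_base2 = m * rcp_ln2 + torch.log2(torch.exp2(rcp_ln2 * (x - m)).sum())
    assert torch.allclose(lse_base2 * math.log(2), torch.logsumexp(x, dim=0), atol=1e-5)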
def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, use_exp2):
"""Compute reference output and softmax_lse using PyTorch's built-in function"""
# Ensure the layout is 'bhsd'
if layout == "bshd":
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()
elif layout != "bhsd":
raise ValueError(f"Unknown layout {layout}")
# Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format
batch_size, num_heads, seq_len_q, head_dim = q.shape
seq_len_k = k.shape[2]
# Merge batch and heads dimensions
q = q.reshape(batch_size * num_heads, seq_len_q, head_dim)
k = k.reshape(batch_size * num_heads, seq_len_k, head_dim)
v = v.reshape(batch_size * num_heads, seq_len_k, head_dim)
# Call the core attention function
o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores = attention_forward_core_ref_impl(
q, k, v, sm_scale, causal, use_exp2
)
# Reshape outputs back to [batch_size, num_heads, seq_len, head_dim]
o = o.reshape(batch_size, num_heads, seq_len_q, head_dim)
softmax_lse = softmax_lse.reshape(batch_size, num_heads, seq_len_q)
exp_scores = exp_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
softmax = softmax.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
attention_shifted_scaled_scores = attention_shifted_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
attention_scaled_scores = attention_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
attention_scores = attention_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
# Restore original layout if necessary
if layout == "bshd":
o = o.transpose(1, 2)
return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores
def attention_varlen_forward_pytorch_ref_impl(
q,
k,
v,
sm_scale,
causal,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q, max_seqlen_k, # pylint: disable=unused-argument
use_exp2
):
# Ensure the layout is 'thd'
if layout != 'thd':
raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.")
batch_size = cu_seqlens_q.shape[0] - 1
num_heads = q.shape[1]
head_dim = q.shape[2]
# Pre-allocate outputs
total_L_q = q.shape[0]
total_L_k = k.shape[0] # pylint: disable=unused-variable
o = torch.empty((total_L_q, num_heads, head_dim), dtype=q.dtype, device=q.device)
softmax_lse = torch.empty((total_L_q, num_heads), dtype=torch.float32, device=q.device)
for i in range(batch_size):
# Get the start and end indices for the current sequence
start_q = cu_seqlens_q[i].item()
end_q = cu_seqlens_q[i + 1].item()
start_k = cu_seqlens_k[i].item()
end_k = cu_seqlens_k[i + 1].item()
# Extract q_i, k_i, v_i
q_i = q[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim]
k_i = k[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim]
v_i = v[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim]
# Permute to [num_heads, L_q_i, head_dim]
q_i = q_i.permute(1, 0, 2)
k_i = k_i.permute(1, 0, 2)
v_i = v_i.permute(1, 0, 2)
# Call the core attention function for this sequence
(
o_i,
softmax_lse_i,
exp_scores_i,
softmax_i,
attention_shifted_scaled_scores_i,
attention_scaled_scores_i,
attention_scores_i,
) = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, use_exp2)
# Convert back to 'thd' layout and float16
o_i = o_i.permute(1, 0, 2).to(torch.float16) # [L_q_i, num_heads, head_dim]
# Place outputs in pre-allocated tensors
o[start_q:end_q, :, :] = o_i
softmax_lse[start_q:end_q, :] = softmax_lse_i.transpose(0, 1) # Transpose to [L_q_i, num_heads]
# Per-sequence score tensors have varying shapes and are not aggregated here; the corresponding return values below are None.
# exp_scores_i: [num_heads, L_q_i, L_k_i] -> [L_q_i, num_heads, L_k_i]
exp_scores_i = exp_scores_i.permute(1, 0, 2)
softmax_i = softmax_i.permute(1, 0, 2)
attention_shifted_scaled_scores_i = attention_shifted_scaled_scores_i.permute(1, 0, 2)
attention_scaled_scores_i = attention_scaled_scores_i.permute(1, 0, 2)
attention_scores_i = attention_scores_i.permute(1, 0, 2)
return (
o,
softmax_lse,
None,
None,
None,
None,
None,
)
def attention_forward_pytorch_ref_impl(
q,
k,
v,
sm_scale,
causal,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
use_exp2
):
# compute reference
if layout == "thd":
(
o_ref,
softmax_lse_ref,
exp_scores_ref,
softmax_ref,
attention_shifted_scaled_scores_ref,
attention_scaled_scores_ref,
attention_scores_ref,
) = attention_varlen_forward_pytorch_ref_impl(
q.clone(),
k.clone(),
v.clone(),
sm_scale,
causal,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
use_exp2,
)
else:
(
o_ref,
softmax_lse_ref,
exp_scores_ref,
softmax_ref,
attention_shifted_scaled_scores_ref,
attention_scaled_scores_ref,
attention_scores_ref,
) = attention_vanilla_forward_pytorch_ref_impl(
q.clone(), k.clone(), v.clone(), sm_scale, causal, layout, use_exp2
)
return (
o_ref,
softmax_lse_ref,
exp_scores_ref,
softmax_ref,
attention_shifted_scaled_scores_ref,
attention_scaled_scores_ref,
attention_scores_ref,
)
def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k):
q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1)
k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K)
relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K)
return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K)
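# Hedged sanity-check sketch (not part of the reference API): compares the pure PyTorch
# reference against torch.nn.functional.scaled_dot_product_attention on a small
# non-causal problem. Assumes a CUDA/ROCm device; tolerances are loose because the
# reference computes in fp32 and casts its output to fp16. sm_scale matches SDPA's
# default 1/sqrt(head_dim), so no explicit scale argument is needed.
def _ref_vs_sdpa_sanity_check():
    torch.manual_seed(0)
    q = torch.randn(2, 4, 128, 64, dtype=torch.float16, device="cuda")  # bhsd
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    sm_scale = q.shape[-1] ** -0.5
    o_ref, _, _, _, _, _, _ = attention_vanilla_forward_pytorch_ref_impl(
        q, k, v, sm_scale, False, "bhsd", False
    )
    o_sdpa = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(o_ref, o_sdpa, atol=2e-2, rtol=2e-2)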

View File

@@ -0,0 +1,394 @@
import os
import torch
from modules.flash_attn_triton_amd.fwd_prefill import attention_prefill_forward_triton_impl
from modules.flash_attn_triton_amd.bwd_prefill import attention_prefill_backward_triton_impl
from modules.flash_attn_triton_amd.fwd_decode import attention_decode_forward_triton_impl
from modules.flash_attn_triton_amd.fwd_ref import attention_forward_pytorch_ref_impl
from modules.flash_attn_triton_amd.bwd_ref import attention_backward_pytorch_ref_impl
from modules.flash_attn_triton_amd.utils import MetaData, get_shape_from_layout
USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')
def fwd(q,
k,
v,
o,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
window_size_left, window_size_right, softcap, # pylint: disable=unused-argument
return_softmax,
gen_ # pylint: disable=unused-argument
):
if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD's Triton Backend yet")
if o is None:
o = torch.empty_like(q)
# Setup metadata
metadata = MetaData(sm_scale=softmax_scale)
metadata.max_seqlens_q = q.shape[1]
metadata.max_seqlens_k = k.shape[1]
metadata.layout = "bshd"
if return_softmax:
metadata.return_scores = True
batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, metadata.layout) # pylint: disable=unused-variable
if causal:
metadata.need_causal()
if alibi_slopes is not None:
metadata.need_alibi(alibi_slopes, batch, nheads_q)
if dropout_p > 0.0:
metadata.need_dropout(dropout_p, return_softmax)
# Check arguments
metadata.check_args(q, k, v, o)
if USE_REF:
(output,
softmax_lse,
exp_scores,
_,
_,
_,
_) = attention_forward_pytorch_ref_impl(
q,
k,
v,
metadata.sm_scale,
metadata.causal,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.use_exp2)
o.copy_(output)
else:
(_,
softmax_lse,
exp_scores,
_,
_,
_,
_,
_,
_) = attention_prefill_forward_triton_impl(
q,
k,
v,
o,
metadata.sm_scale,
metadata.alibi_slopes,
metadata.causal,
metadata.bias,
metadata.dropout_p,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.return_scores,
metadata.use_exp2)
return o, softmax_lse, exp_scores, None
def bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
window_size_left, window_size_right, softcap, deterministic, gen_, rng_state, # pylint: disable=unused-argument
):
if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD yet")
if USE_REF:
dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
dout,
q,
k,
v,
out,
softmax_lse,
softmax_scale,
causal,
"bshd",
None,
None,
None,
None,
False,
)
dq.copy_(dq_ref)
dk.copy_(dk_ref)
dv.copy_(dv_ref)
delta = delta_ref
else:
dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( # pylint: disable=unused-variable
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
softmax_scale,
alibi_slopes,
causal,
"bshd",
None,
None,
None,
None,
False,
)
delta = delta_triton
return dq, dk, dv, delta
def varlen_fwd(
q,
k,
v,
o,
cu_seqlens_q,
cu_seqlens_k,
seqused_k, leftpad_k, block_table_, # pylint: disable=unused-argument
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
zero_tensors, # pylint: disable=unused-argument
causal,
window_size_left, window_size_right, softcap, # pylint: disable=unused-argument
return_softmax,
gen_ # pylint: disable=unused-argument
):
if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD's Triton Backend yet")
if o is None:
o = torch.empty_like(q)
# Setup metadata
metadata = MetaData(sm_scale=softmax_scale)
if return_softmax:
metadata.return_scores = True
metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) # set layout to "thd" and other metadata
# get shapes
batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # pylint: disable=unused-variable
if causal:
metadata.need_causal()
if alibi_slopes is not None:
metadata.need_alibi(alibi_slopes, batch, nheads_q)
if dropout_p > 0.0:
metadata.need_dropout(dropout_p, return_softmax)
# Check arguments
metadata.check_args(q, k, v, o)
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
if USE_REF:
(output,
softmax_lse,
exp_scores,
_,
_,
_,
_) = attention_forward_pytorch_ref_impl(
q,
k,
v,
metadata.sm_scale,
metadata.causal,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.use_exp2)
o.copy_(output)
else:
(_,
softmax_lse,
exp_scores,
_,
_,
_,
_,
_,
_) = attention_prefill_forward_triton_impl(
q,
k,
v,
o,
metadata.sm_scale,
metadata.alibi_slopes,
metadata.causal,
metadata.bias,
metadata.dropout_p,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.return_scores,
metadata.use_exp2)
return o, softmax_lse, exp_scores, None
def varlen_bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
zero_tensors, # pylint: disable=unused-argument
causal,
window_size_left, window_size_right, softcap, deterministic, gen_, rng_state, # pylint: disable=unused-argument
):
if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD yet")
if USE_REF:
dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
dout,
q,
k,
v,
out,
softmax_lse,
softmax_scale,
causal,
"thd",
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
False,
)
dq.copy_(dq_ref)
dk.copy_(dk_ref)
dv.copy_(dv_ref)
delta = delta_ref
else:
dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( # pylint: disable=unused-variable
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
softmax_scale,
alibi_slopes,
causal,
"thd",
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
False,
)
delta = delta_triton
return dq, dk, dv, delta
def fwd_kvcache(
q,
k_cache,
v_cache,
k,
v,
cache_seqlens,
rotary_cos, rotary_sin, # pylint: disable=unused-argument
cache_batch_idx,
cache_leftpad, block_table, # pylint: disable=unused-argument
alibi_slopes,
out,
softmax_scale,
causal,
window_size_left, window_size_right, softcap, rotary_interleaved, num_splits, # pylint: disable=unused-argument
):
if out is None:
out = torch.empty_like(q)
# fill metadata
metadata = MetaData(sm_scale=softmax_scale)
metadata.layout = "bshd"
metadata.max_seqlens_q = q.shape[1]
metadata.max_seqlens_k = k_cache.shape[1]
metadata.cache_seqlens = cache_seqlens
metadata.cache_batch_idx = cache_batch_idx
if k is not None and v is not None:
metadata.new_kv = True
metadata.seqlen_new = k.shape[1]
metadata.k_new = k
metadata.v_new = v
if causal:
metadata.need_causal()
if alibi_slopes is not None:
batch, _ , nheads_q, _= q.shape
metadata.need_alibi(alibi_slopes, batch, nheads_q)
# launch kernel
# TODO: pass output as an arg; copying the output here may be causing a slowdown
output, softmax_lse = attention_decode_forward_triton_impl(
q,
k_cache,
v_cache,
metadata.sm_scale,
metadata.causal,
metadata.alibi_slopes,
metadata.layout,
metadata.cache_seqlens,
metadata.cache_batch_idx,
metadata.new_kv,
metadata.k_new,
metadata.v_new,
)
return output, softmax_lse
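# Hedged usage sketch (illustrative only): fwd() expects "bshd" tensors of shape
# [batch, seqlen, heads, head_dim], so SDPA-style [batch, heads, seqlen, head_dim]
# inputs are transposed first, mirroring the SDPA hook elsewhere in this change.
# Assumes a ROCm/CUDA device; the helper name is hypothetical.
def _fwd_interface_example():
    b, h, s, d = 1, 2, 64, 64
    q = torch.randn(b, h, s, d, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out, softmax_lse, _, _ = fwd(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        None,         # o (allocated inside when None)
        None,         # alibi_slopes
        0.0,          # dropout_p
        d ** -0.5,    # softmax_scale
        False,        # causal
        -1, -1, 0.0,  # window_size_left, window_size_right, softcap
        False,        # return_softmax
        None,         # gen_
    )
    return out.transpose(1, 2), softmax_lse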

View File

@@ -0,0 +1,280 @@
import os
import torch
import triton
AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes')
PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes')
class MetaData():
cu_seqlens_q = None
cu_seqlens_k = None
max_seqlens_q = 0
max_seqlens_k = 0
bias = None
alibi_slopes = None
causal = False
num_contexts = 0
varlen = False
layout = None
cache_seqlens = None
cache_batch_idx = None
new_kv = False
seqlen_new = None
k_new = None
v_new = None
dropout_p, return_scores= 0.0, False
# NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW.
use_exp2 = False
def __repr__(self) -> str:
return (f"MetaData(\n"
f" sm_scale={self.sm_scale},\n"
f" cu_seqlens_q={self.cu_seqlens_q},\n"
f" cu_seqlens_k={self.cu_seqlens_k},\n"
f" max_seqlens_q={self.max_seqlens_q},\n"
f" max_seqlens_k={self.max_seqlens_k},\n"
f" bias={self.bias},\n"
f" alibi_slopes={self.alibi_slopes},\n"
f" causal={self.causal},\n"
f" num_contexts={self.num_contexts},\n"
f" varlen={self.varlen},\n"
f" layout={self.layout},\n"
f" cache_seqlens={self.cache_seqlens},\n"
f" cache_batch_idx={self.cache_batch_idx},\n"
f" new_kv={self.new_kv},\n"
f" seqlen_new={self.seqlen_new},\n"
f" k_new={self.k_new},\n"
f" v_new={self.v_new},\n"
f" dropout_p={self.dropout_p},\n"
f" return_scores={self.return_scores}\n"
f")")
def __init__(self, sm_scale=1.0):
self.sm_scale = sm_scale
def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
self.varlen = True
self.layout = 'thd'
self.cu_seqlens_q = cu_seqlens_q
self.cu_seqlens_k = cu_seqlens_k
# Without "varlen", there should still be one sequence.
assert len(cu_seqlens_q) >= 2
assert len(cu_seqlens_q) == len(cu_seqlens_k)
self.num_contexts = len(cu_seqlens_q) - 1
for i in range(0, self.num_contexts):
self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q)
self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k)
def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): # pylint: disable=unused-argument
assert bias.is_cuda
assert bias.dim() == 4
assert bias.shape[0] == 1
assert bias.shape[2:] == (seqlen_q, seqlen_k)
self.bias = bias
def need_alibi(self, alibi_slopes, batch, nheads):
assert alibi_slopes.is_cuda
assert alibi_slopes.dim() == 2
assert alibi_slopes.shape[0] == batch
assert alibi_slopes.shape[1] == nheads
self.alibi_slopes = alibi_slopes
def need_causal(self):
self.causal = True
def need_dropout(self, dropout_p, return_scores):
self.dropout_p = dropout_p
self.return_scores = return_scores
def check_args(self, q, k, v, o):
assert q.dim() == k.dim() and q.dim() == v.dim()
batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) # pylint: disable=unused-variable
if self.varlen:
assert q.dim() == 3
assert self.cu_seqlens_q is not None
assert self.cu_seqlens_k is not None
assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
# TODO: Remove once bias is supported with varlen
assert self.bias is None
# TODO: Remove once dropout is supported with varlen
assert self.dropout_p == 0.0
# assert not self.return_scores
else:
assert q.dim() == 4
assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0
assert self.cu_seqlens_q is None and self.cu_seqlens_k is None
assert k.shape == v.shape
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support q/k in fp8 and v in fp16
assert q.dtype == k.dtype and q.dtype == v.dtype
assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
assert self.layout is not None
assert self.layout == 'thd' or not self.varlen
def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cuda", DEBUG_INPUT=False):
torch.manual_seed(20)
# Initialize q, k, v
if layout == 'bhsd':
q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD)
k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD)
elif layout == 'bshd':
q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD)
k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD)
else:
assert False, f'Got unsupported tensor layout: {layout}'
q = None
k = None
v = None
if DEBUG_INPUT:
if layout == "bhsd":
q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, 1, N_CTX_Q, 1).expand(*q_tensor_shape).contiguous().requires_grad_()
k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
elif layout == "bshd":
q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, N_CTX_Q, 1, 1).expand(*q_tensor_shape).contiguous().requires_grad_()
k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
else:
q = torch.randn(q_tensor_shape, dtype=dtype, device=device, requires_grad=True)
k = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True)
v = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True)
if DEBUG_INPUT:
sm_scale = 1
else:
sm_scale = D_HEAD**-0.5
input_metadata = MetaData(sm_scale=sm_scale)
input_metadata.max_seqlens_q = N_CTX_Q
input_metadata.max_seqlens_k = N_CTX_K
input_metadata.layout = layout
return q, k, v, input_metadata
def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda", equal_seqlens=False, DEBUG_INPUT=False):
torch.manual_seed(20)
# Random or equal sequence lengths based on 'equal_seqlens' flag
if not equal_seqlens:
max_seqlens_q = N_CTX_Q // Z
max_seqlens_k = N_CTX_K // Z
seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32)
seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32)
else:
seqlens_q = torch.full((Z,), N_CTX_Q // Z, dtype=torch.int32)
seqlens_k = torch.full((Z,), N_CTX_K // Z, dtype=torch.int32)
# Calculate cumulative sequence lengths
cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0)])
cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0)])
cu_seqlens_q = cu_seqlens_q.to(device=device).to(torch.int32)
cu_seqlens_k = cu_seqlens_k.to(device=device).to(torch.int32)
# Total lengths
total_q = cu_seqlens_q[-1].item()
total_k = cu_seqlens_k[-1].item()
if DEBUG_INPUT:
# Initialize q, k, v with deterministic values
q = torch.arange(total_q, dtype=dtype, device=device).view(total_q, 1, 1)
q = q.expand(total_q, HQ, D_HEAD).contiguous().requires_grad_()
k = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1)
k = k.expand(total_k, HK, D_HEAD).contiguous().requires_grad_()
v = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1)
v = v.expand(total_k, HK, D_HEAD).contiguous().requires_grad_()
sm_scale = 1
else:
# Initialize q, k, v with random values
q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device=device).requires_grad_()
k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
sm_scale = D_HEAD ** -0.5
input_metadata = MetaData(sm_scale=sm_scale)
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
return q, k, v, input_metadata
def get_shape_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None):
if layout == 'bhsd':
batch_q, nheads_q, max_seqlen_q, head_size_q = q.shape
batch_k, nheads_k, max_seqlen_k, head_size_k = k.shape
elif layout == 'bshd':
batch_q, max_seqlen_q, nheads_q, head_size_q = q.shape
batch_k, max_seqlen_k, nheads_k, head_size_k = k.shape
elif layout == 'thd':
batch_q, max_seqlen_q, nheads_q, head_size_q = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] # pylint: disable=self-assigning-variable
batch_k, max_seqlen_k, nheads_k, head_size_k = len(cu_seqlens_k) - 1, max_seqlen_k, k.shape[1], k.shape[2] # pylint: disable=self-assigning-variable
else:
assert False, "Got unsupported layout."
# assert
assert batch_q == batch_k
assert head_size_q == head_size_k
return batch_q, nheads_q, nheads_k, head_size_q, max_seqlen_q, max_seqlen_k
def get_strides_from_layout(q, k, v, o, layout):
if layout == 'thd':
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
elif layout == 'bhsd':
q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3))
k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3))
v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3))
o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))
elif layout == 'bshd':
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
else:
assert False, 'Got unsupported layout.'
return q_strides, k_strides, v_strides, o_strides
def get_padded_headsize(size):
# Round the head size up to the next power of 2.
padded_d_model = 1 << (size - 1).bit_length()
# Smallest head_dim supported is 16. If smaller, the tile in the
# kernel is padded - there is no padding in memory for any dims.
padded_d_model = max(padded_d_model, 16)
return padded_d_model
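# Hedged examples of the padding rule above (illustrative only): the head size is
# rounded up to the next power of two and clamped to at least 16.
def _padded_headsize_examples():
    assert get_padded_headsize(8) == 16
    assert get_padded_headsize(40) == 64
    assert get_padded_headsize(64) == 64
    assert get_padded_headsize(72) == 128
    assert get_padded_headsize(128) == 128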
def _strides(x: torch.Tensor, *stride_names: str):
if x is None:
return {f"stride_{s}": 0 for i, s in enumerate(stride_names)}
assert x.ndim == len(stride_names)
return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)}
def get_input_shapes():
cases = [(max(1, 2**(16 - i)), 1, 2**i, 16, 1, 128)
for i in range(8, 18)] + [(max(1, 2**(16 - i)), 1, 2**i, 16, 2, 128) for i in range(8, 18)]
return cases
def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
def is_cdna():
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
'gfx90a', 'gfx908')
def is_rdna():
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ("gfx1030", "gfx1100", "gfx1101",
"gfx1102", "gfx1200", "gfx1201")

View File

@@ -1,13 +1,21 @@
from functools import wraps
import torch
import torch._dynamo.device_interface
from modules import rocm, zluda
from modules import rocm, zluda, shared
_topk = torch.topk
def topk(input: torch.Tensor, *args, **kwargs): # pylint: disable=redefined-builtin
device = input.device
values, indices = _topk(input.cpu(), *args, **kwargs)
return torch.return_types.topk((values.to(device), indices.to(device),))
MEM_BUS_WIDTH = {
"AMD Radeon RX 9070 XT": 256,
"AMD Radeon RX 9070": 256,
"AMD Radeon RX 9060 XT": 192,
"AMD Radeon RX 7900 XTX": 384,
"AMD Radeon RX 7900 XT": 320,
"AMD Radeon RX 7900 GRE": 256,
"AMD Radeon RX 7800 XT": 256,
"AMD Radeon RX 7700 XT": 192,
"AMD Radeon RX 7600 XT": 128,
"AMD Radeon RX 7600": 128,
}
class DeviceProperties:
@@ -35,20 +43,60 @@ def torch__C__cuda_getCurrentRawStream(device):
def do_hijack():
torch.version.hip = rocm.version
torch.topk = topk
if zluda.default_agent is not None:
DeviceProperties.PROPERTIES_OVERRIDE["gcnArchName"] = zluda.default_agent.name
torch.cuda._get_device_properties = torch_cuda__get_device_properties # pylint: disable=protected-access
torch._C._cuda_getCurrentRawStream = torch__C__cuda_getCurrentRawStream # pylint: disable=protected-access
torch._dynamo.device_interface.CudaInterface.get_raw_stream = staticmethod(torch__C__cuda_getCurrentRawStream) # pylint: disable=protected-access
# Triton
try:
import triton
_get_device_properties = triton.runtime.driver.active.utils.get_device_properties
def triton_runtime_driver_active_utils_get_device_properties(device):
props = _get_device_properties(device)
props["mem_bus_width"] = 384
name = torch.cuda.get_device_name()[:-8]
if name in MEM_BUS_WIDTH:
props["mem_bus_width"] = MEM_BUS_WIDTH[name]
else:
props["mem_bus_width"] = 128
shared.log.warning(f'[TRITON] defaulting mem_bus_width=128 for device "{name}".')
return props
triton.runtime.driver.active.utils.get_device_properties = triton_runtime_driver_active_utils_get_device_properties
if 'Flash attention' in shared.opts.sdp_options:
from modules.flash_attn_triton_amd import interface_fa
sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_flash_atten)
def sdpa_flash_atten(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
if scale is None:
scale = query.shape[-1] ** (-0.5)
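# the flash attention path used here expects the head dim padded to a multiple of 8; pad q/k/v and slice the output back to the original head size below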
head_size_og = query.size(3)
if head_size_og % 8 != 0:
query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8])
key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8])
value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8])
out_padded, _, _, _ = interface_fa.fwd(
query.transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2),
None,
None,
dropout_p,
scale,
is_causal,
-1,
-1,
0.0,
False,
None,
)
return out_padded[..., :head_size_og].transpose(1, 2)
else:
return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
shared.log.debug('Torch attention: type="triton flash attention"')
except Exception:
pass
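# Hedged configuration notes (based on the code in this change, not on external docs):
# - FLASH_ATTENTION_TRITON_AMD_REF=1 routes fwd/bwd through the pure PyTorch reference
#   instead of the Triton kernels (see interface_fa / USE_REF).
# - FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1 appears to enable the Triton autotuning
#   configs (see AUTOTUNE in modules.flash_attn_triton_amd.utils).
# - The SDPA hook above is installed only when 'Flash attention' is present in
#   shared.opts.sdp_options, and it falls back to the original SDPA for head dims
#   above 128, explicit attention masks, or float32 inputs.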