diff --git a/html/licenses.html b/html/licenses.html index 6597fa3ae..dc0e1fdbe 100644 --- a/html/licenses.html +++ b/html/licenses.html @@ -637,6 +637,40 @@ SOFTWARE. limitations under the License. +
+BSD 3-Clause License + +Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ++
diff --git a/modules/flash_attn_triton_amd/__init__.py b/modules/flash_attn_triton_amd/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/modules/flash_attn_triton_amd/bwd_prefill.py b/modules/flash_attn_triton_amd/bwd_prefill.py
new file mode 100644
index 000000000..7f5be379b
--- /dev/null
+++ b/modules/flash_attn_triton_amd/bwd_prefill.py
@@ -0,0 +1,606 @@
+import torch
+import triton
+import triton.language as tl
+from modules.flash_attn_triton_amd.utils import get_shape_from_layout, get_strides_from_layout
+
+
+@triton.jit
+def _bwd_preprocess_use_o(
+ Out,
+ DO,
+ Delta,
+ stride_oz, stride_oh, stride_om, stride_ok,
+ stride_doz, stride_doh, stride_dom, stride_dok, # pylint: disable=unused-argument
+ stride_deltaz, stride_deltah, stride_deltam,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ BLOCK_M: tl.constexpr,
+ BLOCK_DMODEL: tl.constexpr,
+ ACTUAL_BLOCK_DMODEL: tl.constexpr,
+ N_CTX_Q: tl.constexpr,
+ Z: tl.constexpr, # pylint: disable=unused-argument
+ H: tl.constexpr,
+ IS_VARLEN: tl.constexpr
+):
+ pid_m = tl.program_id(0)
+ pid_bh = tl.program_id(1)
+
+ # Compute batch and head indices
+ off_z = pid_bh // H
+ off_h = pid_bh % H
+
+ if IS_VARLEN:
+ # Compute sequence lengths for the current batch
+ q_start = tl.load(cu_seqlens_q + off_z)
+ q_end = tl.load(cu_seqlens_q + off_z + 1)
+ k_start = tl.load(cu_seqlens_k + off_z)
+ k_end = tl.load(cu_seqlens_k + off_z + 1)
+
+ # Compute actual sequence lengths
+ N_CTX_Q = q_end - q_start
+ N_CTX_K = k_end - k_start # pylint: disable=unused-variable
+ else:
+ q_start = 0
+ k_start = 0
+ N_CTX_Q = max_seqlen_q
+ N_CTX_K = max_seqlen_k # pylint: disable=unused-variable
+
+ off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_d = tl.arange(0, BLOCK_DMODEL)
+
+ # create masks
+ mask_m = off_m < N_CTX_Q
+ mask_d = off_d < ACTUAL_BLOCK_DMODEL
+
+ # compute offsets
+ o_offset = Out + off_z * stride_oz + off_h * stride_oh + q_start * stride_om
+ do_offset = DO + off_z * stride_oz + off_h * stride_oh + q_start * stride_om
+
+ # compute pointers
+ out_ptrs = o_offset + off_m[:, None] * stride_om + off_d[None, :] * stride_ok
+ do_ptrs = do_offset + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok
+
+ # load
+ o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)
+ do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)
+
+ # compute delta
+ delta = tl.sum(o * do, axis=1)
+
+ # write-back delta
+ delta_offset = Delta + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
+ delta_ptrs = delta_offset + off_m * stride_deltam
+ tl.store(delta_ptrs, delta, mask=mask_m)
+
+
+@triton.jit
+def _bwd_kernel_one_col_block(
+ Q,
+ K,
+ V,
+ sm_scale,
+ Out, DO, DQ, DK, DV, L, D, # pylint: disable=unused-argument
+ q_offset,
+ k_offset,
+ v_offset,
+ do_offset,
+ dq_offset,
+ dk_offset,
+ dv_offset,
+ d_offset,
+ l_offset,
+ stride_dq_all, stride_qz, stride_qh, # pylint: disable=unused-argument
+ stride_qm,
+ stride_qk,
+ stride_kz, stride_kh, # pylint: disable=unused-argument
+ stride_kn,
+ stride_kk,
+ stride_vz, stride_vh, # pylint: disable=unused-argument
+ stride_vn,
+ stride_vk,
+ stride_deltaz, stride_deltah, # pylint: disable=unused-argument
+ stride_deltam,
+ Z, H, # pylint: disable=unused-argument
+ N_CTX_Q,
+ N_CTX_K,
+ off_h, off_z, off_hz, # pylint: disable=unused-argument
+ start_n,
+ num_block_m,
+ num_block_n, # pylint: disable=unused-argument
+ BLOCK_M: tl.constexpr,
+ BLOCK_DMODEL: tl.constexpr,
+ ACTUAL_BLOCK_DMODEL: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ SEQUENCE_PARALLEL: tl.constexpr,
+ CAUSAL: tl.constexpr,
+ USE_EXP2: tl.constexpr,
+):
+ if CAUSAL:
+ # TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M
+ lo = 0
+ else:
+ lo = 0
+
+ # initialize col and head offsets
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_d = tl.arange(0, BLOCK_DMODEL)
+
+ # masks
+ mask_n = offs_n < N_CTX_K
+ mask_d = offs_d < ACTUAL_BLOCK_DMODEL
+ kv_mask = mask_n[:, None] & mask_d[None, :]
+
+ # initialize grad accumulators
+ dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
+ dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
+
+ # load k and v once per column block
+ k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
+ v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
+ k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
+ v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
+
+ # loop over rows
+ for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M):
+ offs_m = start_m + tl.arange(0, BLOCK_M)
+ q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
+ dq_ptrs = dq_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
+ do_ptrs = do_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
+
+ # update mask as row block changes
+ mask_m = offs_m < N_CTX_Q
+ q_mask = mask_m[:, None] & mask_d[None, :]
+
+ # load q, k, v, do on-chip
+ q = tl.load(q_ptrs, mask=q_mask, other=0.0)
+ do = tl.load(do_ptrs, mask=q_mask, other=0.0)
+
+ # recompute p = softmax(qk, dim=-1).T
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ qk += tl.dot(q, tl.trans(k))
+
+ if CAUSAL:
+ col_offset = N_CTX_Q - N_CTX_K
+ causal_mask = offs_m[:, None] >= (col_offset + offs_n[None, :])
+ qk = tl.where(causal_mask, qk, float("-inf"))
+
+ l_ptrs = l_offset + offs_m * stride_deltam
+ l_i = tl.load(l_ptrs, mask=mask_m)
+
+ # compute p
+ if USE_EXP2:
+ RCP_LN2: tl.constexpr = 1.4426950408889634
+ qk *= sm_scale * RCP_LN2
+ l_i *= RCP_LN2
+ p = tl.math.exp2(qk - l_i[:, None])
+ else:
+ qk *= sm_scale
+ p = tl.math.exp(qk - l_i[:, None])
+
+ # mask block in the cases where the data is smaller the block size
+ p_mask = mask_m[:, None] & mask_n[None, :]
+ p = tl.where(p_mask, p, 0.0)
+
+ # compute dv
+ dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
+
+ # compute dp
+ dp = tl.dot(do, tl.trans(v))
+
+ # compute ds , ds = p * (dp - delta[:, None])
+ d_ptrs = d_offset + offs_m * stride_deltam
+ Di = tl.load(d_ptrs, mask=mask_m)
+ ds = (p * (dp - Di[:, None])) * sm_scale
+ ds = tl.where(p_mask, ds, 0.0).to(Q.dtype.element_ty)
+
+ # compute dk = dot(ds.T, q)
+ dk += tl.dot(tl.trans(ds), q)
+
+ # compute dq
+ if SEQUENCE_PARALLEL:
+ dq = tl.dot(ds, k)
+ else:
+ dq = tl.load(dq_ptrs, mask=q_mask, other=0.0)
+ dq += tl.dot(ds, k)
+ tl.store(dq_ptrs, dq.to(Q.dtype.element_ty), mask=q_mask)
+
+ # write-back dv and dk
+ dk_ptrs = dk_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
+ dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
+
+ # write-back
+ tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask)
+ tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask)
+
+@triton.jit
+def _bwd_kernel(
+ Q,
+ K,
+ V,
+ sm_scale,
+ Out,
+ DO,
+ DQ,
+ DK,
+ DV,
+ L,
+ D,
+ stride_dq_all,
+ stride_qz,
+ stride_qh,
+ stride_qm,
+ stride_qk,
+ stride_kz,
+ stride_kh,
+ stride_kn,
+ stride_kk,
+ stride_vz,
+ stride_vh,
+ stride_vn,
+ stride_vk,
+ stride_deltaz,
+ stride_deltah,
+ stride_deltam,
+ Z,
+ H,
+ num_block_m,
+ num_block_n,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ BLOCK_M: tl.constexpr,
+ BLOCK_DMODEL: tl.constexpr,
+ ACTUAL_BLOCK_DMODEL: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ SEQUENCE_PARALLEL: tl.constexpr,
+ CAUSAL: tl.constexpr,
+ USE_EXP2: tl.constexpr,
+ IS_VARLEN: tl.constexpr,
+):
+ # program ids
+ off_hz = tl.program_id(0)
+ if SEQUENCE_PARALLEL:
+ start_n = tl.program_id(1)
+ off_z = off_hz // H
+ off_h = off_hz % H
+
+ if IS_VARLEN:
+ # Compute sequence lengths for the current batch
+ q_start = tl.load(cu_seqlens_q + off_z)
+ q_end = tl.load(cu_seqlens_q + off_z + 1)
+ k_start = tl.load(cu_seqlens_k + off_z)
+ k_end = tl.load(cu_seqlens_k + off_z + 1)
+
+ # Compute actual sequence lengths
+ N_CTX_Q = q_end - q_start
+ N_CTX_K = k_end - k_start
+ else:
+ q_start = 0
+ k_start = 0
+ N_CTX_Q = max_seqlen_q
+ N_CTX_K = max_seqlen_k
+
+ # input tensor offsets
+ q_offset = Q + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
+ k_offset = K + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
+ v_offset = V + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
+ do_offset = DO + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
+ l_offset = L + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
+ d_offset = D + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
+
+ # output tensor offsets
+ dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
+ dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
+ if SEQUENCE_PARALLEL:
+ dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
+ else:
+ dq_offset = DQ + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
+
+ # inner loop
+ if SEQUENCE_PARALLEL:
+ _bwd_kernel_one_col_block(
+ Q,
+ K,
+ V,
+ sm_scale,
+ Out,
+ DO,
+ DQ,
+ DK,
+ DV,
+ L,
+ D,
+ q_offset,
+ k_offset,
+ v_offset,
+ do_offset,
+ dq_offset,
+ dk_offset,
+ dv_offset,
+ d_offset,
+ l_offset,
+ stride_dq_all,
+ stride_qz,
+ stride_qh,
+ stride_qm,
+ stride_qk,
+ stride_kz,
+ stride_kh,
+ stride_kn,
+ stride_kk,
+ stride_vz,
+ stride_vh,
+ stride_vn,
+ stride_vk,
+ stride_deltaz,
+ stride_deltah,
+ stride_deltam,
+ Z,
+ H,
+ N_CTX_Q,
+ N_CTX_K,
+ off_h,
+ off_z,
+ off_hz,
+ start_n,
+ num_block_m,
+ num_block_n,
+ BLOCK_M=BLOCK_M,
+ BLOCK_DMODEL=BLOCK_DMODEL,
+ ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
+ BLOCK_N=BLOCK_N,
+ SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
+ CAUSAL=CAUSAL,
+ USE_EXP2=USE_EXP2,
+ )
+ else:
+ for start_n in range(0, num_block_n):
+ _bwd_kernel_one_col_block(
+ Q,
+ K,
+ V,
+ sm_scale,
+ Out,
+ DO,
+ DQ,
+ DK,
+ DV,
+ L,
+ D,
+ q_offset,
+ k_offset,
+ v_offset,
+ do_offset,
+ dq_offset,
+ dk_offset,
+ dv_offset,
+ d_offset,
+ l_offset,
+ stride_dq_all,
+ stride_qz,
+ stride_qh,
+ stride_qm,
+ stride_qk,
+ stride_kz,
+ stride_kh,
+ stride_kn,
+ stride_kk,
+ stride_vz,
+ stride_vh,
+ stride_vn,
+ stride_vk,
+ stride_deltaz,
+ stride_deltah,
+ stride_deltam,
+ Z,
+ H,
+ N_CTX_Q,
+ N_CTX_K,
+ off_h,
+ off_z,
+ off_hz,
+ start_n,
+ num_block_m,
+ num_block_n,
+ BLOCK_M=BLOCK_M,
+ BLOCK_DMODEL=BLOCK_DMODEL,
+ ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
+ BLOCK_N=BLOCK_N,
+ SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
+ CAUSAL=CAUSAL,
+ USE_EXP2=USE_EXP2,
+ )
+
+
+# NOTE: smaller blocks have lower accuracy. more accumlation error probably 128 * 128 seems good but leads to oom. 64 * 64 has accumlation errors but no oom.
+def attention_prefill_backward_triton_impl(
+ do,
+ q,
+ k,
+ v,
+ o,
+ softmax_lse,
+ dq,
+ dk,
+ dv,
+ sm_scale: float,
+ alibi_slopes, # pylint: disable=unused-argument
+ causal,
+ layout: str,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q: int,
+ max_seqlen_k: int,
+ use_exp2: bool,
+ sequence_parallel = True,
+):
+ # make contigious
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ softmax_lse = softmax_lse.contiguous()
+
+ # get strides and shape
+ batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # pylint: disable=unused-variable
+ q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
+ stride_qz, stride_qh, stride_qm, stride_qk = q_strides
+ stride_kz, stride_kh, stride_kn, stride_kk = k_strides
+ stride_vz, stride_vh, stride_vn, stride_vk = v_strides
+ stride_oz, stride_oh, stride_om, stride_ok = o_strides
+ batch_headsize = batch * nheads_q
+ is_varlen = layout == "thd"
+
+ # FIXME: some configs lead to oom for some reason when using 64 x 64 blocks
+ if max_seqlen_q <= 32 or max_seqlen_k <= 32:
+ BLOCK_M = 32
+ BLOCK_N = 32
+ else:
+ BLOCK_M = 64
+ BLOCK_N = 64
+ num_warps = 4 # NOTE: originial is 8. changing it to 1 caused issues be careful
+ num_stages = 1
+ waves_per_eu = 1
+
+ # divide up the problem
+ num_blocks_m = triton.cdiv(max_seqlen_q, BLOCK_M)
+ num_blocks_n = triton.cdiv(max_seqlen_k, BLOCK_N)
+
+ # get closest power of 2 over or equal to 32.
+ padded_d_model = 1 << (head_size - 1).bit_length()
+ padded_d_model = max(padded_d_model, 16)
+ BLOCK_DMODEL = padded_d_model
+ ACTUAL_BLOCK_DMODEL = head_size
+
+ do = do.contiguous()
+ # NOTE: we might need to copy the output tensor if they are not continuous or have other issues
+ copy_back = {"dq": False, "dk": False, "dv": False}
+
+ dq_og = None
+ # deal with dq
+ if dq is None:
+ if sequence_parallel:
+ dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
+ else:
+ dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype)
+ else:
+ dq_og = dq
+ if not dq.is_contiguous():
+ dq = dq.contiguous()
+ copy_back["dq"] = True
+
+ if sequence_parallel:
+ dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
+ copy_back["dq"] = True
+ else:
+ # NOTE: the kernel does inplace accumlation so dq has to be zeros. This avoids the case where we are passed empty dq and it is not all zeros
+ dq.zero_()
+ stride_dq_all = dq.stride()[0]
+
+ dk_og = None
+ dv_og = None
+ # deal with dk, dv
+ if (dk is None) or (dv is None):
+ dk = torch.empty_like(k)
+ dv = torch.empty_like(v)
+ else:
+ if not dk.is_contiguous():
+ dk_og = dk
+ dk = dk.contiguous()
+ copy_back["dk"] = True
+
+ if not dv.is_contiguous():
+ dv_og = dv
+ dv = dv.contiguous()
+ copy_back["dv"] = True
+
+ # assert contigious
+ assert do.is_contiguous()
+ assert q.is_contiguous()
+ assert k.is_contiguous()
+ assert v.is_contiguous()
+ assert o.is_contiguous()
+ assert softmax_lse.is_contiguous()
+
+ # init delta
+ delta = torch.empty_like(softmax_lse)
+ if is_varlen:
+ stride_deltam, stride_deltah = delta.stride()
+ stride_deltaz = 0
+ else:
+ stride_deltaz, stride_deltah, stride_deltam = delta.stride()
+
+ _bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
+ o,
+ do,
+ delta,
+ stride_oz, stride_oh, stride_om, stride_ok,
+ stride_oz, stride_oh, stride_om, stride_ok,
+ stride_deltaz, stride_deltah, stride_deltam,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ BLOCK_M=BLOCK_M,
+ BLOCK_DMODEL=BLOCK_DMODEL,
+ ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
+ N_CTX_Q=max_seqlen_q,
+ Z=batch,
+ H=nheads_q,
+ IS_VARLEN=is_varlen
+ )
+
+ _bwd_kernel[(batch_headsize, num_blocks_n if sequence_parallel else 1)](
+ q,
+ k,
+ v,
+ sm_scale,
+ o,
+ do,
+ dq,
+ dk,
+ dv,
+ softmax_lse,
+ delta,
+ stride_dq_all,
+ stride_qz, stride_qh, stride_qm, stride_qk,
+ stride_kz, stride_kh, stride_kn, stride_kk,
+ stride_vz, stride_vh, stride_vn, stride_vk,
+ stride_deltaz, stride_deltah, stride_deltam,
+ batch,
+ nheads_q,
+ num_blocks_m,
+ num_blocks_n,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ BLOCK_DMODEL=BLOCK_DMODEL,
+ ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
+ SEQUENCE_PARALLEL=sequence_parallel,
+ CAUSAL=causal,
+ USE_EXP2=use_exp2,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ waves_per_eu = waves_per_eu,
+ IS_VARLEN=is_varlen
+ )
+
+ if sequence_parallel:
+ dq = dq.sum(dim=0)
+
+ if copy_back["dq"]:
+ dq_og.copy_(dq)
+ dq = dq_og
+ if copy_back["dk"]:
+ dk_og.copy_(dk)
+ dk = dk_og
+ if copy_back["dv"]:
+ dv_og.copy_(dv)
+ dv = dv_og
+
+ return dq, dk, dv, delta, None, None
diff --git a/modules/flash_attn_triton_amd/bwd_ref.py b/modules/flash_attn_triton_amd/bwd_ref.py
new file mode 100644
index 000000000..2b1befd88
--- /dev/null
+++ b/modules/flash_attn_triton_amd/bwd_ref.py
@@ -0,0 +1,271 @@
+import math
+import torch
+
+
+def attention_backward_core_ref_impl(
+ do, q, k, v, o, softmax_lse, sm_scale, causal, use_exp2
+):
+ # cast to float32
+ do = do.to(torch.float32)
+ q = q.to(torch.float32)
+ k = k.to(torch.float32)
+ v = v.to(torch.float32)
+ o = o.to(torch.float32)
+ softmax_lse = softmax_lse.to(torch.float32)
+
+ # recompute attention_scores. Make sure it matches the forward impl. i.e. It use float32
+ attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
+
+ # scale scores
+ attention_scaled_scores = sm_scale * attention_scores
+
+ # Apply causal mask if necessary
+ if causal:
+ L_q, L_k = q.shape[1], k.shape[1]
+ row_idx = torch.arange(L_q, device=q.device).unsqueeze(1)
+ col_idx = torch.arange(L_k, device=q.device).unsqueeze(0)
+ col_offset = L_q-L_k
+ causal_mask = row_idx >= (col_offset + col_idx)
+ # set -inf to places the causal mask is false
+ attention_scaled_scores = attention_scaled_scores.masked_fill(
+ torch.logical_not(causal_mask.unsqueeze(0)), float('-inf')
+ )
+
+ # compute probabilities using softmax_lse
+ if use_exp2:
+ RCP_LN = 1 / math.log(2)
+ attention_scaled_scores_base2 = attention_scaled_scores * RCP_LN
+ softmax_lse_base2 = softmax_lse * RCP_LN
+ softmax_lse_3d = softmax_lse_base2.unsqueeze(-1)
+ p = torch.exp2(attention_scaled_scores_base2 - softmax_lse_3d)
+ else:
+ softmax_lse_3d = softmax_lse.unsqueeze(-1)
+ p = torch.exp(attention_scaled_scores - softmax_lse_3d)
+
+ # compute gradient wrt v
+ dv = torch.matmul(p.transpose(-2, -1), do.to(torch.float32))
+
+ # compute dp
+ dp = torch.matmul(do, v.transpose(-2, -1))
+
+ # calculate ds using dp
+ delta = torch.sum(o * do, axis=-1).to(torch.float32) # what OAI kernel uses
+ delta_3d = delta.unsqueeze(-1)
+ ds = (p * (dp - delta_3d)) * sm_scale
+
+ # compute gradient wrt k
+ dk = torch.matmul(ds.transpose(-2, -1), q.to(torch.float32))
+
+ # compute gradient wrt q
+ dq = torch.matmul(ds, k.to(torch.float32))
+
+ # cast back to original dtype
+ dq = dq.to(torch.float16)
+ dk = dk.to(torch.float16)
+ dv = dv.to(torch.float16)
+
+ # remove d dim with size 1
+ delta = delta_3d.squeeze(-1)
+
+ return dq, dk, dv, delta
+
+def attention_varlen_backward_pytorch_ref_impl(
+ do,
+ q,
+ k,
+ v,
+ o,
+ softmax_lse,
+ sm_scale,
+ causal,
+ layout,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q, max_seqlen_k, # pylint: disable=unused-argument
+ use_exp2,
+):
+ # Ensure the layout is 'thd'
+ if layout != 'thd':
+ raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.")
+
+ batch_size = cu_seqlens_q.shape[0] - 1
+ num_heads = q.shape[1]
+ head_dim = q.shape[2] # pylint: disable=unused-variable
+
+ # Pre-allocate outputs
+ total_L_q = q.shape[0]
+ total_L_k = k.shape[0] # pylint: disable=unused-variable
+
+ dq = torch.zeros_like(q)
+ dk = torch.zeros_like(k)
+ dv = torch.zeros_like(v)
+ # delta has the same shape as softmax_lse: [total_L_q, num_heads]
+ delta = torch.zeros((total_L_q, num_heads), dtype=torch.float32, device=o.device)
+
+ for i in range(batch_size):
+ # Get the start and end indices for the current sequence
+ start_q = cu_seqlens_q[i].item()
+ end_q = cu_seqlens_q[i + 1].item()
+ start_k = cu_seqlens_k[i].item()
+ end_k = cu_seqlens_k[i + 1].item()
+
+ # Extract q_i, k_i, v_i, do_i, o_i, softmax_lse_i
+ q_i = q[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim]
+ k_i = k[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim]
+ v_i = v[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim]
+ do_i = do[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim]
+ o_i = o[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim]
+ # softmax_lse has shape [total_L_q, num_heads]
+ softmax_lse_i = softmax_lse[start_q:end_q, :] # [L_q_i, num_heads]
+ softmax_lse_i = softmax_lse_i.transpose(0, 1) # [num_heads, L_q_i]
+
+ # Permute to [num_heads, L_q_i, head_dim]
+ q_i = q_i.permute(1, 0, 2)
+ k_i = k_i.permute(1, 0, 2)
+ v_i = v_i.permute(1, 0, 2)
+ do_i = do_i.permute(1, 0, 2)
+ o_i = o_i.permute(1, 0, 2)
+ # softmax_lse_i is already in [num_heads, L_q_i]
+
+ # Call the core backward function for this sequence
+ dq_i, dk_i, dv_i, delta_i = attention_backward_core_ref_impl(
+ do_i,
+ q_i,
+ k_i,
+ v_i,
+ o_i,
+ softmax_lse_i,
+ sm_scale,
+ causal,
+ use_exp2
+ )
+
+ # Convert back to 'thd' layout
+ dq_i = dq_i.permute(1, 0, 2) # [L_q_i, num_heads, head_dim]
+ dk_i = dk_i.permute(1, 0, 2) # [L_k_i, num_heads, head_dim]
+ dv_i = dv_i.permute(1, 0, 2) # [L_k_i, num_heads, head_dim]
+
+ # Place outputs in pre-allocated tensors
+ dq[start_q:end_q, :, :] = dq_i
+ dk[start_k:end_k, :, :] += dk_i # Accumulate gradients for shared keys
+ dv[start_k:end_k, :, :] += dv_i # Accumulate gradients for shared values
+ # delta_i has shape [num_heads, L_q_i]
+ delta_i = delta_i.transpose(1, 0) # [L_q_i, num_heads]
+ delta[start_q:end_q, :] = delta_i
+
+ return dq, dk, dv, delta
+
+def attention_vanilla_backward_pytorch_ref_impl(
+ do,
+ q,
+ k,
+ v,
+ o,
+ softmax_lse,
+ sm_scale,
+ causal,
+ layout,
+ use_exp2,
+):
+ if layout == "bshd":
+ do = do.transpose(1, 2).contiguous()
+ q = q.transpose(1, 2).contiguous()
+ k = k.transpose(1, 2).contiguous()
+ v = v.transpose(1, 2).contiguous()
+ o = o.transpose(1, 2).contiguous()
+ elif layout == "bhsd":
+ pass
+ else:
+ raise ValueError(f"Unknown layout {layout}")
+
+ # Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format
+ batch_size, num_heads, seq_len_q, head_dim = q.shape
+ seq_len_k = k.shape[2]
+
+ # Merge batch and heads dimensions
+ do = do.reshape(batch_size * num_heads, seq_len_q, head_dim)
+ q = q.reshape(batch_size * num_heads, seq_len_q, head_dim)
+ k = k.reshape(batch_size * num_heads, seq_len_k, head_dim)
+ v = v.reshape(batch_size * num_heads, seq_len_k, head_dim)
+ softmax_lse = softmax_lse.reshape(batch_size * num_heads, seq_len_q)
+ o = o.reshape(batch_size * num_heads, seq_len_q, head_dim)
+
+ dq, dk, dv, delta = attention_backward_core_ref_impl(
+ do,
+ q,
+ k,
+ v,
+ o,
+ softmax_lse,
+ sm_scale,
+ causal,
+ use_exp2
+ )
+
+ # Reshape outputs back to [batch_size, num_heads, seq_len, head_dim]
+ dq = dq.reshape(batch_size, num_heads, seq_len_q, head_dim)
+ dk = dk.reshape(batch_size, num_heads, seq_len_k, head_dim)
+ dv = dv.reshape(batch_size, num_heads, seq_len_k, head_dim)
+ delta = delta.reshape(batch_size, num_heads, seq_len_q)
+
+ # Go back to original layout
+ if layout == "bshd":
+ dq = dq.transpose(1, 2)
+ dk = dk.transpose(1, 2)
+ dv = dv.transpose(1, 2)
+ elif layout == "bhsd":
+ pass
+ else:
+ raise ValueError(f"Unknown layout {layout}")
+
+ return dq, dk, dv, delta
+
+
+def attention_backward_pytorch_ref_impl(
+ do,
+ q,
+ k,
+ v,
+ o,
+ softmax_lse,
+ sm_scale,
+ causal,
+ layout,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ use_exp2
+):
+ if layout == "thd":
+ dq, dk, dv, delta = attention_varlen_backward_pytorch_ref_impl(
+ do,
+ q,
+ k,
+ v,
+ o,
+ softmax_lse,
+ sm_scale,
+ causal,
+ layout,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ use_exp2,
+ )
+ else:
+ dq, dk, dv, delta = attention_vanilla_backward_pytorch_ref_impl(
+ do,
+ q,
+ k,
+ v,
+ o,
+ softmax_lse,
+ sm_scale,
+ causal,
+ layout,
+ use_exp2,
+ )
+
+ return dq, dk, dv, delta
diff --git a/modules/flash_attn_triton_amd/fwd_decode.py b/modules/flash_attn_triton_amd/fwd_decode.py
new file mode 100644
index 000000000..7a2a234d6
--- /dev/null
+++ b/modules/flash_attn_triton_amd/fwd_decode.py
@@ -0,0 +1,700 @@
+import torch
+import triton
+import triton.language as tl
+from modules.flash_attn_triton_amd.utils import _strides, get_padded_headsize
+
+
+@triton.jit
+def _fwd_kernel_splitK(
+ Q,
+ K,
+ V,
+ sm_scale,
+ Out_splitK, # [B, H, split_k, Mq, K]
+ Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
+ K_new,
+ V_new,
+ Cache_seqlens,
+ Cache_batch_idx,
+ Alibi_slopes,
+ stride_qz,
+ stride_qm,
+ stride_qg,
+ stride_qh,
+ stride_qd,
+ stride_kz,
+ stride_kn,
+ stride_kg,
+ stride_kh,
+ stride_kd,
+ stride_vz,
+ stride_vn,
+ stride_vg,
+ stride_vh,
+ stride_vd,
+ stride_osk_zhg,
+ stride_osk_s,
+ stride_osk_m,
+ stride_osk_d, # pylint: disable=unused-argument
+ stride_mzhg,
+ stride_m2,
+ stride_ms,
+ stride_mm, # pylint: disable=unused-argument
+ stride_kn_z,
+ stride_kn_n,
+ stride_kn_g,
+ stride_kn_h,
+ stride_kn_d,
+ stride_vn_z,
+ stride_vn_n,
+ stride_vn_g,
+ stride_vn_h,
+ stride_vn_d,
+ stride_az,
+ stride_ah,
+ Z, # pylint: disable=unused-argument
+ N_CTX_Q,
+ N_CTX_K,
+ N_CTX_NEW,
+ BLOCK_N_PER_SPLIT,
+ H_q: tl.constexpr,
+ H_kv: tl.constexpr,
+ G_q: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_DMODEL: tl.constexpr,
+ ACTUAL_BLOCK_DMODEL: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BOUNDS_CHECKS_N: tl.constexpr,
+ USE_CACHE_SEQLENs: tl.constexpr,
+ USE_CACHE_BATCH_IDX: tl.constexpr,
+ NEW_KV: tl.constexpr,
+ IS_GQA: tl.constexpr,
+ IS_CAUSAL: tl.constexpr,
+ USE_ALIBI: tl.constexpr,
+):
+ # Padding
+ PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
+ if PADDED_HEAD:
+ d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL
+
+ start_m = tl.program_id(0)
+ off_zhg = tl.program_id(1)
+ off_z = off_zhg // (H_q * G_q)
+ off_h_q = (off_zhg // G_q) % H_q
+ off_g_q = off_zhg % G_q
+ splitk_idx = tl.program_id(2)
+
+ # pick batch index
+ if USE_CACHE_BATCH_IDX:
+ cache_batch_idx = tl.load(Cache_batch_idx + off_z)
+ else:
+ cache_batch_idx = off_z
+
+ # Load ALiBi slope if enabled
+ if USE_ALIBI:
+ a_offset = off_z * stride_az + off_h_q * stride_ah
+ alibi_slope = tl.load(Alibi_slopes + a_offset)
+ else:
+ alibi_slope = None
+
+ lo = splitk_idx * BLOCK_N_PER_SPLIT
+ if USE_CACHE_SEQLENs:
+ cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z)
+ if NEW_KV:
+ kv_len = cache_seqlen_last_idx + N_CTX_NEW
+ else:
+ kv_len = cache_seqlen_last_idx
+ else:
+ kv_len = N_CTX_K
+ hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)
+
+ HEAD_RATIO: tl.constexpr = H_q // H_kv
+ if IS_GQA:
+ k_head_idx = off_h_q // HEAD_RATIO
+ v_head_idx = k_head_idx
+ else:
+ k_head_idx = off_h_q
+ v_head_idx = off_h_q
+
+ # calculate base offset
+ k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg
+ v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg
+
+ # Copy new Keys and Values into Cache
+ if NEW_KV:
+ knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g
+
+ # Determine the starting position for new data in the cache
+ if USE_CACHE_SEQLENs:
+ start_idx = tl.load(Cache_seqlens + off_z)
+ else:
+ start_idx = N_CTX_K - N_CTX_NEW
+
+ # Copy new Keys
+ for i in range(0, N_CTX_NEW, BLOCK_N):
+ # Load from K_new
+ k_new_block = tl.load(
+ knew_base +
+ tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
+ (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
+ mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
+ (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
+ other=0
+ )
+
+ # Store to K
+ tl.store(
+ k_base +
+ tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd +
+ (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn,
+ k_new_block,
+ mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
+ (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
+ )
+
+ # Copy new Values
+ vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g
+ for i in range(0, N_CTX_NEW, BLOCK_N):
+ # Load from V_new
+ v_new_block = tl.load(
+ vnew_base +
+ (tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n +
+ tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d,
+ mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
+ (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
+ other=0
+ )
+
+ # Store to V
+ tl.store(
+ v_base +
+ (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn +
+ tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd,
+ v_new_block,
+ mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
+ (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
+ )
+
+ Q_block_ptr = tl.make_block_ptr(
+ base=Q + off_h_q * stride_qh + off_z * stride_qz + off_g_q * stride_qg,
+ shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL),
+ strides=(stride_qm, stride_qd),
+ offsets=(start_m * BLOCK_M, 0),
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
+ order=(1, 0),
+ )
+
+ K_block_ptr = tl.make_block_ptr(
+ base=k_base,
+ shape=(ACTUAL_BLOCK_DMODEL, hi),
+ strides=(stride_kd, stride_kn),
+ offsets=(0, lo),
+ block_shape=(BLOCK_DMODEL, BLOCK_N),
+ order=(0, 1),
+ )
+ V_block_ptr = tl.make_block_ptr(
+ base=v_base,
+ shape=(hi, ACTUAL_BLOCK_DMODEL),
+ strides=(stride_vn, stride_vd),
+ offsets=(lo, 0),
+ block_shape=(BLOCK_N, BLOCK_DMODEL),
+ order=(1, 0),
+ )
+
+ K_scale_shift_block_ptr = None
+ V_scale_shift_block_ptr = None
+
+ # initialize pointer to m and l
+ m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821
+
+ # scale sm_scale by log_2(e) and use
+ # 2^x instead of exp in the loop because CSE and LICM
+ # don't work as expected with `exp` in the loop
+ qk_scale = sm_scale * 1.44269504
+ # load q: it will stay in SRAM throughout
+ q = tl.load( # noqa: F821
+ tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
+ q = (q * qk_scale).to(q.dtype)
+ if PADDED_HEAD:
+ q = tl.where(d_mask[None, :], q, 0.0)
+
+ # loop over k, v and update accumulator
+ for start_n in range(lo, hi, BLOCK_N):
+ k, v = load_k_v_group(
+ K_block_ptr,
+ V_block_ptr,
+ K_scale_shift_block_ptr,
+ V_scale_shift_block_ptr,
+ BOUNDS_CHECKS_N,
+ 1,
+ BLOCK_DMODEL,
+ ACTUAL_BLOCK_DMODEL,
+ Q.dtype.element_ty,
+ 0,
+ )
+ if PADDED_HEAD:
+ k = tl.where(d_mask[:, None], k, 0.0)
+ v = tl.where(d_mask[None, :], v, 0.0)
+
+ # -- compute qk ---
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ qk += tl.dot(q, k) # noqa: F821
+
+ if USE_ALIBI:
+ row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ col_idx = start_n + tl.arange(0, BLOCK_N)
+
+ # Compute relative positions
+ relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :])
+ relative_pos = tl.abs(relative_pos)
+
+ # Compute ALiBi bias
+ alibi_bias = -1 * alibi_slope * relative_pos
+ qk += (alibi_bias * 1.44269504)
+
+ # Apply causal mask if IS_CAUSAL is True
+ if IS_CAUSAL:
+ row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ col_idx = start_n + tl.arange(0, BLOCK_N)
+
+ # create a N_CTX_Q x kv_len causal mask
+ col_offset = N_CTX_Q - kv_len
+ causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :])
+
+ # Apply the mask
+ qk = tl.where(causal_mask, qk, float("-inf"))
+
+ # TODO: This is slow, and only needed at the last iteration.
+ # Maybe we can unroll the last iteration instead?
+ if BOUNDS_CHECKS_N:
+ qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
+
+ # -- compute scaling constant ---
+ m_i_new = tl.maximum(m_i, tl.max(qk, 1))
+ if IS_CAUSAL:
+ alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf")))
+ else:
+ alpha = tl.math.exp2(m_i - m_i_new)
+ # cause of nan because subtracting infs
+ if IS_CAUSAL:
+ qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf"))
+ else:
+ qk = qk - m_i_new[:, None]
+
+ p = tl.math.exp2(qk)
+
+ # -- update m_i and l_i --
+ l_i = l_i * alpha + tl.sum(p, 1)
+ m_i = m_i_new
+ p = p.to(Q.dtype.element_ty)
+
+ # -- scale and update acc --
+ acc *= alpha[:, None]
+ acc += tl.dot(p.to(v.dtype), v)
+
+ # update pointers
+ K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
+ V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
+
+ # write back O
+ O_block_ptr = tl.make_block_ptr(
+ base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
+ shape=(N_CTX_Q, BLOCK_DMODEL),
+ strides=(stride_osk_m, 1),
+ offsets=(start_m * BLOCK_M, 0),
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
+ order=(1, 0),
+ )
+ tl.store(
+ tl.advance(O_block_ptr, (0, 0)),
+ acc,
+ boundary_check=(0, ),
+ )
+ # Write metadata for split-K reduction
+ Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M +
+ tl.arange(0, BLOCK_M))
+ tl.store(Metadata_ptr, m_i)
+ tl.store(Metadata_ptr + stride_m2, l_i)
+
+
+@triton.jit
+def load_k_v_group(
+ K_block_ptr,
+ V_block_ptr,
+ K_scale_shift_block_ptr, V_scale_shift_block_ptr, # pylint: disable=unused-argument
+ BOUNDS_CHECKS_N: tl.constexpr,
+ PACKED_PER_VAL: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # pylint: disable=unused-argument
+ ACTUAL_BLOCK_DMODEL: tl.constexpr,
+ dtype: tl.constexpr, # pylint: disable=unused-argument
+ group_id: tl.constexpr,
+):
+ # Load K/V for a given block
+ # Advance to the current quantization group
+ K_block_ptr = tl.advance(K_block_ptr, (ACTUAL_BLOCK_DMODEL * group_id, 0))
+ V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id))
+
+ # -- load k, v --
+ k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ())
+ v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ())
+
+ return k, v
+
+
+@triton.jit
+def cast_uint32_to_half2(scale_shift):
+ # Extract two float16 packed into one int32
+ scale = scale_shift & 0xFFFF
+ shift = scale_shift >> 16
+ scale = scale.to(tl.uint16).to(tl.float16, bitcast=True)
+ shift = shift.to(tl.uint16).to(tl.float16, bitcast=True)
+ return scale, shift
+
+
+@triton.jit
+def dequantize(
+ x_,
+ scale,
+ shift,
+ PACKED_PER_VAL: tl.constexpr = 8,
+):
+ # PACKED_PER_VAL is the number of values packed into
+ # each element x_. For example, for int4 quantization
+ #and x_ of type int32, PACKED_PER_VAL is 8.
+
+ BLOCK_N: tl.constexpr = x_.shape[0]
+ BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
+ offsets = tl.arange(0, PACKED_PER_VAL) * 4
+ quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL)
+
+ quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL))
+ # Trick - instead of converting int4 to float16 we view it as float16
+ # and then multiply by 32768 * 512 == 2**24
+ quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
+ quant_offset = (quant_offset * 32768.0).to(tl.float16)
+ scale_512 = scale * 512
+
+ dequant = quant_offset * scale_512 + shift
+ return dequant
+
+
+@triton.jit
+def _splitK_reduce(
+ Out_splitK, # [B, H, split_k, Mq, K]
+ Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
+ Out, # [B, H, M, K]
+ LSE, # [B, H, M]
+ stride_osk_zhg,
+ stride_osk_s,
+ stride_osk_m,
+ stride_osk_k,
+ stride_mzhg,
+ stride_m2,
+ stride_ms,
+ stride_mm,
+ stride_oz,
+ stride_oh,
+ stride_og,
+ stride_om,
+ stride_ok, # pylint: disable=unused-argument
+ stride_lse_zhg,
+ stride_lse_m, M_ceil: tl.constexpr, # pylint: disable=unused-argument
+ BLOCK_SIZE: tl.constexpr,
+ H: tl.constexpr,
+ G: tl.constexpr,
+ split_k: tl.constexpr,
+ splitK_pow2: tl.constexpr,
+ use_mask: tl.constexpr,
+ IS_CAUSAL: tl.constexpr,
+):
+ off_zhg = tl.program_id(0)
+ off_z = off_zhg // (H * G)
+ off_h = (off_zhg // G) % H
+ off_g = off_zhg % G
+ off_m = tl.program_id(1)
+ off_k = tl.program_id(2)
+
+ # read chunk
+ spk_idx = tl.arange(0, splitK_pow2)
+ kidx = tl.arange(0, BLOCK_SIZE)
+
+ Metadata_ptr = Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm
+
+ o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE +
+ stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k)
+
+ # read max values of each splitK
+ if use_mask:
+ spk_mask = spk_idx < split_k
+ l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf"))
+ l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0)
+ acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0)
+ else:
+ l_m = tl.load(Metadata_ptr)
+ l_sum = tl.load(Metadata_ptr + stride_m2)
+ acc = tl.load(o_ptr)
+
+ g_m = tl.max(l_m, axis=0)
+
+ if IS_CAUSAL:
+ l_m_offset = l_m - g_m
+ alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0)
+ else:
+ alpha = tl.math.exp2(l_m - g_m)
+
+ # read sum
+ l_sum *= alpha
+ g_sum = tl.sum(l_sum, axis=0)
+ acc = acc * alpha[:, None]
+
+ if IS_CAUSAL:
+ # Avoid division by zero
+ g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0)
+ acc_out = tl.sum(acc, axis=0) / g_sum_safe
+ else:
+ acc_out = tl.sum(acc, axis=0) / g_sum
+
+ # Store output
+ Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m +
+ off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE))
+ tl.store(Out_ptr, acc_out)
+
+ # Store lse
+ l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m
+ if IS_CAUSAL:
+ lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m)
+ tl.store(l_ptrs, lse)
+ else:
+ tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504)
+
+
+def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
+ # Scale and shift are such that quantization linearly maps
+ # int4 values range [0..15] to input values range min(k)..max(k)
+ # individually for every row
+ k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups)
+ max_vals = torch.max(k, dim=-1, keepdim=True).values
+ min_vals = torch.min(k, dim=-1, keepdim=True).values
+ scale_k: torch.Tensor = (max_vals - min_vals) / 15
+
+ shift_k = torch.min(k, dim=-1, keepdim=True).values
+ scale_k = scale_k.to(torch.float16)
+ shift_k = shift_k.to(torch.float16)
+
+ in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5
+ in_bytes = in_bytes.to(torch.uint8)
+ in_int4 = in_bytes & 0xF
+ in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4)
+ scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1)
+ k_quant = torch.concat(
+ [
+ scale_shift.flatten(start_dim=-2),
+ in_int4_packed.flatten(start_dim=-2),
+ ],
+ dim=-1,
+ ).view(torch.int16)
+ return k_quant
+
+
+def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
+ k_i16 = quant_k.view(torch.int16)
+ k_ui8 = k_i16.view(torch.uint8)
+
+ ss_size = num_groups * 4
+ scale_shift_ui8 = k_ui8[..., 0:ss_size]
+ scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4)
+ scale = scale_shift_ui8[..., 0:2].view(torch.float16)
+ shift = scale_shift_ui8[..., 2:4].view(torch.float16)
+
+ kv_ui8 = k_ui8[..., ss_size:]
+ k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1)
+ k1_i4 = k_ui8 & 0xF
+ k2_i4 = (k_ui8 & 0xF0) >> 4
+ k_shape = k1_i4.shape
+ k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape)
+ k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape)
+
+ out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device)
+ out[..., ::2] = k1_f16
+ out[..., 1::2] = k2_f16
+ out = out.reshape(*k_shape[:-2], -1)
+
+ return out
+
+
+def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
+ """Heuristic for the number of splits"""
+ bh = max(B * H, 1) # NOTE: Handle B*h=0 case
+ split_k = max(Mk, 1024) // bh
+ max_chunk_size = 64
+ while split_k > 0 and Mk / split_k < max_chunk_size:
+ split_k = split_k // 2
+ while B * H * G * split_k >= 1024:
+ split_k = split_k // 2
+ split_k = min(split_k, 512)
+ split_k = max(split_k, 1)
+ return split_k
+
+def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new):
+ # kernel config
+ BLOCK_M = 16
+ BLOCK_N = 64
+ SPLIT_K = None
+ NUM_QUANT_GROUPS = 1 # pylint: disable=unused-variable
+
+ # kernels expects "bsghd"
+ original_layout = layout
+ if layout == "bshd":
+ q = q.unsqueeze(2)
+ k = k.unsqueeze(2)
+ v = v.unsqueeze(2)
+ if new_kv:
+ k_new = k_new.unsqueeze(2)
+ v_new = v_new.unsqueeze(2)
+ layout = "bsghd"
+ elif layout == "bhsd":
+ q = q.permute(0, 2, 1, 3).unsqueeze(2)
+ k = k.permute(0, 2, 1, 3).unsqueeze(2)
+ v = v.permute(0, 2, 1, 3).unsqueeze(2)
+ if new_kv:
+ k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2)
+ v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2)
+ layout = "bsghd"
+ elif layout == "bsghd":
+ pass
+ elif layout is None:
+ raise ValueError("Layout not given")
+ assert layout == "bsghd"
+
+ # get dims
+ batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape
+ _, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape # pylint: disable=unused-variable
+ _, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape # pylint: disable=unused-variable
+
+ assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}"
+
+ # get padded size
+ dim_padded = get_padded_headsize(dim_k)
+
+ # Handle MQA/GQA case
+ if heads_per_group_q > heads_per_group_k:
+ is_gqa = True
+ elif heads_per_group_q < heads_per_group_k:
+ raise ValueError("heads_per_group_q < heads_per_group_k")
+ else:
+ is_gqa = False
+
+ assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}"
+
+ if SPLIT_K is not None:
+ split_k = SPLIT_K
+ else:
+ # Use heuristics
+ split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_k) # NOTE: should the split think about seqlens?
+
+ seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M
+ out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device)
+ metadata = torch.empty([batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device)
+ lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), device=q.device, dtype=torch.float32)
+ grid = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * n_group_q * heads_per_group_q, split_k)
+
+ num_warps = 1
+ split_size = (seqlen_k + split_k - 1) // split_k
+ use_cache_seqlens = cache_seqlens is not None
+
+ # TODO: enable quantization
+ _fwd_kernel_splitK[grid](
+ Q=q,
+ K=k,
+ V=v,
+ sm_scale=sm_scale,
+ Out_splitK=out_splitk,
+ Metadata=metadata,
+ K_new = k_new,
+ V_new = v_new,
+ Cache_seqlens=cache_seqlens,
+ Cache_batch_idx=cache_batch_idx,
+ Alibi_slopes=alibi_slopes,
+ **_strides(q, "qz", "qm", "qg", "qh", "qd"),
+ **_strides(k, "kz", "kn", "kg", "kh", "kd"),
+ **_strides(v, "vz", "vn", "vg", "vh", "vd"),
+ **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"),
+ **_strides(metadata, "mzhg", "m2", "ms", "mm"),
+ **_strides(k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"),
+ **_strides(v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"),
+ **_strides(alibi_slopes, "az", "ah"),
+ Z=batch_size,
+ H_q=heads_per_group_q,
+ H_kv=heads_per_group_k,
+ G_q=n_group_q,
+ N_CTX_Q=seqlen_q,
+ N_CTX_K=seqlen_k,
+ N_CTX_NEW=k_new.shape[1] if new_kv else None,
+ BLOCK_N_PER_SPLIT=split_size,
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ BLOCK_DMODEL=dim_padded,
+ ACTUAL_BLOCK_DMODEL=dim_k,
+ BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens,
+ USE_CACHE_SEQLENs=use_cache_seqlens,
+ USE_CACHE_BATCH_IDX=cache_batch_idx is not None,
+ NEW_KV=new_kv,
+ IS_GQA=is_gqa,
+ IS_CAUSAL=causal,
+ USE_ALIBI=False if alibi_slopes is None else True,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+
+ out = torch.empty((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype)
+
+ # Merge together
+ splitK_pow2 = triton.next_power_of_2(split_k)
+ use_mask = splitK_pow2 > split_k
+ if batch_size * n_group_q * heads_per_group_q * seqlen_q >= 512:
+ k_block_num = 1
+ else:
+ k_block_num = 2
+ assert dim_padded % k_block_num == 0
+ k_block_size = dim_padded // k_block_num
+ grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num)
+
+ _splitK_reduce[grid](
+ out_splitk,
+ metadata,
+ out,
+ lse,
+ **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
+ **_strides(metadata, "mzhg", "m2", "ms", "mm"),
+ **_strides(out, "oz", "om", "og", "oh", "ok"),
+ **_strides(lse, "lse_zhg", "lse_m"),
+ M_ceil=seqlen_q_ceil,
+ BLOCK_SIZE=k_block_size,
+ G=n_group_q,
+ H=heads_per_group_q,
+ # TODO: Tune num_warps
+ split_k=split_k,
+ splitK_pow2=splitK_pow2,
+ use_mask=use_mask,
+ IS_CAUSAL=causal,
+ num_warps=4)
+
+ lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q])
+ if q.ndim == 4:
+ # BMGHK -> BMHK
+ assert n_group_q == 1
+ out = out[:, :, 0]
+ lse = lse[:, 0]
+ if seqlen_k == 0:
+ out.zero_()
+ out = out.reshape(batch_size, heads_per_group_q * n_group_q, -1, dim_padded).contiguous()
+
+ # output is batch_size, heads_per_group_q * group_q, seqlen_q, dim_q
+ if original_layout == "bshd":
+ # out=out.transpose(1, 2).contiguous() # this screws up heads and data.
+ # the data is laid out properly. Just need to reshape dims
+ out = out.reshape(batch_size, seqlen_q, -1, dim_padded)
+
+ return out.narrow(-1, 0, dim_k), lse
diff --git a/modules/flash_attn_triton_amd/fwd_prefill.py b/modules/flash_attn_triton_amd/fwd_prefill.py
new file mode 100644
index 000000000..3e2cd32af
--- /dev/null
+++ b/modules/flash_attn_triton_amd/fwd_prefill.py
@@ -0,0 +1,634 @@
+import torch
+import triton
+import triton.language as tl
+from modules.flash_attn_triton_amd.utils import get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, AUTOTUNE
+
+
+@triton.jit
+def cdiv_fn(x, y):
+ return (x + y - 1) // y
+
+
+@triton.jit
+def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): # pylint: disable=unused-argument
+ ms = tl.arange(0, m)
+ ns = tl.arange(0, n)
+ return philox_offset + ms[:, None] * stride + ns[None, :]
+
+
+@triton.jit
+def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
+ rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32)
+ # TODO: use tl.randint for better performance
+ return tl.rand(philox_seed, rng_offsets)
+
+
+@triton.jit
+def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
+ rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
+ rng_keep = rng_output > dropout_p
+ return rng_keep
+
+
+# Convenience function to load with optional boundary checks.
+# "First" is the major dim, "second" is the minor dim.
+@triton.jit
+def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second):
+ if offset_first is not None and offset_second is not None:
+ mask = (offset_first[:, None] < boundary_first) & \
+ (offset_second[None, :] < boundary_second)
+ tensor = tl.load(ptrs, mask=mask, other=0.0)
+ elif offset_first is not None:
+ mask = offset_first[:, None] < boundary_first
+ tensor = tl.load(ptrs, mask=mask, other=0.0)
+ elif offset_second is not None:
+ mask = offset_second[None, :] < boundary_second
+ tensor = tl.load(ptrs, mask=mask, other=0.0)
+ else:
+ tensor = tl.load(ptrs)
+ return tensor
+
+
+@triton.jit
+def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False):
+ # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix
+ # for casual mask we want something like this where (1 is kept and 0 is masked)
+ # seqlen_q = 2 and seqlen_k = 5
+ # 1 1 1 1 0
+ # 1 1 1 1 1
+ # seqlen_q = 5 and seqlen_k = 2
+ # 0 0
+ # 0 0
+ # 0 0
+ # 1 0
+ # 1 1
+ # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal
+ # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
+ # 1. offs_m[:,None] = [[0],
+ # [1],
+ # 2. offs_m[:,None] + seqlen_k = [[5],
+ # [6],
+ # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
+ # [4],
+ # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1],
+ # [4], [ 4, 3, 2, 1, 0]]
+ # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
+ # [ -4, -3, -2, -1, 0]],
+ relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :]
+ alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
+ if transpose:
+ return alibi_block.T
+ else:
+ return alibi_block
+
+
+@triton.jit
+def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m,
+ actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, exp_scores_ptrs,
+ block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs, # pylint: disable=unused-argument
+ IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
+ OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
+ ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr,
+ ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_EXP2: tl.constexpr,
+ RETURN_SCORES: tl.constexpr):
+ if USE_EXP2:
+ RCP_LN2: tl.constexpr = 1.4426950408889634
+
+ # loop over k, v, and update accumulator
+ for start_n in range(block_min, block_max, BLOCK_N):
+ # For padded blocks, we will overrun the tensor size if
+ # we load all BLOCK_N. For others, the blocks are all within range.
+ if MASK_STEPS:
+ k_offs_n = start_n + tl.arange(0, BLOCK_N)
+ else:
+ k_offs_n = None
+ k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
+ k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k)
+ if PRE_LOAD_V:
+ # We can use the same offsets as k, just with dims transposed.
+ v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ # We start from end of seqlen_k so only the first iteration would need
+ # to be checked for padding if it is not a multiple of block_n
+ # TODO: This can be optimized to only be true for the padded block.
+ if MASK_STEPS:
+ # If this is the last block / iteration, we want to
+ # mask if the sequence length is not a multiple of block size
+ # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn.
+ # last step might get wasted but that is okay. check if this masking works For
+ # that case.
+ if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
+ boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
+ size_n = start_n + OFFS_N[None, :]
+ mask = size_n < boundary_m[:, None]
+ qk = tl.where(mask, qk, float("-inf"))
+
+ # -- compute qk ----
+ qk += tl.dot(q, k)
+ qk_scaled = qk * SM_SCALE
+ if RETURN_SCORES:
+ score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
+ tl.store(score_ptrs, qk_scaled, mask=score_mask)
+
+ if IS_CAUSAL:
+ causal_boundary = start_n + offs_n_causal
+ causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
+ qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf"))
+ if bias_ptrs is not None:
+ bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None
+ bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k)
+ qk_scaled += bias
+
+ if alibi_slope is not None:
+ # Compute the global position of each token within the sequence
+ global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ global_n_positions = start_n + tl.arange(0, BLOCK_N)
+ alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions,
+ global_n_positions)
+ qk_scaled += alibi_block
+ # get max scores so far
+ m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1))
+
+ # scale and subtract max
+ q_shifted = qk_scaled - m_ij[:, None]
+ if RETURN_SCORES:
+ # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
+ scores_scaled_shifted_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
+ tl.store(scores_scaled_shifted_ptrs, q_shifted, mask=scores_scaled_shifted_mask)
+
+ # Compute scaled QK and softmax probabilities
+ if USE_EXP2:
+ p = tl.math.exp2(q_shifted * RCP_LN2)
+ else:
+ p = tl.math.exp(q_shifted)
+
+ # CAVEAT: Must update l_ij before applying dropout
+ l_ij = tl.sum(p, 1)
+ if ENABLE_DROPOUT:
+ philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N
+ keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k)
+ if RETURN_SCORES:
+ # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
+ exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
+ tl.store(exp_scores_ptrs, tl.where(keep, p, -p), mask=exp_score_mask)
+ p = tl.where(keep, p, 0.0)
+ elif RETURN_SCORES:
+ # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
+ exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
+ tl.store(exp_scores_ptrs, p, mask=exp_score_mask)
+
+ # -- update output accumulator --
+ # alpha is an adjustment factor for acc and li as we loop and find new maxes
+ # store the diff in maxes to adjust acc and li as we discover new maxes
+ m_diff = m_i - m_ij
+ if USE_EXP2:
+ alpha = tl.math.exp2(m_diff * RCP_LN2)
+ else:
+ alpha = tl.math.exp(m_diff)
+ acc = acc * alpha[:, None]
+ v = None
+ if not PRE_LOAD_V:
+ v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
+ # -- update m_i and l_i
+ l_i = l_i * alpha + l_ij
+ # update m_i and l_i
+ m_i = m_ij
+ acc += tl.dot(p.to(v.type.element_ty), v)
+ k_ptrs += BLOCK_N * stride_kn
+ v_ptrs += BLOCK_N * stride_vk
+ if bias_ptrs is not None:
+ bias_ptrs += BLOCK_N * stride_bn
+ if RETURN_SCORES:
+ score_ptrs += BLOCK_N
+ scores_scaled_shifted_ptrs += BLOCK_N
+ exp_scores_ptrs += BLOCK_N
+ return acc, l_i, m_i
+
+
+def get_cdna_autotune_configs():
+ return [
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=4),
+ # Fall-back config.
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=4),
+ ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']
+
+
+def get_rdna_autotune_configs():
+ return [
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=2),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=2),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=2),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=2),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=2),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=2),
+ # Fall-back config.
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
+ num_warps=2),
+ ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']
+
+
+def get_autotune_configs():
+ if AUTOTUNE:
+ if is_rdna():
+ return get_rdna_autotune_configs()
+ elif is_cdna():
+ return get_cdna_autotune_configs()
+ else:
+ raise ValueError("Unknown Device Type")
+ else:
+ return [
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False},
+ num_stages=1,
+ num_warps=4,
+ ),
+ ], [
+ "IS_CAUSAL",
+ "dropout_p",
+ "MAX_SEQLENS_Q",
+ "MAX_SEQLENS_K",
+ "ACTUAL_BLOCK_DMODEL",
+ "VARLEN",
+ "HQ",
+ "HK",
+ ]
+
+
+autotune_configs, autotune_keys = get_autotune_configs()
+
+@triton.autotune(
+ configs=autotune_configs,
+ key=autotune_keys,
+ # use_cuda_graph=True,
+)
+@triton.jit
+def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
+ stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
+ stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, # pylint: disable=unused-argument
+ stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
+ dropout_p, philox_seed, philox_offset_base, scores, scores_scaled_shifted, exp_scores, alibi_slopes, HQ: tl.constexpr,
+ HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
+ MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
+ BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
+ ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr):
+ start_m = tl.program_id(0)
+ off_h_q = tl.program_id(1)
+ off_z = tl.program_id(2)
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = tl.arange(0, BLOCK_N)
+ offs_d = tl.arange(0, BLOCK_DMODEL)
+ if VARLEN:
+ cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
+ cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
+ # print("cu_seqlens_q_start:", cu_seqlens_q_start)
+
+ seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
+ # We have a one-size-fits-all grid in id(0). Some seqlens might be too
+ # small for all start_m so for those we return early.
+ if start_m * BLOCK_M > seqlen_q:
+ return
+ cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
+ cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
+ seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
+ else:
+ cu_seqlens_q_start = 0
+ cu_seqlens_k_start = 0
+ seqlen_q = MAX_SEQLENS_Q
+ seqlen_k = MAX_SEQLENS_K
+
+ # Now we compute whether we need to exit early due to causal masking.
+ # This is because for seqlen_q > seqlen_k, M rows of the attn scores
+ # are completely masked, resulting in 0s written to the output, and
+ # inf written to LSE. We don't need to do any GEMMs in this case.
+ # This block of code determines what N is, and if this WG is operating
+ # on those M rows.
+ n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
+ if IS_CAUSAL:
+ # If seqlen_q == seqlen_k, the attn scores are a square matrix.
+ # If seqlen_q != seqlen_k, attn scores are rectangular which means
+ # the causal mask boundary is bottom right aligned, and ends at either
+ # the top edge (seqlen_q < seqlen_k) or left edge.
+ # This captures the decrease in n_blocks if we have a rectangular attn matrix
+ n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
+ # This is what adjusts the block_max for the current WG, only
+ # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
+ n_blocks = min(n_blocks, n_blocks_seqlen)
+ # If we have no blocks after adjusting for seqlen deltas, this WG is part of
+ # the blocks that are all 0. We exit early.
+ if n_blocks <= 0:
+ o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om
+ o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
+ o_ptrs_mask = offs_m[:, None] < seqlen_q
+ # We still need to write 0s to the result
+ tl.store(o_ptrs, acc, mask=o_ptrs_mask)
+ # The tensor allocated for L is based on MAX_SEQLENS_Q as that is
+ # statically known.
+ l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m
+ l_ptrs = l_offset + offs_m * stride_lse_m
+
+ l = tl.full([BLOCK_M], value=0.0, dtype=tl.float32)
+
+ # mask_m_offsets = start_m + tl.arange(0, BLOCK_M)
+ # lse_mask = mask_m_offsets < causal_start_idx
+ # softmax_lse = tl.where(lse_mask, 0.0, softmax_lse)
+ l_ptrs_mask = offs_m < MAX_SEQLENS_Q
+ tl.store(l_ptrs, l, mask=l_ptrs_mask)
+ # TODO: Should dropout and return encoded softmax be handled here too?
+ return
+
+ # If MQA / GQA, set the K and V head offsets appropriately.
+ GROUP_SIZE: tl.constexpr = HQ // HK
+ if GROUP_SIZE != 1:
+ off_h_k = off_h_q // GROUP_SIZE
+ else:
+ off_h_k = off_h_q
+
+ n_extra_tokens = 0
+ # print("n_extra_tokens:", n_extra_tokens)
+ # print("seqlen_k:", seqlen_k)
+ # print("BLOCK_N:", BLOCK_N)
+ # return
+ if seqlen_k < BLOCK_N:
+ n_extra_tokens = BLOCK_N - seqlen_k
+ elif seqlen_k % BLOCK_N:
+ n_extra_tokens = seqlen_k % BLOCK_N
+ PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
+
+ # Compute pointers for all the tensors used in this kernel.
+ q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
+ q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
+ k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
+ k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn
+ v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
+ v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn
+ if USE_BIAS:
+ # Note: this might get large enough to overflow on some configs
+ bias_offset = off_h_q * stride_bh
+ bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn
+ else:
+ bias_ptrs = None
+
+ if USE_ALIBI:
+ a_offset = off_z * stride_az + off_h_q * stride_ah
+ alibi_slope = tl.load(alibi_slopes + a_offset)
+ else:
+ alibi_slope = None
+
+ if RETURN_SCORES:
+ scores_offset = scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
+ score_ptrs = scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
+
+ scores_scaled_shifted_offset = scores_scaled_shifted + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
+ scores_scaled_shifted_ptrs = scores_scaled_shifted_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
+
+ exp_scores_offset = exp_scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
+ exp_scores_ptrs = exp_scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
+ else:
+ score_ptrs = None
+ scores_scaled_shifted_ptrs = None
+ exp_scores_ptrs = None
+
+ if ENABLE_DROPOUT:
+ off_hz = off_z * HQ + off_h_q
+ batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k
+ else:
+ batch_philox_offset = 0
+ # initialize pointer to m and l
+ m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+ l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+ # Q is loaded once at the beginning and shared by all N blocks.
+ q_ptrs_mask = offs_m[:, None] < seqlen_q
+ if PADDED_HEAD:
+ q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
+ q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)
+
+ # Here we compute how many full and masked blocks we have.
+ padded_block_k = n_extra_tokens != 0
+ is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
+ if IS_CAUSAL:
+ # There are always at least BLOCK_M // BLOCK_N masked blocks.
+ # Additionally there might be one more due to dissimilar seqlens.
+ masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
+ else:
+ # Padding on Q does not need to be masked in the FA loop.
+ masked_blocks = padded_block_k
+ # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block.
+ # In this case we might exceed n_blocks so pick the min.
+ masked_blocks = min(masked_blocks, n_blocks)
+ n_full_blocks = n_blocks - masked_blocks
+ block_min = 0
+ block_max = n_blocks * BLOCK_N
+ # Compute for full blocks. Here we set causal to false regardless of its actual
+ # value because there is no masking. Similarly we do not need padding.
+ if n_full_blocks > 0:
+ block_max = (n_blocks - masked_blocks) * BLOCK_N
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn,
+ start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset,
+ exp_scores_ptrs,
+ # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
+ block_min, block_max, 0, 0, 0, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs,
+ # IS_CAUSAL, ....
+ False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
+ # _, MASK_STEPS, ...
+ PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD,
+ ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
+ block_min = block_max
+ block_max = n_blocks * BLOCK_N
+
+ tl.debug_barrier()
+ # Remaining blocks, if any, are full / not masked.
+ if masked_blocks > 0:
+ if IS_CAUSAL:
+ offs_n_causal = offs_n + (seqlen_q - seqlen_k)
+ else:
+ offs_n_causal = 0
+ k_ptrs += n_full_blocks * BLOCK_N * stride_kn
+ v_ptrs += n_full_blocks * BLOCK_N * stride_vk
+ if USE_BIAS:
+ bias_ptrs += n_full_blocks * BLOCK_N * stride_bn
+ if RETURN_SCORES:
+ score_ptrs += n_full_blocks * BLOCK_N
+ scores_scaled_shifted_ptrs += n_full_blocks * BLOCK_N
+ exp_scores_ptrs += n_full_blocks * BLOCK_N
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn,
+ start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset,
+ exp_scores_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
+ n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs,
+ IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
+ # _, MASK_STEPS, ...
+ PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD,
+ ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
+ # epilogue
+ # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger.
+ l_recip = 1 / l_i[:, None]
+ acc = acc * l_recip
+ if ENABLE_DROPOUT:
+ acc = acc / (1 - dropout_p)
+ # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
+ # then we have one block with a row of all NaNs which come from computing
+ # softmax over a row of all -infs (-inf - inf = NaN). We check for that here
+ # and store 0s where there are NaNs as these rows should've been zeroed out.
+ end_m_idx = (start_m + 1) * BLOCK_M
+ start_m_idx = start_m * BLOCK_M
+ causal_start_idx = seqlen_q - seqlen_k
+ acc = acc.to(Out.type.element_ty)
+ if IS_CAUSAL:
+ if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
+ out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32)
+ mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
+ out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
+ z: tl.tensor = 0.0
+ acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
+
+ # write back LSE(Log Sum Exponents), the log of the normalization constant
+ l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m
+ l_ptrs = l_offset + offs_m * stride_lse_m
+ if USE_EXP2:
+ RCP_LN2: tl.constexpr = 1.4426950408889634
+ LN2: tl.constexpr = 0.6931471824645996
+ # compute log-sum-exp in base 2 units
+ mi_base2 = m_i * RCP_LN2
+ softmax_lse = mi_base2 + tl.math.log2(l_i)
+ # convert back to natural units
+ softmax_lse *= LN2
+ else:
+ softmax_lse = m_i + tl.math.log(l_i)
+
+ if IS_CAUSAL:
+ # zero out nans caused by -infs when doing causal
+ lse_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx
+ softmax_lse = tl.where(lse_mask, 0.0, softmax_lse)
+
+ # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows.
+ # This is only true for the last M block. For others, overflow_size will be -ve
+ overflow_size = end_m_idx - seqlen_q
+ if overflow_size > 0:
+ boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32)
+ l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary
+ tl.store(l_ptrs, softmax_lse, mask=l_ptrs_mask) # the log of the normalization constant
+ else:
+ tl.store(l_ptrs, softmax_lse) # the log of the normalization constant
+
+ # write back O
+ o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om
+ o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
+ o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1)
+ if overflow_size > 0:
+ o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q)
+ if PADDED_HEAD:
+ o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
+ tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)
+
+
+def attention_prefill_forward_triton_impl(
+ q,
+ k,
+ v,
+ o,
+ sm_scale,
+ alibi_slopes,
+ causal,
+ bias,
+ dropout_p,
+ layout,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlens_q,
+ max_seqlens_k,
+ return_scores,
+ use_exp2):
+ # check if varlen
+ is_varlen = layout == "thd"
+
+ # NOTE: a large bias tensor leads to overflow during pointer arithmetic
+ if bias is not None:
+ assert bias.numel() < 2**31
+
+ batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k) # pylint: disable=unused-variable
+ q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
+
+ # Get closest power of 2 over or equal to 32.
+ padded_d_model = 1 << (head_size - 1).bit_length()
+ # Smallest head_dim supported is 16. If smaller, the tile in the
+ # kernel is padded - there is no padding in memory for any dims.
+ padded_d_model = max(padded_d_model, 16)
+
+ grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) # pylint: disable=unnecessary-lambda-assignment
+
+ if return_scores:
+ scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
+ dtype=torch.float32)
+ scores_scaled_shifted = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
+ dtype=torch.float32)
+ scores_strides = (scores.stride(0), scores.stride(1), scores.stride(2), scores.stride(3))
+ else:
+ scores = None
+ scores_scaled_shifted = None
+ scores_strides = (0, 0 , 0 , 0)
+
+ # exp_scores is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out
+ # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according
+ # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing
+ # only. This return holds no useful output aside from debugging.
+ if return_scores:
+ exp_scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
+ dtype=torch.float32)
+ else:
+ exp_scores = None
+
+ # stores LSE the log of the normalization constant / sum of expoential score(unnormalzied probablities)
+ if is_varlen:
+ softmax_lse = torch.empty((q.shape[0], nheads_q), device=q.device, dtype=torch.float32)
+ stride_lse_m, stride_lse_h = softmax_lse.stride()
+ stride_lse_z = 0
+ else:
+ softmax_lse = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32)
+ stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride()
+
+ # Seed the RNG so we get reproducible results for testing.
+ philox_seed = 0x1BF52
+ philox_offset = 0x1D4B42
+
+ if bias is not None:
+ bias_strides = (bias.stride(0), bias.stride(1),bias.stride(2),
+ bias.stride(3))
+ else:
+ bias_strides = (0, 0, 0, 0)
+
+ if alibi_slopes is not None:
+ alibi_strides = (alibi_slopes.stride(0), alibi_slopes.stride(1))
+ else:
+ alibi_strides = (0, 0)
+
+
+ attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
+ *bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
+ dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, scores=scores,
+ scores_scaled_shifted=scores_scaled_shifted, exp_scores=exp_scores, alibi_slopes=alibi_slopes,
+ HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
+ MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen,
+ BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
+ USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
+ > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_scores)
+
+ return o, softmax_lse, exp_scores, grid, head_size, philox_seed, philox_offset, scores, scores_scaled_shifted
diff --git a/modules/flash_attn_triton_amd/fwd_ref.py b/modules/flash_attn_triton_amd/fwd_ref.py
new file mode 100644
index 000000000..03e53efde
--- /dev/null
+++ b/modules/flash_attn_triton_amd/fwd_ref.py
@@ -0,0 +1,258 @@
+import math
+import torch
+
+
+def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2):
+ # Compute attention scores
+ attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
+
+ # Scale scores
+ attention_scaled_scores = sm_scale * attention_scores
+
+ # Apply causal mask if necessary
+ if causal:
+ L_q, L_k = q.shape[1], k.shape[1]
+ row_idx = torch.arange(L_q, device=q.device).unsqueeze(1)
+ col_idx = torch.arange(L_k, device=q.device).unsqueeze(0)
+ col_offset = L_q-L_k
+ causal_mask = row_idx >= (col_offset + col_idx)
+ # set -inf to places the causal mask is false
+ attention_scaled_scores = attention_scaled_scores.masked_fill(
+ torch.logical_not(causal_mask.unsqueeze(0)), float('-inf')
+ )
+
+
+ # Compute max for numerical stability
+ max_scores = torch.max(attention_scaled_scores, dim=-1, keepdim=True)[0]
+ if causal:
+ # Replace -inf in max_scores with zeros to avoid NaN in subtraction
+ max_scores = torch.where(
+ torch.isinf(max_scores), torch.zeros_like(max_scores), max_scores
+ )
+
+ # Shift scores
+ attention_shifted_scaled_scores = attention_scaled_scores - max_scores
+
+ # Exponentiate
+ if use_exp2:
+ RCP_LN = 1 / math.log(2)
+ exp_scores = torch.exp2(RCP_LN * attention_shifted_scaled_scores)
+ else:
+ exp_scores = torch.exp(attention_shifted_scaled_scores)
+
+ # Sum of exponentials
+ sum_exp_scores = torch.sum(exp_scores, dim=-1, keepdim=True)
+ if causal:
+ # if sum of exp scores is 0.0 it means scores where -inf, we cannot compute softmax and softmax_lse. Setting to 1 deals with -inf case cleanly
+ sum_exp_scores = torch.where(
+ sum_exp_scores == 0,
+ torch.ones_like(sum_exp_scores),
+ sum_exp_scores
+ )
+
+ # Compute softmax probabilities
+ softmax = exp_scores / sum_exp_scores
+
+ # Compute log-sum-exp
+ if use_exp2:
+ LN2 = math.log(2)
+ RCP_LN = 1 / math.log(2)
+ max_scores_base2 = max_scores * RCP_LN
+ softmax_lse_base2 = max_scores_base2 + torch.log2(sum_exp_scores)
+ softmax_lse = softmax_lse_base2 * LN2
+ softmax_lse.squeeze_(-1)
+ else:
+ softmax_lse = max_scores + torch.log(sum_exp_scores)
+ softmax_lse = softmax_lse.squeeze(-1)
+
+ # Compute output
+ o = torch.matmul(softmax, v.to(torch.float32)).to(torch.float16)
+
+ return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores
+
+def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, use_exp2):
+ """Compute reference output and softmax_lse using PyTorch's built-in function"""
+
+ # Ensure the layout is 'bhsd'
+ if layout == "bshd":
+ q = q.transpose(1, 2).contiguous()
+ k = k.transpose(1, 2).contiguous()
+ v = v.transpose(1, 2).contiguous()
+ elif layout != "bhsd":
+ raise ValueError(f"Unknown layout {layout}")
+
+ # Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format
+ batch_size, num_heads, seq_len_q, head_dim = q.shape
+ seq_len_k = k.shape[2]
+
+ # Merge batch and heads dimensions
+ q = q.reshape(batch_size * num_heads, seq_len_q, head_dim)
+ k = k.reshape(batch_size * num_heads, seq_len_k, head_dim)
+ v = v.reshape(batch_size * num_heads, seq_len_k, head_dim)
+
+ # Call the core attention function
+ o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores = attention_forward_core_ref_impl(
+ q, k, v, sm_scale, causal, use_exp2
+ )
+
+ # Reshape outputs back to [batch_size, num_heads, seq_len, head_dim]
+ o = o.reshape(batch_size, num_heads, seq_len_q, head_dim)
+ softmax_lse = softmax_lse.reshape(batch_size, num_heads, seq_len_q)
+ exp_scores = exp_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
+ softmax = softmax.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
+ attention_shifted_scaled_scores = attention_shifted_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
+ attention_scaled_scores = attention_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
+ attention_scores = attention_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
+
+ # Restore original layout if necessary
+ if layout == "bshd":
+ o = o.transpose(1, 2)
+
+ return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores
+
+def attention_varlen_forward_pytorch_ref_impl(
+ q,
+ k,
+ v,
+ sm_scale,
+ causal,
+ layout,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q, max_seqlen_k, # pylint: disable=unused-argument
+ use_exp2
+):
+ # Ensure the layout is 'thd'
+ if layout != 'thd':
+ raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.")
+
+ batch_size = cu_seqlens_q.shape[0] - 1
+ num_heads = q.shape[1]
+ head_dim = q.shape[2]
+
+ # Pre-allocate outputs
+ total_L_q = q.shape[0]
+ total_L_k = k.shape[0] # pylint: disable=unused-variable
+
+ o = torch.empty((total_L_q, num_heads, head_dim), dtype=q.dtype, device=q.device)
+ softmax_lse = torch.empty((total_L_q, num_heads), dtype=torch.float32, device=q.device)
+
+ for i in range(batch_size):
+ # Get the start and end indices for the current sequence
+ start_q = cu_seqlens_q[i].item()
+ end_q = cu_seqlens_q[i + 1].item()
+ start_k = cu_seqlens_k[i].item()
+ end_k = cu_seqlens_k[i + 1].item()
+
+ # Extract q_i, k_i, v_i
+ q_i = q[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim]
+ k_i = k[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim]
+ v_i = v[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim]
+
+ # Permute to [num_heads, L_q_i, head_dim]
+ q_i = q_i.permute(1, 0, 2)
+ k_i = k_i.permute(1, 0, 2)
+ v_i = v_i.permute(1, 0, 2)
+
+ # Call the core attention function for this sequence
+ (
+ o_i,
+ softmax_lse_i,
+ exp_scores_i,
+ softmax_i,
+ attention_shifted_scaled_scores_i,
+ attention_scaled_scores_i,
+ attention_scores_i,
+ ) = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, use_exp2)
+
+ # Convert back to 'thd' layout and float16
+ o_i = o_i.permute(1, 0, 2).to(torch.float16) # [L_q_i, num_heads, head_dim]
+
+ # Place outputs in pre-allocated tensors
+ o[start_q:end_q, :, :] = o_i
+ softmax_lse[start_q:end_q, :] = softmax_lse_i.transpose(0, 1) # Transpose to [L_q_i, num_heads]
+
+ # For variable-sized outputs, map them into the preallocated tensors
+ # exp_scores_i: [num_heads, L_q_i, L_k_i] -> [L_q_i, num_heads, L_k_i]
+ exp_scores_i = exp_scores_i.permute(1, 0, 2)
+ softmax_i = softmax_i.permute(1, 0, 2)
+ attention_shifted_scaled_scores_i = attention_shifted_scaled_scores_i.permute(1, 0, 2)
+ attention_scaled_scores_i = attention_scaled_scores_i.permute(1, 0, 2)
+ attention_scores_i = attention_scores_i.permute(1, 0, 2)
+
+ return (
+ o,
+ softmax_lse,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+def attention_forward_pytorch_ref_impl(
+ q,
+ k,
+ v,
+ sm_scale,
+ causal,
+ layout,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ use_exp2
+ ):
+ # compute reference
+ if layout == "thd":
+ (
+ o_ref,
+ softmax_lse_ref,
+ exp_scores_ref,
+ softmax_ref,
+ attention_shifted_scaled_scores_ref,
+ attention_scaled_scores_ref,
+ attention_scores_ref,
+ ) = attention_varlen_forward_pytorch_ref_impl(
+ q.clone(),
+ k.clone(),
+ v.clone(),
+ sm_scale,
+ causal,
+ layout,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ use_exp2,
+ )
+ else:
+ (
+ o_ref,
+ softmax_lse_ref,
+ exp_scores_ref,
+ softmax_ref,
+ attention_shifted_scaled_scores_ref,
+ attention_scaled_scores_ref,
+ attention_scores_ref,
+ ) = attention_vanilla_forward_pytorch_ref_impl(
+ q.clone(), k.clone(), v.clone(), sm_scale, causal, layout, use_exp2
+ )
+
+ return (
+ o_ref,
+ softmax_lse_ref,
+ exp_scores_ref,
+ softmax_ref,
+ attention_shifted_scaled_scores_ref,
+ attention_scaled_scores_ref,
+ attention_scores_ref,
+ )
+
+
+def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k):
+ q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1)
+ k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K)
+ relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K)
+ return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K)
diff --git a/modules/flash_attn_triton_amd/interface_fa.py b/modules/flash_attn_triton_amd/interface_fa.py
new file mode 100644
index 000000000..72373d35f
--- /dev/null
+++ b/modules/flash_attn_triton_amd/interface_fa.py
@@ -0,0 +1,394 @@
+import os
+import torch
+from modules.flash_attn_triton_amd.fwd_prefill import attention_prefill_forward_triton_impl
+from modules.flash_attn_triton_amd.bwd_prefill import attention_prefill_backward_triton_impl
+from modules.flash_attn_triton_amd.fwd_decode import attention_decode_forward_triton_impl
+from modules.flash_attn_triton_amd.fwd_ref import attention_forward_pytorch_ref_impl
+from modules.flash_attn_triton_amd.bwd_ref import attention_backward_pytorch_ref_impl
+from modules.flash_attn_triton_amd.utils import MetaData, get_shape_from_layout
+
+
+USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')
+
+
+def fwd(q,
+ k,
+ v,
+ o,
+ alibi_slopes,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size_left, window_size_right, softcap, # pylint: disable=unused-argument
+ return_softmax,
+ gen_ # pylint: disable=unused-argument
+):
+ if dropout_p != 0.0:
+ raise ValueError("dropout is not supported on AMD's Triton Backend yet")
+
+ if o is None:
+ o = torch.empty_like(q)
+
+ # Setup metadata
+ metadata = MetaData(sm_scale=softmax_scale)
+ metadata.max_seqlens_q = q.shape[1]
+ metadata.max_seqlens_k = k.shape[1]
+ metadata.layout = "bshd"
+ if return_softmax:
+ metadata.return_scores = True
+
+ batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, metadata.layout) # pylint: disable=unused-variable
+
+ if causal:
+ metadata.need_causal()
+
+ if alibi_slopes is not None:
+ metadata.need_alibi(alibi_slopes, batch, nheads_q)
+
+ if dropout_p > 0.0:
+ metadata.need_dropout(dropout_p, return_softmax)
+
+ # Check arguments
+ metadata.check_args(q, k, v, o)
+ if USE_REF:
+ (output,
+ softmax_lse,
+ exp_scores,
+ _,
+ _,
+ _,
+ _) = attention_forward_pytorch_ref_impl(
+ q,
+ k,
+ v,
+ metadata.sm_scale,
+ metadata.causal,
+ metadata.layout,
+ metadata.cu_seqlens_q,
+ metadata.cu_seqlens_k,
+ metadata.max_seqlens_q,
+ metadata.max_seqlens_k,
+ metadata.use_exp2)
+ o.copy_(output)
+ else:
+ (_,
+ softmax_lse,
+ exp_scores,
+ _,
+ _,
+ _,
+ _,
+ _,
+ _) = attention_prefill_forward_triton_impl(
+ q,
+ k,
+ v,
+ o,
+ metadata.sm_scale,
+ metadata.alibi_slopes,
+ metadata.causal,
+ metadata.bias,
+ metadata.dropout_p,
+ metadata.layout,
+ metadata.cu_seqlens_q,
+ metadata.cu_seqlens_k,
+ metadata.max_seqlens_q,
+ metadata.max_seqlens_k,
+ metadata.return_scores,
+ metadata.use_exp2)
+
+ return o, softmax_lse, exp_scores, None
+
+
+def bwd(
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ dq,
+ dk,
+ dv,
+ alibi_slopes,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size_left, window_size_right, softcap, deterministic, gen_, rng_state, # pylint: disable=unused-argument
+):
+ if dropout_p != 0.0:
+ raise ValueError("dropout is not supported on AMD yet")
+
+ if USE_REF:
+ dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ softmax_scale,
+ causal,
+ "bshd",
+ None,
+ None,
+ None,
+ None,
+ False,
+ )
+ dq.copy_(dq_ref)
+ dk.copy_(dk_ref)
+ dv.copy_(dv_ref)
+ delta = delta_ref
+ else:
+ dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( # pylint: disable=unused-variable
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ dq,
+ dk,
+ dv,
+ softmax_scale,
+ alibi_slopes,
+ causal,
+ "bshd",
+ None,
+ None,
+ None,
+ None,
+ False,
+ )
+ delta = delta_triton
+
+ return dq, dk, dv, delta
+
+
+def varlen_fwd(
+ q,
+ k,
+ v,
+ o,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ seqused_k, leftpad_k, block_table_, # pylint: disable=unused-argument
+ alibi_slopes,\
+ max_seqlen_q,
+ max_seqlen_k,
+ dropout_p,
+ softmax_scale,
+ zero_tensors, # pylint: disable=unused-argument
+ causal,
+ window_size_left, window_size_right, softcap, # pylint: disable=unused-argument
+ return_softmax,
+ gen_ # pylint: disable=unused-argument
+):
+ if dropout_p != 0.0:
+ raise ValueError("dropout is not supported on AMD's Triton Backend yet")
+
+ if o is None:
+ o = torch.empty_like(q)
+
+ # Setup metadata
+ metadata = MetaData(sm_scale=softmax_scale)
+ if return_softmax:
+ metadata.return_scores = True
+ metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) # set layout to "thd" and other metdata
+
+ # get shapes
+ batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # pylint: disable=unused-variable
+
+ if causal:
+ metadata.need_causal()
+
+ if alibi_slopes is not None:
+ metadata.need_alibi(alibi_slopes, batch, nheads_q)
+
+ if dropout_p > 0.0:
+ metadata.need_dropout(dropout_p, return_softmax)
+
+ # Check arguments
+ metadata.check_args(q, k, v, o)
+ if o is None:
+ o = torch.empty_like(q, dtype=v.dtype)
+
+ if USE_REF:
+ (output,
+ softmax_lse,
+ exp_scores,
+ _,
+ _,
+ _,
+ _) = attention_forward_pytorch_ref_impl(
+ q,
+ k,
+ v,
+ metadata.sm_scale,
+ metadata.causal,
+ metadata.layout,
+ metadata.cu_seqlens_q,
+ metadata.cu_seqlens_k,
+ metadata.max_seqlens_q,
+ metadata.max_seqlens_k,
+ metadata.use_exp2)
+ o.copy_(output)
+ else:
+ (_,
+ softmax_lse,
+ exp_scores,
+ _,
+ _,
+ _,
+ _,
+ _,
+ _) = attention_prefill_forward_triton_impl(
+ q,
+ k,
+ v,
+ o,
+ metadata.sm_scale,
+ metadata.alibi_slopes,
+ metadata.causal,
+ metadata.bias,
+ metadata.dropout_p,
+ metadata.layout,
+ metadata.cu_seqlens_q,
+ metadata.cu_seqlens_k,
+ metadata.max_seqlens_q,
+ metadata.max_seqlens_k,
+ metadata.return_scores,
+ metadata.use_exp2)
+
+ return o, softmax_lse, exp_scores, None
+
+
+def varlen_bwd(
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ dq,
+ dk,
+ dv,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ alibi_slopes,
+ max_seqlen_q,
+ max_seqlen_k,
+ dropout_p,
+ softmax_scale,
+ zero_tensors, # pylint: disable=unused-argument
+ causal,
+ window_size_left, window_size_right, softcap, deterministic, gen_, rng_state, # pylint: disable=unused-argument
+):
+ if dropout_p != 0.0:
+ raise ValueError("dropout is not supported on AMD yet")
+
+ if USE_REF:
+ dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ softmax_scale,
+ causal,
+ "thd",
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ False,
+ )
+ dq.copy_(dq_ref)
+ dk.copy_(dk_ref)
+ dv.copy_(dv_ref)
+ delta = delta_ref
+ else:
+ dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( # pylint: disable=unused-variable
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ dq,
+ dk,
+ dv,
+ softmax_scale,
+ alibi_slopes,
+ causal,
+ "thd",
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ False,
+ )
+ delta = delta_triton
+
+ return dq, dk, dv, delta
+
+
+def fwd_kvcache(
+ q,
+ k_cache,
+ v_cache,
+ k,
+ v,
+ cache_seqlens,
+ rotary_cos, rotary_sin, # pylint: disable=unused-argument
+ cache_batch_idx,
+ cache_leftpad, block_table, # pylint: disable=unused-argument
+ alibi_slopes,
+ out,
+ softmax_scale,
+ causal,
+ window_size_left, window_size_right, softcap, rotary_interleaved, num_splits, # pylint: disable=unused-argument
+):
+ if out is None:
+ out = torch.empty_like(q)
+
+ # fill metadata
+ metadata = MetaData(sm_scale=softmax_scale)
+ metadata.layout = "bshd"
+ metadata.max_seqlens_q = q.shape[1]
+ metadata.max_seqlens_k = k_cache.shape[1]
+ metadata.cache_seqlens = cache_seqlens
+ metadata.cache_batch_idx = cache_batch_idx
+
+ if k is not None and v is not None:
+ metadata.new_kv = True
+ metadata.seqlen_new = k.shape[1]
+ metadata.k_new = k
+ metadata.v_new = v
+
+ if causal:
+ metadata.need_causal()
+
+ if alibi_slopes is not None:
+ batch, _ , nheads_q, _= q.shape
+ metadata.need_alibi(alibi_slopes, batch, nheads_q)
+
+ # launch kernel
+ # TODO: pass output as an arg. Maybe we are copying output which is causing slow down
+ output, softmax_lse = attention_decode_forward_triton_impl(
+ q,
+ k_cache,
+ v_cache,
+ metadata.sm_scale,
+ metadata.causal,
+ metadata.alibi_slopes,
+ metadata.layout,
+ metadata.cache_seqlens,
+ metadata.cache_batch_idx,
+ metadata.new_kv,
+ metadata.k_new,
+ metadata.v_new,
+ )
+ return output, softmax_lse
diff --git a/modules/flash_attn_triton_amd/utils.py b/modules/flash_attn_triton_amd/utils.py
new file mode 100644
index 000000000..77384cff6
--- /dev/null
+++ b/modules/flash_attn_triton_amd/utils.py
@@ -0,0 +1,280 @@
+import os
+import torch
+import triton
+
+
+AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes')
+PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes')
+
+
+class MetaData():
+ cu_seqlens_q = None
+ cu_seqlens_k = None
+ max_seqlens_q = 0
+ max_seqlens_k = 0
+ bias = None
+ alibi_slopes = None
+ causal = False
+ num_contexts = 0
+ varlen = False
+ layout = None
+ cache_seqlens = None
+ cache_batch_idx = None
+ new_kv = False
+ seqlen_new = None
+ k_new = None
+ v_new = None
+ dropout_p, return_scores= 0.0, False
+ # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW.
+ use_exp2 = False
+
+ def __repr__(self) -> str:
+ return (f"MetaData(\n"
+ f" sm_scale={self.sm_scale},\n"
+ f" cu_seqlens_q={self.cu_seqlens_q},\n"
+ f" cu_seqlens_k={self.cu_seqlens_k},\n"
+ f" max_seqlens_q={self.max_seqlens_q},\n"
+ f" max_seqlens_k={self.max_seqlens_k},\n"
+ f" bias={self.bias},\n"
+ f" alibi_slopes={self.alibi_slopes},\n"
+ f" causal={self.causal},\n"
+ f" num_contexts={self.num_contexts},\n"
+ f" varlen={self.varlen},\n"
+ f" layout={self.layout},\n"
+ f" cache_seqlens={self.cache_seqlens},\n"
+ f" cache_batch_idx={self.cache_batch_idx},\n"
+ f" new_kv={self.new_kv},\n"
+ f" seqlen_new={self.seqlen_new},\n"
+ f" k_new={self.k_new},\n"
+ f" v_new={self.v_new},\n"
+ f" dropout_p={self.dropout_p},\n"
+ f" return_scores={self.return_scores}\n"
+ f")")
+
+ def __init__(self, sm_scale=1.0):
+ self.sm_scale = sm_scale
+
+ def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
+ self.varlen = True
+ self.layout = 'thd'
+ self.cu_seqlens_q = cu_seqlens_q
+ self.cu_seqlens_k = cu_seqlens_k
+ # Without "varlen", there should still be one sequence.
+ assert len(cu_seqlens_q) >= 2
+ assert len(cu_seqlens_q) == len(cu_seqlens_k)
+ self.num_contexts = len(cu_seqlens_q) - 1
+ for i in range(0, self.num_contexts):
+ self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q)
+ self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k)
+
+ def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): # pylint: disable=unused-argument
+ assert bias.is_cuda
+ assert bias.dim() == 4
+ assert bias.shape[0] == 1
+ assert bias.shape[2:] == (seqlen_q, seqlen_k)
+ self.bias = bias
+
+ def need_alibi(self, alibi_slopes, batch, nheads):
+ assert alibi_slopes.is_cuda
+ assert alibi_slopes.dim() == 2
+ assert alibi_slopes.shape[0] == batch
+ assert alibi_slopes.shape[1] == nheads
+ self.alibi_slopes = alibi_slopes
+
+ def need_causal(self):
+ self.causal = True
+
+ def need_dropout(self, dropout_p, return_scores):
+ self.dropout_p = dropout_p
+ self.return_scores = return_scores
+
+ def check_args(self, q, k, v, o):
+ assert q.dim() == k.dim() and q.dim() == v.dim()
+
+ batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) # pylint: disable=unused-variable
+ if self.varlen:
+ assert q.dim() == 3
+ assert self.cu_seqlens_q is not None
+ assert self.cu_seqlens_k is not None
+ assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
+ # TODO: Remove once bias is supported with varlen
+ assert self.bias is None
+ # TODO:Remove once dropout is supported with varlen
+ assert self.dropout_p == 0.0
+ # assert not self.return_scores
+ else:
+ assert q.dim() == 4
+ assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0
+ assert self.cu_seqlens_q is None and self.cu_seqlens_k is None
+ assert k.shape == v.shape
+ assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
+ # TODO: Change assert if we support qkl f8 and v f16
+ assert q.dtype == k.dtype and q.dtype == v.dtype
+ assert head_size <= 256
+ assert o.shape == q.shape
+ assert (nheads_q % nheads_k) == 0
+ assert self.layout is not None
+ assert self.layout == 'thd' or not self.varlen
+
+def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cuda", DEBUG_INPUT=False):
+ torch.manual_seed(20)
+
+ # Initialize q, k, v
+ if layout == 'bhsd':
+ q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD)
+ k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD)
+ elif layout == 'bshd':
+ q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD)
+ k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD)
+ else:
+ assert False, f'Got unsupported tensor layout: {layout}'
+
+ q = None
+ k = None
+ v = None
+
+ if DEBUG_INPUT:
+ if layout == "bhsd":
+ q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, 1, N_CTX_Q, 1).expand(*q_tensor_shape).contiguous().requires_grad_()
+ k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
+ v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
+ elif layout == "bshd":
+ q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, N_CTX_Q, 1, 1).expand(*q_tensor_shape).contiguous().requires_grad_()
+ k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
+ v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
+ else:
+ q = torch.randn(q_tensor_shape, dtype=dtype, device=device, requires_grad=True)
+ k = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True)
+ v = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True)
+
+ if DEBUG_INPUT:
+ sm_scale = 1
+ else:
+ sm_scale = D_HEAD**-0.5
+ input_metadata = MetaData(sm_scale=sm_scale)
+ input_metadata.max_seqlens_q = N_CTX_Q
+ input_metadata.max_seqlens_k = N_CTX_K
+ input_metadata.layout = layout
+ return q, k, v, input_metadata
+
+
+def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda", equal_seqlens=False, DEBUG_INPUT=False):
+ torch.manual_seed(20)
+
+ # Random or equal sequence lengths based on 'equal_seqlens' flag
+ if not equal_seqlens:
+ max_seqlens_q = N_CTX_Q // Z
+ max_seqlens_k = N_CTX_K // Z
+ seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32)
+ seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32)
+ else:
+ seqlens_q = torch.full((Z,), N_CTX_Q // Z, dtype=torch.int32)
+ seqlens_k = torch.full((Z,), N_CTX_K // Z, dtype=torch.int32)
+
+ # Calculate cumulative sequence lengths
+ cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0)])
+ cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0)])
+ cu_seqlens_q = cu_seqlens_q.to(device=device).to(torch.int32)
+ cu_seqlens_k = cu_seqlens_k.to(device=device).to(torch.int32)
+
+ # Total lengths
+ total_q = cu_seqlens_q[-1].item()
+ total_k = cu_seqlens_k[-1].item()
+
+ if DEBUG_INPUT:
+ # Initialize q, k, v with deterministic values
+ q = torch.arange(total_q, dtype=dtype, device=device).view(total_q, 1, 1)
+ q = q.expand(total_q, HQ, D_HEAD).contiguous().requires_grad_()
+ k = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1)
+ k = k.expand(total_k, HK, D_HEAD).contiguous().requires_grad_()
+ v = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1)
+ v = v.expand(total_k, HK, D_HEAD).contiguous().requires_grad_()
+ sm_scale = 1
+ else:
+ # Initialize q, k, v with random values
+ q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device=device).requires_grad_()
+ k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
+ v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
+ sm_scale = D_HEAD ** -0.5
+
+ input_metadata = MetaData(sm_scale=sm_scale)
+ input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
+ return q, k, v, input_metadata
+
+
+def get_shape_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None):
+ if layout == 'bhsd':
+ batch_q, nheads_q, max_seqlen_q, head_size_q = q.shape
+ batch_k, nheads_k, max_seqlen_k, head_size_k = k.shape
+ elif layout == 'bshd':
+ batch_q, max_seqlen_q, nheads_q, head_size_q = q.shape
+ batch_k, max_seqlen_k, nheads_k, head_size_k = k.shape
+ elif layout == 'thd':
+ batch_q, max_seqlen_q, nheads_q, head_size_q = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] # pylint: disable=self-assigning-variable
+ batch_k, max_seqlen_k, nheads_k, head_size_k = len(cu_seqlens_k) - 1, max_seqlen_k, k.shape[1], k.shape[2] # pylint: disable=self-assigning-variable
+ else:
+ assert False, "Got unsupported layout."
+
+ # assert
+ assert batch_q == batch_k
+ assert head_size_q == head_size_k
+
+ return batch_q, nheads_q, nheads_k, head_size_q, max_seqlen_q, max_seqlen_k
+
+
+def get_strides_from_layout(q, k, v, o, layout):
+ if layout == 'thd':
+ q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
+ k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
+ v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
+ o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
+ elif layout == 'bhsd':
+ q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3))
+ k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3))
+ v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3))
+ o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))
+ elif layout == 'bshd':
+ q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
+ k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
+ v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
+ o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
+ else:
+ assert False, 'Got unsupported layout.'
+ return q_strides, k_strides, v_strides, o_strides
+
+
+def get_padded_headsize(size):
+ # Get closest power of 2 over or equal to 32.
+ padded_d_model = 1 << (size - 1).bit_length()
+ # Smallest head_dim supported is 16. If smaller, the tile in the
+ # kernel is padded - there is no padding in memory for any dims.
+ padded_d_model = max(padded_d_model, 16)
+ return padded_d_model
+
+
+def _strides(x: torch.Tensor, *stride_names: str):
+ if x is None:
+ return {f"stride_{s}": 0 for i, s in enumerate(stride_names)}
+
+ assert x.ndim == len(stride_names)
+ return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)}
+
+
+def get_input_shapes():
+ cases = [(max(1, 2**(16 - i)), 1, 2**i, 16, 1, 128)
+ for i in range(8, 18)] + [(max(1, 2**(16 - i)), 1, 2**i, 16, 2, 128) for i in range(8, 18)]
+ return cases
+
+
+def is_hip():
+ return triton.runtime.driver.active.get_current_target().backend == "hip"
+
+
+def is_cdna():
+ return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
+ 'gfx90a', 'gfx908')
+
+
+def is_rdna():
+ return is_hip() and triton.runtime.driver.active.get_current_target().arch in ("gfx1030", "gfx1100", "gfx1101",
+ "gfx1102", "gfx1200", "gfx1201")
diff --git a/modules/zluda_hijacks.py b/modules/zluda_hijacks.py
index 1b3f750a9..4f1224923 100644
--- a/modules/zluda_hijacks.py
+++ b/modules/zluda_hijacks.py
@@ -1,13 +1,21 @@
+from functools import wraps
import torch
import torch._dynamo.device_interface
-from modules import rocm, zluda
+from modules import rocm, zluda, shared
-_topk = torch.topk
-def topk(input: torch.Tensor, *args, **kwargs): # pylint: disable=redefined-builtin
- device = input.device
- values, indices = _topk(input.cpu(), *args, **kwargs)
- return torch.return_types.topk((values.to(device), indices.to(device),))
+MEM_BUS_WIDTH = {
+ "AMD Radeon RX 9070 XT": 256,
+ "AMD Radeon RX 9070": 256,
+ "AMD Radeon RX 9060 XT": 192,
+ "AMD Radeon RX 7900 XTX": 384,
+ "AMD Radeon RX 7900 XT": 320,
+ "AMD Radeon RX 7900 GRE": 256,
+ "AMD Radeon RX 7800 XT": 256,
+ "AMD Radeon RX 7700 XT": 192,
+ "AMD Radeon RX 7600 XT": 128,
+ "AMD Radeon RX 7600": 128,
+}
class DeviceProperties:
@@ -35,20 +43,60 @@ def torch__C__cuda_getCurrentRawStream(device):
def do_hijack():
torch.version.hip = rocm.version
- torch.topk = topk
if zluda.default_agent is not None:
DeviceProperties.PROPERTIES_OVERRIDE["gcnArchName"] = zluda.default_agent.name
torch.cuda._get_device_properties = torch_cuda__get_device_properties # pylint: disable=protected-access
torch._C._cuda_getCurrentRawStream = torch__C__cuda_getCurrentRawStream # pylint: disable=protected-access
torch._dynamo.device_interface.CudaInterface.get_raw_stream = staticmethod(torch__C__cuda_getCurrentRawStream) # pylint: disable=protected-access
+
+ # Triton
try:
import triton
_get_device_properties = triton.runtime.driver.active.utils.get_device_properties
def triton_runtime_driver_active_utils_get_device_properties(device):
props = _get_device_properties(device)
- props["mem_bus_width"] = 384
+ name = torch.cuda.get_device_name()[:-8]
+ if name in MEM_BUS_WIDTH:
+ props["mem_bus_width"] = MEM_BUS_WIDTH[name]
+ else:
+ props["mem_bus_width"] = 128
+ shared.log.warning(f'[TRITON] defaulting mem_bus_width=128 for device "{name}".')
return props
triton.runtime.driver.active.utils.get_device_properties = triton_runtime_driver_active_utils_get_device_properties
+
+ if 'Flash attention' in shared.opts.sdp_options:
+ from modules.flash_attn_triton_amd import interface_fa
+ sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
+ @wraps(sdpa_pre_flash_atten)
+ def sdpa_flash_atten(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+ if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
+ if scale is None:
+ scale = query.shape[-1] ** (-0.5)
+ head_size_og = query.size(3)
+ if head_size_og % 8 != 0:
+ query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8])
+ key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8])
+ value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8])
+ out_padded, _, _, _ = interface_fa.fwd(
+ query.transpose(1, 2),
+ key.transpose(1, 2),
+ value.transpose(1, 2),
+ None,
+ None,
+ dropout_p,
+ scale,
+ is_causal,
+ -1,
+ -1,
+ 0.0,
+ False,
+ None,
+ )
+ return out_padded[..., :head_size_og].transpose(1, 2)
+ else:
+ return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+ torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
+ shared.log.debug('Torch attention: type="triton flash attention"')
except Exception:
pass