zluda 3.9.4 & update flash attention 2

pull/3914/head
Seunghoon Lee 2025-05-07 00:39:40 +09:00
parent 83fc68ece3
commit 2a4eb86d10
No known key found for this signature in database
GPG Key ID: 436E38F4E70BD152
7 changed files with 536 additions and 1904 deletions

@ -1,606 +0,0 @@
import torch
import triton
import triton.language as tl
from modules.flash_attn_triton_amd.utils import get_shape_from_layout, get_strides_from_layout
@triton.jit
def _bwd_preprocess_use_o(
Out,
DO,
Delta,
stride_oz, stride_oh, stride_om, stride_ok,
stride_doz, stride_doh, stride_dom, stride_dok, # pylint: disable=unused-argument
stride_deltaz, stride_deltah, stride_deltam,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
N_CTX_Q: tl.constexpr,
Z: tl.constexpr, # pylint: disable=unused-argument
H: tl.constexpr,
IS_VARLEN: tl.constexpr
):
pid_m = tl.program_id(0)
pid_bh = tl.program_id(1)
# Compute batch and head indices
off_z = pid_bh // H
off_h = pid_bh % H
if IS_VARLEN:
# Compute sequence lengths for the current batch
q_start = tl.load(cu_seqlens_q + off_z)
q_end = tl.load(cu_seqlens_q + off_z + 1)
k_start = tl.load(cu_seqlens_k + off_z)
k_end = tl.load(cu_seqlens_k + off_z + 1)
# Compute actual sequence lengths
N_CTX_Q = q_end - q_start
N_CTX_K = k_end - k_start # pylint: disable=unused-variable
else:
q_start = 0
k_start = 0
N_CTX_Q = max_seqlen_q
N_CTX_K = max_seqlen_k # pylint: disable=unused-variable
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_d = tl.arange(0, BLOCK_DMODEL)
# create masks
mask_m = off_m < N_CTX_Q
mask_d = off_d < ACTUAL_BLOCK_DMODEL
# compute offsets
o_offset = Out + off_z * stride_oz + off_h * stride_oh + q_start * stride_om
do_offset = DO + off_z * stride_oz + off_h * stride_oh + q_start * stride_om
# compute pointers
out_ptrs = o_offset + off_m[:, None] * stride_om + off_d[None, :] * stride_ok
do_ptrs = do_offset + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok
# load
o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)
do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)
# compute delta
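# delta[i] = sum_d o[i, d] * do[i, d], the "D" term that the backward kernels reuse when forming ds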
delta = tl.sum(o * do, axis=1)
# write-back delta
delta_offset = Delta + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
delta_ptrs = delta_offset + off_m * stride_deltam
tl.store(delta_ptrs, delta, mask=mask_m)
@triton.jit
def _bwd_kernel_one_col_block(
Q,
K,
V,
sm_scale,
Out, DO, DQ, DK, DV, L, D, # pylint: disable=unused-argument
q_offset,
k_offset,
v_offset,
do_offset,
dq_offset,
dk_offset,
dv_offset,
d_offset,
l_offset,
stride_dq_all, stride_qz, stride_qh, # pylint: disable=unused-argument
stride_qm,
stride_qk,
stride_kz, stride_kh, # pylint: disable=unused-argument
stride_kn,
stride_kk,
stride_vz, stride_vh, # pylint: disable=unused-argument
stride_vn,
stride_vk,
stride_deltaz, stride_deltah, # pylint: disable=unused-argument
stride_deltam,
Z, H, # pylint: disable=unused-argument
N_CTX_Q,
N_CTX_K,
off_h, off_z, off_hz, # pylint: disable=unused-argument
start_n,
num_block_m,
num_block_n, # pylint: disable=unused-argument
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
USE_EXP2: tl.constexpr,
):
if CAUSAL:
# TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M
lo = 0
else:
lo = 0
# initialize col and head offsets
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
# masks
mask_n = offs_n < N_CTX_K
mask_d = offs_d < ACTUAL_BLOCK_DMODEL
kv_mask = mask_n[:, None] & mask_d[None, :]
# initialize grad accumulators
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
# load k and v once per column block
k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
# loop over rows
for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M):
offs_m = start_m + tl.arange(0, BLOCK_M)
q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
dq_ptrs = dq_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
do_ptrs = do_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
# update mask as row block changes
mask_m = offs_m < N_CTX_Q
q_mask = mask_m[:, None] & mask_d[None, :]
# load q, k, v, do on-chip
q = tl.load(q_ptrs, mask=q_mask, other=0.0)
do = tl.load(do_ptrs, mask=q_mask, other=0.0)
# recompute p = softmax(qk, dim=-1).T
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
if CAUSAL:
col_offset = N_CTX_Q - N_CTX_K
causal_mask = offs_m[:, None] >= (col_offset + offs_n[None, :])
qk = tl.where(causal_mask, qk, float("-inf"))
l_ptrs = l_offset + offs_m * stride_deltam
l_i = tl.load(l_ptrs, mask=mask_m)
# compute p
if USE_EXP2:
RCP_LN2: tl.constexpr = 1.4426950408889634
qk *= sm_scale * RCP_LN2
l_i *= RCP_LN2
p = tl.math.exp2(qk - l_i[:, None])
else:
qk *= sm_scale
p = tl.math.exp(qk - l_i[:, None])
# mask block in the cases where the data is smaller than the block size
p_mask = mask_m[:, None] & mask_n[None, :]
p = tl.where(p_mask, p, 0.0)
# compute dv
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
# compute dp
dp = tl.dot(do, tl.trans(v))
# compute ds: ds = p * (dp - delta[:, None])
d_ptrs = d_offset + offs_m * stride_deltam
Di = tl.load(d_ptrs, mask=mask_m)
ds = (p * (dp - Di[:, None])) * sm_scale
ds = tl.where(p_mask, ds, 0.0).to(Q.dtype.element_ty)
# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds), q)
# compute dq
if SEQUENCE_PARALLEL:
dq = tl.dot(ds, k)
else:
dq = tl.load(dq_ptrs, mask=q_mask, other=0.0)
dq += tl.dot(ds, k)
tl.store(dq_ptrs, dq.to(Q.dtype.element_ty), mask=q_mask)
# write-back dv and dk
dk_ptrs = dk_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
# write-back
tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask)
tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask)
@triton.jit
def _bwd_kernel(
Q,
K,
V,
sm_scale,
Out,
DO,
DQ,
DK,
DV,
L,
D,
stride_dq_all,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vn,
stride_vk,
stride_deltaz,
stride_deltah,
stride_deltam,
Z,
H,
num_block_m,
num_block_n,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
USE_EXP2: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
# program ids
off_hz = tl.program_id(0)
if SEQUENCE_PARALLEL:
start_n = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
if IS_VARLEN:
# Compute sequence lengths for the current batch
q_start = tl.load(cu_seqlens_q + off_z)
q_end = tl.load(cu_seqlens_q + off_z + 1)
k_start = tl.load(cu_seqlens_k + off_z)
k_end = tl.load(cu_seqlens_k + off_z + 1)
# Compute actual sequence lengths
N_CTX_Q = q_end - q_start
N_CTX_K = k_end - k_start
else:
q_start = 0
k_start = 0
N_CTX_Q = max_seqlen_q
N_CTX_K = max_seqlen_k
# input tensor offsets
q_offset = Q + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
k_offset = K + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
v_offset = V + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
do_offset = DO + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
l_offset = L + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
d_offset = D + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
# output tensor offsets
dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
if SEQUENCE_PARALLEL:
dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
else:
dq_offset = DQ + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
# inner loop
if SEQUENCE_PARALLEL:
_bwd_kernel_one_col_block(
Q,
K,
V,
sm_scale,
Out,
DO,
DQ,
DK,
DV,
L,
D,
q_offset,
k_offset,
v_offset,
do_offset,
dq_offset,
dk_offset,
dv_offset,
d_offset,
l_offset,
stride_dq_all,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vn,
stride_vk,
stride_deltaz,
stride_deltah,
stride_deltam,
Z,
H,
N_CTX_Q,
N_CTX_K,
off_h,
off_z,
off_hz,
start_n,
num_block_m,
num_block_n,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
USE_EXP2=USE_EXP2,
)
else:
for start_n in range(0, num_block_n):
_bwd_kernel_one_col_block(
Q,
K,
V,
sm_scale,
Out,
DO,
DQ,
DK,
DV,
L,
D,
q_offset,
k_offset,
v_offset,
do_offset,
dq_offset,
dk_offset,
dv_offset,
d_offset,
l_offset,
stride_dq_all,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vn,
stride_vk,
stride_deltaz,
stride_deltah,
stride_deltam,
Z,
H,
N_CTX_Q,
N_CTX_K,
off_h,
off_z,
off_hz,
start_n,
num_block_m,
num_block_n,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
USE_EXP2=USE_EXP2,
)
# NOTE: smaller blocks have lower accuracy, probably from extra accumulation error. 128 * 128 seems good but leads to OOM; 64 * 64 has accumulation errors but no OOM.
def attention_prefill_backward_triton_impl(
do,
q,
k,
v,
o,
softmax_lse,
dq,
dk,
dv,
sm_scale: float,
alibi_slopes, # pylint: disable=unused-argument
causal,
layout: str,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q: int,
max_seqlen_k: int,
use_exp2: bool,
sequence_parallel = True,
):
# make contiguous
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
softmax_lse = softmax_lse.contiguous()
# get strides and shape
batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # pylint: disable=unused-variable
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
stride_qz, stride_qh, stride_qm, stride_qk = q_strides
stride_kz, stride_kh, stride_kn, stride_kk = k_strides
stride_vz, stride_vh, stride_vn, stride_vk = v_strides
stride_oz, stride_oh, stride_om, stride_ok = o_strides
batch_headsize = batch * nheads_q
is_varlen = layout == "thd"
# FIXME: some configs lead to oom for some reason when using 64 x 64 blocks
if max_seqlen_q <= 32 or max_seqlen_k <= 32:
BLOCK_M = 32
BLOCK_N = 32
else:
BLOCK_M = 64
BLOCK_N = 64
num_warps = 4 # NOTE: original is 8. Changing it to 1 caused issues, be careful
num_stages = 1
waves_per_eu = 1
# divide up the problem
num_blocks_m = triton.cdiv(max_seqlen_q, BLOCK_M)
num_blocks_n = triton.cdiv(max_seqlen_k, BLOCK_N)
# get the closest power of 2 greater than or equal to head_size (minimum 16).
padded_d_model = 1 << (head_size - 1).bit_length()
padded_d_model = max(padded_d_model, 16)
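# e.g. head_size = 80 -> (80 - 1).bit_length() == 7 -> padded_d_model = 128; head_size = 64 stays 64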
BLOCK_DMODEL = padded_d_model
ACTUAL_BLOCK_DMODEL = head_size
do = do.contiguous()
# NOTE: we might need to copy the output tensors if they are not contiguous or have other issues
copy_back = {"dq": False, "dk": False, "dv": False}
dq_og = None
# deal with dq
if dq is None:
if sequence_parallel:
dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
else:
dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype)
else:
dq_og = dq
if not dq.is_contiguous():
dq = dq.contiguous()
copy_back["dq"] = True
if sequence_parallel:
dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
copy_back["dq"] = True
else:
# NOTE: the kernel does in-place accumulation, so dq has to be zeros. This avoids the case where we are passed an empty dq that is not all zeros
dq.zero_()
stride_dq_all = dq.stride()[0]
dk_og = None
dv_og = None
# deal with dk, dv
if (dk is None) or (dv is None):
dk = torch.empty_like(k)
dv = torch.empty_like(v)
else:
if not dk.is_contiguous():
dk_og = dk
dk = dk.contiguous()
copy_back["dk"] = True
if not dv.is_contiguous():
dv_og = dv
dv = dv.contiguous()
copy_back["dv"] = True
# assert contiguous
assert do.is_contiguous()
assert q.is_contiguous()
assert k.is_contiguous()
assert v.is_contiguous()
assert o.is_contiguous()
assert softmax_lse.is_contiguous()
# init delta
delta = torch.empty_like(softmax_lse)
if is_varlen:
stride_deltam, stride_deltah = delta.stride()
stride_deltaz = 0
else:
stride_deltaz, stride_deltah, stride_deltam = delta.stride()
_bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
o,
do,
delta,
stride_oz, stride_oh, stride_om, stride_ok,
stride_oz, stride_oh, stride_om, stride_ok,
stride_deltaz, stride_deltah, stride_deltam,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
N_CTX_Q=max_seqlen_q,
Z=batch,
H=nheads_q,
IS_VARLEN=is_varlen
)
_bwd_kernel[(batch_headsize, num_blocks_n if sequence_parallel else 1)](
q,
k,
v,
sm_scale,
o,
do,
dq,
dk,
dv,
softmax_lse,
delta,
stride_dq_all,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
stride_deltaz, stride_deltah, stride_deltam,
batch,
nheads_q,
num_blocks_m,
num_blocks_n,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
SEQUENCE_PARALLEL=sequence_parallel,
CAUSAL=causal,
USE_EXP2=use_exp2,
num_warps=num_warps,
num_stages=num_stages,
waves_per_eu = waves_per_eu,
IS_VARLEN=is_varlen
)
if sequence_parallel:
dq = dq.sum(dim=0)
if copy_back["dq"]:
dq_og.copy_(dq)
dq = dq_og
if copy_back["dk"]:
dk_og.copy_(dk)
dk = dk_og
if copy_back["dv"]:
dv_og.copy_(dv)
dv = dv_og
return dq, dk, dv, delta, None, None

@ -1,700 +0,0 @@
import torch
import triton
import triton.language as tl
from modules.flash_attn_triton_amd.utils import _strides, get_padded_headsize
@triton.jit
def _fwd_kernel_splitK(
Q,
K,
V,
sm_scale,
Out_splitK, # [B, H, split_k, Mq, K]
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
K_new,
V_new,
Cache_seqlens,
Cache_batch_idx,
Alibi_slopes,
stride_qz,
stride_qm,
stride_qg,
stride_qh,
stride_qd,
stride_kz,
stride_kn,
stride_kg,
stride_kh,
stride_kd,
stride_vz,
stride_vn,
stride_vg,
stride_vh,
stride_vd,
stride_osk_zhg,
stride_osk_s,
stride_osk_m,
stride_osk_d, # pylint: disable=unused-argument
stride_mzhg,
stride_m2,
stride_ms,
stride_mm, # pylint: disable=unused-argument
stride_kn_z,
stride_kn_n,
stride_kn_g,
stride_kn_h,
stride_kn_d,
stride_vn_z,
stride_vn_n,
stride_vn_g,
stride_vn_h,
stride_vn_d,
stride_az,
stride_ah,
Z, # pylint: disable=unused-argument
N_CTX_Q,
N_CTX_K,
N_CTX_NEW,
BLOCK_N_PER_SPLIT,
H_q: tl.constexpr,
H_kv: tl.constexpr,
G_q: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
BOUNDS_CHECKS_N: tl.constexpr,
USE_CACHE_SEQLENs: tl.constexpr,
USE_CACHE_BATCH_IDX: tl.constexpr,
NEW_KV: tl.constexpr,
IS_GQA: tl.constexpr,
IS_CAUSAL: tl.constexpr,
USE_ALIBI: tl.constexpr,
):
# Padding
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
if PADDED_HEAD:
d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL
start_m = tl.program_id(0)
off_zhg = tl.program_id(1)
off_z = off_zhg // (H_q * G_q)
off_h_q = (off_zhg // G_q) % H_q
off_g_q = off_zhg % G_q
splitk_idx = tl.program_id(2)
# pick batch index
if USE_CACHE_BATCH_IDX:
cache_batch_idx = tl.load(Cache_batch_idx + off_z)
else:
cache_batch_idx = off_z
# Load ALiBi slope if enabled
if USE_ALIBI:
a_offset = off_z * stride_az + off_h_q * stride_ah
alibi_slope = tl.load(Alibi_slopes + a_offset)
else:
alibi_slope = None
lo = splitk_idx * BLOCK_N_PER_SPLIT
if USE_CACHE_SEQLENs:
cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z)
if NEW_KV:
kv_len = cache_seqlen_last_idx + N_CTX_NEW
else:
kv_len = cache_seqlen_last_idx
else:
kv_len = N_CTX_K
hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)
HEAD_RATIO: tl.constexpr = H_q // H_kv
if IS_GQA:
k_head_idx = off_h_q // HEAD_RATIO
v_head_idx = k_head_idx
else:
k_head_idx = off_h_q
v_head_idx = off_h_q
# calculate base offset
k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg
v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg
# Copy new Keys and Values into Cache
if NEW_KV:
knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g
# Determine the starting position for new data in the cache
if USE_CACHE_SEQLENs:
start_idx = tl.load(Cache_seqlens + off_z)
else:
start_idx = N_CTX_K - N_CTX_NEW
# Copy new Keys
for i in range(0, N_CTX_NEW, BLOCK_N):
# Load from K_new
k_new_block = tl.load(
knew_base +
tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
(tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
(tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
other=0
)
# Store to K
tl.store(
k_base +
tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd +
(tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn,
k_new_block,
mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
(tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
)
# Copy new Values
vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g
for i in range(0, N_CTX_NEW, BLOCK_N):
# Load from V_new
v_new_block = tl.load(
vnew_base +
(tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n +
tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d,
mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
(tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
other=0
)
# Store to V
tl.store(
v_base +
(tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn +
tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd,
v_new_block,
mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
(tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
)
Q_block_ptr = tl.make_block_ptr(
base=Q + off_h_q * stride_qh + off_z * stride_qz + off_g_q * stride_qg,
shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL),
strides=(stride_qm, stride_qd),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
base=k_base,
shape=(ACTUAL_BLOCK_DMODEL, hi),
strides=(stride_kd, stride_kn),
offsets=(0, lo),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=v_base,
shape=(hi, ACTUAL_BLOCK_DMODEL),
strides=(stride_vn, stride_vd),
offsets=(lo, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
K_scale_shift_block_ptr = None
V_scale_shift_block_ptr = None
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821
# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.44269504
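# 1.44269504 ~= log2(e); exp(x) == exp2(x * log2(e)), so folding the factor into qk_scale lets the inner loop use exp2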
# load q: it will stay in SRAM throughout
q = tl.load( # noqa: F821
tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
q = (q * qk_scale).to(q.dtype)
if PADDED_HEAD:
q = tl.where(d_mask[None, :], q, 0.0)
# loop over k, v and update accumulator
for start_n in range(lo, hi, BLOCK_N):
k, v = load_k_v_group(
K_block_ptr,
V_block_ptr,
K_scale_shift_block_ptr,
V_scale_shift_block_ptr,
BOUNDS_CHECKS_N,
1,
BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL,
Q.dtype.element_ty,
0,
)
if PADDED_HEAD:
k = tl.where(d_mask[:, None], k, 0.0)
v = tl.where(d_mask[None, :], v, 0.0)
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k) # noqa: F821
if USE_ALIBI:
row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
col_idx = start_n + tl.arange(0, BLOCK_N)
# Compute relative positions
relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :])
relative_pos = tl.abs(relative_pos)
# Compute ALiBi bias
alibi_bias = -1 * alibi_slope * relative_pos
qk += (alibi_bias * 1.44269504)
# Apply causal mask if IS_CAUSAL is True
if IS_CAUSAL:
row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
col_idx = start_n + tl.arange(0, BLOCK_N)
# create a N_CTX_Q x kv_len causal mask
col_offset = N_CTX_Q - kv_len
causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :])
# Apply the mask
qk = tl.where(causal_mask, qk, float("-inf"))
# TODO: This is slow, and only needed at the last iteration.
# Maybe we can unroll the last iteration instead?
if BOUNDS_CHECKS_N:
qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
if IS_CAUSAL:
alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf")))
else:
alpha = tl.math.exp2(m_i - m_i_new)
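# online softmax: alpha = 2^(m_old - m_new) rescales the running l_i and acc so they stay consistent with the new row max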
# guard against NaN caused by subtracting infs (-inf - -inf)
if IS_CAUSAL:
qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf"))
else:
qk = qk - m_i_new[:, None]
p = tl.math.exp2(qk)
# -- update m_i and l_i --
l_i = l_i * alpha + tl.sum(p, 1)
m_i = m_i_new
p = p.to(Q.dtype.element_ty)
# -- scale and update acc --
acc *= alpha[:, None]
acc += tl.dot(p.to(v.dtype), v)
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
# write back O
O_block_ptr = tl.make_block_ptr(
base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
shape=(N_CTX_Q, BLOCK_DMODEL),
strides=(stride_osk_m, 1),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
tl.store(
tl.advance(O_block_ptr, (0, 0)),
acc,
boundary_check=(0, ),
)
# Write metadata for split-K reduction
Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M +
tl.arange(0, BLOCK_M))
tl.store(Metadata_ptr, m_i)
tl.store(Metadata_ptr + stride_m2, l_i)
@triton.jit
def load_k_v_group(
K_block_ptr,
V_block_ptr,
K_scale_shift_block_ptr, V_scale_shift_block_ptr, # pylint: disable=unused-argument
BOUNDS_CHECKS_N: tl.constexpr,
PACKED_PER_VAL: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # pylint: disable=unused-argument
ACTUAL_BLOCK_DMODEL: tl.constexpr,
dtype: tl.constexpr, # pylint: disable=unused-argument
group_id: tl.constexpr,
):
# Load K/V for a given block
# Advance to the current quantization group
K_block_ptr = tl.advance(K_block_ptr, (ACTUAL_BLOCK_DMODEL * group_id, 0))
V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id))
# -- load k, v --
k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ())
v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ())
return k, v
@triton.jit
def cast_uint32_to_half2(scale_shift):
# Extract two float16 packed into one int32
scale = scale_shift & 0xFFFF
shift = scale_shift >> 16
scale = scale.to(tl.uint16).to(tl.float16, bitcast=True)
shift = shift.to(tl.uint16).to(tl.float16, bitcast=True)
return scale, shift
@triton.jit
def dequantize(
x_,
scale,
shift,
PACKED_PER_VAL: tl.constexpr = 8,
):
# PACKED_PER_VAL is the number of values packed into
# each element x_. For example, for int4 quantization
# and x_ of type int32, PACKED_PER_VAL is 8.
BLOCK_N: tl.constexpr = x_.shape[0]
BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
offsets = tl.arange(0, PACKED_PER_VAL) * 4
quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL)
quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL))
# Trick - instead of converting int4 to float16 we view it as float16
# and then multiply by 32768 * 512 == 2**24
quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
quant_offset = (quant_offset * 32768.0).to(tl.float16)
scale_512 = scale * 512
dequant = quant_offset * scale_512 + shift
return dequant
@triton.jit
def _splitK_reduce(
Out_splitK, # [B, H, split_k, Mq, K]
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
Out, # [B, H, M, K]
LSE, # [B, H, M]
stride_osk_zhg,
stride_osk_s,
stride_osk_m,
stride_osk_k,
stride_mzhg,
stride_m2,
stride_ms,
stride_mm,
stride_oz,
stride_oh,
stride_og,
stride_om,
stride_ok, # pylint: disable=unused-argument
stride_lse_zhg,
stride_lse_m, M_ceil: tl.constexpr, # pylint: disable=unused-argument
BLOCK_SIZE: tl.constexpr,
H: tl.constexpr,
G: tl.constexpr,
split_k: tl.constexpr,
splitK_pow2: tl.constexpr,
use_mask: tl.constexpr,
IS_CAUSAL: tl.constexpr,
):
off_zhg = tl.program_id(0)
off_z = off_zhg // (H * G)
off_h = (off_zhg // G) % H
off_g = off_zhg % G
off_m = tl.program_id(1)
off_k = tl.program_id(2)
# read chunk
spk_idx = tl.arange(0, splitK_pow2)
kidx = tl.arange(0, BLOCK_SIZE)
Metadata_ptr = Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm
o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE +
stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k)
# read max values of each splitK
if use_mask:
spk_mask = spk_idx < split_k
l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf"))
l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0)
acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0)
else:
l_m = tl.load(Metadata_ptr)
l_sum = tl.load(Metadata_ptr + stride_m2)
acc = tl.load(o_ptr)
g_m = tl.max(l_m, axis=0)
if IS_CAUSAL:
l_m_offset = l_m - g_m
alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0)
else:
alpha = tl.math.exp2(l_m - g_m)
# read sum
l_sum *= alpha
g_sum = tl.sum(l_sum, axis=0)
acc = acc * alpha[:, None]
if IS_CAUSAL:
# Avoid division by zero
g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0)
acc_out = tl.sum(acc, axis=0) / g_sum_safe
else:
acc_out = tl.sum(acc, axis=0) / g_sum
# Store output
Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m +
off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE))
tl.store(Out_ptr, acc_out)
# Store lse
l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m
if IS_CAUSAL:
lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m)
tl.store(l_ptrs, lse)
else:
tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504)
def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
# Scale and shift are such that quantization linearly maps
# int4 values range [0..15] to input values range min(k)..max(k)
# individually for every row
k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups)
max_vals = torch.max(k, dim=-1, keepdim=True).values
min_vals = torch.min(k, dim=-1, keepdim=True).values
scale_k: torch.Tensor = (max_vals - min_vals) / 15
shift_k = torch.min(k, dim=-1, keepdim=True).values
scale_k = scale_k.to(torch.float16)
shift_k = shift_k.to(torch.float16)
in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5
in_bytes = in_bytes.to(torch.uint8)
in_int4 = in_bytes & 0xF
in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4)
scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1)
k_quant = torch.concat(
[
scale_shift.flatten(start_dim=-2),
in_int4_packed.flatten(start_dim=-2),
],
dim=-1,
).view(torch.int16)
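# per group, k_quant packs an fp16 scale and shift (4 bytes) followed by the group's int4 values (2 per byte), all viewed as int16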
return k_quant
def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
k_i16 = quant_k.view(torch.int16)
k_ui8 = k_i16.view(torch.uint8)
ss_size = num_groups * 4
scale_shift_ui8 = k_ui8[..., 0:ss_size]
scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4)
scale = scale_shift_ui8[..., 0:2].view(torch.float16)
shift = scale_shift_ui8[..., 2:4].view(torch.float16)
kv_ui8 = k_ui8[..., ss_size:]
k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1)
k1_i4 = k_ui8 & 0xF
k2_i4 = (k_ui8 & 0xF0) >> 4
k_shape = k1_i4.shape
k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape)
k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape)
out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device)
out[..., ::2] = k1_f16
out[..., 1::2] = k2_f16
out = out.reshape(*k_shape[:-2], -1)
return out
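# A minimal round-trip sketch of the two quantization helpers above (illustrative only;
# the sizes and the CUDA device are assumptions, not part of the original module):
k_ref = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
k_quant = quantize_kv_int4(k_ref, num_groups=1) # int16 payload: per-row scale/shift plus packed int4 pairs
k_approx = dequantize_kv_fp16(k_quant, num_groups=1).reshape(k_ref.shape)
max_err = (k_ref - k_approx).abs().max() # small quantization error; not an exact round-trip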
def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
"""Heuristic for the number of splits"""
bh = max(B * H, 1) # NOTE: Handle B*H=0 case
split_k = max(Mk, 1024) // bh
max_chunk_size = 64
while split_k > 0 and Mk / split_k < max_chunk_size:
split_k = split_k // 2
while B * H * G * split_k >= 1024:
split_k = split_k // 2
split_k = min(split_k, 512)
split_k = max(split_k, 1)
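# e.g. get_split_k(B=2, G=1, H=8, Mk=4096) -> 32: start at 4096 // 16 = 256, halve until each split covers at least 64 keys, then halve again until B * H * G * split_k < 1024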
return split_k
def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new):
# kernel config
BLOCK_M = 16
BLOCK_N = 64
SPLIT_K = None
NUM_QUANT_GROUPS = 1 # pylint: disable=unused-variable
# kernel expects "bsghd"
original_layout = layout
if layout == "bshd":
q = q.unsqueeze(2)
k = k.unsqueeze(2)
v = v.unsqueeze(2)
if new_kv:
k_new = k_new.unsqueeze(2)
v_new = v_new.unsqueeze(2)
layout = "bsghd"
elif layout == "bhsd":
q = q.permute(0, 2, 1, 3).unsqueeze(2)
k = k.permute(0, 2, 1, 3).unsqueeze(2)
v = v.permute(0, 2, 1, 3).unsqueeze(2)
if new_kv:
k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2)
v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2)
layout = "bsghd"
elif layout == "bsghd":
pass
elif layout is None:
raise ValueError("Layout not given")
assert layout == "bsghd"
# get dims
batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape
_, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape # pylint: disable=unused-variable
_, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape # pylint: disable=unused-variable
assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}"
# get padded size
dim_padded = get_padded_headsize(dim_k)
# Handle MQA/GQA case
if heads_per_group_q > heads_per_group_k:
is_gqa = True
elif heads_per_group_q < heads_per_group_k:
raise ValueError("heads_per_group_q < heads_per_group_k")
else:
is_gqa = False
assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}"
if SPLIT_K is not None:
split_k = SPLIT_K
else:
# Use heuristics
split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_k) # NOTE: should the split think about seqlens?
seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M
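# round seqlen_q up to a multiple of BLOCK_M (e.g. seqlen_q = 1, BLOCK_M = 16 -> seqlen_q_ceil = 16)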
out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device)
metadata = torch.empty([batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device)
lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), device=q.device, dtype=torch.float32)
grid = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * n_group_q * heads_per_group_q, split_k)
num_warps = 1
split_size = (seqlen_k + split_k - 1) // split_k
use_cache_seqlens = cache_seqlens is not None
# TODO: enable quantization
_fwd_kernel_splitK[grid](
Q=q,
K=k,
V=v,
sm_scale=sm_scale,
Out_splitK=out_splitk,
Metadata=metadata,
K_new = k_new,
V_new = v_new,
Cache_seqlens=cache_seqlens,
Cache_batch_idx=cache_batch_idx,
Alibi_slopes=alibi_slopes,
**_strides(q, "qz", "qm", "qg", "qh", "qd"),
**_strides(k, "kz", "kn", "kg", "kh", "kd"),
**_strides(v, "vz", "vn", "vg", "vh", "vd"),
**_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"),
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
**_strides(k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"),
**_strides(v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"),
**_strides(alibi_slopes, "az", "ah"),
Z=batch_size,
H_q=heads_per_group_q,
H_kv=heads_per_group_k,
G_q=n_group_q,
N_CTX_Q=seqlen_q,
N_CTX_K=seqlen_k,
N_CTX_NEW=k_new.shape[1] if new_kv else None,
BLOCK_N_PER_SPLIT=split_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=dim_padded,
ACTUAL_BLOCK_DMODEL=dim_k,
BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens,
USE_CACHE_SEQLENs=use_cache_seqlens,
USE_CACHE_BATCH_IDX=cache_batch_idx is not None,
NEW_KV=new_kv,
IS_GQA=is_gqa,
IS_CAUSAL=causal,
USE_ALIBI=False if alibi_slopes is None else True,
num_warps=num_warps,
num_stages=1,
)
out = torch.empty((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype)
# Merge together
splitK_pow2 = triton.next_power_of_2(split_k)
use_mask = splitK_pow2 > split_k
if batch_size * n_group_q * heads_per_group_q * seqlen_q >= 512:
k_block_num = 1
else:
k_block_num = 2
assert dim_padded % k_block_num == 0
k_block_size = dim_padded // k_block_num
grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num)
_splitK_reduce[grid](
out_splitk,
metadata,
out,
lse,
**_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
**_strides(out, "oz", "om", "og", "oh", "ok"),
**_strides(lse, "lse_zhg", "lse_m"),
M_ceil=seqlen_q_ceil,
BLOCK_SIZE=k_block_size,
G=n_group_q,
H=heads_per_group_q,
# TODO: Tune num_warps
split_k=split_k,
splitK_pow2=splitK_pow2,
use_mask=use_mask,
IS_CAUSAL=causal,
num_warps=4)
lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q])
if q.ndim == 4:
# BMGHK -> BMHK
assert n_group_q == 1
out = out[:, :, 0]
lse = lse[:, 0]
if seqlen_k == 0:
out.zero_()
out = out.reshape(batch_size, heads_per_group_q * n_group_q, -1, dim_padded).contiguous()
# output is batch_size, heads_per_group_q * n_group_q, seqlen_q, dim_padded
if original_layout == "bshd":
# out=out.transpose(1, 2).contiguous() # this screws up heads and data.
# the data is laid out properly. Just need to reshape dims
out = out.reshape(batch_size, seqlen_q, -1, dim_padded)
return out.narrow(-1, 0, dim_k), lse

@ -1,33 +1,8 @@
from typing import Literal, Optional, Union
import torch
import triton
import triton.language as tl
from modules.flash_attn_triton_amd.utils import get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, AUTOTUNE
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): # pylint: disable=unused-argument
ms = tl.arange(0, m)
ns = tl.arange(0, n)
return philox_offset + ms[:, None] * stride + ns[None, :]
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
rng_keep = rng_output > dropout_p
return rng_keep
from modules.flash_attn_triton_amd.utils import AUTOTUNE, compute_alibi_block, get_shapes_from_layout, get_strides_from_layout, is_cdna, is_rdna
# Convenience function to load with optional boundary checks.
@ -50,47 +25,14 @@ def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second):
@triton.jit
def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False):
# when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix
# for a causal mask we want something like this, where 1 is kept and 0 is masked
# seqlen_q = 2 and seqlen_k = 5
# 1 1 1 1 0
# 1 1 1 1 1
# seqlen_q = 5 and seqlen_k = 2
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
# for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal
# e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
# 1. offs_m[:,None] = [[0],
# [1],
# 2. offs_m[:,None] + seqlen_k = [[5],
# [6],
# 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
# [4],
# 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1],
# [4], [ 4, 3, 2, 1, 0]]
# 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
# [ -4, -3, -2, -1, 0]],
relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :]
alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
if transpose:
return alibi_block.T
else:
return alibi_block
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m,
actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, exp_scores_ptrs,
block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs, # pylint: disable=unused-argument
def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m,
actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs,
block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, # pylint: disable=unused-argument
IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_EXP2: tl.constexpr,
RETURN_SCORES: tl.constexpr):
ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr,
RETURN_SCORES: tl.constexpr, ACCUMULATOR_TYPE):
if USE_EXP2:
RCP_LN2: tl.constexpr = 1.4426950408889634
@ -107,7 +49,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
if PRE_LOAD_V:
# We can use the same offsets as k, just with dims transposed.
v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
@ -117,18 +59,20 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn.
# last step might get wasted but that is okay. check if this masking works For
# that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
if start_n + BLOCK_N == block_max and n_extra_tokens != 0:
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
size_n = start_n + OFFS_N[None, :]
mask = size_n < boundary_m[:, None]
qk = tl.where(mask, qk, float("-inf"))
# compute masks
q_mask = OFFS_M[:, None] < actual_seqlen_q
k_mask = (start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k
p_mask = q_mask & k_mask
# -- compute qk ----
qk += tl.dot(q, k)
qk_scaled = qk * SM_SCALE
if RETURN_SCORES:
score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
tl.store(score_ptrs, qk_scaled, mask=score_mask)
qk_scaled = qk * SM_SCALE
if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
@ -139,8 +83,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k)
qk_scaled += bias
if alibi_slope is not None:
# Compute the global position of each token within the sequence
if USE_ALIBI:
# compute the global position of each token within the sequence
global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
global_n_positions = start_n + tl.arange(0, BLOCK_N)
alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions,
@ -151,10 +95,6 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
# scale and subtract max
q_shifted = qk_scaled - m_ij[:, None]
if RETURN_SCORES:
# NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
scores_scaled_shifted_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
tl.store(scores_scaled_shifted_ptrs, q_shifted, mask=scores_scaled_shifted_mask)
# Compute scaled QK and softmax probabilities
if USE_EXP2:
@ -165,17 +105,18 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT:
philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N
keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k)
if RETURN_SCORES:
# NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
tl.store(exp_scores_ptrs, tl.where(keep, p, -p), mask=exp_score_mask)
p = tl.where(keep, p, 0.0)
rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance
dropout_mask = rng_output > dropout_p
# return scores with negative values for dropped vals
sd_mask = tl.where(dropout_mask, p, -p)
tl.store(sd_mask_ptrs, sd_mask, mask=p_mask)
# apply dropout mask in place
p = tl.where(dropout_mask, p, 0.0)
elif RETURN_SCORES:
# NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
tl.store(exp_scores_ptrs, p, mask=exp_score_mask)
tl.store(sd_mask_ptrs, p, mask=p_mask)
# -- update output accumulator --
# alpha is an adjustment factor for acc and li as we loop and find new maxes
@ -186,7 +127,6 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
else:
alpha = tl.math.exp(m_diff)
acc = acc * alpha[:, None]
v = None
if not PRE_LOAD_V:
v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
# -- update m_i and l_i
@ -199,9 +139,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
if bias_ptrs is not None:
bias_ptrs += BLOCK_N * stride_bn
if RETURN_SCORES:
score_ptrs += BLOCK_N
scores_scaled_shifted_ptrs += BLOCK_N
exp_scores_ptrs += BLOCK_N
sd_mask_ptrs += BLOCK_N * stride_sn
if ENABLE_DROPOUT:
dropout_mask_ptrs += BLOCK_N * stride_sn
philox_ptrs += BLOCK_N * stride_sn
return acc, l_i, m_i
@ -222,7 +164,7 @@ def get_cdna_autotune_configs():
# Fall-back config.
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']
], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK']
def get_rdna_autotune_configs():
@ -242,7 +184,7 @@ def get_rdna_autotune_configs():
# Fall-back config.
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']
], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK']
def get_autotune_configs():
@ -266,7 +208,7 @@ def get_autotune_configs():
"MAX_SEQLENS_Q",
"MAX_SEQLENS_K",
"ACTUAL_BLOCK_DMODEL",
"VARLEN",
"IS_VARLEN",
"HQ",
"HK",
]
@ -277,37 +219,47 @@ autotune_configs, autotune_keys = get_autotune_configs()
@triton.autotune(
configs=autotune_configs,
key=autotune_keys,
# use_cuda_graph=True,
use_cuda_graph=True,
)
@triton.jit
def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, # pylint: disable=unused-argument
SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, # pylint: disable=unused-argument
stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p, philox_seed, philox_offset_base, scores, scores_scaled_shifted, exp_scores, alibi_slopes, HQ: tl.constexpr,
dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr,
HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_INFERENCE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr):
# set params
ACCUMULATOR_TYPE = tl.float32
# compute offsets
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
if VARLEN:
# handle seqlen
if IS_VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
# print("cu_seqlens_q_start:", cu_seqlens_q_start)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
# we have a one-size-fits-all grid in id(0). Some seqlens might be too small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q:
return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
elif IS_INFERENCE:
cu_seqlens_q_start = 0
cu_seqlens_k_start = 0
seqlen_q = MAX_SEQLENS_Q
seqlen_k = tl.load(Cache_seqlens + off_z)
else:
cu_seqlens_q_start = 0
cu_seqlens_k_start = 0
@ -320,14 +272,14 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
n_blocks = tl.cdiv(seqlen_k, BLOCK_N)
if IS_CAUSAL:
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn matrix
n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
n_blocks_seqlen = tl.cdiv((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen)
@ -345,7 +297,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m
l_ptrs = l_offset + offs_m * stride_lse_m
l = tl.full([BLOCK_M], value=0.0, dtype=tl.float32)
l = tl.full([BLOCK_M], value=0.0, dtype=ACCUMULATOR_TYPE)
# mask_m_offsets = start_m + tl.arange(0, BLOCK_M)
# lse_mask = mask_m_offsets < causal_start_idx
@ -371,7 +323,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
# Compute pointers for all the tensors used in this kernel.
q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
@ -394,28 +346,23 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
alibi_slope = None
if RETURN_SCORES:
scores_offset = scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
score_ptrs = scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
scores_scaled_shifted_offset = scores_scaled_shifted + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
scores_scaled_shifted_ptrs = scores_scaled_shifted_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
exp_scores_offset = exp_scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
exp_scores_ptrs = exp_scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
sd_mask_offset = sd_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm
sd_mask_ptrs = sd_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
else:
score_ptrs = None
scores_scaled_shifted_ptrs = None
exp_scores_ptrs = None
sd_mask_ptrs = None
if ENABLE_DROPOUT:
off_hz = off_z * HQ + off_h_q
batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k
dropout_mask_offset = dropout_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm
dropout_mask_ptrs = dropout_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
batch_philox_offset = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm
philox_ptrs = batch_philox_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
else:
batch_philox_offset = 0
dropout_mask_ptrs = None
philox_ptrs = 0
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
m_i = tl.full([BLOCK_M], float("-inf"), dtype=ACCUMULATOR_TYPE)
l_i = tl.full([BLOCK_M], 1.0, dtype=ACCUMULATOR_TYPE)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=ACCUMULATOR_TYPE)
# Q is loaded once at the beginning and shared by all N blocks.
q_ptrs_mask = offs_m[:, None] < seqlen_q
if PADDED_HEAD:
@ -442,16 +389,16 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
# value because there is no masking. Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset,
exp_scores_ptrs,
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs,
sd_mask_ptrs, dropout_mask_ptrs,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min, block_max, 0, 0, 0, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs,
block_min, block_max, 0, 0, 0, alibi_slope,
# IS_CAUSAL, ....
False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD,
ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE)
block_min = block_max
block_max = n_blocks * BLOCK_N
@ -467,23 +414,25 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
if USE_BIAS:
bias_ptrs += n_full_blocks * BLOCK_N * stride_bn
if RETURN_SCORES:
score_ptrs += n_full_blocks * BLOCK_N
scores_scaled_shifted_ptrs += n_full_blocks * BLOCK_N
exp_scores_ptrs += n_full_blocks * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset,
exp_scores_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs,
sd_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn
if ENABLE_DROPOUT:
dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn
philox_ptrs += n_full_blocks * BLOCK_N * stride_sn
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs,
sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
n_extra_tokens, alibi_slope,
IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD,
ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE)
# epilogue
# This helps the compiler do Newton Raphson on l_i vs on acc which is much larger.
l_recip = 1 / l_i[:, None]
acc = acc * l_recip
if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
dropout_scale = 1 / (1 - dropout_p)
acc = acc * dropout_scale
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
@ -491,13 +440,12 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL:
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
z: tl.tensor = 0.0
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE(Log Sum Exponents), the log of the normalization constant
@ -541,30 +489,43 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
def attention_prefill_forward_triton_impl(
q,
k,
v,
o,
sm_scale,
alibi_slopes,
causal,
bias,
dropout_p,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
return_scores,
use_exp2):
# check if varlen
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
sm_scale: float,
alibi_slopes: Optional[torch.Tensor],
causal: bool,
bias: Optional[torch.Tensor],
layout: Literal["bshd", "bhsd", "thd"],
# varlen
cu_seqlens_q: Optional[torch.Tensor],
cu_seqlens_k: Optional[torch.Tensor],
max_seqlens_q: int,
max_seqlens_k: int,
# inference
cache_seqlens: Optional[Union[(int, torch.Tensor)]],
cache_batch_idx: Optional[torch.Tensor],
# dropout
dropout_p: float,
philox_seed: Optional[int],
philox_offset: Optional[int],
# misc
return_softmax: bool,
use_exp2: bool,
):
# check flags
is_varlen = layout == "thd"
use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0))
is_inference = cache_seqlens is not None
if is_inference:
assert layout == "bshd", f"{layout} layout is not supported with inference. Use bshd layout"
# NOTE: a large bias tensor leads to overflow during pointer arithmetic
if bias is not None:
assert bias.numel() < 2**31
if (bias is not None):
assert (bias.numel() < 2**31)
batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k) # pylint: disable=unused-variable
batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k)
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
# Get closest power of 2 over or equal to 32.
@ -573,59 +534,45 @@ def attention_prefill_forward_triton_impl(
# kernel is padded - there is no padding in memory for any dims.
padded_d_model = max(padded_d_model, 16)
grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) # pylint: disable=unnecessary-lambda-assignment
grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch)
if return_scores:
scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
dtype=torch.float32)
scores_scaled_shifted = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
dtype=torch.float32)
scores_strides = (scores.stride(0), scores.stride(1), scores.stride(2), scores.stride(3))
else:
scores = None
scores_scaled_shifted = None
scores_strides = (0, 0 , 0 , 0)
# exp_scores is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out
# sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out
# to give a consistent starting point and then populate it with the output of softmax with the sign bit set according
# to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing
# only. This return holds no useful output aside from debugging.
if return_scores:
exp_scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
dtype=torch.float32)
# only. This return holds no useful output aside from debugging.
use_dropout = (dropout_p > 0.0)
if use_dropout or return_softmax:
sd_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
dtype=torch.float32)
dropout_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
dtype=torch.float32)
scores_strides = (sd_mask.stride(0), sd_mask.stride(1), sd_mask.stride(2), sd_mask.stride(3))
else:
exp_scores = None
sd_mask = None
dropout_mask = None
scores_strides = (0, 0, 0, 0)
    # stores LSE, the log of the normalization constant / sum of exponential scores (unnormalized probabilities)
if is_varlen:
softmax_lse = torch.empty((q.shape[0], nheads_q), device=q.device, dtype=torch.float32)
softmax_lse = torch.zeros((q.shape[0], nheads_q), device=q.device, dtype=torch.float32)
stride_lse_m, stride_lse_h = softmax_lse.stride()
stride_lse_z = 0
else:
softmax_lse = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32)
softmax_lse = torch.zeros((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32)
stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride()
# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42
if bias is not None:
bias_strides = (bias.stride(0), bias.stride(1),bias.stride(2),
bias.stride(3))
else:
bias_strides = (0, 0, 0, 0)
if alibi_slopes is not None:
alibi_strides = (alibi_slopes.stride(0), alibi_slopes.stride(1))
else:
alibi_strides = (0, 0)
attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, scores=scores,
scores_scaled_shifted=scores_scaled_shifted, exp_scores=exp_scores, alibi_slopes=alibi_slopes,
attn_fwd[grid](q, k, v, bias, cache_seqlens, cache_batch_idx,
sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, stride_az, stride_ah, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen,
MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, IS_VARLEN=is_varlen, IS_INFERENCE=is_inference,
BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_scores)
USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax)
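# Hedged usage sketch (not part of the patch): one way the updated entry point above
# might be invoked for a plain, non-varlen "bshd" prefill. Every shape, dtype, and
# value below is an illustrative assumption, not taken from the surrounding diff.
def _example_prefill_call():
    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    o = torch.empty_like(q)
    attention_prefill_forward_triton_impl(
        q, k, v, o,
        sm_scale=64 ** -0.5,
        alibi_slopes=None, causal=True, bias=None, layout="bshd",
        cu_seqlens_q=None, cu_seqlens_k=None, max_seqlens_q=128, max_seqlens_k=128,
        cache_seqlens=None, cache_batch_idx=None,
        dropout_p=0.0, philox_seed=None, philox_offset=None,
        return_softmax=False, use_exp2=True,
    )
    return o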

View File

@@ -1,22 +1,16 @@
import torch
from modules.flash_attn_triton_amd.fwd_prefill import attention_prefill_forward_triton_impl
from modules.flash_attn_triton_amd.bwd_prefill import attention_prefill_backward_triton_impl
from modules.flash_attn_triton_amd.fwd_decode import attention_decode_forward_triton_impl
from modules.flash_attn_triton_amd.utils import MetaData, get_shape_from_layout
from modules.flash_attn_triton_amd.utils import MetaData
def fwd(q,
k,
v,
dropout_p,
softmax_scale,
causal,
def fwd(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
dropout_p: float,
softmax_scale: float,
causal: bool
):
if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD's Triton Backend yet")
o = torch.empty_like(q)
# Setup metadata
metadata = MetaData(sm_scale=softmax_scale)
metadata.max_seqlens_q = q.shape[1]
@@ -24,251 +18,35 @@ def fwd(q,
metadata.layout = "bshd"
if causal:
metadata.need_causal()
#if dropout_p > 0.0:
# metadata.need_dropout(dropout_p, False)
# Check arguments
metadata.check_args(q, k, v, o)
attention_prefill_forward_triton_impl(
q,
k,
v,
o,
metadata.sm_scale,
metadata.alibi_slopes,
metadata.causal,
metadata.bias,
metadata.dropout_p,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.return_scores,
metadata.use_exp2)
return o
def bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
window_size_left, window_size_right, softcap, deterministic, gen_, rng_state, # pylint: disable=unused-argument
):
if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD yet")
dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( # pylint: disable=unused-variable
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
softmax_scale,
alibi_slopes,
causal,
"bshd",
None,
None,
None,
None,
False,
)
delta = delta_triton
return dq, dk, dv, delta
def varlen_fwd(
q,
k,
v,
o,
cu_seqlens_q,
cu_seqlens_k,
seqused_k, leftpad_k, block_table_, # pylint: disable=unused-argument
    alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
zero_tensors, # pylint: disable=unused-argument
causal,
window_size_left, window_size_right, softcap, # pylint: disable=unused-argument
return_softmax,
gen_ # pylint: disable=unused-argument
):
if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD's Triton Backend yet")
if o is None:
o = torch.empty_like(q)
# Setup metadata
metadata = MetaData(sm_scale=softmax_scale)
if return_softmax:
metadata.return_scores = True
    metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) # set layout to "thd" and other metadata
# get shapes
batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # pylint: disable=unused-variable
if causal:
metadata.need_causal()
if alibi_slopes is not None:
metadata.need_alibi(alibi_slopes, batch, nheads_q)
metadata.need_causal(True)
if dropout_p > 0.0:
metadata.need_dropout(dropout_p, return_softmax)
metadata.need_dropout(dropout_p)
# Check arguments
metadata.check_args(q, k, v, o)
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
# check arguments
metadata.check_args(q, k, v, out)
# call implementation
attention_prefill_forward_triton_impl(
q,
k,
v,
o,
out,
metadata.sm_scale,
metadata.alibi_slopes,
metadata.causal,
metadata.bias,
metadata.dropout_p,
None,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.return_scores,
metadata.cache_seqlens,
metadata.cache_batch_idx,
metadata.dropout_p,
metadata.philox_seed,
metadata.philox_offset,
False,
metadata.use_exp2)
return o
def varlen_bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
zero_tensors, # pylint: disable=unused-argument
causal,
window_size_left, window_size_right, softcap, deterministic, gen_, rng_state, # pylint: disable=unused-argument
):
if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD yet")
dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( # pylint: disable=unused-variable
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
softmax_scale,
alibi_slopes,
causal,
"thd",
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
False,
)
delta = delta_triton
return dq, dk, dv, delta
def fwd_kvcache(
q,
k_cache,
v_cache,
k,
v,
cache_seqlens,
rotary_cos, rotary_sin, # pylint: disable=unused-argument
cache_batch_idx,
cache_leftpad, block_table, # pylint: disable=unused-argument
alibi_slopes,
out,
softmax_scale,
causal,
window_size_left, window_size_right, softcap, rotary_interleaved, num_splits, # pylint: disable=unused-argument
):
if out is None:
out = torch.empty_like(q)
# fill metadata
metadata = MetaData(sm_scale=softmax_scale)
metadata.layout = "bshd"
metadata.max_seqlens_q = q.shape[1]
metadata.max_seqlens_k = k_cache.shape[1]
metadata.cache_seqlens = cache_seqlens
metadata.cache_batch_idx = cache_batch_idx
if k is not None and v is not None:
metadata.new_kv = True
metadata.seqlen_new = k.shape[1]
metadata.k_new = k
metadata.v_new = v
if causal:
metadata.need_causal()
if alibi_slopes is not None:
        batch, _, nheads_q, _ = q.shape
metadata.need_alibi(alibi_slopes, batch, nheads_q)
# launch kernel
    # TODO: pass output as an arg. Maybe we are copying the output, which is causing a slowdown
output, softmax_lse = attention_decode_forward_triton_impl(
q,
k_cache,
v_cache,
metadata.sm_scale,
metadata.causal,
metadata.alibi_slopes,
metadata.layout,
metadata.cache_seqlens,
metadata.cache_batch_idx,
metadata.new_kv,
metadata.k_new,
metadata.v_new,
)
return output, softmax_lse
# varlen

View File

@@ -1,33 +1,48 @@
import os
import csv
import math
import torch
import os
import random
import functools
import triton
import triton.language as tl
from typing import Literal, Optional, Union
from modules.rocm import Agent, MicroArchitecture
AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes')
USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')
PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes')
# -------------------------------
# Metadata
# -------------------------------
class MetaData():
cu_seqlens_q = None
cu_seqlens_k = None
max_seqlens_q = 0
max_seqlens_k = 0
bias = None
alibi_slopes = None
causal = False
cu_seqlens_q: Optional[torch.Tensor] = None
cu_seqlens_k: Optional[torch.Tensor] = None
max_seqlens_q: int = 0
max_seqlens_k: int = 0
bias: Optional[torch.Tensor] = None
alibi_slopes: Optional[torch.Tensor] = None
causal: bool = False
num_contexts = 0
varlen = False
layout = None
cache_seqlens = None
varlen: bool = False
layout: Optional[Literal["bshd", "bhsd", "thd"]] = None
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None
cache_batch_idx = None
new_kv = False
seqlen_new = None
k_new = None
v_new = None
dropout_p, return_scores= 0.0, False
packing: Optional[bool] = None
return_scores: bool = False
dropout_p: float = 0.0
philox_seed: Optional[int] = None
    philox_offset: Optional[int] = None # if dropout_p > 0.0, seed the RNG so we get reproducible results for testing.
# NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW.
use_exp2 = False
use_exp2: bool = False
rotary_sin: Optional[torch.Tensor] = None
rotary_cos: Optional[torch.Tensor] = None
rotary_interleaved: bool = False
rotary_conjunction: bool = False
def __repr__(self) -> str:
return (f"MetaData(\n"
@@ -44,10 +59,6 @@ class MetaData():
f" layout={self.layout},\n"
f" cache_seqlens={self.cache_seqlens},\n"
f" cache_batch_idx={self.cache_batch_idx},\n"
f" new_kv={self.new_kv},\n"
f" seqlen_new={self.seqlen_new},\n"
f" k_new={self.k_new},\n"
f" v_new={self.v_new},\n"
f" dropout_p={self.dropout_p},\n"
f" return_scores={self.return_scores}\n"
f")")
@@ -55,20 +66,19 @@ class MetaData():
def __init__(self, sm_scale=1.0):
self.sm_scale = sm_scale
def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
self.varlen = True
self.layout = 'thd'
self.cu_seqlens_q = cu_seqlens_q
self.cu_seqlens_k = cu_seqlens_k
self.max_seqlens_q = max_seqlen_q
self.max_seqlens_k = max_seqlen_k
# Without "varlen", there should still be one sequence.
assert len(cu_seqlens_q) >= 2
assert len(cu_seqlens_q) == len(cu_seqlens_k)
self.num_contexts = len(cu_seqlens_q) - 1
for i in range(0, self.num_contexts):
self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q)
self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k)
def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): # pylint: disable=unused-argument
def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k):
assert bias.is_cuda
assert bias.dim() == 4
assert bias.shape[0] == 1
@@ -82,17 +92,25 @@ class MetaData():
assert alibi_slopes.shape[1] == nheads
self.alibi_slopes = alibi_slopes
def need_causal(self):
self.causal = True
def need_causal(self, causal):
self.causal = causal
def need_dropout(self, dropout_p, return_scores):
self.dropout_p = dropout_p
self.return_scores = return_scores
def need_rotary(self, sin, cos, rotary_interleaved, rotary_conjunction=False):
self.rotary_sin = sin
self.rotary_cos = cos
self.rotary_interleaved = rotary_interleaved
self.rotary_conjunction = rotary_conjunction
def need_dropout(self, dropout_p, return_scores = True):
if dropout_p > 0.0:
self.dropout_p = dropout_p
self.return_scores = return_scores
self.philox_seed, self.philox_offset = 0x1BF58, 0x1D4B49
def check_args(self, q, k, v, o):
assert q.dim() == k.dim() and q.dim() == v.dim()
batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) # pylint: disable=unused-variable
batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k)
if self.varlen:
assert q.dim() == 3
assert self.cu_seqlens_q is not None
@@ -100,8 +118,6 @@ class MetaData():
assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
# TODO: Remove once bias is supported with varlen
assert self.bias is None
            # TODO: Remove once dropout is supported with varlen
assert self.dropout_p == 0.0
# assert not self.return_scores
else:
assert q.dim() == 4
@@ -111,138 +127,286 @@ class MetaData():
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
assert q.dtype == k.dtype and q.dtype == v.dtype
assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
assert self.layout is not None
assert self.layout == 'thd' or not self.varlen
def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cuda", DEBUG_INPUT=False):
# -------------------------------
# Input Helper
# -------------------------------
def random_seqlens_composition(SEQ_LEN, BATCH):
    # generate a random composition of SEQ_LEN into BATCH positive parts.
idx = torch.randperm(SEQ_LEN - 1)[: BATCH - 1] + 1
idx, _ = torch.sort(idx)
breakpoints = torch.cat([
torch.tensor([0], dtype=torch.long),
idx,
torch.tensor([SEQ_LEN], dtype=torch.long),
])
seqlens = (breakpoints[1:] - breakpoints[:-1]).to(torch.int32)
return seqlens
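# Hedged usage sketch: the helper above splits a total sequence length into BATCH
# positive parts that always sum back to SEQ_LEN. The numbers are illustrative.
seqlens = random_seqlens_composition(SEQ_LEN=1024, BATCH=4)
assert seqlens.numel() == 4                # one length per batch entry
assert int(seqlens.sum()) == 1024          # parts recompose the total
assert bool((seqlens > 0).all())           # every part is strictly positive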
def generate_varlen_tensor(
total_seqlen: int,
num_heads: int,
head_size: int,
batch_size: Optional[int] = None,
equal_seqlens: bool = False,
device: str = "cuda",
dtype: torch.dtype = torch.float32,
DEBUG_INPUT: bool = False
):
# get valid batch_size
if batch_size is None:
valid_batch_sizes = [bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen]
batch_size = random.choice(valid_batch_sizes)
# get seqlens
if equal_seqlens:
seqlens = torch.full(
(batch_size,),
total_seqlen // batch_size,
dtype=torch.int32,
device=device
)
seqlens[-1] += total_seqlen % batch_size
else:
seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device)
# create cumulative sequence lengths
cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)]).to(torch.int32).to(device=device)
max_seqlen = torch.max(seqlens).to(torch.int32).item()
# create varlen tensor
if DEBUG_INPUT:
x = torch.zeros(total_seqlen, num_heads, head_size, dtype=dtype, device=device)
for i in range(batch_size):
start = cu_seqlens[i].item()
end = cu_seqlens[i+1].item()
length = end - start
x[start:end, :, :] = (
torch.arange(length, dtype=dtype, device=device)
.view(length, 1, 1)
.expand(length, num_heads, head_size)
)
else:
x = torch.randn((total_seqlen, num_heads, head_size), dtype=dtype, device=device)
x.requires_grad_()
return x, cu_seqlens, max_seqlen
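# Hedged usage sketch: in the varlen ("thd") layout all sequences are packed along the
# first dimension, and cu_seqlens holds the per-batch offsets into it. Sizes are illustrative.
x, cu_seqlens, max_seqlen = generate_varlen_tensor(
    total_seqlen=512, num_heads=8, head_size=64, batch_size=4, dtype=torch.float16
)
assert x.shape == (512, 8, 64)                          # total_seqlen x heads x head_size
assert cu_seqlens.shape == (5,) and int(cu_seqlens[-1]) == 512  # batch_size + 1 offsets
assert max_seqlen <= 512                                # longest single sequence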
def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype, device="cuda", DEBUG_INPUT=False):
# gen tensor
tensor_shape = (BATCH, SEQ_LEN, NUM_HEADS, D_HEAD)
if DEBUG_INPUT:
x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, SEQ_LEN, 1, 1).expand(*tensor_shape).contiguous()
else:
x = torch.randn(tensor_shape, dtype=dtype, device=device)
x.requires_grad_()
return x
def generate_bhsd_tensor(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype, device="cuda", DEBUG_INPUT=False):
# gen tensor
tensor_shape = (BATCH, NUM_HEADS, SEQ_LEN, D_HEAD)
if DEBUG_INPUT:
x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, 1, SEQ_LEN, 1).expand(*tensor_shape).contiguous()
else:
x = torch.randn(tensor_shape, dtype=dtype, device=device)
x.requires_grad_()
return x
def input_helper(
BATCH: int,
HQ: int,
HK: int,
N_CTX_Q: int,
N_CTX_K: int,
D_HEAD: int,
CAUSAL: bool,
DROPOUT_P: float,
dtype: torch.dtype,
layout: Literal["bshd", "bhsd", "thd"],
packing: Optional[Literal["kv", "qkv"]] = None,
device: Literal["cpu", "cuda"] = "cuda",
DEBUG_INPUT: bool = False,
):
torch.manual_seed(20)
# Initialize q, k, v
if layout == 'bhsd':
q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD)
k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD)
elif layout == 'bshd':
q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD)
k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD)
if layout == "thd":
# set params
TOTAL_SEQLENS_Q = BATCH * N_CTX_Q
TOTAL_SEQLENS_K = BATCH * N_CTX_K
equal_seqlens=False
# gen tensors
# TODO: the gen functions should maybe have different gen modes like random, ones, increasing seqlen
q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
v, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q)
# setup metadata
if DEBUG_INPUT:
sm_scale = 1
else:
sm_scale = D_HEAD**-0.5
metadata = MetaData(sm_scale=sm_scale)
metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
metadata.need_causal(CAUSAL)
metadata.need_dropout(DROPOUT_P)
elif layout == 'bshd' or layout == "bhsd":
# gen tensors
if layout == "bshd":
q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q)
elif layout == "bhsd":
q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q)
# setup metadata
if DEBUG_INPUT:
sm_scale = 1
else:
sm_scale = D_HEAD**-0.5
metadata = MetaData(sm_scale=sm_scale)
metadata.max_seqlens_q = N_CTX_Q
metadata.max_seqlens_k = N_CTX_K
metadata.layout = layout
metadata.need_causal(CAUSAL)
metadata.need_dropout(DROPOUT_P)
else:
assert False, f'Got unsupported tensor layout: {layout}'
raise ValueError(f"Unknown layout: {layout}")
q = None
k = None
v = None
if DEBUG_INPUT:
if layout == "bhsd":
q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, 1, N_CTX_Q, 1).expand(*q_tensor_shape).contiguous().requires_grad_()
k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
# deal with packing
if packing is None:
return q, k, v, do, metadata
elif packing == "kv":
# pack k and v
if layout in ["bhsd", "thd"]:
kv = torch.stack([k, v], dim=1)
elif layout == "bshd":
q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, N_CTX_Q, 1, 1).expand(*q_tensor_shape).contiguous().requires_grad_()
k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
kv = torch.stack([k, v], dim=2)
else:
raise ValueError(f"Unknown layout: {layout}")
return q, kv, do, metadata
elif packing == "qkv":
# qkv packing - requires same sequence length for q and k
assert N_CTX_Q == N_CTX_K, "For QKV packing, Q and K must have same sequence length"
assert HQ == HK, "For QKV packing, Q and K must have same number of heads"
# pack q, k, and v
if layout in ["bhsd", "thd"]:
qkv = torch.stack([q, k, v], dim=1)
elif layout == "bshd":
qkv = torch.stack([q, k, v], dim=2)
else:
raise ValueError(f"Unknown layout: {layout}")
return qkv, do, metadata
else:
q = torch.randn(q_tensor_shape, dtype=dtype, device=device, requires_grad=True)
k = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True)
v = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True)
assert False, f"Unsupported packing mode: {packing}"
if DEBUG_INPUT:
sm_scale = 1
# -------------------------------
# Alibi
# -------------------------------
@triton.jit
def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False):
    # when seqlen_k and seqlen_q are different, we want the diagonal to stick to the bottom right of the attention matrix
    # for a causal mask we want something like this, where 1 is kept and 0 is masked:
# seqlen_q = 2 and seqlen_k = 5
# 1 1 1 1 0
# 1 1 1 1 1
# seqlen_q = 5 and seqlen_k = 2
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
    # for alibi, the diagonal is 0 (no penalty for attending to that spot), and the penalty increases the further you attend from the diagonal
# e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
# 1. offs_m[:,None] = [[0],
# [1],
# 2. offs_m[:,None] + seqlen_k = [[5],
# [6],
# 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
# [4],
# 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1],
# [4], [ 4, 3, 2, 1, 0]]
# 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
# [ -4, -3, -2, -1, 0]],
relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :]
alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
if transpose:
return alibi_block.T
else:
sm_scale = D_HEAD**-0.5
input_metadata = MetaData(sm_scale=sm_scale)
input_metadata.max_seqlens_q = N_CTX_Q
input_metadata.max_seqlens_k = N_CTX_K
input_metadata.layout = layout
return q, k, v, input_metadata
return alibi_block
def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda", equal_seqlens=False, DEBUG_INPUT=False):
torch.manual_seed(20)
# Random or equal sequence lengths based on 'equal_seqlens' flag
if not equal_seqlens:
max_seqlens_q = N_CTX_Q // Z
max_seqlens_k = N_CTX_K // Z
seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32)
seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32)
else:
seqlens_q = torch.full((Z,), N_CTX_Q // Z, dtype=torch.int32)
seqlens_k = torch.full((Z,), N_CTX_K // Z, dtype=torch.int32)
# Calculate cumulative sequence lengths
cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0)])
cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0)])
cu_seqlens_q = cu_seqlens_q.to(device=device).to(torch.int32)
cu_seqlens_k = cu_seqlens_k.to(device=device).to(torch.int32)
# Total lengths
total_q = cu_seqlens_q[-1].item()
total_k = cu_seqlens_k[-1].item()
if DEBUG_INPUT:
# Initialize q, k, v with deterministic values
q = torch.arange(total_q, dtype=dtype, device=device).view(total_q, 1, 1)
q = q.expand(total_q, HQ, D_HEAD).contiguous().requires_grad_()
k = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1)
k = k.expand(total_k, HK, D_HEAD).contiguous().requires_grad_()
v = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1)
v = v.expand(total_k, HK, D_HEAD).contiguous().requires_grad_()
sm_scale = 1
else:
# Initialize q, k, v with random values
q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device=device).requires_grad_()
k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
sm_scale = D_HEAD ** -0.5
input_metadata = MetaData(sm_scale=sm_scale)
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
return q, k, v, input_metadata
def get_shape_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None):
# -------------------------------
# Misc
# -------------------------------
def get_shape_from_layout(
x: torch.Tensor,
layout: Literal["bshd", "bhsd", "thd"],
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
) -> tuple[int, int, int, int]:
if layout == 'bhsd':
batch_q, nheads_q, max_seqlen_q, head_size_q = q.shape
batch_k, nheads_k, max_seqlen_k, head_size_k = k.shape
batch, num_heads, max_seqlen_final, head_dim = x.shape
elif layout == 'bshd':
batch_q, max_seqlen_q, nheads_q, head_size_q = q.shape
batch_k, max_seqlen_k, nheads_k, head_size_k = k.shape
batch, max_seqlen_final, num_heads, head_dim = x.shape
elif layout == 'thd':
batch_q, max_seqlen_q, nheads_q, head_size_q = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] # pylint: disable=self-assigning-variable
batch_k, max_seqlen_k, nheads_k, head_size_k = len(cu_seqlens_k) - 1, max_seqlen_k, k.shape[1], k.shape[2] # pylint: disable=self-assigning-variable
total_seqlen, num_heads, head_dim = x.shape
if cu_seqlens is None:
raise ValueError("cu_seqlens must be provided for varlen (thd) layout")
if max_seqlen is None:
raise ValueError("max_seqlen must be provided for varlen (thd) layout")
batch, max_seqlen_final, num_heads, head_dim = len(cu_seqlens) - 1, max_seqlen, num_heads, head_dim
else:
assert False, "Got unsupported layout."
return batch, max_seqlen_final, num_heads, head_dim
def get_shapes_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None):
batch_q, seqlen_q, nheads_q, head_size_q = get_shape_from_layout(q, layout, cu_seqlens_q, max_seqlen_q)
batch_k, seqlen_k, nheads_k, head_size_k = get_shape_from_layout(k, layout, cu_seqlens_k, max_seqlen_k)
# assert
assert batch_q == batch_k
assert head_size_q == head_size_k
return batch_q, nheads_q, nheads_k, head_size_q, max_seqlen_q, max_seqlen_k
return batch_q, nheads_q, nheads_k, head_size_q, seqlen_q, seqlen_k
def get_strides_from_layout(q, k, v, o, layout):
def get_stride_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"]):
if layout == 'thd':
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
strides = (0, x.stride(1), x.stride(0), x.stride(2))
elif layout == 'bhsd':
q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3))
k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3))
v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3))
o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))
strides = (x.stride(0), x.stride(1), x.stride(2), x.stride(3))
elif layout == 'bshd':
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
strides = (x.stride(0), x.stride(2), x.stride(1), x.stride(3))
else:
assert False, 'Got unsupported layout.'
return q_strides, k_strides, v_strides, o_strides
return strides
def get_shape_and_strides_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"], cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
return get_shape_from_layout(x, layout, cu_seqlens, max_seqlen), get_stride_from_layout(x, layout)
def get_strides_from_layout(q, k, v, o, layout):
q_strides = get_stride_from_layout(q, layout)
k_strides = get_stride_from_layout(k, layout)
v_strides = get_stride_from_layout(v, layout)
o_strides = get_stride_from_layout(o, layout)
return q_strides, k_strides, v_strides, o_strides
def get_padded_headsize(size):
    # Get the closest power of 2 greater than or equal to 32.
@@ -252,24 +416,79 @@ def get_padded_headsize(size):
padded_d_model = max(padded_d_model, 16)
return padded_d_model
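# Hedged sketch of the padding rule described above; the elided body is not shown in
# this hunk, so this is an assumed equivalent, not the patch's exact code. It rounds the
# head size up to the next power of two and applies the visible minimum-width clamp of 16.
def _padded_headsize_sketch(size: int) -> int:
    padded_d_model = 1 << (size - 1).bit_length()   # next power of 2 >= size
    return max(padded_d_model, 16)                  # clamp to the minimum supported width

# e.g. _padded_headsize_sketch(72) -> 128, _padded_headsize_sketch(8) -> 16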
def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k):
q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1)
k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K)
relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K)
return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K)
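# Hedged usage sketch: with slopes of shape (Z, H), the reference above broadcasts to a
# (Z, H, seqlen_q, seqlen_k) bias that is 0 on the bottom-right-aligned diagonal and grows
# more negative with distance from it. Shapes below are illustrative.
slopes = torch.rand(2, 8, device="cuda")                      # (Z, H)
bias = compute_alibi_tensor_ref(slopes, seqlen_q=4, seqlen_k=6)
assert bias.shape == (2, 8, 4, 6)
assert bool((bias <= 0).all())                                # penalties are never positive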
def _strides(x: torch.Tensor, *stride_names: str):
if x is None:
return {f"stride_{s}": 0 for i, s in enumerate(stride_names)}
# -------------------------------
# Dropouts
# -------------------------------
def create_dropout_mask(dropout_p, shape, seed):
device = "cuda"
rand_vals = torch.rand(shape, generator=torch.Generator(device=device).manual_seed(seed), device=device, dtype=torch.float32)
return rand_vals > dropout_p
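# Hedged usage sketch: the mask above is boolean with True meaning "keep", so for dropout
# probability p roughly a (1 - p) fraction of entries survives. Values are illustrative.
keep = create_dropout_mask(0.25, (2, 8, 128, 128), seed=42)
assert keep.dtype == torch.bool and keep.shape == (2, 8, 128, 128)
keep_fraction = keep.float().mean().item()   # expected to be close to 0.75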
assert x.ndim == len(stride_names)
return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)}
def create_dropout_mask_varlen(dropout_p, batch, nheads_q, cu_seqlens_q, cu_seqlens_k, philox_seed):
device = "cuda"
qlens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1])
klens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1])
max_qlen = qlens.max()
max_klen = klens.max()
dropout_mask = torch.zeros((batch, nheads_q, max_qlen, max_klen), device=device)
for b in range(batch):
qlen = qlens[b]
klen = klens[b]
rand_vals = torch.rand((nheads_q, qlen, klen), generator=torch.Generator(device=device).manual_seed(philox_seed), device=device, dtype=torch.float32)
submask = rand_vals > dropout_p
dropout_mask[b, :, :qlen, :klen] = submask
return dropout_mask
def get_input_shapes():
cases = [(max(1, 2**(16 - i)), 1, 2**i, 16, 1, 128)
for i in range(8, 18)] + [(max(1, 2**(16 - i)), 1, 2**i, 16, 2, 128) for i in range(8, 18)]
return cases
def write_dropout_mask(x, tensor_name = "tensor"):
batch, head, seqlen_m, seqlen_n = x.shape
x = x.tolist()
with open(f'{tensor_name}.csv', 'w') as f:
writer = csv.writer(f)
for b in range(batch):
for h in range(head):
dropout_mask = x[b][h]
if True:
BLOCK_M = 64
BLOCK_N = 64
# Calculate number of blocks in each dimension
m_blocks = math.ceil(seqlen_m / BLOCK_M)
n_blocks = math.ceil(seqlen_n / BLOCK_N)
# Process each block
for m_block in range(m_blocks):
# Calculate row range for current block
row_start = m_block * BLOCK_M
row_end = min(row_start + BLOCK_M, seqlen_m)
for n_block in range(n_blocks):
# Calculate column range for current block
col_start = n_block * BLOCK_N
col_end = min(col_start + BLOCK_N, seqlen_n)
# Extract and write the current block
for row_idx in range(row_start, row_end):
row_data = dropout_mask[row_idx][col_start:col_end]
writer.writerow(row_data)
else:
writer.writerows(dropout_mask)
# -------------------------------
# Runtime info
# -------------------------------
@functools.cache
def is_cdna():
return Agent(triton.runtime.driver.active.get_current_target().arch).arch == MicroArchitecture.CDNA
@functools.cache
def is_rdna():
return Agent(triton.runtime.driver.active.get_current_target().arch).arch == MicroArchitecture.RDNA

View File

@@ -33,16 +33,8 @@ MEM_BUS_WIDTH = {
}
_topk = torch.topk
def topk(input: torch.Tensor, *args, **kwargs): # pylint: disable=redefined-builtin
device = input.device
values, indices = _topk(input.cpu(), *args, **kwargs)
return torch.return_types.topk((values.to(device), indices.to(device),))
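# Hedged usage note: once do_hijack() has installed the wrapper above, torch.topk keeps its
# stock signature and named-tuple return type while the actual computation runs on the CPU
# and the results are moved back to the original device. Shapes below are illustrative.
x = torch.randn(4, 16, device="cuda")
values, indices = torch.topk(x, k=3, dim=-1)   # same API as stock torch.topk
assert values.device == x.device and indices.shape == (4, 3)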
class DeviceProperties:
PROPERTIES_OVERRIDE = {
"regs_per_multiprocessor": 65535,
# sometimes gcnArchName contains device name ("AMD Radeon RX ..."), not architecture name ("gfx...")
"gcnArchName": "UNKNOWN ARCHITECTURE",
}
@@ -68,7 +60,6 @@ def torch__C__cuda_getCurrentRawStream(device):
def do_hijack():
torch.topk = topk
if zluda.default_agent is not None:
DeviceProperties.PROPERTIES_OVERRIDE["gcnArchName"] = zluda.default_agent.name
torch.cuda._get_device_properties = torch_cuda__get_device_properties # pylint: disable=protected-access
@@ -104,10 +95,13 @@ def do_hijack():
query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8])
key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8])
value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8])
out_padded = interface_fa.fwd(
query.transpose(1, 2),
query = query.transpose(1, 2)
out_padded = torch.zeros_like(query)
interface_fa.fwd(
query,
key.transpose(1, 2),
value.transpose(1, 2),
out_padded,
dropout_p,
scale,
is_causal,

View File

@@ -69,8 +69,8 @@ def set_default_agent(agent: rocm.Agent):
default_agent = agent
def is_reinstall_needed() -> bool: # ZLUDA<3.8.7
return not os.path.exists(os.path.join(path, 'cufftw.dll'))
def is_reinstall_needed() -> bool: # ZLUDA<3.9.4
return os.path.exists(os.path.join(path, 'cudart.dll'))
def install():
@@ -78,7 +78,7 @@ def install():
return
platform = "windows"
commit = os.environ.get("ZLUDA_HASH", "dba64c0966df2c71e82255e942c96e2e1cea3a2d")
commit = os.environ.get("ZLUDA_HASH", "8d2128caf460b853b165cab0b4d8826b6b734ae7")
if os.environ.get("ZLUDA_NIGHTLY", "0") == "1":
log.warning("Environment variable 'ZLUDA_NIGHTLY' will be removed. Please use command-line argument '--use-nightly' instead.")
args.use_nightly = True