mirror of https://github.com/vladmandic/automatic
492 lines
20 KiB
Python
492 lines
20 KiB
Python
import csv
|
|
import math
|
|
import torch
|
|
import os
|
|
import random
|
|
import functools
|
|
import triton
|
|
import triton.language as tl
|
|
from typing import Literal, Optional, Union
|
|
from modules.rocm import Agent, MicroArchitecture
|
|
|
|
|
|
def _env_flag(name: str) -> bool:
    """Return True when environment variable *name* holds a truthy string ('1'/'true'/'yes', case-insensitive)."""
    return os.environ.get(name, '0').lower() in ('1', 'true', 'yes')


# Feature flags read once at import time.
AUTOTUNE = _env_flag('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE')
USE_REF = _env_flag('FLASH_ATTENTION_TRITON_AMD_REF')
PERF = _env_flag('FLASH_ATTENTION_TRITON_AMD_PERF')
|
|
|
|
|
|
# -------------------------------
|
|
# Metadata
|
|
# -------------------------------
|
|
class MetaData():
    """Mutable bag of metadata describing one attention call.

    Instances start from the class-level defaults below and are populated via
    the ``need_*`` / ``set_*`` helpers before being handed to the kernels.
    """

    # Cumulative sequence lengths for varlen ("thd") inputs; None otherwise.
    cu_seqlens_q: Optional[torch.Tensor] = None
    cu_seqlens_k: Optional[torch.Tensor] = None
    # Maximum per-batch sequence lengths (also the full seqlens for dense layouts).
    max_seqlens_q: int = 0
    max_seqlens_k: int = 0
    # Optional additive attention bias; need_bias() asserts shape (1, ..., seqlen_q, seqlen_k).
    bias: Optional[torch.Tensor] = None
    # Optional per-(batch, head) alibi slopes; see need_alibi().
    alibi_slopes: Optional[torch.Tensor] = None
    causal: bool = False
    num_contexts: int = 0
    # True once set_varlen_params() has been called (layout becomes "thd").
    varlen: bool = False
    layout: Optional[Literal["bshd", "bhsd", "thd"]] = None
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None
    cache_batch_idx = None
    packing: Optional[bool] = None
    # Whether the kernel should also return attention scores (used with dropout).
    return_scores: bool = False
    dropout_p: float = 0.0
    philox_seed: Optional[int] = None
    philox_offset : Optional[int]= None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing.
    # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW.
    use_exp2: bool = False
    # Rotary-embedding inputs; None when rotary is not applied.
    rotary_sin: Optional[torch.Tensor] = None
    rotary_cos: Optional[torch.Tensor] = None
    rotary_interleaved: bool = False
    rotary_conjunction: bool = False

    def __repr__(self) -> str:
        """Multi-line summary of the most commonly inspected fields."""
        return (f"MetaData(\n"
                f"  sm_scale={self.sm_scale},\n"
                f"  cu_seqlens_q={self.cu_seqlens_q},\n"
                f"  cu_seqlens_k={self.cu_seqlens_k},\n"
                f"  max_seqlens_q={self.max_seqlens_q},\n"
                f"  max_seqlens_k={self.max_seqlens_k},\n"
                f"  bias={self.bias},\n"
                f"  alibi_slopes={self.alibi_slopes},\n"
                f"  causal={self.causal},\n"
                f"  num_contexts={self.num_contexts},\n"
                f"  varlen={self.varlen},\n"
                f"  layout={self.layout},\n"
                f"  cache_seqlens={self.cache_seqlens},\n"
                f"  cache_batch_idx={self.cache_batch_idx},\n"
                f"  dropout_p={self.dropout_p},\n"
                f"  return_scores={self.return_scores}\n"
                f")")

    def __init__(self, sm_scale=1.0):
        """Create metadata with the given softmax scale; everything else keeps the class defaults."""
        self.sm_scale = sm_scale

    def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
        """Switch to the varlen ("thd") layout and record the cu_seqlens/max_seqlen metadata."""
        self.varlen = True
        self.layout = 'thd'
        self.cu_seqlens_q = cu_seqlens_q
        self.cu_seqlens_k = cu_seqlens_k
        self.max_seqlens_q = max_seqlen_q
        self.max_seqlens_k = max_seqlen_k

        # Without "varlen", there should still be one sequence.
        assert len(cu_seqlens_q) >= 2
        assert len(cu_seqlens_q) == len(cu_seqlens_k)

    def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k):
        """Validate and attach an additive attention bias.

        NOTE(review): `batch` and `nheads` are accepted but not checked against
        the bias shape here - only dim 0 == 1 and the trailing (seqlen_q, seqlen_k)
        dims are asserted.
        """
        assert bias.is_cuda
        assert bias.dim() == 4
        assert bias.shape[0] == 1
        assert bias.shape[2:] == (seqlen_q, seqlen_k)
        self.bias = bias

    def need_alibi(self, alibi_slopes, batch, nheads):
        """Validate and attach per-(batch, head) alibi slopes of shape (batch, nheads)."""
        assert alibi_slopes.is_cuda
        assert alibi_slopes.dim() == 2
        assert alibi_slopes.shape[0] == batch
        assert alibi_slopes.shape[1] == nheads
        self.alibi_slopes = alibi_slopes

    def need_causal(self, causal):
        """Set whether a causal mask is applied."""
        self.causal = causal

    def need_rotary(self, sin, cos, rotary_interleaved, rotary_conjunction=False):
        """Attach rotary-embedding sin/cos tables and their application flags."""
        self.rotary_sin = sin
        self.rotary_cos = cos
        self.rotary_interleaved = rotary_interleaved
        self.rotary_conjunction = rotary_conjunction

    def need_dropout(self, dropout_p, return_scores = True):
        """Enable dropout (no-op when dropout_p == 0) with fixed philox seed/offset for reproducibility."""
        if dropout_p > 0.0:
            self.dropout_p = dropout_p
            self.return_scores = return_scores
            self.philox_seed, self.philox_offset = 0x1BF58, 0x1D4B49

    def check_args(self, q, k, v, o):
        """Assert that q/k/v/o tensors are mutually consistent with this metadata."""
        assert q.dim() == k.dim() and q.dim() == v.dim()

        batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k)
        if self.varlen:
            # varlen tensors are 3-D (total_tokens, heads, head_dim) with cu_seqlens boundaries
            assert q.dim() == 3
            assert self.cu_seqlens_q is not None
            assert self.cu_seqlens_k is not None
            assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
            assert self.bias is None
            # assert not self.return_scores
        else:
            assert q.dim() == 4
            assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0
            assert self.cu_seqlens_q is None and self.cu_seqlens_k is None
        assert k.shape == v.shape
        assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
        assert q.dtype == k.dtype and q.dtype == v.dtype
        assert o.shape == q.shape
        # grouped-query attention: query heads must be a multiple of kv heads
        assert (nheads_q % nheads_k) == 0
        assert self.layout is not None
        assert self.layout == 'thd' or not self.varlen
|
|
|
|
# -------------------------------
|
|
# Input Helper
|
|
# -------------------------------
|
|
def random_seqlens_composition(SEQ_LEN, BATCH):
    """Split SEQ_LEN into BATCH random positive parts that sum to SEQ_LEN.

    Picks BATCH-1 distinct cut points in (0, SEQ_LEN) and differences the
    sorted breakpoints, yielding an int32 tensor of per-sequence lengths.
    """
    cuts, _ = torch.sort(torch.randperm(SEQ_LEN - 1)[: BATCH - 1] + 1)
    bounds = torch.cat([
        torch.tensor([0], dtype=torch.long),
        cuts,
        torch.tensor([SEQ_LEN], dtype=torch.long),
    ])
    return (bounds[1:] - bounds[:-1]).to(torch.int32)
|
|
|
|
def generate_varlen_tensor(
    total_seqlen: int,
    num_heads: int,
    head_size: int,
    batch_size: Optional[int] = None,
    equal_seqlens: bool = False,
    device: str = "cuda",
    dtype: torch.dtype = torch.float32,
    DEBUG_INPUT: bool = False
):
    """Build a varlen ("thd") tensor plus its cu_seqlens and max_seqlen.

    Returns (x, cu_seqlens, max_seqlen) where x has shape
    (total_seqlen, num_heads, head_size) with requires_grad enabled.
    With DEBUG_INPUT, each sequence is a deterministic 0..len-1 ramp.
    """
    # Pick a batch size that does not exceed the total token count.
    if batch_size is None:
        candidates = [bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen]
        batch_size = random.choice(candidates)

    # Per-sequence lengths.
    if equal_seqlens:
        base = total_seqlen // batch_size
        seqlens = torch.full((batch_size,), base, dtype=torch.int32, device=device)
        # Any remainder goes to the last sequence.
        seqlens[-1] += total_seqlen % batch_size
    else:
        seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device)

    # Cumulative boundaries and the longest sequence.
    zero = torch.tensor([0], dtype=torch.int32, device=device)
    cu_seqlens = torch.cat([zero, seqlens.cumsum(dim=0)]).to(torch.int32).to(device=device)
    max_seqlen = torch.max(seqlens).to(torch.int32).item()

    if DEBUG_INPUT:
        x = torch.zeros(total_seqlen, num_heads, head_size, dtype=dtype, device=device)
        for b in range(batch_size):
            lo = cu_seqlens[b].item()
            hi = cu_seqlens[b + 1].item()
            span = hi - lo
            ramp = torch.arange(span, dtype=dtype, device=device)
            x[lo:hi, :, :] = ramp.view(span, 1, 1).expand(span, num_heads, head_size)
    else:
        x = torch.randn((total_seqlen, num_heads, head_size), dtype=dtype, device=device)

    x.requires_grad_()
    return x, cu_seqlens, max_seqlen
|
|
|
|
def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype, device="cuda", DEBUG_INPUT=False):
    """Create a (BATCH, SEQ_LEN, NUM_HEADS, D_HEAD) tensor with grad enabled.

    DEBUG_INPUT yields a deterministic 0..SEQ_LEN-1 ramp broadcast along the
    sequence dimension; otherwise values are standard-normal.
    """
    shape = (BATCH, SEQ_LEN, NUM_HEADS, D_HEAD)
    if DEBUG_INPUT:
        ramp = torch.arange(SEQ_LEN, dtype=dtype, device=device)
        out = ramp.view(1, SEQ_LEN, 1, 1).expand(*shape).contiguous()
    else:
        out = torch.randn(shape, dtype=dtype, device=device)
    out.requires_grad_()
    return out
|
|
|
|
def generate_bhsd_tensor(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype, device="cuda", DEBUG_INPUT=False):
    """Create a (BATCH, NUM_HEADS, SEQ_LEN, D_HEAD) tensor with grad enabled.

    DEBUG_INPUT yields a deterministic 0..SEQ_LEN-1 ramp broadcast along the
    sequence dimension; otherwise values are standard-normal.
    """
    shape = (BATCH, NUM_HEADS, SEQ_LEN, D_HEAD)
    if DEBUG_INPUT:
        ramp = torch.arange(SEQ_LEN, dtype=dtype, device=device)
        out = ramp.view(1, 1, SEQ_LEN, 1).expand(*shape).contiguous()
    else:
        out = torch.randn(shape, dtype=dtype, device=device)
    out.requires_grad_()
    return out
|
|
|
|
def input_helper(
    BATCH: int,
    HQ: int,
    HK: int,
    N_CTX_Q: int,
    N_CTX_K: int,
    D_HEAD: int,
    CAUSAL: bool,
    DROPOUT_P: float,
    dtype: torch.dtype,
    layout: Literal["bshd", "bhsd", "thd"],
    packing: Optional[Literal["kv", "qkv"]] = None,
    device: Literal["cpu", "cuda"] = "cuda",
    DEBUG_INPUT: bool = False,
):
    """Generate q/k/v/do tensors and a configured MetaData for a test case.

    Returns (q, k, v, do, metadata) when packing is None,
    (q, kv, do, metadata) for packing == "kv", and
    (qkv, do, metadata) for packing == "qkv".

    Raises ValueError for unknown layouts and asserts on unsupported packing.
    Seeds torch RNG with a fixed value so the generated inputs are reproducible.
    """
    torch.manual_seed(20)

    if layout == "thd":
        # set params
        TOTAL_SEQLENS_Q = BATCH * N_CTX_Q
        TOTAL_SEQLENS_K = BATCH * N_CTX_K
        equal_seqlens=False

        # gen tensors
        # NOTE(review): k and v are generated by independent calls, so v's random
        # per-batch split is discarded and cu_seqlens_k (from the k call) is the
        # one recorded in metadata.
        q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
        k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
        v, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
        do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q)

        # setup metadata
        # sm_scale = 1 in debug mode keeps hand-checked expected values simple
        if DEBUG_INPUT:
            sm_scale = 1
        else:
            sm_scale = D_HEAD**-0.5
        metadata = MetaData(sm_scale=sm_scale)
        metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
        metadata.need_causal(CAUSAL)
        metadata.need_dropout(DROPOUT_P)
    elif layout == 'bshd' or layout == "bhsd":
        # gen tensors
        if layout == "bshd":
            q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
            k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
            v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
            do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q)
        elif layout == "bhsd":
            q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
            k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
            v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
            do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q)

        # setup metadata
        if DEBUG_INPUT:
            sm_scale = 1
        else:
            sm_scale = D_HEAD**-0.5
        metadata = MetaData(sm_scale=sm_scale)
        metadata.max_seqlens_q = N_CTX_Q
        metadata.max_seqlens_k = N_CTX_K
        metadata.layout = layout
        metadata.need_causal(CAUSAL)
        metadata.need_dropout(DROPOUT_P)
    else:
        raise ValueError(f"Unknown layout: {layout}")

    # deal with packing
    if packing is None:
        return q, k, v, do, metadata
    elif packing == "kv":
        # pack k and v
        # stacking axis differs per layout so that the pair dim sits next to heads
        if layout in ["bhsd", "thd"]:
            kv = torch.stack([k, v], dim=1)
        elif layout == "bshd":
            kv = torch.stack([k, v], dim=2)
        else:
            raise ValueError(f"Unknown layout: {layout}")

        return q, kv, do, metadata
    elif packing == "qkv":
        # qkv packing - requires same sequence length for q and k
        assert N_CTX_Q == N_CTX_K, "For QKV packing, Q and K must have same sequence length"
        assert HQ == HK, "For QKV packing, Q and K must have same number of heads"

        # pack q, k, and v
        if layout in ["bhsd", "thd"]:
            qkv = torch.stack([q, k, v], dim=1)
        elif layout == "bshd":
            qkv = torch.stack([q, k, v], dim=2)
        else:
            raise ValueError(f"Unknown layout: {layout}")

        return qkv, do, metadata
    else:
        assert False, f"Unsupported packing mode: {packing}"
|
|
|
|
# -------------------------------
|
|
# Alibi
|
|
# -------------------------------
|
|
@triton.jit
def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False):
    # Returns the additive alibi penalty block for query offsets offs_m and key
    # offsets offs_n; `transpose` returns the transposed block (used when the
    # caller iterates the attention matrix in the opposite orientation).
    # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix
    # for casual mask we want something like this where (1 is kept and 0 is masked)
    # seqlen_q = 2 and seqlen_k = 5
    # 1 1 1 1 0
    # 1 1 1 1 1
    # seqlen_q = 5 and seqlen_k = 2
    # 0 0
    # 0 0
    # 0 0
    # 1 0
    # 1 1
    # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal
    # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
    # 1. offs_m[:,None] = [[0],
    #                       [1],
    # 2. offs_m[:,None] + seqlen_k = [[5],
    #                                  [6],
    # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
    #                                             [4],
    # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1],
    #                                                              [4],                       [ 4, 3, 2, 1, 0]]
    # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
    #                                                      [ -4, -3, -2, -1, 0]],
    # signed distance from the bottom-right-aligned diagonal
    relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :]
    # penalty grows linearly with |distance|; zero on the diagonal
    alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
    if transpose:
        return alibi_block.T
    else:
        return alibi_block
|
|
|
|
# -------------------------------
|
|
# Misc
|
|
# -------------------------------
|
|
def get_shape_from_layout(
    x: torch.Tensor,
    layout: Literal["bshd", "bhsd", "thd"],
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
) -> tuple[int, int, int, int]:
    """Return (batch, max_seqlen, num_heads, head_dim) for *x* under *layout*.

    For the varlen "thd" layout the batch is derived from *cu_seqlens* and the
    sequence length taken from *max_seqlen*; both must then be provided.

    Raises:
        ValueError: if *layout* is not one of "bshd"/"bhsd"/"thd", or if the
            varlen metadata is missing for the "thd" layout.
    """
    if layout == 'bhsd':
        batch, num_heads, max_seqlen_final, head_dim = x.shape
    elif layout == 'bshd':
        batch, max_seqlen_final, num_heads, head_dim = x.shape
    elif layout == 'thd':
        _total_seqlen, num_heads, head_dim = x.shape
        if cu_seqlens is None:
            raise ValueError("cu_seqlens must be provided for varlen (thd) layout")
        if max_seqlen is None:
            raise ValueError("max_seqlen must be provided for varlen (thd) layout")

        # cu_seqlens has batch+1 boundary entries
        batch = len(cu_seqlens) - 1
        max_seqlen_final = max_seqlen
    else:
        # ValueError (not assert) so the check survives `python -O` and matches
        # the error style of the branches above.
        raise ValueError("Got unsupported layout.")

    return batch, max_seqlen_final, num_heads, head_dim
|
|
|
|
|
|
def get_shapes_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None):
    """Resolve q and k shapes together and sanity-check their compatibility.

    Returns (batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k).
    """
    batch_q, seqlen_q, nheads_q, head_q = get_shape_from_layout(q, layout, cu_seqlens_q, max_seqlen_q)
    batch_k, seqlen_k, nheads_k, head_k = get_shape_from_layout(k, layout, cu_seqlens_k, max_seqlen_k)

    # q and k must agree on batch size and head dimension
    assert batch_q == batch_k
    assert head_q == head_k

    return batch_q, nheads_q, nheads_k, head_q, seqlen_q, seqlen_k
|
|
|
|
def get_stride_from_layout(x: torch.Tensor, layout:Literal["bshd", "bhsd", "thd"]):
    """Return strides of *x* reordered to the kernel's (batch, head, seq, dim) convention.

    For "thd" the batch stride is 0 since varlen tensors have no batch axis.

    Raises:
        ValueError: if *layout* is not one of "bshd"/"bhsd"/"thd".
    """
    if layout == 'thd':
        strides = (0, x.stride(1), x.stride(0), x.stride(2))
    elif layout == 'bhsd':
        strides = (x.stride(0), x.stride(1), x.stride(2), x.stride(3))
    elif layout == 'bshd':
        strides = (x.stride(0), x.stride(2), x.stride(1), x.stride(3))
    else:
        # ValueError (not assert) so the check survives `python -O` and matches
        # get_shape_from_layout's error style.
        raise ValueError('Got unsupported layout.')
    return strides
|
|
|
|
def get_shape_and_strides_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"], cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
    """Convenience wrapper returning both the resolved shape and the strides of *x*."""
    shape = get_shape_from_layout(x, layout, cu_seqlens, max_seqlen)
    strides = get_stride_from_layout(x, layout)
    return shape, strides
|
|
|
|
def get_strides_from_layout(q, k, v, o, layout):
    """Return the kernel-order strides for each of q, k, v and o under *layout*."""
    return tuple(get_stride_from_layout(t, layout) for t in (q, k, v, o))
|
|
|
|
def get_padded_headsize(size):
    """Round *size* up to the next power of two, with a floor of 16.

    1 << (size - 1).bit_length() is the smallest power of two >= size.
    The smallest head_dim supported is 16; smaller sizes are padded in the
    kernel tile only - there is no padding in memory for any dims.
    """
    return max(1 << (size - 1).bit_length(), 16)
|
|
|
|
def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k, device="cuda"):
    """Reference (non-triton) alibi bias of shape (Z, H, seqlen_q, seqlen_k).

    The penalty is -slope * |distance| from the bottom-right-aligned diagonal.
    *device* was previously hard-coded to "cuda"; it now defaults to "cuda"
    for backward compatibility but can be overridden (e.g. for CPU tests).
    """
    q_idx = torch.arange(seqlen_q, dtype=torch.int32, device=device).unsqueeze(-1)  # (N_CTX_Q, 1)
    k_idx = torch.arange(seqlen_k, dtype=torch.int32, device=device).unsqueeze(0)  # (1, N_CTX_K)
    # distance from the bottom-right-aligned diagonal (matches compute_alibi_block)
    relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx)  # (N_CTX_Q, N_CTX_K)
    return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos  # (Z, H, N_CTX_Q, N_CTX_K)
|
|
|
|
# -------------------------------
|
|
# Dropouts
|
|
# -------------------------------
|
|
def create_dropout_mask(dropout_p, shape, seed, device="cuda"):
    """Boolean keep-mask of *shape*: True where the element survives dropout.

    Uses a freshly seeded generator so results are reproducible for a given
    *seed*. *device* was previously hard-coded to "cuda"; it now defaults to
    "cuda" for backward compatibility but can be overridden.
    """
    gen = torch.Generator(device=device).manual_seed(seed)
    rand_vals = torch.rand(shape, generator=gen, device=device, dtype=torch.float32)
    return rand_vals > dropout_p
|
|
|
|
def create_dropout_mask_varlen(dropout_p, batch, nheads_q, cu_seqlens_q, cu_seqlens_k, philox_seed, device="cuda"):
    """Per-batch keep-masks padded into a dense (batch, nheads_q, max_q, max_k) tensor.

    Positions beyond a batch entry's true (qlen, klen) remain 0. *device* was
    previously hard-coded to "cuda"; it now defaults to "cuda" for backward
    compatibility but can be overridden.

    NOTE(review): the generator is re-seeded with the same philox_seed for
    every batch entry, so all entries draw from the same random stream -
    presumably intentional for reproducible testing; confirm against callers.
    """
    qlens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1])
    klens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1])
    max_qlen = qlens.max()
    max_klen = klens.max()
    dropout_mask = torch.zeros((batch, nheads_q, max_qlen, max_klen), device=device)
    for b in range(batch):
        qlen = qlens[b]
        klen = klens[b]
        gen = torch.Generator(device=device).manual_seed(philox_seed)
        rand_vals = torch.rand((nheads_q, qlen, klen), generator=gen, device=device, dtype=torch.float32)
        # keep where the draw exceeds dropout_p; pad region stays 0
        dropout_mask[b, :, :qlen, :klen] = rand_vals > dropout_p

    return dropout_mask
|
|
|
|
def write_dropout_mask(x, tensor_name = "tensor"):
    """Dump a (batch, head, seqlen_m, seqlen_n) mask to '<tensor_name>.csv'.

    Rows are written block-by-block in 64x64 tiles (row blocks outer, column
    blocks inner) for each (batch, head) slice, one CSV row per tile row.

    The dead `if True:/else: writer.writerows(...)` scaffolding from the
    original has been removed; the tiled path was the only one ever taken.
    """
    batch, head, seqlen_m, seqlen_n = x.shape
    x = x.tolist()

    BLOCK_M = 64
    BLOCK_N = 64

    # newline='' is required by the csv module to avoid extra blank lines on Windows
    with open(f'{tensor_name}.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        for b in range(batch):
            for h in range(head):
                dropout_mask = x[b][h]

                # Number of tiles along each dimension (last tile may be partial).
                m_blocks = math.ceil(seqlen_m / BLOCK_M)
                n_blocks = math.ceil(seqlen_n / BLOCK_N)

                for m_block in range(m_blocks):
                    row_start = m_block * BLOCK_M
                    row_end = min(row_start + BLOCK_M, seqlen_m)
                    for n_block in range(n_blocks):
                        col_start = n_block * BLOCK_N
                        col_end = min(col_start + BLOCK_N, seqlen_n)
                        # Emit the current tile row by row.
                        for row_idx in range(row_start, row_end):
                            writer.writerow(dropout_mask[row_idx][col_start:col_end])
|
|
|
|
# -------------------------------
|
|
# Runtime info
|
|
# -------------------------------
|
|
@functools.cache
def is_cdna():
    """True when the active triton target is a CDNA-class AMD GPU (result cached)."""
    arch_name = triton.runtime.driver.active.get_current_target().arch
    return Agent(arch_name).arch == MicroArchitecture.CDNA
|
|
|
|
|
|
@functools.cache
def is_rdna():
    """True when the active triton target is an RDNA-class AMD GPU (result cached)."""
    arch_name = triton.runtime.driver.active.get_current_target().arch
    return Agent(arch_name).arch == MicroArchitecture.RDNA
|