# mirror of https://github.com/vladmandic/automatic
"""Sharpfin sparse matrix backend for Triton DDS matmul.
|
|
|
|
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
|
|
Adapted from https://github.com/stanford-futuredata/stk (Apache 2.0)
|
|
"""

import numpy as np
import torch
import triton
import triton.language as tl
from typing import Tuple
from dataclasses import dataclass

from .triton_functional import linear_to_srgb_triton, srgb_to_linear_triton, magic_kernel_sharp_2021_triton, lanczos_triton

# Code is all adapted from https://github.com/stanford-futuredata/stk, licensed under Apache-2.0
# Very reduced set of functions for handling DDS (Dense = Dense @ Sparse) matmul only, with the
# DDS kernel modified to be more flexible on input shapes.

def _validate_matrix(shape, data, row_indices, column_indices, offsets):
|
|
if data.dim() == 1:
|
|
data = torch.reshape(data, [data.numel(), 1, 1])
|
|
|
|
if data.shape[-2] != data.shape[-1]:
|
|
raise ValueError(
|
|
"Expected square blocking in data. "
|
|
f"Got block shape {[data.shape[-2], data.shape[-1]]}")
|
|
|
|
block_size = data.shape[-1]
|
|
data = data.view([-1, block_size, block_size])
|
|
|
|
if data.dim() != 3:
|
|
raise ValueError(
|
|
"Expected 3D shape for data (nnz, block, block). "
|
|
f"Got shape {data.dim()}D shape.")
|
|
|
|
block_size = data.shape[1]
|
|
if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
|
|
raise ValueError(
|
|
"Matrix shape must be dividible by blocking. "
|
|
f"Got shape {shape} with "
|
|
f"{[block_size, block_size]} blocking.")
|
|
|
|
if np.prod(shape) < data.numel():
|
|
raise ValueError(
|
|
"Invalid matrix. Number of nonzeros exceeds matrix capacity "
|
|
f"({data.numel()} v. {np.prod(shape)})")
|
|
|
|
if row_indices.dim() != 1:
|
|
raise ValueError(
|
|
f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
|
|
|
|
if column_indices.dim() != 1:
|
|
raise ValueError(
|
|
f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
|
|
|
|
if offsets.dim() != 1:
|
|
raise ValueError(
|
|
f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
|
|
|
|
if row_indices.numel() != data.shape[0]:
|
|
raise ValueError(
|
|
"Expected 1 index per nonzero block. "
|
|
f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
|
|
|
|
if column_indices.numel() != data.shape[0]:
|
|
raise ValueError(
|
|
"Expected 1 index per nonzero block. "
|
|
f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
|
|
|
|
block_rows = np.prod(shape[:-1]) / block_size
|
|
if offsets.numel() != block_rows + 1:
|
|
raise ValueError(
|
|
"Expected one offset per block row plus one. "
|
|
f"Got {offsets.numel()} offsets with {block_rows} block rows.")
|
|
|
|
is_cuda = (data.is_cuda and
|
|
row_indices.is_cuda and
|
|
column_indices.is_cuda and
|
|
offsets.is_cuda)
|
|
is_cpu = (not data.is_cuda and
|
|
not row_indices.is_cuda and
|
|
not column_indices.is_cuda and
|
|
not offsets.is_cuda)
|
|
if not (is_cuda or is_cpu):
|
|
raise ValueError(
|
|
"Expected data & meta-data on common device. "
|
|
f"Got data on {data.device}, row_indices on {row_indices.device} "
|
|
f"column_indices on {column_indices.device} and "
|
|
f"offsets on {offsets.device}.")
|
|
|
|
if data.dtype != torch.float16:
|
|
raise ValueError(
|
|
f"Expected float16 data. Got {data.dtype} data.")
|
|
if row_indices.dtype != torch.int16:
|
|
raise ValueError(
|
|
f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
|
|
if column_indices.dtype != torch.int16:
|
|
raise ValueError(
|
|
f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
|
|
if offsets.dtype != torch.int32:
|
|
raise ValueError(
|
|
f"Expected int32 offsets. Got {offsets.dtype} offsets.")
|
|
return data
|
|
|
|
def _transpose(size, data: torch.Tensor, row_indices: torch.Tensor, column_indices: torch.Tensor, offsets):
|
|
block_columns = size[1] // data.shape[1]
|
|
|
|
gather_indices = column_indices.argsort()
|
|
column_indices_t = row_indices.gather(0, gather_indices)
|
|
block_offsets_t = gather_indices.int()
|
|
|
|
column_indices_float = column_indices.float()
|
|
|
|
zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
|
|
nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
|
|
nnz_per_column = nnz_per_column.int()
|
|
offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
|
|
return column_indices_t, offsets_t, block_offsets_t
|
|
|
|
class SBSCMatrix(torch.nn.Module):
    """Single Block Sparse Column (SBSC) matrix format.

    Holds exactly one dense block per block column, stacked along dim 0
    of ``data`` as (num_blocks, block_height, col_width), together with
    the fixed row ``offset`` between consecutive block columns.
    """

    def __init__(
        self,
        size,
        data: torch.Tensor,
        offset: int,
        block_size: int
    ):
        super().__init__()
        self.data = data                  # stacked blocks (num_blocks, block_height, col_width)
        self.offset = offset              # row offset between consecutive block columns
        self.size = size                  # logical dense shape
        self.num_blocks = data.shape[0]   # one block per block column
        self.col_width = data.shape[2]    # width of each block column
        self.col_block_size = block_size  # reduction block size used by the kernel
|
|
|
|
class Matrix(torch.nn.Module):
    """A matrix stored in block compressed sparse row (BCSR) format.

    ``data`` holds the nonzero blocks as (nnz_blocks, block, block);
    ``row_indices``/``column_indices`` give each block's coordinates and
    ``offsets`` delimits the blocks of each block row.  The transposed
    (CSC-style) metadata is derived via ``_transpose`` when not supplied,
    so the same storage serves both orientations — ``t()`` only flips a
    flag and the logical size.
    """

    def __init__(self,
                 size,
                 data: torch.Tensor,
                 row_indices: torch.Tensor,
                 column_indices: torch.Tensor,
                 offsets: torch.Tensor,
                 column_indices_t: torch.Tensor=None,
                 offsets_t: torch.Tensor=None,
                 block_offsets_t: torch.Tensor=None):
        super().__init__()
        self._size = size
        self._data = data
        self._row_indices = row_indices
        self._column_indices = column_indices
        self._offsets = offsets

        # Derive the transposed metadata if any piece of it is missing.
        if ((column_indices_t is None) or (offsets_t is None) or
            (block_offsets_t is None)):
            column_indices_t, offsets_t, block_offsets_t = _transpose(
                size, data, row_indices, column_indices, offsets)
        self._column_indices_t = column_indices_t
        self._offsets_t = offsets_t
        self._block_offsets_t = block_offsets_t

        self._transposed = False

        # int16 indices address whole blocks, so the largest representable
        # dimension is int16_max * block_size.
        max_dim = np.iinfo(np.int16).max * self.blocking
        if column_indices.dtype == torch.int16:
            if size[0] > max_dim or size[1] > max_dim:
                # BUGFIX: this message was missing its f-prefix, so {size}
                # was printed literally instead of being interpolated.
                raise ValueError(
                    f"Sparse matrix with shape {size} exceeds representable "
                    "size with 16-bit indices.")

    def validate(self):
        """Run the full consistency checks on this matrix's storage."""
        _validate_matrix(self._size,
                         self._data,
                         self._row_indices,
                         self._column_indices,
                         self._offsets)

    def to(self, device):
        """Move data and all metadata tensors to ``device``; returns self."""
        self._data = self._data.to(device)
        self._row_indices = self._row_indices.to(device)
        self._column_indices = self._column_indices.to(device)
        self._offsets = self._offsets.to(device)
        self._column_indices_t = self._column_indices_t.to(device)
        self._offsets_t = self._offsets_t.to(device)
        self._block_offsets_t = self._block_offsets_t.to(device)
        return self

    def cuda(self):
        """Move the matrix to the current CUDA device; returns self."""
        return self.to(torch.cuda.current_device())

    def clone(self):
        """Return a deep copy (all tensors cloned) with the same layout."""
        return Matrix(
            self.size(),
            self.data.clone(),
            self.row_indices.clone(),
            self.column_indices.clone(),
            self.offsets.clone(),
            self.column_indices_t.clone(),
            self.offsets_t.clone(),
            self.block_offsets_t.clone())

    def t(self):
        """Return a lazily transposed view sharing this matrix's storage."""
        if self.dim() != 2:
            raise ValueError(
                "t() expects a tensor with <= 2 dimensions, "
                f"but self is {self.dim()}D.")
        out = Matrix(self.size(),
                     self.data,
                     self.row_indices,
                     self.column_indices,
                     self.offsets,
                     self.column_indices_t,
                     self.offsets_t,
                     self.block_offsets_t)
        # Only the flag and the logical size change; storage is shared.
        out._transposed = not self._transposed
        out._size = torch.Size((self._size[1], self._size[0]))
        return out

    def contiguous(self):
        raise ValueError("Not yet implemented.")

    def is_contiguous(self):
        # A transposed view is considered non-contiguous.
        return not self._transposed

    @property
    def is_cuda(self):
        return self._data.is_cuda

    @property
    def device(self):
        return self._data.device

    def size(self):
        return self._size

    @property
    def shape(self):
        return self.size()

    def dim(self):
        return len(self._size)

    @property
    def data(self):
        return self._data

    @property
    def row_indices(self):
        return self._row_indices

    @property
    def column_indices(self):
        return self._column_indices

    @property
    def offsets(self):
        return self._offsets

    @property
    def offsets_t(self):
        return self._offsets_t

    @property
    def column_indices_t(self):
        return self._column_indices_t

    @property
    def block_offsets_t(self):
        return self._block_offsets_t

    @property
    def dtype(self):
        return self.data.dtype

    @property
    def nnz(self):
        # Number of stored (materialized) elements, including explicit zeros
        # inside nonzero blocks.
        return self.data.numel()

    @property
    def blocking(self):
        return self.data.shape[1]

    @property
    def requires_grad(self):
        return self.data.requires_grad

    def requires_grad_(self, x):
        """Set requires_grad on the block data in place; returns self."""
        self.data.requires_grad_(x)
        return self

    def view(self, *shape):
        """Reshape the leading (uncompressed) dimensions; storage is shared."""
        assert self.is_contiguous()
        # The last (compressed) dimension cannot change.
        if shape[-1] != self.size()[-1]:
            raise ValueError(
                "Can't change view on compressed dimension. "
                f"{self.size()[-1]} v. {shape[-1]}.")
        if np.prod(shape) != np.prod(self.size()):
            raise ValueError(
                "Mismatch in numel of Matrix and new shape. "
                f"{np.prod(self.size())} v. {np.prod(shape)}")
        return Matrix(shape,
                      self.data,
                      self.row_indices,
                      self.column_indices,
                      self.offsets,
                      self.column_indices_t,
                      self.offsets_t,
                      self.block_offsets_t)

    @property
    def grad(self):
        """Wrap ``data.grad`` in a Matrix with this matrix's sparsity layout."""
        size = self.size()
        if not self.is_contiguous():
            size = torch.Size((size[1], size[0]))
        out = Matrix(size,
                     self.data.grad,
                     self.row_indices,
                     self.column_indices,
                     self.offsets,
                     self.column_indices_t,
                     self.offsets_t,
                     self.block_offsets_t)
        # Re-apply the transposition so the gradient matches self's view.
        return out if self.is_contiguous() else out.t()
|
|
|
|
@torch.no_grad()
|
|
def _expand_for_blocking(idxs, blocking):
|
|
idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
|
|
|
|
idxs[:, :, 1] *= blocking
|
|
idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
|
|
|
|
idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
|
|
idxs = idxs.repeat(1, blocking, 1, 1)
|
|
|
|
idxs[:, :, :, 0] *= blocking
|
|
idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
|
|
idxs = torch.reshape(idxs, [-1, 2])
|
|
return idxs
|
|
|
|
|
|
@torch.no_grad()
def to_dense(x):
    """Materialize a sparse ``Matrix`` as a dense tensor of shape ``x.size()``."""
    assert isinstance(x, Matrix)

    # Collapse leading dims so the matrix is logically 2D (rows, cols).
    flat_shape = (np.prod(x.shape[:-1]), x.shape[-1])
    block_coords = torch.stack(
        [x.row_indices.type(torch.int32), x.column_indices.type(torch.int32)],
        dim=1)
    coords = _expand_for_blocking(block_coords, x.blocking)
    flat_idx = (coords[:, 0] * flat_shape[1] + coords[:, 1]).type(torch.int64)

    # Scatter every stored element into a zeroed flat buffer, then reshape.
    dense = torch.zeros(flat_shape[0] * flat_shape[1], dtype=x.dtype, device=x.device)
    dense.scatter_(0, flat_idx, x.data.flatten())
    return dense.reshape(x.size())
|
|
|
|
|
|
@dataclass
class TritonConfig:
    """Default tile/launch parameters for the Triton DDS kernels."""

    BLOCK_M: int = 128    # output tile rows
    BLOCK_N: int = 128    # output tile columns
    BLOCK_K: int = 32     # reduction tile width
    BLOCK_SIZE: int = 64  # sparse block size
    NUM_STAGES: int = 4   # software pipeline stages
    NUM_WARPS: int = 4    # warps per program
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _dds_kernel(
    A: tl.tensor, B: tl.tensor, C: tl.tensor, M, N, K, O_M, O_N,
    stride_ac, stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cc, stride_cm, stride_cn,
    row_indices: tl.tensor, column_indices: tl.tensor,
    offsets: tl.tensor, block_offsets_t: tl.tensor,
    fuse_srgb: tl.constexpr, clamp_output: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    # Dense = Dense @ Sparse (BCSR) tile kernel: each program computes one
    # (BLOCK_M, BLOCK_N) tile of C for one channel, accumulating over the
    # nonzero blocks of sparse block column pid_n using the CSC-side
    # metadata (column_indices_t / offsets_t / block_offsets_t built by
    # _transpose, for a non-transposed rhs).
    pid_c = tl.program_id(0)  # channel
    pid_m = tl.program_id(1)  # output tile row
    pid_n = tl.program_id(2)  # output tile column == sparse block column

    num_pid_m = tl.num_programs(1)
    num_pid_n = tl.num_programs(2)

    # Group tiles for better L2 reuse.
    pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)

    # offsets[pid_n] .. offsets[pid_n + 1] delimit this column's blocks.
    offsets += pid_n

    start_inx = tl.load(offsets)
    end_inx = tl.load(offsets + 1)

    # Advance metadata pointers to this column's first block.
    column_indices += start_inx
    block_offsets_t += start_inx

    BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE  # elements per sparse block

    A_block_ptr = tl.make_block_ptr(
        base=A + pid_c * stride_ac, shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(0, 1)
    )

    rn = tl.arange(0, BLOCK_N)
    rbk = tl.arange(0, BLOCK_K)

    # B indexes the flat block storage; per-block base is added in the loop.
    B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)

    # NOTE(review): accumulation is in fp16 (out_dtype=tl.float16), trading
    # precision for speed; ACC_TYPE is accepted but not used here.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float16)
    nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)  # K-steps per sparse block

    bk_sub_incr = BLOCK_K * stride_bk

    for block_inx in range(end_inx - start_inx):
        # For the transposed metadata, column_indices holds each block's
        # position along A's K dimension (in units of BLOCK_SIZE).
        a_col_idx = tl.load(column_indices + block_inx)
        ptr_A = tl.advance(A_block_ptr, (0, a_col_idx * BLOCK_SIZE))

        # block_offsets_t maps column-sorted order to storage order.
        b_block_offset = tl.load(block_offsets_t + block_inx)
        ptr_B = B + b_block_offset * BLOCK_ELEMENTS

        for sub_block_inx in range(nsub_blocks):
            a = tl.load(ptr_A)
            b = tl.load(ptr_B)

            acc = tl.dot(a, b, acc, out_dtype=tl.float16)

            ptr_A = tl.advance(ptr_A, (0, BLOCK_K))
            ptr_B += bk_sub_incr

    # Optional fused epilogue: linear -> sRGB transfer, then clamp to [0, 1].
    if fuse_srgb:
        acc = linear_to_srgb_triton(acc)

    if clamp_output:
        acc = tl.clamp(acc, 0.0, 1.0)

    if fuse_srgb or clamp_output:
        acc = acc.to(C.dtype.element_ty)

    C_block_ptr = tl.make_block_ptr(
        base=C + pid_c * stride_cc, shape=(O_M, O_N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0)
    )

    # Boundary-checked store handles ragged (O_M, O_N) output edges.
    tl.store(C_block_ptr, acc, boundary_check=(0, 1))
|
|
|
|
|
|
def triton_dds(
    lhs: torch.Tensor,
    rhs: Matrix,
    fuse_srgb: bool = False,
    clamp_output: bool = False,
    output_mt: bool = False,
    output_slice: None | Tuple[int,int] = None
):
    """Dense = Dense @ Sparse (BCSR) matmul via `_dds_kernel`.

    Args:
        lhs: dense (CH, M, K) tensor; last dimension must be contiguous.
        rhs: sparse BCSR ``Matrix``; a transposed (non-contiguous) rhs is
            supported and selects the CSR-side metadata below.
        fuse_srgb: fuse a linear->sRGB conversion into the kernel epilogue.
        clamp_output: clamp the result to [0, 1] in the epilogue.
        output_mt: allocate the output transposed as (CH, N, M) and write
            it through swapped strides.
        output_slice: optional logical output extent overriding (M, N)
            (given as (N, M) when output_mt is set).

    Returns:
        The result tensor, (CH, M, N) or transposed/sliced per the flags.
    """
    assert isinstance(lhs, torch.Tensor)
    assert isinstance(rhs, Matrix)
    assert lhs.ndim == 3
    CH = lhs.shape[0]
    stride_ac = lhs.stride(0)

    M, K = lhs.shape[-2:]
    N = rhs.shape[-1]

    # Allocate the output, optionally transposed and/or sliced; the kernel
    # only sees strides, so every layout shares the same launch.
    if output_mt:
        if output_slice is not None:
            O_N, O_M = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_N, O_M = N, M
            out = torch.empty(
                (*lhs.shape[:-2], rhs.shape[1], lhs.shape[-2]),
                dtype=lhs.dtype,
                device=lhs.device
            )
        # Transposed output: swap the roles of the last two strides.
        stride_cm, stride_cn = out.stride(-1), out.stride(-2)
        stride_cc = out.stride(-3)
    else:
        if output_slice is not None:
            O_M, O_N = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_M, O_N = M, N
            out = torch.empty(
                (*lhs.shape[:-1], rhs.shape[1]),
                dtype=lhs.dtype,
                device=lhs.device
            )
        stride_cm, stride_cn = out.stride(-2), out.stride(-1)
        stride_cc = out.stride(-3)

    trans_B = not rhs.is_contiguous()
    trans_A = (lhs.stride(-2) > 1 and lhs.stride(-1) > 1)
    # BUGFIX: was `assert trans_A == False, trans_B == False`, where the
    # second comparison was the assert *message* and thus never enforced.
    # A transposed rhs (trans_B) is legitimate and handled below; only a
    # transposed lhs is unsupported.
    assert not trans_A, "transposed lhs is not supported"

    assert lhs.shape[-1] <= rhs.shape[0], "incompatible dimensions"

    stride_am, stride_ak = lhs.stride(-2), lhs.stride(-1)

    # Select CSR- or CSC-side metadata depending on rhs orientation.
    if trans_B:
        stride_bk, stride_bn = rhs.data.stride(2), rhs.data.stride(1)
        b_column_indices, b_offsets = rhs.column_indices, rhs.offsets
    else:
        stride_bk, stride_bn = rhs.data.stride(1), rhs.data.stride(2)
        b_column_indices, b_offsets = rhs.column_indices_t, rhs.offsets_t

    grid = lambda META: (CH, triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))

    _dds_kernel[grid](
        lhs, rhs.data, out, M, N, K, O_M, O_N,
        stride_ac, stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cc, stride_cm, stride_cn,
        rhs.row_indices, b_column_indices, b_offsets,
        rhs.block_offsets_t, fuse_srgb, clamp_output,
        GROUP_M=128, ACC_TYPE=tl.float16, BLOCK_M=min(rhs.data.shape[1], 64),
        BLOCK_N=rhs.data.shape[1], BLOCK_SIZE=rhs.data.shape[1], BLOCK_K=min(rhs.data.shape[1], 64)
    )
    return out
|
|
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=4, num_warps=2),
    ],
    key=['BLOCK_SIZE', 'BLOCK_N'],
)
@triton.jit
def _dds_sbsc_kernel(
    A: tl.tensor, B: tl.tensor, C: tl.tensor, M, N, K, O_M, O_N,
    stride_ac, stride_am, stride_ak,
    stride_bb, stride_bk, stride_bn,
    stride_cc, stride_cm, stride_cn,
    block_offset: tl.constexpr,
    fuse_srgb: tl.constexpr, clamp_output: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    # Dense @ SBSC matmul: exactly one sparse block per block column, whose
    # starting row in A is block_offset * pid_n, so no index-metadata loads
    # are needed (unlike _dds_kernel).
    pid_n = tl.program_id(0)  # block column
    pid_m = tl.program_id(1)  # output tile row
    pid_c = tl.program_id(2)  # channel

    nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)  # K-steps per block

    # First row of A touched by block column pid_n.
    start_row = block_offset * pid_n

    A_block_ptr = tl.make_block_ptr(
        base=A + pid_c * stride_ac, shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, start_row),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(0, 1)
    )

    # pid_n-th stored block of the SBSC matrix (stride_bb steps blocks).
    B_block_ptr = tl.make_block_ptr(
        base=B + pid_n * stride_bb, shape=(BLOCK_SIZE, BLOCK_N),
        strides=(stride_bk, stride_bn),
        offsets=(0, 0),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(0, 1)
    )

    # fp32 accumulation here (unlike the fp16 accumulator in _dds_kernel).
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for block_slice in range(nsub_blocks):
        # Cache hints: A tiles are streamed (evict_first); B is re-read by
        # every pid_m program (evict_last). A is boundary-checked on rows
        # and zero-padded past M.
        a = tl.load(A_block_ptr, eviction_policy='evict_first', boundary_check=(0,), padding_option='zero')
        b = tl.load(B_block_ptr, eviction_policy='evict_last')

        acc = tl.dot(a, b, acc, out_dtype=tl.float32)

        A_block_ptr = A_block_ptr.advance((0, BLOCK_K))
        B_block_ptr = B_block_ptr.advance((BLOCK_K, 0))

    # Optional fused epilogue: linear -> sRGB transfer, then clamp to [0, 1].
    if fuse_srgb:
        acc = linear_to_srgb_triton(acc)

    if clamp_output:
        acc = tl.clamp(acc, 0.0, 1.0)

    acc = acc.to(C.dtype.element_ty)

    C_block_ptr = tl.make_block_ptr(
        base=C + pid_c * stride_cc, shape=(O_M, O_N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0)
    )

    # Streaming store (.cs) with boundary checks for ragged output edges.
    tl.store(C_block_ptr, acc, boundary_check=(0, 1), cache_modifier='.cs')
|
|
|
|
def triton_dds_sbsc(
    lhs: torch.Tensor,
    rhs: SBSCMatrix,
    fuse_srgb: bool = False,
    clamp_output: bool = False,
    output_mt: bool = False,
    output_slice: None | Tuple[int,int] = None
):
    """Dense = Dense @ SBSC matmul via `_dds_sbsc_kernel`.

    With ``output_mt`` the result buffer is allocated transposed (N, M)
    and written through swapped strides; ``output_slice`` overrides the
    logical output extent.
    """
    assert isinstance(lhs, torch.Tensor)
    assert isinstance(rhs, SBSCMatrix)
    assert lhs.ndim == 3
    CH = lhs.shape[0]
    stride_ac = lhs.stride(0)

    M, K = lhs.shape[-2:]
    N = rhs.size[-1]

    # Allocate the output; the kernel only sees strides, so transposed and
    # normal layouts share one launch path.
    if output_mt:
        if output_slice is None:
            O_N, O_M = N, M
            result_shape = (*lhs.shape[:-2], rhs.size[1], lhs.shape[-2])
        else:
            O_N, O_M = output_slice
            result_shape = (*lhs.shape[:-2], *output_slice)
        out = torch.empty(result_shape, dtype=lhs.dtype, device=lhs.device)
        stride_cm, stride_cn = out.stride(-1), out.stride(-2)
    else:
        if output_slice is None:
            O_M, O_N = M, N
            result_shape = (*lhs.shape[:-1], rhs.size[1])
        else:
            O_M, O_N = output_slice
            result_shape = (*lhs.shape[:-2], *output_slice)
        out = torch.empty(result_shape, dtype=lhs.dtype, device=lhs.device)
        stride_cm, stride_cn = out.stride(-2), out.stride(-1)
    stride_cc = out.stride(-3)

    assert lhs.shape[-1] <= rhs.size[0], f"incompatible dimensions: {lhs.shape[-1]} > {rhs.size[0]}"

    stride_am, stride_ak = lhs.stride(-2), lhs.stride(-1)
    stride_bb, stride_bk, stride_bn = rhs.data.stride(0), rhs.data.stride(1), rhs.data.stride(2)

    def grid(META):
        return (triton.cdiv(N, META['BLOCK_N']), triton.cdiv(M, META['BLOCK_M']), CH)

    _dds_sbsc_kernel[grid](
        lhs, rhs.data, out, M, N, K, O_M, O_N,
        stride_ac, stride_am, stride_ak,
        stride_bb, stride_bk, stride_bn,
        stride_cc, stride_cm, stride_cn,
        rhs.offset, fuse_srgb, clamp_output,
        GROUP_M=32, ACC_TYPE=tl.float16, BLOCK_M=32,
        BLOCK_N=rhs.data.shape[2], BLOCK_SIZE=rhs.data.shape[1], BLOCK_K=rhs.col_block_size
    )
    return out
|
|
|
|
from triton.language.extra import libdevice
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=4, num_warps=2),
    ],
    key=['BLOCK_SIZE', 'BLOCK_N'],
)
@triton.jit
def _dds_sbsc_zerorhs_kernel(
    A: tl.tensor, C: tl.tensor, M, N, K, O_M, O_N,
    stride_ac, stride_am, stride_ak,
    stride_cc, stride_cm, stride_cn,
    k, PAD, block_offset: tl.constexpr,
    fuse_srgb: tl.constexpr, gamma_correction: tl.constexpr, clamp_output: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # SBSC variant with no stored rhs: the resampling weights are
    # synthesized per tile via magic_kernel_sharp_2021_triton, with
    # optional gamma decode of the input / encode of the output fused in.
    pid_n = tl.program_id(0)  # block column (destination tile)
    pid_m = tl.program_id(1)  # output tile row
    pid_c = tl.program_id(2)  # channel

    # NOTE(review): this uses triton.cdiv (host-side helper) where
    # _dds_sbsc_kernel uses tl.cdiv; presumably fine since both operands
    # are constexpr — confirm.
    nsub_blocks = triton.cdiv(BLOCK_SIZE, BLOCK_K)

    # First source row touched by this block column.
    start_row = block_offset * pid_n

    # NOTE(review): offs_k is pre-multiplied by stride_ak but later clamped
    # against the unscaled PAD and K - 1, which only lines up when
    # stride_ak == 1 (contiguous last dim) — confirm callers guarantee this.
    offs_k = (start_row + tl.arange(0, BLOCK_K)) * stride_ak
    m_range = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    A_mask = (m_range < M)[None, :].broadcast_to(BLOCK_K, BLOCK_M)

    A_M_ptr = A + pid_c * stride_ac + stride_am * m_range

    # Filter-tap coordinates: b_k are source sample centers scaled by k
    # (shifted left by the PAD border), b_n are destination sample centers.
    b_k = ((start_row - PAD + tl.arange(0, BLOCK_K)).to(tl.float32) + 0.5) * k
    b_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.float32) + 0.5

    b_base = (b_k[None, :] - b_n[:, None])

    # fp16 accumulation throughout.
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float16)

    for _ in tl.range(nsub_blocks):
        # Replicate padding: clamp the source index into [0, K - 1].
        A_ptr = A_M_ptr[None, :] + tl.minimum(tl.maximum(offs_k, PAD) - PAD, K - 1)[:, None]

        # Synthesize this tile's resampling weights on the fly.
        b = magic_kernel_sharp_2021_triton(b_base) * k

        b = b.to(tl.float16)

        a = tl.load(A_ptr, mask=A_mask)

        # Optional gamma decode of the input before filtering.
        if fuse_srgb == 'input':
            if gamma_correction == 'fast':
                a = libdevice.fast_powf(a, 2.2).to(tl.float16)
            elif gamma_correction == 'srgb':
                a = srgb_to_linear_triton(a).to(tl.float16)

        acc = tl.dot(b, a, acc, out_dtype=tl.float16)

        offs_k += BLOCK_K * stride_ak
        b_base += BLOCK_K * k

    # Optional gamma encode of the output after filtering.
    if fuse_srgb == 'output':
        if gamma_correction == 'fast':
            acc = libdevice.fast_powf(acc, 1.0/2.2)
        elif gamma_correction == 'srgb':
            acc = linear_to_srgb_triton(acc)

    if clamp_output:
        acc = tl.clamp(acc, 0.0, 1.0)

    if fuse_srgb == 'output' or clamp_output:
        acc = acc.to(C.dtype.element_ty)

    # The accumulator is (BLOCK_N, BLOCK_M); the store uses swapped strides
    # so C receives the transposed tile.
    C_block_ptr = tl.make_block_ptr(
        base=C + pid_c * stride_cc, shape=(O_N, O_M),
        strides=(stride_cn, stride_cm),
        offsets=(pid_n * BLOCK_N, pid_m * BLOCK_M),
        block_shape=(BLOCK_N, BLOCK_M),
        order=(1, 0)
    )

    # Streaming store (.cs) with boundary checks for ragged output edges.
    tl.store(C_block_ptr, acc, boundary_check=(0, 1), cache_modifier='.cs')
|
|
|
|
import math
|
|
|
|
|
|
def triton_dds_zerorhs_sbsc(
    lhs: torch.Tensor,
    target_size: int,
    source_size: int,
    kernel_window: float,
    block_specs,
    fuse_srgb: str = '',
    gamma_correction: str = 'fast',
    clamp_output: bool = False,
    output_mt: bool = False,
    output_slice: None | Tuple[int,int] = None
):
    """Resampling matmul whose sparse rhs is generated on the fly.

    Launches `_dds_sbsc_zerorhs_kernel`, which synthesizes the resampling
    weights per tile instead of reading a stored sparse matrix.
    ``block_specs`` unpacks to (offset, block_height, num_blocks,
    col_width); ``fuse_srgb`` selects where gamma conversion is fused
    ('input', 'output', or '' for none).
    """
    assert isinstance(lhs, torch.Tensor)

    assert fuse_srgb in ['input', 'output', '']
    assert gamma_correction in ['fast', 'srgb']

    # Resampling scale factor and the number of padded border samples
    # needed to cover the kernel window.
    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)

    offset, block_height, num_blocks, col_width = block_specs

    assert lhs.ndim == 3
    CH = lhs.shape[0]
    stride_ac = lhs.stride(0)

    M, K = lhs.shape[-2:]
    N = target_size

    # Allocate the output; the kernel only sees strides, so transposed and
    # normal layouts share one launch path.
    if output_mt:
        if output_slice is None:
            O_N, O_M = N, M
            result_shape = (*lhs.shape[:-2], N, M)
        else:
            O_N, O_M = output_slice
            result_shape = (*lhs.shape[:-2], *output_slice)
        out = torch.empty(result_shape, dtype=lhs.dtype, device=lhs.device)
        stride_cm, stride_cn = out.stride(-1), out.stride(-2)
    else:
        if output_slice is None:
            O_M, O_N = M, N
            result_shape = (*lhs.shape[:-2], M, N)
        else:
            O_M, O_N = output_slice
            result_shape = (*lhs.shape[:-2], *output_slice)
        out = torch.empty(result_shape, dtype=lhs.dtype, device=lhs.device)
        stride_cm, stride_cn = out.stride(-2), out.stride(-1)
    stride_cc = out.stride(-3)

    stride_am, stride_ak = lhs.stride(-2), lhs.stride(-1)

    def grid(META):
        return (triton.cdiv(N, META['BLOCK_N']), triton.cdiv(M, META['BLOCK_M']), CH)

    _dds_sbsc_zerorhs_kernel[grid](
        lhs, out, M, N, K, O_M, O_N,
        stride_ac, stride_am, stride_ak,
        stride_cc, stride_cm, stride_cn,
        k, PAD, offset, fuse_srgb, gamma_correction, clamp_output,
        BLOCK_M=32, BLOCK_K=16, BLOCK_N=col_width, BLOCK_SIZE=block_height,
    )
    return out
|