"""Sharpfin sparse matrix backend for Triton DDS matmul. Vendored from https://github.com/drhead/sharpfin (Apache 2.0) Adapted from https://github.com/stanford-futuredata/stk (Apache 2.0) """ import numpy as np import torch import triton import triton.language as tl from typing import Tuple from dataclasses import dataclass from .triton_functional import linear_to_srgb_triton, srgb_to_linear_triton, magic_kernel_sharp_2021_triton, lanczos_triton # Code is all adapted from https://github.com/stanford-futuredata/stk, licensed under Apache-2.0 # Very reduced set of functions for handling DDS (Dense = Dense @ Sparse) matmul only, with the # DDS kernel modified to be more flexible on input shapes. def _validate_matrix(shape, data, row_indices, column_indices, offsets): if data.dim() == 1: data = torch.reshape(data, [data.numel(), 1, 1]) if data.shape[-2] != data.shape[-1]: raise ValueError( "Expected square blocking in data. " f"Got block shape {[data.shape[-2], data.shape[-1]]}") block_size = data.shape[-1] data = data.view([-1, block_size, block_size]) if data.dim() != 3: raise ValueError( "Expected 3D shape for data (nnz, block, block). " f"Got shape {data.dim()}D shape.") block_size = data.shape[1] if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: raise ValueError( "Matrix shape must be dividible by blocking. " f"Got shape {shape} with " f"{[block_size, block_size]} blocking.") if np.prod(shape) < data.numel(): raise ValueError( "Invalid matrix. Number of nonzeros exceeds matrix capacity " f"({data.numel()} v. {np.prod(shape)})") if row_indices.dim() != 1: raise ValueError( f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") if column_indices.dim() != 1: raise ValueError( f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") if offsets.dim() != 1: raise ValueError( f"Expected 1D offsets. Got {offsets.dim()}D offsets.") if row_indices.numel() != data.shape[0]: raise ValueError( "Expected 1 index per nonzero block. " f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") if column_indices.numel() != data.shape[0]: raise ValueError( "Expected 1 index per nonzero block. " f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") block_rows = np.prod(shape[:-1]) / block_size if offsets.numel() != block_rows + 1: raise ValueError( "Expected one offset per block row plus one. " f"Got {offsets.numel()} offsets with {block_rows} block rows.") is_cuda = (data.is_cuda and row_indices.is_cuda and column_indices.is_cuda and offsets.is_cuda) is_cpu = (not data.is_cuda and not row_indices.is_cuda and not column_indices.is_cuda and not offsets.is_cuda) if not (is_cuda or is_cpu): raise ValueError( "Expected data & meta-data on common device. " f"Got data on {data.device}, row_indices on {row_indices.device} " f"column_indices on {column_indices.device} and " f"offsets on {offsets.device}.") if data.dtype != torch.float16: raise ValueError( f"Expected float16 data. Got {data.dtype} data.") if row_indices.dtype != torch.int16: raise ValueError( f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") if column_indices.dtype != torch.int16: raise ValueError( f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") if offsets.dtype != torch.int32: raise ValueError( f"Expected int32 offsets. Got {offsets.dtype} offsets.") return data def _transpose(size, data: torch.Tensor, row_indices: torch.Tensor, column_indices: torch.Tensor, offsets): block_columns = size[1] // data.shape[1] gather_indices = column_indices.argsort() column_indices_t = row_indices.gather(0, gather_indices) block_offsets_t = gather_indices.int() column_indices_float = column_indices.float() zero = torch.zeros((1,), dtype=torch.int32, device=data.device) nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) nnz_per_column = nnz_per_column.int() offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) return column_indices_t, offsets_t, block_offsets_t class SBSCMatrix(torch.nn.Module): """Single Block Sparse Column (SBSC) matrix format.""" def __init__( self, size, data: torch.Tensor, offset: int, block_size: int ): super().__init__() self.data = data self.offset = offset self.size = size self.num_blocks = data.shape[0] self.col_width = data.shape[2] self.col_block_size = block_size class Matrix(torch.nn.Module): """A matrix stored in block compressed sparse row (BCSR) format.""" def __init__(self, size, data: torch.Tensor, row_indices: torch.Tensor, column_indices: torch.Tensor, offsets: torch.Tensor, column_indices_t: torch.Tensor=None, offsets_t: torch.Tensor=None, block_offsets_t: torch.Tensor=None): super().__init__() self._size = size self._data = data self._row_indices = row_indices self._column_indices = column_indices self._offsets = offsets if ((column_indices_t is None) or (offsets_t is None) or (block_offsets_t is None)): column_indices_t, offsets_t, block_offsets_t = _transpose( size, data, row_indices, column_indices, offsets) self._column_indices_t = column_indices_t self._offsets_t = offsets_t self._block_offsets_t = block_offsets_t self._transposed = False max_dim = np.iinfo(np.int16).max * self.blocking if column_indices.dtype == torch.int16: if size[0] > max_dim or size[1] > max_dim: raise ValueError( "Sparse matrix with shape {size} exceeds representable " "size with 16-bit indices.") def validate(self): _validate_matrix(self._size, self._data, self._row_indices, self._column_indices, self._offsets) def to(self, device): self._data = self._data.to(device) self._row_indices = self._row_indices.to(device) self._column_indices = self._column_indices.to(device) self._offsets = self._offsets.to(device) self._column_indices_t = self._column_indices_t.to(device) self._offsets_t = self._offsets_t.to(device) self._block_offsets_t = self._block_offsets_t.to(device) return self def cuda(self): return self.to(torch.cuda.current_device()) def clone(self): return Matrix( self.size(), self.data.clone(), self.row_indices.clone(), self.column_indices.clone(), self.offsets.clone(), self.column_indices_t.clone(), self.offsets_t.clone(), self.block_offsets_t.clone()) def t(self): if self.dim() != 2: raise ValueError( "t() expects a tensor with <= 2 dimensions, " f"but self is {self.dim()}D.") out = Matrix(self.size(), self.data, self.row_indices, self.column_indices, self.offsets, self.column_indices_t, self.offsets_t, self.block_offsets_t) out._transposed = not self._transposed out._size = torch.Size((self._size[1], self._size[0])) return out def contiguous(self): raise ValueError("Not yet implemented.") def is_contiguous(self): return not self._transposed @property def is_cuda(self): return self._data.is_cuda @property def device(self): return self._data.device def size(self): return self._size @property def shape(self): return self.size() def dim(self): return len(self._size) @property def data(self): return self._data @property def row_indices(self): return self._row_indices @property def column_indices(self): return self._column_indices @property def offsets(self): return self._offsets @property def offsets_t(self): return self._offsets_t @property def column_indices_t(self): return self._column_indices_t @property def block_offsets_t(self): return self._block_offsets_t @property def dtype(self): return self.data.dtype @property def nnz(self): return self.data.numel() @property def blocking(self): return self.data.shape[1] @property def requires_grad(self): return self.data.requires_grad def requires_grad_(self, x): self.data.requires_grad_(x) return self def view(self, *shape): assert self.is_contiguous() if shape[-1] != self.size()[-1]: raise ValueError( "Can't change view on compressed dimension. " f"{self.size()[-1]} v. {shape[-1]}.") if np.prod(shape) != np.prod(self.size()): raise ValueError( "Mismatch in numel of Matrix and new shape. " f"{np.prod(self.size())} v. {np.prod(shape)}") return Matrix(shape, self.data, self.row_indices, self.column_indices, self.offsets, self.column_indices_t, self.offsets_t, self.block_offsets_t) @property def grad(self): size = self.size() if not self.is_contiguous(): size = torch.Size((size[1], size[0])) out = Matrix(size, self.data.grad, self.row_indices, self.column_indices, self.offsets, self.column_indices_t, self.offsets_t, self.block_offsets_t) return out if self.is_contiguous() else out.t() @torch.no_grad() def _expand_for_blocking(idxs, blocking): idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) idxs[:, :, 1] *= blocking idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) idxs = idxs.repeat(1, blocking, 1, 1) idxs[:, :, :, 0] *= blocking idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) idxs = torch.reshape(idxs, [-1, 2]) return idxs @torch.no_grad() def to_dense(x): assert isinstance(x, Matrix) shape = (np.prod(x.shape[:-1]), x.shape[-1]) row_idxs = x.row_indices.type(torch.int32) col_idxs = x.column_indices.type(torch.int32) indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) out.scatter_(0, indices, x.data.flatten()) return out.reshape(x.size()) @dataclass class TritonConfig: BLOCK_M: int = 128 BLOCK_N: int = 128 BLOCK_K: int = 32 BLOCK_SIZE: int = 64 NUM_STAGES: int = 4 NUM_WARPS: int = 4 @triton.autotune( configs=[ triton.Config({}, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), ], key=['M', 'N', 'K'], ) @triton.jit def _dds_kernel( A: tl.tensor, B: tl.tensor, C: tl.tensor, M, N, K, O_M, O_N, stride_ac, stride_am, stride_ak, stride_bk, stride_bn, stride_cc, stride_cm, stride_cn, row_indices: tl.tensor, column_indices: tl.tensor, offsets: tl.tensor, block_offsets_t: tl.tensor, fuse_srgb: tl.constexpr, clamp_output: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, ): pid_c = tl.program_id(0) pid_m = tl.program_id(1) pid_n = tl.program_id(2) num_pid_m = tl.num_programs(1) num_pid_n = tl.num_programs(2) pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) offsets += pid_n start_inx = tl.load(offsets) end_inx = tl.load(offsets + 1) column_indices += start_inx block_offsets_t += start_inx BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE A_block_ptr = tl.make_block_ptr( base=A + pid_c * stride_ac, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(0, 1) ) rn = tl.arange(0, BLOCK_N) rbk = tl.arange(0, BLOCK_K) B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float16) nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) bk_sub_incr = BLOCK_K * stride_bk for block_inx in range(end_inx - start_inx): a_col_idx = tl.load(column_indices + block_inx) ptr_A = tl.advance(A_block_ptr, (0, a_col_idx * BLOCK_SIZE)) b_block_offset = tl.load(block_offsets_t + block_inx) ptr_B = B + b_block_offset * BLOCK_ELEMENTS for sub_block_inx in range(nsub_blocks): a = tl.load(ptr_A) b = tl.load(ptr_B) acc = tl.dot(a, b, acc, out_dtype=tl.float16) ptr_A = tl.advance(ptr_A, (0, BLOCK_K)) ptr_B += bk_sub_incr if fuse_srgb: acc = linear_to_srgb_triton(acc) if clamp_output: acc = tl.clamp(acc, 0.0, 1.0) if fuse_srgb or clamp_output: acc = acc.to(C.dtype.element_ty) C_block_ptr = tl.make_block_ptr( base=C + pid_c * stride_cc, shape=(O_M, O_N), strides=(stride_cm, stride_cn), offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0) ) tl.store(C_block_ptr, acc, boundary_check=(0, 1)) def triton_dds( lhs: torch.Tensor, rhs: Matrix, fuse_srgb: bool = False, clamp_output: bool = False, output_mt: bool = False, output_slice: None | Tuple[int,int] = None ): assert isinstance(lhs, torch.Tensor) assert isinstance(rhs, Matrix) assert lhs.ndim == 3 CH = lhs.shape[0] stride_ac = lhs.stride(0) M, K = lhs.shape[-2:] N = rhs.shape[-1] if output_mt: if output_slice is not None: O_N, O_M = output_slice out = torch.empty( (*lhs.shape[:-2], *output_slice), dtype=lhs.dtype, device=lhs.device ) else: O_N, O_M = N, M out = torch.empty( (*lhs.shape[:-2], rhs.shape[1], lhs.shape[-2]), dtype=lhs.dtype, device=lhs.device ) stride_cm, stride_cn = out.stride(-1), out.stride(-2) stride_cc = out.stride(-3) else: if output_slice is not None: O_M, O_N = output_slice out = torch.empty( (*lhs.shape[:-2], *output_slice), dtype=lhs.dtype, device=lhs.device ) else: O_M, O_N = M, N out = torch.empty( (*lhs.shape[:-1], rhs.shape[1]), dtype=lhs.dtype, device=lhs.device ) stride_cm, stride_cn = out.stride(-2), out.stride(-1) stride_cc = out.stride(-3) trans_B = not rhs.is_contiguous() trans_A = (lhs.stride(-2) > 1 and lhs.stride(-1) > 1) assert trans_A == False, trans_B == False assert lhs.shape[-1] <= rhs.shape[0], "incompatible dimensions" stride_am, stride_ak = lhs.stride(-2), lhs.stride(-1) if trans_B: stride_bk, stride_bn = rhs.data.stride(2), rhs.data.stride(1) b_column_indices, b_offsets = rhs.column_indices, rhs.offsets else: stride_bk, stride_bn = rhs.data.stride(1), rhs.data.stride(2) b_column_indices, b_offsets = rhs.column_indices_t, rhs.offsets_t grid = lambda META: (CH, triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) _dds_kernel[grid]( lhs, rhs.data, out, M, N, K, O_M, O_N, stride_ac, stride_am, stride_ak, stride_bk, stride_bn, stride_cc, stride_cm, stride_cn, rhs.row_indices, b_column_indices, b_offsets, rhs.block_offsets_t, fuse_srgb, clamp_output, GROUP_M=128, ACC_TYPE=tl.float16, BLOCK_M=min(rhs.data.shape[1], 64), BLOCK_N=rhs.data.shape[1], BLOCK_SIZE=rhs.data.shape[1], BLOCK_K=min(rhs.data.shape[1], 64) ) return out @triton.autotune( configs=[ triton.Config({}, num_stages=4, num_warps=2), ], key=['BLOCK_SIZE', 'BLOCK_N'], ) @triton.jit def _dds_sbsc_kernel( A: tl.tensor, B: tl.tensor, C: tl.tensor, M, N, K, O_M, O_N, stride_ac, stride_am, stride_ak, stride_bb, stride_bk, stride_bn, stride_cc, stride_cm, stride_cn, block_offset: tl.constexpr, fuse_srgb: tl.constexpr, clamp_output: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, ): pid_n = tl.program_id(0) pid_m = tl.program_id(1) pid_c = tl.program_id(2) nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) start_row = block_offset * pid_n A_block_ptr = tl.make_block_ptr( base=A + pid_c * stride_ac, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid_m * BLOCK_M, start_row), block_shape=(BLOCK_M, BLOCK_K), order=(0, 1) ) B_block_ptr = tl.make_block_ptr( base=B + pid_n * stride_bb, shape=(BLOCK_SIZE, BLOCK_N), strides=(stride_bk, stride_bn), offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1) ) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for block_slice in range(nsub_blocks): a = tl.load(A_block_ptr, eviction_policy='evict_first', boundary_check=(0,), padding_option='zero') b = tl.load(B_block_ptr, eviction_policy='evict_last') acc = tl.dot(a, b, acc, out_dtype=tl.float32) A_block_ptr = A_block_ptr.advance((0, BLOCK_K)) B_block_ptr = B_block_ptr.advance((BLOCK_K, 0)) if fuse_srgb: acc = linear_to_srgb_triton(acc) if clamp_output: acc = tl.clamp(acc, 0.0, 1.0) acc = acc.to(C.dtype.element_ty) C_block_ptr = tl.make_block_ptr( base=C + pid_c * stride_cc, shape=(O_M, O_N), strides=(stride_cm, stride_cn), offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0) ) tl.store(C_block_ptr, acc, boundary_check=(0, 1), cache_modifier='.cs') def triton_dds_sbsc( lhs: torch.Tensor, rhs: SBSCMatrix, fuse_srgb: bool = False, clamp_output: bool = False, output_mt: bool = False, output_slice: None | Tuple[int,int] = None ): assert isinstance(lhs, torch.Tensor) assert isinstance(rhs, SBSCMatrix) assert lhs.ndim == 3 CH = lhs.shape[0] stride_ac = lhs.stride(0) M, K = lhs.shape[-2:] N = rhs.size[-1] if output_mt: if output_slice is not None: O_N, O_M = output_slice out = torch.empty( (*lhs.shape[:-2], *output_slice), dtype=lhs.dtype, device=lhs.device ) else: O_N, O_M = N, M out = torch.empty( (*lhs.shape[:-2], rhs.size[1], lhs.shape[-2]), dtype=lhs.dtype, device=lhs.device ) stride_cm, stride_cn = out.stride(-1), out.stride(-2) stride_cc = out.stride(-3) else: if output_slice is not None: O_M, O_N = output_slice out = torch.empty( (*lhs.shape[:-2], *output_slice), dtype=lhs.dtype, device=lhs.device ) else: O_M, O_N = M, N out = torch.empty( (*lhs.shape[:-1], rhs.size[1]), dtype=lhs.dtype, device=lhs.device ) stride_cm, stride_cn = out.stride(-2), out.stride(-1) stride_cc = out.stride(-3) assert lhs.shape[-1] <= rhs.size[0], f"incompatible dimensions: {lhs.shape[-1]} > {rhs.size[0]}" stride_am, stride_ak = lhs.stride(-2), lhs.stride(-1) stride_bb, stride_bk, stride_bn = rhs.data.stride(0), rhs.data.stride(1), rhs.data.stride(2) grid = lambda META: (triton.cdiv(N, META['BLOCK_N']), triton.cdiv(M, META['BLOCK_M']), CH) _dds_sbsc_kernel[grid]( lhs, rhs.data, out, M, N, K, O_M, O_N, stride_ac, stride_am, stride_ak, stride_bb, stride_bk, stride_bn, stride_cc, stride_cm, stride_cn, rhs.offset, fuse_srgb, clamp_output, GROUP_M=32, ACC_TYPE=tl.float16, BLOCK_M=32, BLOCK_N=rhs.data.shape[2], BLOCK_SIZE=rhs.data.shape[1], BLOCK_K=rhs.col_block_size ) return out from triton.language.extra import libdevice @triton.autotune( configs=[ triton.Config({}, num_stages=4, num_warps=2), ], key=['BLOCK_SIZE', 'BLOCK_N'], ) @triton.jit def _dds_sbsc_zerorhs_kernel( A: tl.tensor, C: tl.tensor, M, N, K, O_M, O_N, stride_ac, stride_am, stride_ak, stride_cc, stride_cm, stride_cn, k, PAD, block_offset: tl.constexpr, fuse_srgb: tl.constexpr, gamma_correction: tl.constexpr, clamp_output: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid_n = tl.program_id(0) pid_m = tl.program_id(1) pid_c = tl.program_id(2) nsub_blocks = triton.cdiv(BLOCK_SIZE, BLOCK_K) start_row = block_offset * pid_n offs_k = (start_row + tl.arange(0, BLOCK_K)) * stride_ak m_range = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) A_mask = (m_range < M)[None, :].broadcast_to(BLOCK_K, BLOCK_M) A_M_ptr = A + pid_c * stride_ac + stride_am * m_range b_k = ((start_row - PAD + tl.arange(0, BLOCK_K)).to(tl.float32) + 0.5) * k b_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.float32) + 0.5 b_base = (b_k[None, :] - b_n[:, None]) acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float16) for _ in tl.range(nsub_blocks): A_ptr = A_M_ptr[None, :] + tl.minimum(tl.maximum(offs_k, PAD) - PAD, K - 1)[:, None] b = magic_kernel_sharp_2021_triton(b_base) * k b = b.to(tl.float16) a = tl.load(A_ptr, mask=A_mask) if fuse_srgb == 'input': if gamma_correction == 'fast': a = libdevice.fast_powf(a, 2.2).to(tl.float16) elif gamma_correction == 'srgb': a = srgb_to_linear_triton(a).to(tl.float16) acc = tl.dot(b, a, acc, out_dtype=tl.float16) offs_k += BLOCK_K * stride_ak b_base += BLOCK_K * k if fuse_srgb == 'output': if gamma_correction == 'fast': acc = libdevice.fast_powf(acc, 1.0/2.2) elif gamma_correction == 'srgb': acc = linear_to_srgb_triton(acc) if clamp_output: acc = tl.clamp(acc, 0.0, 1.0) if fuse_srgb == 'output' or clamp_output: acc = acc.to(C.dtype.element_ty) C_block_ptr = tl.make_block_ptr( base=C + pid_c * stride_cc, shape=(O_N, O_M), strides=(stride_cn, stride_cm), offsets=(pid_n * BLOCK_N, pid_m * BLOCK_M), block_shape=(BLOCK_N, BLOCK_M), order=(1, 0) ) tl.store(C_block_ptr, acc, boundary_check=(0, 1), cache_modifier='.cs') import math def triton_dds_zerorhs_sbsc( lhs: torch.Tensor, target_size: int, source_size: int, kernel_window: float, block_specs, fuse_srgb: str = '', gamma_correction: str = 'fast', clamp_output: bool = False, output_mt: bool = False, output_slice: None | Tuple[int,int] = None ): assert isinstance(lhs, torch.Tensor) assert fuse_srgb in ['input', 'output', ''] assert gamma_correction in ['fast', 'srgb'] k = target_size / source_size PAD = math.ceil((kernel_window - 0.5) / k) offset, block_height, num_blocks, col_width = block_specs assert lhs.ndim == 3 CH = lhs.shape[0] stride_ac = lhs.stride(0) M, K = lhs.shape[-2:] N = target_size if output_mt: if output_slice is not None: O_N, O_M = output_slice out = torch.empty( (*lhs.shape[:-2], *output_slice), dtype=lhs.dtype, device=lhs.device ) else: O_N, O_M = N, M out = torch.empty( (*lhs.shape[:-2], N, M), dtype=lhs.dtype, device=lhs.device ) stride_cm, stride_cn = out.stride(-1), out.stride(-2) stride_cc = out.stride(-3) else: if output_slice is not None: O_M, O_N = output_slice out = torch.empty( (*lhs.shape[:-2], *output_slice), dtype=lhs.dtype, device=lhs.device ) else: O_M, O_N = M, N out = torch.empty( (*lhs.shape[:-2], M, N), dtype=lhs.dtype, device=lhs.device ) stride_cm, stride_cn = out.stride(-2), out.stride(-1) stride_cc = out.stride(-3) stride_am, stride_ak = lhs.stride(-2), lhs.stride(-1) grid = lambda META: (triton.cdiv(N, META['BLOCK_N']), triton.cdiv(M, META['BLOCK_M']), CH) _dds_sbsc_zerorhs_kernel[grid]( lhs, out, M, N, K, O_M, O_N, stride_ac, stride_am, stride_ak, stride_cc, stride_cm, stride_cn, k, PAD, offset, fuse_srgb, gamma_correction, clamp_output, BLOCK_M=32, BLOCK_K=16, BLOCK_N=col_width, BLOCK_SIZE=block_height, ) return out