SDNQ add tensor descriptor kernel to triton mm for Intel Arc

pull/4733/head^2
Disty0 2026-04-04 01:32:34 +03:00
parent ffeda702c5
commit 470a0d816e
7 changed files with 92 additions and 37 deletions

View File

@ -40,18 +40,18 @@ def conv_fp16_matmul(
scale = scale.t() scale = scale.t()
elif weight.dtype != torch.float16: elif weight.dtype != torch.float16:
weight = weight.to(dtype=torch.float16) # fp8 weights weight = weight.to(dtype=torch.float16) # fp8 weights
input, scale = quantize_fp_mm_input_tensorwise(input, scale, matmul_dtype="float16") input, input_scale = quantize_fp_mm_input_tensorwise(input, dtype=scale.dtype, matmul_dtype="float16")
input, weight = check_mats(input, weight) input, weight = check_mats(input, weight)
if groups == 1: if groups == 1:
result = fp_mm_func(input, weight) result = fp_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale)
else: else:
weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups) weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups)
input = input.view(input.shape[0], groups, input.shape[1] // groups) input = input.view(input.shape[0], groups, input.shape[1] // groups)
result = [] result = []
for i in range(groups): for i in range(groups):
result.append(fp_mm_func(input[:, i], weight[:, i])) result.append(fp_mm_func(input[:, i], weight[:, i]))
result = torch.cat(result, dim=-1) result = torch.cat(result, dim=-1).to(dtype=input_scale.dtype).mul_(input_scale)
if bias is not None: if bias is not None:
dequantize_symmetric_with_bias(result, scale, bias, dtype=return_dtype, result_shape=mm_output_shape) dequantize_symmetric_with_bias(result, scale, bias, dtype=return_dtype, result_shape=mm_output_shape)
else: else:

View File

@ -38,19 +38,19 @@ def conv_fp8_matmul_tensorwise(
if quantized_weight_shape is not None: if quantized_weight_shape is not None:
weight = unpack_float(weight, weights_dtype, quantized_weight_shape).to(dtype=torch.float8_e4m3fn).t_() weight = unpack_float(weight, weights_dtype, quantized_weight_shape).to(dtype=torch.float8_e4m3fn).t_()
scale = scale.t() scale = scale.t()
input, scale = quantize_fp_mm_input_tensorwise(input, scale) input, input_scale = quantize_fp_mm_input_tensorwise(input, dtype=scale.dtype)
input, weight = check_mats(input, weight) input, weight = check_mats(input, weight)
dummy_input_scale = torch.ones(1, device=input.device, dtype=torch.float32) dummy_input_scale = torch.ones(1, device=input.device, dtype=torch.float32)
if groups == 1: if groups == 1:
result = torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype) result = torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=input_scale.dtype).mul_(input_scale)
else: else:
weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups) weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups)
input = input.view(input.shape[0], groups, input.shape[1] // groups) input = input.view(input.shape[0], groups, input.shape[1] // groups)
result = [] result = []
for i in range(groups): for i in range(groups):
result.append(torch._scaled_mm(input[:, i], weight[:, i], scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype)) result.append(torch._scaled_mm(input[:, i], weight[:, i], scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=input_scale.dtype))
result = torch.cat(result, dim=-1) result = torch.cat(result, dim=-1).mul_(input_scale)
if bias is not None: if bias is not None:
dequantize_symmetric_with_bias(result, scale, bias, dtype=return_dtype, result_shape=mm_output_shape) dequantize_symmetric_with_bias(result, scale, bias, dtype=return_dtype, result_shape=mm_output_shape)
else: else:

View File

@ -38,18 +38,18 @@ def conv_int8_matmul(
if quantized_weight_shape is not None: if quantized_weight_shape is not None:
weight = unpack_int(weight, weights_dtype, quantized_weight_shape, dtype=torch.int8).t_() weight = unpack_int(weight, weights_dtype, quantized_weight_shape, dtype=torch.int8).t_()
scale = scale.t() scale = scale.t()
input, scale = quantize_int_mm_input(input, scale) input, input_scale = quantize_int_mm_input(input, dtype=scale.dtype)
input, weight = check_mats(input, weight) input, weight = check_mats(input, weight)
if groups == 1: if groups == 1:
result = int_mm_func(input, weight) result = int_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale)
else: else:
weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups) weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups)
input = input.view(input.shape[0], groups, input.shape[1] // groups) input = input.view(input.shape[0], groups, input.shape[1] // groups)
result = [] result = []
for i in range(groups): for i in range(groups):
result.append(int_mm_func(input[:, i], weight[:, i])) result.append(int_mm_func(input[:, i], weight[:, i]))
result = torch.cat(result, dim=-1) result = torch.cat(result, dim=-1).to(dtype=input_scale.dtype).mul_(input_scale)
if bias is not None: if bias is not None:
result = dequantize_symmetric_with_bias(result, scale, bias, dtype=return_dtype, result_shape=mm_output_shape) result = dequantize_symmetric_with_bias(result, scale, bias, dtype=return_dtype, result_shape=mm_output_shape)
else: else:

View File

@ -33,12 +33,12 @@ def fp16_matmul(
bias = torch.addmm(bias.to(dtype=svd_down.dtype), torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up) bias = torch.addmm(bias.to(dtype=svd_down.dtype), torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
else: else:
bias = torch.mm(torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up) bias = torch.mm(torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
input, scale = quantize_fp_mm_input_tensorwise(input, scale, matmul_dtype="float16") input, input_scale = quantize_fp_mm_input_tensorwise(input, dtype=scale.dtype, matmul_dtype="float16")
input, weight = check_mats(input, weight) input, weight = check_mats(input, weight)
if bias is not None: if bias is not None:
return dequantize_symmetric_with_bias(fp_mm_func(input, weight), scale, bias, dtype=return_dtype, result_shape=output_shape) return dequantize_symmetric_with_bias(fp_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale), scale, bias, dtype=return_dtype, result_shape=output_shape)
else: else:
return dequantize_symmetric(fp_mm_func(input, weight), scale, dtype=return_dtype, result_shape=output_shape) return dequantize_symmetric(fp_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale), scale, dtype=return_dtype, result_shape=output_shape)
def quantized_linear_forward_fp16_matmul(self, input: torch.FloatTensor) -> torch.FloatTensor: def quantized_linear_forward_fp16_matmul(self, input: torch.FloatTensor) -> torch.FloatTensor:

View File

@ -9,13 +9,14 @@ from ...dequantizer import quantize_fp_mm, dequantize_symmetric, dequantize_symm
from .forward import check_mats from .forward import check_mats
def quantize_fp_mm_input_tensorwise(input: torch.FloatTensor, scale: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]: def quantize_fp_mm_input_tensorwise(input: torch.FloatTensor, dtype: torch.dtype | None = None, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]:
input = input.flatten(0,-2).to(dtype=scale.dtype) input = input.flatten(0,-2)
if dtype is not None:
input = input.to(dtype=dtype)
input, input_scale = quantize_fp_mm(input, dim=-1, matmul_dtype=matmul_dtype) input, input_scale = quantize_fp_mm(input, dim=-1, matmul_dtype=matmul_dtype)
scale = torch.mul(input_scale, scale) if input_scale.dtype == torch.float16: # fp16 will overflow
if scale.dtype == torch.float16: # fp16 will overflow input_scale = input_scale.to(dtype=torch.float32)
scale = scale.to(dtype=torch.float32) return input, input_scale
return input, scale
def fp8_matmul_tensorwise( def fp8_matmul_tensorwise(
@ -40,12 +41,12 @@ def fp8_matmul_tensorwise(
else: else:
bias = torch.mm(torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up) bias = torch.mm(torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
dummy_input_scale = torch.ones(1, device=input.device, dtype=torch.float32) dummy_input_scale = torch.ones(1, device=input.device, dtype=torch.float32)
input, scale = quantize_fp_mm_input_tensorwise(input, scale) input, input_scale = quantize_fp_mm_input_tensorwise(input, dtype=scale.dtype)
input, weight = check_mats(input, weight, allow_contiguous_mm=False) input, weight = check_mats(input, weight, allow_contiguous_mm=False)
if bias is not None: if bias is not None:
return dequantize_symmetric_with_bias(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype), scale, bias, dtype=return_dtype, result_shape=output_shape) return dequantize_symmetric_with_bias(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=input_scale.dtype).to(dtype=input_scale.dtype).mul_(input_scale), scale, bias, dtype=return_dtype, result_shape=output_shape)
else: else:
return dequantize_symmetric(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype), scale, dtype=return_dtype, result_shape=output_shape) return dequantize_symmetric(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=input_scale.dtype).to(dtype=input_scale.dtype).mul_(input_scale), scale, dtype=return_dtype, result_shape=output_shape)
def quantized_linear_forward_fp8_matmul_tensorwise(self, input: torch.FloatTensor) -> torch.FloatTensor: def quantized_linear_forward_fp8_matmul_tensorwise(self, input: torch.FloatTensor) -> torch.FloatTensor:

View File

@ -9,13 +9,14 @@ from ...packed_int import unpack_int # noqa: TID252
from .forward import check_mats from .forward import check_mats
def quantize_int_mm_input(input: torch.FloatTensor, scale: torch.FloatTensor) -> tuple[torch.CharTensor, torch.FloatTensor]: def quantize_int_mm_input(input: torch.FloatTensor, dtype: torch.dtype | None = None) -> tuple[torch.CharTensor, torch.FloatTensor]:
input = input.flatten(0,-2).to(dtype=scale.dtype) input = input.flatten(0,-2)
if dtype is not None:
input = input.to(dtype=dtype)
input, input_scale = quantize_int_mm(input, dim=-1) input, input_scale = quantize_int_mm(input, dim=-1)
scale = torch.mul(input_scale, scale) if input_scale.dtype == torch.float16: # fp16 will overflow
if scale.dtype == torch.float16: # fp16 will overflow input_scale = input_scale.to(dtype=torch.float32)
scale = scale.to(dtype=torch.float32) return input, input_scale
return input, scale
def int8_matmul( def int8_matmul(
@ -39,12 +40,12 @@ def int8_matmul(
bias = torch.addmm(bias.to(dtype=svd_down.dtype), torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up) bias = torch.addmm(bias.to(dtype=svd_down.dtype), torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
else: else:
bias = torch.mm(torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up) bias = torch.mm(torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
input, scale = quantize_int_mm_input(input, scale) input, input_scale = quantize_int_mm_input(input, dtype=scale.dtype)
input, weight = check_mats(input, weight) input, weight = check_mats(input, weight)
if bias is not None: if bias is not None:
return dequantize_symmetric_with_bias(int_mm_func(input, weight), scale, bias, dtype=return_dtype, result_shape=output_shape) return dequantize_symmetric_with_bias(int_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale), scale, bias, dtype=return_dtype, result_shape=output_shape)
else: else:
return dequantize_symmetric(int_mm_func(input, weight), scale, dtype=return_dtype, result_shape=output_shape) return dequantize_symmetric(int_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale), scale, dtype=return_dtype, result_shape=output_shape)
def quantized_linear_forward_int8_matmul(self, input: torch.FloatTensor) -> torch.FloatTensor: def quantized_linear_forward_int8_matmul(self, input: torch.FloatTensor) -> torch.FloatTensor:

View File

@ -1,8 +1,10 @@
""" """
Modified from Triton MatMul example. Modified from Triton MatMul example.
PyTorch torch._int_mm is broken on backward pass with Nvidia. PyTorch torch._int_mm is broken on backward pass with Nvidia, so we use Triton on the backward pass with Nvidia.
AMD RDNA2 doesn't support torch._int_mm, so we use int_mm via Triton. AMD RDNA2 doesn't support torch._int_mm as it requires INT8 WMMA, so we use INT8 DP4A via Triton.
PyTorch doesn't support FP32 output type with FP16 MM so we use Triton for it too. PyTorch doesn't support FP32 output type with FP16 MM, so we use Triton for FP16 MM too.
The matmul_configs we use take AMD and Intel into consideration too.
SDNQ Triton configs can outperform rocBLAS and oneDNN.
""" """
import torch import torch
@ -22,7 +24,7 @@ matmul_configs = [
] ]
@triton.autotune(configs=matmul_configs, key=["M", "N", "K", "stride_bk", "ACCUMULATOR_DTYPE"]) @triton.autotune(configs=matmul_configs, key=["M", "N", "K", "stride_bk", "ACCUMULATOR_DTYPE"], cache_results=True)
@triton.jit @triton.jit
def triton_mm_kernel( def triton_mm_kernel(
a_ptr, b_ptr, c_ptr, a_ptr, b_ptr, c_ptr,
@ -76,6 +78,55 @@ def triton_mm_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask) tl.store(c_ptrs, accumulator, mask=c_mask)
# Intel requires tensor descriptors to perform well.
# Same tiling / program-swizzling scheme as the plain pointer-arithmetic MM kernel,
# but all global-memory traffic goes through tl.make_tensor_descriptor instead.
@triton.autotune(configs=matmul_configs, key=["M", "N", "K", "stride_bk", "ACCUMULATOR_DTYPE"], cache_results=True)
@triton.jit
def triton_mm_td_kernel(
    a_ptr, b_ptr, c_ptr,
    M: int, N: int, K: int,
    stride_am: int, stride_ak: int,
    stride_bk: int, stride_bn: int,
    stride_cm: int, stride_cn: int,
    ACCUMULATOR_DTYPE: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Computes one (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of C[M, N] = A[M, K] @ B[K, N]
    # per program instance; accumulation happens in ACCUMULATOR_DTYPE.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped program ordering: GROUP_SIZE_M rows of output tiles are walked
    # together so consecutive programs reuse the same A/B tiles (the standard
    # Triton matmul-tutorial L2-locality scheme).
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The final group along M may contain fewer than GROUP_SIZE_M tile rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # Compiler hints: tile indices are non-negative and strides are strictly
    # positive, which lets Triton simplify address arithmetic.
    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)
    # Descriptors cover the full A and B matrices; each load below selects a
    # block by offset. NOTE(review): ragged edges (M/N/K not multiples of the
    # block sizes) rely on descriptor out-of-bounds handling — confirm that
    # loads zero-fill and stores mask per Triton's tensor-descriptor semantics.
    a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
    b_desc = tl.make_tensor_descriptor(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))
    off_k = 0
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
    # March along K in BLOCK_SIZE_K steps, accumulating partial dot products.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
        b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
        accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
        off_k += BLOCK_SIZE_K
    # Write the finished output tile back through a descriptor as well.
    c_desc = tl.make_tensor_descriptor(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
    c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], accumulator)
def int_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: def int_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
assert a.shape[1] == b.shape[0], "Incompatible dimensions" assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous" assert a.is_contiguous(), "Matrix A must be contiguous"
@ -84,7 +135,8 @@ def int_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
c = torch.empty((M, N), device=a.device, dtype=torch.int32) c = torch.empty((M, N), device=a.device, dtype=torch.int32)
def grid(META): def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
triton_mm_kernel[grid]( mm_kernel_func = triton_mm_td_kernel if b.is_contiguous() else triton_mm_kernel
mm_kernel_func[grid](
a, b, c, a, b, c,
M, N, K, M, N, K,
a.stride(0), a.stride(1), a.stride(0), a.stride(1),
@ -103,7 +155,8 @@ def fp_mm(a: torch.FloatTensor, b: torch.FloatTensor) -> torch.FloatTensor:
c = torch.empty((M, N), device=a.device, dtype=torch.float32) c = torch.empty((M, N), device=a.device, dtype=torch.float32)
def grid(META): def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
triton_mm_kernel[grid]( mm_kernel_func = triton_mm_td_kernel if b.is_contiguous() else triton_mm_kernel
mm_kernel_func[grid](
a, b, c, a, b, c,
M, N, K, M, N, K,
a.stride(0), a.stride(1), a.stride(0), a.stride(1),