# automatic/modules/sdnq/layers/conv/conv_fp8.py (98 lines, 4.3 KiB, Python)

# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
import torch
from ...common import compile_func # noqa: TID252
from ...packed_float import unpack_float # noqa: TID252
from .forward import get_conv_args, process_conv_input
from ..linear.linear_fp8 import quantize_fp_mm_input # noqa: TID252
from ..linear.forward import check_mats # noqa: TID252
def conv_fp8_matmul(
    input: torch.FloatTensor,
    weight: torch.Tensor,
    scale: torch.FloatTensor,
    result_shape: torch.Size,
    reversed_padding_repeated_twice: list[int],
    padding_mode: str, conv_type: int,
    groups: int, stride: list[int],
    padding: list[int], dilation: list[int],
    bias: torch.FloatTensor = None,
    svd_up: torch.FloatTensor = None,
    svd_down: torch.FloatTensor = None,
    quantized_weight_shape: torch.Size = None,
    weights_dtype: str = None,
) -> torch.FloatTensor:
    """Run a convolution as an FP8 ``torch._scaled_mm`` on a matmul-shaped input.

    Steps, in order:

    1. ``process_conv_input`` turns ``input`` into the matmul operand and also
       returns the shape the matmul output is later viewed as.
    2. When ``svd_up`` is given, ``(input @ svd_down) @ svd_up`` is computed in
       ``svd_down``'s dtype and added to the matmul result afterwards.
    3. When ``quantized_weight_shape`` is given, ``weight`` is unpacked with
       ``unpack_float``, cast to ``torch.float8_e4m3fn`` and transposed, and
       ``scale`` is transposed as well.
    4. ``input`` is quantized by ``quantize_fp_mm_input`` and both operands go
       through ``check_mats``.
    5. ``groups == 1`` issues one scaled matmul; otherwise one scaled matmul is
       issued per group and the outputs are concatenated on the last dim.
    6. The result is viewed as the matmul output shape, cast back to the dtype
       ``input`` had on entry, and its last axis is moved to position 1
       according to ``conv_type`` (1, 2 or 3).

    Args:
        input: Activation tensor handed to ``process_conv_input``.
        weight: Quantized (optionally packed) weight.
        scale: Weight scale passed to ``torch._scaled_mm`` as ``scale_b``.
        result_shape: Forwarded to ``process_conv_input``.
        reversed_padding_repeated_twice: Forwarded to ``process_conv_input``.
        padding_mode: Forwarded to ``process_conv_input``.
        conv_type: Selects the final axis reordering (1, 2 or 3).
        groups: Number of convolution groups.
        stride: Forwarded to ``process_conv_input``.
        padding: Forwarded to ``process_conv_input``.
        dilation: Forwarded to ``process_conv_input``.
        bias: Optional bias; cast to bfloat16 before the scaled matmul.
        svd_up: Optional second low-rank factor.
        svd_down: Optional first low-rank factor; used only with ``svd_up``.
        quantized_weight_shape: When not ``None``, ``weight`` is unpacked first.
        weights_dtype: Forwarded to ``unpack_float``.

    Returns:
        The convolution output in the dtype ``input`` had on entry.
    """
    out_dtype = input.dtype
    has_svd = svd_up is not None

    input, mm_output_shape = process_conv_input(
        conv_type, input, reversed_padding_repeated_twice, padding_mode,
        result_shape, stride, padding, dilation,
    )

    if has_svd:
        # Low-rank correction, evaluated on the unquantized input.
        input = input.flatten(0, -2)
        low_rank = torch.mm(input.to(dtype=svd_down.dtype), svd_down)
        svd_bias = torch.mm(low_rank, svd_up)

    if quantized_weight_shape is not None:
        unpacked = unpack_float(weight, weights_dtype, quantized_weight_shape)
        weight = unpacked.to(dtype=torch.float8_e4m3fn).t_()
        scale = scale.t()

    input, input_scale = quantize_fp_mm_input(input)
    input, weight = check_mats(input, weight)

    if groups != 1:
        # Split every operand into `groups` chunks and multiply chunk by chunk.
        scale = scale.view(groups, 1, scale.shape[1] // groups)
        # NOTE(review): input_scale is split along its first dim while input is
        # split along its second dim - confirm this matches the shape that
        # quantize_fp_mm_input returns.
        input_scale = input_scale.view(groups, input_scale.shape[0] // groups, 1)
        weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups)
        input = input.view(input.shape[0], groups, input.shape[1] // groups)
        if bias is not None:
            bias = bias.view(groups, bias.shape[0] // groups)
            if bias.dtype != torch.bfloat16:
                bias = bias.to(dtype=torch.bfloat16)
        partials = [
            torch._scaled_mm(
                input[:, g], weight[:, g],
                scale_a=input_scale[g], scale_b=scale[g],
                bias=None if bias is None else bias[g],
                out_dtype=torch.bfloat16,
            )
            for g in range(groups)
        ]
        result = torch.cat(partials, dim=-1)
    else:
        if bias is not None and bias.dtype != torch.bfloat16:
            bias = bias.to(dtype=torch.bfloat16)
        result = torch._scaled_mm(
            input, weight,
            scale_a=input_scale, scale_b=scale,
            bias=bias, out_dtype=torch.bfloat16,
        )

    if has_svd:
        result.add_(svd_bias)

    result = result.view(mm_output_shape).to(out_dtype)

    # Move the last axis of the matmul output to position 1.
    if conv_type == 1:
        return result.transpose_(1, 2)
    if conv_type == 2:
        return result.permute(0, 3, 1, 2)
    if conv_type == 3:
        return result.permute(0, 4, 1, 2, 3)
    return result
def quantized_conv_forward_fp8_matmul(self, input) -> torch.FloatTensor:
    """Forward pass of a quantized conv layer through the FP8 matmul path.

    When ``numel(input) / input.shape[2]`` is below 32, the weight is
    dequantized (with ``skip_quantized_matmul=True``) and the layer's regular
    ``_conv_forward`` is used instead.
    NOTE(review): presumably this skips FP8 for inputs too small to benefit -
    confirm the intent of the threshold.

    Otherwise the weight/scale pair is either re-quantized for matmul by the
    dequantizer or taken as stored (with its packed shape passed along when
    the dequantizer reports packed weights), and the work is delegated to
    ``conv_fp8_matmul``.

    Args:
        input: Activation tensor for the convolution.

    Returns:
        The convolution output tensor.
    """
    dequantizer = self.sdnq_dequantizer

    if torch.numel(input) / input.shape[2] < 32:
        dense_weight = dequantizer(
            self.weight, self.scale, self.zero_point,
            self.svd_up, self.svd_down,
            skip_quantized_matmul=True,
        )
        return self._conv_forward(input, dense_weight, self.bias)

    # Only a stored, packed weight needs its packed shape for unpacking later.
    quantized_weight_shape = None
    if dequantizer.re_quantize_for_matmul:
        weight, scale = dequantizer.re_quantize_matmul(self.weight, self.scale, self.zero_point, None, None)
    else:
        weight, scale = self.weight, self.scale
        if dequantizer.is_packed:
            quantized_weight_shape = dequantizer.quantized_weight_shape

    conv_type, stride, padding, dilation = get_conv_args(input.ndim, self.stride, self.padding, self.dilation)

    return conv_fp8_matmul(
        input, weight, scale,
        dequantizer.result_shape,
        self._reversed_padding_repeated_twice,
        self.padding_mode, conv_type,
        self.groups, stride, padding, dilation,
        bias=self.bias,
        svd_up=self.svd_up,
        svd_down=self.svd_down,
        quantized_weight_shape=quantized_weight_shape,
        weights_dtype=dequantizer.weights_dtype,
    )
# Rebind the module-level name to whatever compile_func returns. Code that
# looks up `conv_fp8_matmul` at call time (e.g. quantized_conv_forward_fp8_matmul)
# therefore calls the wrapped version.
# NOTE(review): compile_func comes from ...common - presumably a torch.compile
# wrapper; confirm.
conv_fp8_matmul = compile_func(conv_fp8_matmul)