automatic/modules/sdnq/layers/linear/linear_fp8_tensorwise.py

53 lines
2.4 KiB
Python

# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
from typing import Tuple
import torch
from ...common import compile_func # noqa: TID252
from ...dequantizer import quantize_fp8, dequantize_symmetric, dequantize_symmetric_with_bias # noqa: TID252
from .forward import check_mats
def quantize_fp8_matmul_input_tensorwise(input: torch.FloatTensor, scale: torch.FloatTensor) -> Tuple[torch.Tensor, torch.FloatTensor]:
    """Quantize a matmul input to FP8 and fold its scale into the weight scale.

    The input is collapsed to 2-D (all leading dimensions merged, last
    dimension kept), cast to ``scale.dtype`` and handed to ``quantize_fp8``
    with ``dim=-1``. The input scale returned by ``quantize_fp8`` is
    multiplied with ``scale`` so a single combined scale can be applied after
    the matmul.

    Args:
        input: Activation tensor; only the last dimension is preserved.
        scale: Weight scale; also decides the dtype the input is cast to.

    Returns:
        ``(quantized_input, combined_scale)``. ``combined_scale`` is promoted
        to float32 when it would otherwise be float16.
    """
    rows = input.flatten(0, -2).to(dtype=scale.dtype)
    quantized, input_scale = quantize_fp8(rows, dim=-1)
    combined_scale = torch.mul(input_scale, scale)
    if combined_scale.dtype == torch.float16:  # fp16 will overflow
        combined_scale = combined_scale.to(dtype=torch.float32)
    return quantized, combined_scale
def fp8_matmul_tensorwise(
    input: torch.FloatTensor,
    weight: torch.Tensor,
    bias: torch.FloatTensor,
    scale: torch.FloatTensor,
    svd_up: torch.FloatTensor,
    svd_down: torch.FloatTensor,
) -> torch.FloatTensor:
    """Multiply ``input`` by an FP8 ``weight`` via ``torch._scaled_mm`` and dequantize.

    Args:
        input: Activation tensor of any rank >= 2; the last dimension is the
            contraction dimension.
        weight: Quantized weight; ``weight.shape[-1]`` is the output feature size.
        bias: Optional bias (may be ``None``).
        scale: Weight scale; combined with the input scale before dequantizing.
        svd_up: Optional low-rank "up" factor (may be ``None``). When set,
            ``svd_down`` is used as well.
        svd_down: Optional low-rank "down" factor paired with ``svd_up``.

    Returns:
        The dequantized result in ``input``'s original dtype, shaped as
        ``(*input.shape[:-1], weight.shape[-1])``.
    """
    return_dtype = input.dtype
    # Captured before any flattening so the caller gets its leading dims back.
    output_shape = (*input.shape[:-1], weight.shape[-1])

    if svd_up is not None:
        # torch.mm only accepts 2-D operands and Tensor.flatten is NOT in-place,
        # so the flattened view has to be rebound to `input`.
        input = input.flatten(0, -2)
        low_rank = torch.mm(input.to(dtype=svd_down.dtype), svd_down)
        # Fold the low-rank correction (input @ svd_down @ svd_up) into the bias
        # so it is added back after dequantization.
        if bias is not None:
            bias = torch.addmm(bias.to(dtype=svd_down.dtype), low_rank, svd_up)
        else:
            bias = torch.mm(low_rank, svd_up)

    # The real scales are applied by the dequantize helpers afterwards, so the
    # kernel itself is driven with unit scales.
    dummy_input_scale = torch.ones(1, device=input.device, dtype=torch.float32)
    input, scale = quantize_fp8_matmul_input_tensorwise(input, scale)
    input, weight = check_mats(input, weight)
    result = torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype)

    if bias is not None:
        return dequantize_symmetric_with_bias(result, scale, bias, return_dtype, output_shape)
    return dequantize_symmetric(result, scale, return_dtype, output_shape)
def quantized_linear_forward_fp8_matmul_tensorwise(self, input: torch.FloatTensor) -> torch.FloatTensor:
    """Forward pass for a quantized linear layer, choosing between two paths.

    The number of rows is the element count divided by the last dimension.
    With 32 rows or more the FP8 matmul path (``fp8_matmul_tensorwise``) is
    used. Below that, the weight is dequantized through
    ``self.sdnq_dequantizer`` (with ``skip_quantized_matmul=True``) and a
    regular ``torch.nn.functional.linear`` is run.
    NOTE(review): the threshold of 32 presumably reflects where the FP8 kernel
    becomes usable/worthwhile - confirm.

    Args:
        input: Activation tensor whose last dimension is the feature dimension.

    Returns:
        The layer output.
    """
    row_count = torch.numel(input) / input.shape[-1]
    if row_count >= 32:
        return fp8_matmul_tensorwise(input, self.weight, self.bias, self.scale, self.svd_up, self.svd_down)

    # Too few rows for the FP8 path: fall back to a plain dense linear.
    dense_weight = self.sdnq_dequantizer(
        self.weight,
        self.scale,
        self.zero_point,
        self.svd_up,
        self.svd_down,
        skip_quantized_matmul=True,
    )
    return torch.nn.functional.linear(input, dense_weight, self.bias)
# Rebind the module-level name to the callable returned by compile_func, so code
# that looks up fp8_matmul_tensorwise by name at call time (such as
# quantized_linear_forward_fp8_matmul_tensorwise) gets the wrapped version.
# NOTE(review): compile_func is imported from ...common; presumably a
# torch.compile-style wrapper - confirm there.
fp8_matmul_tensorwise = compile_func(fp8_matmul_tensorwise)