# Mirror of https://github.com/vladmandic/automatic - 92 lines, 4.0 KiB, Python
# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
|
|
|
|
import torch
|
|
|
|
from ...common import compile_func, int_mm_func # noqa: TID252
|
|
from ...dequantizer import dequantize_symmetric, dequantize_symmetric_with_bias # noqa: TID252
|
|
from ...packed_int import unpack_int # noqa: TID252
|
|
|
|
from .forward import get_conv_args, process_conv_input
|
|
from ..linear.linear_int8 import quantize_int_mm_input # noqa: TID252
|
|
from ..linear.forward import check_mats # noqa: TID252
|
|
|
|
|
|
def conv_int8_matmul(
    input: torch.FloatTensor,
    weight: torch.Tensor,
    scale: torch.FloatTensor,
    result_shape: torch.Size,
    reversed_padding_repeated_twice: list[int],
    padding_mode: str, conv_type: int,
    groups: int, stride: list[int],
    padding: list[int], dilation: list[int],
    bias: torch.FloatTensor | None = None,
    svd_up: torch.FloatTensor | None = None,
    svd_down: torch.FloatTensor | None = None,
    quantized_weight_shape: torch.Size | None = None,
    weights_dtype: str | None = None,
) -> torch.FloatTensor:
    """Evaluate a convolution as an int8 x int8 matrix multiplication.

    Steps:
    1. ``process_conv_input`` turns the activation into a matmul operand.
    2. An optional low-rank (SVD) term is folded into the bias.
    3. ``quantize_int_mm_input`` quantizes the operand.
    4. ``int_mm_func`` multiplies it against the int8 weight.
    5. The result is dequantized with ``scale`` (plus the bias, if any) to
       the dtype of the incoming ``input``.

    Args:
        input: activation tensor, passed as-is to ``process_conv_input``.
        weight: int8 weight matrix, or packed storage when
            ``quantized_weight_shape`` is given.
        scale: weight scale used by the dequantizer.
        result_shape: shape hint forwarded to ``process_conv_input``.
        reversed_padding_repeated_twice, padding_mode, stride, padding, dilation:
            convolution settings forwarded to ``process_conv_input``.
        conv_type: forwarded to ``process_conv_input``. It also selects the
            final permutation of the output (1, 2 or 3; other values return
            the dequantized result unpermuted).
        groups: number of convolution groups. Values > 1 run one matmul per
            group and concatenate the results.
        bias: optional bias applied during dequantization.
        svd_up, svd_down: optional low-rank factors. The term
            ``input_2d @ svd_down @ svd_up`` is added to (or becomes) the bias.
        quantized_weight_shape: when not None, ``weight`` is unpacked with
            ``unpack_int`` and transposed, and ``scale`` is transposed as well.
        weights_dtype: packed-weight format name handed to ``unpack_int``.

    Returns:
        The convolution output, dequantized to ``input.dtype``.
    """
    out_dtype = input.dtype
    mat_input, mm_output_shape = process_conv_input(
        conv_type, input, reversed_padding_repeated_twice, padding_mode,
        result_shape, stride, padding, dilation,
    )

    # Fold the low-rank correction into the bias so that it is applied by the
    # dequantizer instead of needing a separate add on the output.
    if svd_up is not None:
        mat_input = mat_input.flatten(0, -2)
        svd_dtype = svd_down.dtype
        low_rank = torch.mm(mat_input.to(dtype=svd_dtype), svd_down)
        if bias is None:
            bias = torch.mm(low_rank, svd_up)
        else:
            bias = torch.addmm(bias.to(dtype=svd_dtype), low_rank, svd_up)

    if quantized_weight_shape is not None:
        # Packed storage: expand to int8 and transpose both weight and scale
        # into the orientation used by the matmul below.
        weight = unpack_int(weight, weights_dtype, quantized_weight_shape, dtype=torch.int8).t_()
        scale = scale.t()

    q_input, input_scale = quantize_int_mm_input(mat_input, dtype=scale.dtype)
    q_input, weight = check_mats(q_input, weight)

    if groups == 1:
        int_result = int_mm_func(q_input, weight)
    else:
        # Split the input's columns and the weight's columns into `groups`
        # equal slices. Multiply slice g of each and concatenate the
        # per-group outputs along the last dim.
        grouped_weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups)
        grouped_input = q_input.view(q_input.shape[0], groups, q_input.shape[1] // groups)
        int_result = torch.cat(
            [int_mm_func(grouped_input[:, g], grouped_weight[:, g]) for g in range(groups)],
            dim=-1,
        )
    result = int_result.to(dtype=input_scale.dtype).mul_(input_scale)

    if bias is None:
        result = dequantize_symmetric(result, scale, dtype=out_dtype, result_shape=mm_output_shape)
    else:
        result = dequantize_symmetric_with_bias(result, scale, bias, dtype=out_dtype, result_shape=mm_output_shape)

    # Move the trailing dim to position 1. Presumably this is channels-last ->
    # channels-first; confirm against process_conv_input.
    if conv_type == 1:
        return result.transpose_(1, 2)
    if conv_type == 2:
        return result.permute(0, 3, 1, 2)
    if conv_type == 3:
        return result.permute(0, 4, 1, 2, 3)
    return result
|
|
|
|
|
|
def quantized_conv_forward_int8_matmul(self, input) -> torch.FloatTensor:
    """Forward pass of an SDNQ-quantized conv layer through the int8 matmul path.

    When ``numel(input) / input.shape[2]`` is below 32, the weight is
    dequantized with ``skip_quantized_matmul=True`` and ``self._conv_forward``
    is used instead. Otherwise the stored int8 weight (or one re-quantized via
    ``sdnq_dequantizer.re_quantize_matmul`` when ``re_quantize_for_matmul`` is
    set) is passed to ``conv_int8_matmul``.
    """
    dequantizer = self.sdnq_dequantizer

    # NOTE(review): this size heuristic divides by input.shape[2], not the
    # channel dim (shape[1]). Confirm that this is the intended measure of the
    # matmul row count.
    if torch.numel(input) / input.shape[2] < 32:
        dense_weight = dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down, skip_quantized_matmul=True)
        return self._conv_forward(input, dense_weight, self.bias)

    conv_type, stride, padding, dilation = get_conv_args(input.ndim, self.stride, self.padding, self.dilation)

    if not dequantizer.re_quantize_for_matmul:
        weight, scale = self.weight, self.scale
        # Only packed storage needs the unpack shape. Plain int8 weights are used directly.
        packed_shape = dequantizer.quantized_weight_shape if dequantizer.is_packed else None
    else:
        # The re-quantized weight is not packed, so no unpack shape is needed.
        weight, scale = dequantizer.re_quantize_matmul(self.weight, self.scale, self.zero_point, None, None)
        packed_shape = None

    return conv_int8_matmul(
        input, weight, scale,
        dequantizer.result_shape,
        self._reversed_padding_repeated_twice,
        self.padding_mode, conv_type,
        self.groups, stride, padding, dilation,
        bias=self.bias,
        svd_up=self.svd_up,
        svd_down=self.svd_down,
        quantized_weight_shape=packed_shape,
        weights_dtype=dequantizer.weights_dtype,
    )
|
|
|
|
|
|
# Rebind the module-level name to the wrapper returned by compile_func, which
# is defined in ...common (presumably a torch.compile wrapper; not visible
# here). quantized_conv_forward_int8_matmul resolves `conv_int8_matmul` at
# call time, so it calls this wrapped version.
conv_int8_matmul = compile_func(conv_int8_matmul)
|