# automatic/modules/sdnq/forward.py
# pylint: disable=redefined-builtin,no-member,protected-access
from typing import Callable, List, Tuple, Optional
import torch
from modules import shared
from .common import conv_types, conv_transpose_types
from .dequantizer import dequantize_symmetric, dequantize_symmetric_with_bias
from .packed_int import unpack_int_symetric


def get_forward_func(layer_class_name: str, use_quantized_matmul: bool, is_integer: bool, use_tensorwise_fp8_matmul: bool) -> Callable: # pylint: disable=inconsistent-return-statements
    if layer_class_name in conv_types:
        if use_quantized_matmul:
            if is_integer:
                return quantized_conv_forward_int8_matmul
            elif use_tensorwise_fp8_matmul:
                return quantized_conv_forward_fp8_matmul_tensorwise
            else:
                return quantized_conv_forward_fp8_matmul
        else:
            return quantized_conv_forward
    elif layer_class_name in conv_transpose_types:
        if layer_class_name.endswith("1d"):
            return quantized_conv_transpose_1d_forward
        elif layer_class_name.endswith("2d"):
            return quantized_conv_transpose_2d_forward
        elif layer_class_name.endswith("3d"):
            return quantized_conv_transpose_3d_forward
    else:
        if use_quantized_matmul:
            if is_integer:
                return quantized_linear_forward_int8_matmul
            elif use_tensorwise_fp8_matmul:
                return quantized_linear_forward_fp8_matmul_tensorwise
            else:
                return quantized_linear_forward_fp8_matmul
        else:
            return quantized_linear_forward
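
# Usage sketch (hypothetical; SDNQ attaches these during quantization): the
# selected function replaces the module's forward, bound to the instance, e.g.:
#   layer = torch.nn.Linear(64, 64)
#   ... quantize layer.weight and attach layer.sdnq_dequantizer ...
#   layer.forward = get_forward_func("Linear", True, True, False).__get__(layer)
#   # -> quantized_linear_forward_int8_matmul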


def quantize_fp8_matmul_input(input: torch.FloatTensor) -> Tuple[torch.Tensor, torch.FloatTensor]:
    input = input.flatten(0, -2).contiguous()
    input_scale = torch.amax(input.abs(), dim=-1, keepdim=True).div_(448) # 448 is the largest finite value of float8_e4m3fn
    input = torch.div(input, input_scale).clamp_(-448, 448).to(dtype=torch.float8_e4m3fn)
    input_scale = input_scale.to(dtype=torch.float32) # torch._scaled_mm expects float32 scales
    return input, input_scale
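
# Round-trip sketch (assumed shapes/device): every flattened row gets its own scale.
#   x = torch.randn(2, 16, 64, dtype=torch.bfloat16, device="cuda")
#   x_fp8, x_scale = quantize_fp8_matmul_input(x)
#   x_fp8: (32, 64) float8_e4m3fn; x_scale: (32, 1) float32
#   x_fp8.to(torch.float32) * x_scale  # recovers x up to fp8 rounding error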


def quantize_fp8_matmul_input_tensorwise(input: torch.FloatTensor, scale: torch.FloatTensor) -> Tuple[torch.Tensor, torch.FloatTensor]:
    input = input.flatten(0, -2).contiguous()
    input_scale = torch.amax(input.abs(), dim=-1, keepdim=True).div_(448)
    input = torch.div(input, input_scale).clamp_(-448, 448).to(dtype=torch.float8_e4m3fn)
    scale = torch.mul(input_scale, scale) # fold the input scale into the weight scale for a single dequantize after the matmul
    if scale.dtype == torch.float16: # fp16 will overflow
        scale = scale.to(dtype=torch.float32)
    return input, scale


def quantize_int8_matmul_input(input: torch.FloatTensor, scale: torch.FloatTensor) -> Tuple[torch.CharTensor, torch.FloatTensor]:
    input = input.flatten(0, -2).contiguous()
    input_scale = torch.amax(input.abs(), dim=-1, keepdim=True).div_(127) # 127 is the largest value of int8
    input = torch.div(input, input_scale).round_().clamp_(-128, 127).to(dtype=torch.int8)
    scale = torch.mul(input_scale, scale) # fold the input scale into the weight scale for a single dequantize after the matmul
    if scale.dtype == torch.float16: # fp16 will overflow
        scale = scale.to(dtype=torch.float32)
    return input, scale
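
# Both input quantizers rely on the same scale-folding identity:
#   (x / sx) @ (w / sw) * (sx * sw) ≈ x @ w
# so the low-precision matmul output needs only a single elementwise multiply
# by the combined scale (plus the bias) to get back to the original range.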


def fp8_matmul(
    input: torch.FloatTensor,
    weight: torch.Tensor,
    bias: torch.FloatTensor,
    scale: torch.FloatTensor,
) -> torch.FloatTensor:
    return_dtype = input.dtype
    output_shape = list(input.shape)
    output_shape[-1] = weight.shape[-1]
    input, input_scale = quantize_fp8_matmul_input(input)
    return torch._scaled_mm(input, weight, scale_a=input_scale, scale_b=scale, bias=bias, out_dtype=return_dtype).reshape(output_shape)
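
# Usage sketch (hypothetical tensors): the weight is expected pre-quantized to
# float8_e4m3fn and stored transposed as (in_features, out_features) with a
# (1, out_features) float32 scale, so that:
#   fp8_matmul(x, w_fp8_t, bias, w_scale)  # ≈ torch.nn.functional.linear(x, w, bias)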


# sm89 doesn't support row-wise scales on Windows, so fall back to tensor-wise scaling
def fp8_matmul_tensorwise(
    input: torch.FloatTensor,
    weight: torch.Tensor,
    bias: torch.FloatTensor,
    scale: torch.FloatTensor,
) -> torch.FloatTensor:
    return_dtype = input.dtype
    output_shape = list(input.shape)
    output_shape[-1] = weight.shape[-1]
    dummy_input_scale = torch.ones(1, device=input.device, dtype=torch.float32)
    input, scale = quantize_fp8_matmul_input_tensorwise(input, scale)
    if bias is not None:
        return dequantize_symmetric_with_bias(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype), scale, bias, return_dtype, output_shape)
    else:
        return dequantize_symmetric(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype), scale, return_dtype, output_shape)
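
# The unit "dummy" scales make _scaled_mm return the raw fp8 product; the real
# combined input-times-weight scale is applied afterwards in a single pass by
# dequantize_symmetric(_with_bias). The result matches the row-wise path above
# while only ever handing tensor-wise scales to _scaled_mm.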


def int8_matmul(
    input: torch.FloatTensor,
    weight: torch.Tensor,
    bias: torch.FloatTensor,
    scale: torch.FloatTensor,
    quantized_weight_shape: torch.Size,
    weights_dtype: str,
) -> torch.FloatTensor:
    if quantized_weight_shape is not None: # weights below 8 bits are bit-packed and have to be unpacked first
        weight = unpack_int_symetric(weight, quantized_weight_shape, weights_dtype, dtype=torch.int8, transpose=True)
    return_dtype = input.dtype
    output_shape = list(input.shape)
    output_shape[-1] = weight.shape[-1]
    input, scale = quantize_int8_matmul_input(input, scale)
    if bias is not None:
        return dequantize_symmetric_with_bias(torch._int_mm(input, weight), scale, bias, return_dtype, output_shape)
    else:
        return dequantize_symmetric(torch._int_mm(input, weight), scale, return_dtype, output_shape)
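
# Usage sketch (hypothetical tensors): with an int8 weight stored transposed as
# (in_features, out_features) and a (1, out_features) weight scale:
#   y = int8_matmul(x, w_int8_t, bias, w_scale, None, "int8")
# torch._int_mm accumulates in int32, so accuracy is bounded only by the
# symmetric int8 quantization of the inputs and weights.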


def process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation):
    if conv_type == 1:
        batch_size, _, L_in = input.shape
        C_out, _, K_l = result_shape
        L_out = (L_in + 2 * padding[1] - dilation[1] * (K_l - 1) - 1) // stride[1] + 1
        mm_output_shape = (batch_size, L_out, C_out)
        kernel_size = (1, K_l)
    elif conv_type == 2:
        batch_size, _, H_in, W_in = input.shape
        C_out, _, K_h, K_w = result_shape
        H_out = (H_in + 2 * padding[0] - dilation[0] * (K_h - 1) - 1) // stride[0] + 1
        W_out = (W_in + 2 * padding[1] - dilation[1] * (K_w - 1) - 1) // stride[1] + 1
        mm_output_shape = (batch_size, H_out, W_out, C_out)
        kernel_size = (K_h, K_w)
    else:
        batch_size, _, D_in, H_in, W_in = input.shape
        C_out, _, K_d, K_h, K_w = result_shape
        D_out = (D_in + 2 * padding[0] - dilation[0] * (K_d - 1) - 1) // stride[0] + 1
        H_out = (H_in + 2 * padding[1] - dilation[1] * (K_h - 1) - 1) // stride[1] + 1
        W_out = (W_in + 2 * padding[2] - dilation[2] * (K_w - 1) - 1) // stride[2] + 1
        mm_output_shape = (batch_size, D_out, H_out, W_out, C_out)
        kernel_size = (K_d, K_h, K_w)
    if padding_mode != "zeros":
        input = torch.nn.functional.pad(input, reversed_padding_repeated_twice, mode=padding_mode)
        padding = (0,) * (conv_type if conv_type != 1 else 2) # padding is already applied
    elif conv_type == 3:
        input = torch.nn.functional.pad(input, reversed_padding_repeated_twice) # the manual unfold below doesn't pad, so pad here
    if conv_type == 1:
        input = input.unsqueeze(2) # treat conv1d as a conv2d with a height of 1
    if conv_type == 3:
        # torch.nn.functional.unfold only supports 4D inputs, so do the 3D im2col by hand
        K_D_eff = K_d + (K_d - 1) * (dilation[0] - 1)
        K_H_eff = K_h + (K_h - 1) * (dilation[1] - 1)
        K_W_eff = K_w + (K_w - 1) * (dilation[2] - 1)
        input = input.unfold(2, K_D_eff, stride[0]).unfold(3, K_H_eff, stride[1]).unfold(4, K_W_eff, stride[2])
        if dilation[0] > 1:
            input = input[..., ::dilation[0], :, :]
        if dilation[1] > 1:
            input = input[..., ::dilation[1], :]
        if dilation[2] > 1:
            input = input[..., ::dilation[2]]
        input = input.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(batch_size, D_out * H_out * W_out, -1)
    else:
        input = torch.nn.functional.unfold(input, kernel_size=kernel_size, padding=padding, stride=stride, dilation=dilation).transpose(1, 2)
    return input, mm_output_shape
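
# Sketch of the im2col identity this implements (hypothetical shapes, zero padding):
#   cols, out_shape = process_conv_input(2, x, [0, 0, 0, 0], "zeros", w.shape, (1, 1), (0, 0), (1, 1))
#   y = (cols @ w.reshape(w.shape[0], -1).t()).reshape(out_shape).permute(0, 3, 1, 2)
#   y ≈ torch.nn.functional.conv2d(x, w)
# which is what lets every conv below run as a single (quantizable) matmul.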


def conv_fp8_matmul(
    input: torch.FloatTensor,
    weight: torch.Tensor,
    bias: torch.FloatTensor,
    scale: torch.FloatTensor,
    result_shape: torch.Size,
    reversed_padding_repeated_twice: List[int],
    padding_mode: str, conv_type: int,
    groups: int, stride: List[int],
    padding: List[int], dilation: List[int],
) -> torch.FloatTensor:
    return_dtype = input.dtype
    input, mm_output_shape = process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation)
    input, input_scale = quantize_fp8_matmul_input(input)
    if groups == 1:
        result = torch._scaled_mm(input, weight, scale_a=input_scale, scale_b=scale, bias=bias, out_dtype=return_dtype).reshape(mm_output_shape)
    else:
        # grouped conv: run one matmul per group on the matching input / weight / scale slices
        scale = scale.reshape(groups, 1, scale.shape[1] // groups)
        weight = weight.reshape(weight.shape[0], groups, weight.shape[1] // groups).transpose(0, 1)
        input = input.reshape(input.shape[0], groups, input.shape[1] // groups).transpose(0, 1)
        result = []
        if bias is not None:
            bias = bias.reshape(groups, bias.shape[0] // groups)
            for i in range(groups):
                result.append(torch._scaled_mm(input[i], weight[i], scale_a=input_scale, scale_b=scale[i], bias=bias[i], out_dtype=return_dtype)) # every group shares the same per-row input scale
        else:
            for i in range(groups):
                result.append(torch._scaled_mm(input[i], weight[i], scale_a=input_scale, scale_b=scale[i], bias=None, out_dtype=return_dtype))
        result = torch.cat(result, dim=-1).reshape(mm_output_shape)
    if conv_type == 1:
        result = result.transpose(1, 2)
    elif conv_type == 2:
        result = result.permute(0, 3, 1, 2)
    elif conv_type == 3:
        result = result.permute(0, 4, 1, 2, 3)
    return result


def conv_fp8_matmul_tensorwise(
    input: torch.FloatTensor,
    weight: torch.Tensor,
    bias: torch.FloatTensor,
    scale: torch.FloatTensor,
    result_shape: torch.Size,
    reversed_padding_repeated_twice: List[int],
    padding_mode: str, conv_type: int,
    groups: int, stride: List[int],
    padding: List[int], dilation: List[int],
) -> torch.FloatTensor:
    return_dtype = input.dtype
    input, mm_output_shape = process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation)
    input, scale = quantize_fp8_matmul_input_tensorwise(input, scale)
    dummy_input_scale = torch.ones(1, device=input.device, dtype=torch.float32)
    if groups == 1:
        result = torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype)
    else:
        weight = weight.reshape(weight.shape[0], groups, weight.shape[1] // groups).transpose(0, 1)
        input = input.reshape(input.shape[0], groups, input.shape[1] // groups).transpose(0, 1)
        result = []
        for i in range(groups):
            result.append(torch._scaled_mm(input[i], weight[i], scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype))
        result = torch.cat(result, dim=-1)
    if bias is not None:
        result = dequantize_symmetric_with_bias(result, scale, bias, return_dtype, mm_output_shape)
    else:
        result = dequantize_symmetric(result, scale, return_dtype, mm_output_shape)
    if conv_type == 1:
        result = result.transpose(1, 2)
    elif conv_type == 2:
        result = result.permute(0, 3, 1, 2)
    elif conv_type == 3:
        result = result.permute(0, 4, 1, 2, 3)
    return result


def conv_int8_matmul(
    input: torch.FloatTensor,
    weight: torch.CharTensor,
    bias: torch.FloatTensor,
    scale: torch.FloatTensor,
    result_shape: torch.Size,
    quantized_weight_shape: torch.Size,
    weights_dtype: str,
    reversed_padding_repeated_twice: List[int],
    padding_mode: str, conv_type: int,
    groups: int, stride: List[int],
    padding: List[int], dilation: List[int],
) -> torch.FloatTensor:
    return_dtype = input.dtype
    input, mm_output_shape = process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation)
    input, scale = quantize_int8_matmul_input(input, scale)
    if quantized_weight_shape is not None: # weights below 8 bits are bit-packed and have to be unpacked first
        weight = unpack_int_symetric(weight, quantized_weight_shape, weights_dtype, dtype=torch.int8, transpose=True)
    if groups == 1:
        result = torch._int_mm(input, weight)
    else:
        weight = weight.reshape(weight.shape[0], groups, weight.shape[1] // groups).transpose(0, 1)
        input = input.reshape(input.shape[0], groups, input.shape[1] // groups).transpose(0, 1)
        result = []
        for i in range(groups):
            result.append(torch._int_mm(input[i], weight[i]))
        result = torch.cat(result, dim=-1)
    if bias is not None:
        result = dequantize_symmetric_with_bias(result, scale, bias, return_dtype, mm_output_shape)
    else:
        result = dequantize_symmetric(result, scale, return_dtype, mm_output_shape)
    if conv_type == 1:
        result = result.transpose(1, 2)
    elif conv_type == 2:
        result = result.permute(0, 3, 1, 2)
    elif conv_type == 3:
        result = result.permute(0, 4, 1, 2, 3)
    return result
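
# Note on the grouped paths above: a grouped conv is block-diagonal over
# channels, so after im2col it decomposes into `groups` independent matmuls on
# column slices of the input and weight; the per-group outputs are concatenated
# back along the output-channel dimension.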


def quantized_linear_forward_fp8_matmul(self, input: torch.FloatTensor) -> torch.FloatTensor:
    if torch.numel(input) / input.shape[-1] < 32:
        return torch.nn.functional.linear(input, self.sdnq_dequantizer(self.weight, skip_quantized_matmul=True), self.bias)
    return fp8_matmul(input, self.weight, self.bias, self.sdnq_dequantizer.scale)


def quantized_linear_forward_fp8_matmul_tensorwise(self, input: torch.FloatTensor) -> torch.FloatTensor:
    if torch.numel(input) / input.shape[-1] < 32:
        return torch.nn.functional.linear(input, self.sdnq_dequantizer(self.weight, skip_quantized_matmul=True), self.bias)
    return fp8_matmul_tensorwise(input, self.weight, self.bias, self.sdnq_dequantizer.scale)


def quantized_linear_forward_int8_matmul(self, input: torch.FloatTensor) -> torch.FloatTensor:
    if torch.numel(input) / input.shape[-1] < 32:
        return torch.nn.functional.linear(input, self.sdnq_dequantizer(self.weight, skip_quantized_matmul=True), self.bias)
    return int8_matmul(input, self.weight, self.bias, self.sdnq_dequantizer.scale, getattr(self.sdnq_dequantizer, "quantized_weight_shape", None), self.sdnq_dequantizer.weights_dtype)


def quantized_linear_forward(self, input: torch.FloatTensor) -> torch.FloatTensor:
    return torch.nn.functional.linear(input, self.sdnq_dequantizer(self.weight), self.bias)
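
# The `< 32` guards above compute the number of matmul rows: for an input of
# shape (batch, seq, features), numel / features == batch * seq. Below roughly
# 32 rows the cost of quantizing the input outweighs the faster int8/fp8
# kernels, so the weight is dequantized and the regular linear path is used
# (a heuristic threshold, not a hardware limit).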


def get_conv_args(input_ndim: int, stride, padding, dilation):
    if input_ndim == 3:
        conv_type = 1
    elif input_ndim == 4:
        conv_type = 2
    else:
        conv_type = 3
    if isinstance(stride, int):
        stride = (stride,) * conv_type
    if isinstance(padding, int):
        padding = (padding,) * conv_type
    if isinstance(dilation, int):
        dilation = (dilation,) * conv_type
    if conv_type == 1:
        # conv1d is handled as a conv2d with a height of 1, so prepend no-op values
        stride = (1, stride[0])
        padding = (0, padding[0])
        dilation = (1, dilation[0])
    return conv_type, stride, padding, dilation
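
# Example: a Conv2d layer (4D input) with scalar hyperparameters
#   get_conv_args(4, 1, 0, 1)  # -> (2, (1, 1), (0, 0), (1, 1))
# and a Conv1d layer (3D input), widened to the conv2d form used above:
#   get_conv_args(3, 2, 1, 1)  # -> (1, (1, 2), (0, 1), (1, 1))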


def quantized_conv_forward_fp8_matmul(self, input) -> torch.FloatTensor:
    if torch.numel(input) / input.shape[2] < 32:
        return self._conv_forward(input, self.sdnq_dequantizer(self.weight, skip_quantized_matmul=True), self.bias)
    conv_type, stride, padding, dilation = get_conv_args(input.ndim, self.stride, self.padding, self.dilation)
    return conv_fp8_matmul(
        input, self.weight, self.bias,
        self.sdnq_dequantizer.scale,
        self.sdnq_dequantizer.result_shape,
        self._reversed_padding_repeated_twice,
        self.padding_mode, conv_type,
        self.groups, stride, padding, dilation,
    )


def quantized_conv_forward_fp8_matmul_tensorwise(self, input) -> torch.FloatTensor:
    if torch.numel(input) / input.shape[2] < 32:
        return self._conv_forward(input, self.sdnq_dequantizer(self.weight, skip_quantized_matmul=True), self.bias)
    conv_type, stride, padding, dilation = get_conv_args(input.ndim, self.stride, self.padding, self.dilation)
    return conv_fp8_matmul_tensorwise(
        input, self.weight, self.bias,
        self.sdnq_dequantizer.scale,
        self.sdnq_dequantizer.result_shape,
        self._reversed_padding_repeated_twice,
        self.padding_mode, conv_type,
        self.groups, stride, padding, dilation,
    )


def quantized_conv_forward_int8_matmul(self, input) -> torch.FloatTensor:
    if torch.numel(input) / input.shape[2] < 32:
        return self._conv_forward(input, self.sdnq_dequantizer(self.weight, skip_quantized_matmul=True), self.bias)
    conv_type, stride, padding, dilation = get_conv_args(input.ndim, self.stride, self.padding, self.dilation)
    return conv_int8_matmul(
        input, self.weight, self.bias,
        self.sdnq_dequantizer.scale,
        self.sdnq_dequantizer.result_shape,
        getattr(self.sdnq_dequantizer, "quantized_weight_shape", None),
        self.sdnq_dequantizer.weights_dtype,
        self._reversed_padding_repeated_twice,
        self.padding_mode, conv_type,
        self.groups, stride, padding, dilation,
    )


def quantized_conv_forward(self, input) -> torch.FloatTensor:
    return self._conv_forward(input, self.sdnq_dequantizer(self.weight), self.bias)


def quantized_conv_transpose_1d_forward(self, input: torch.FloatTensor, output_size: Optional[List[int]] = None) -> torch.FloatTensor:
    output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 1, self.dilation)
    return torch.nn.functional.conv_transpose1d(input, self.sdnq_dequantizer(self.weight), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)


def quantized_conv_transpose_2d_forward(self, input: torch.FloatTensor, output_size: Optional[List[int]] = None) -> torch.FloatTensor:
    output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 2, self.dilation)
    return torch.nn.functional.conv_transpose2d(input, self.sdnq_dequantizer(self.weight), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)


def quantized_conv_transpose_3d_forward(self, input: torch.FloatTensor, output_size: Optional[List[int]] = None) -> torch.FloatTensor:
    output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 3, self.dilation)
    return torch.nn.functional.conv_transpose3d(input, self.sdnq_dequantizer(self.weight), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)


if shared.opts.sdnq_dequantize_compile:
    try:
        torch._dynamo.config.cache_size_limit = max(8192, torch._dynamo.config.cache_size_limit)
        int8_matmul = torch.compile(int8_matmul, fullgraph=True)
        fp8_matmul = torch.compile(fp8_matmul, fullgraph=True)
        fp8_matmul_tensorwise = torch.compile(fp8_matmul_tensorwise, fullgraph=True)
        conv_int8_matmul = torch.compile(conv_int8_matmul, fullgraph=True)
        conv_fp8_matmul = torch.compile(conv_fp8_matmul, fullgraph=True)
        conv_fp8_matmul_tensorwise = torch.compile(conv_fp8_matmul_tensorwise, fullgraph=True)
    except Exception as e:
        shared.log.warning(f"Quantization: type=sdnq MatMul using torch.compile is not available: {e}")