# automatic/modules/sdnq/dequantizer.py
# pylint: disable=redefined-builtin,no-member,protected-access
import torch
from modules import shared
from .common import dtype_dict
from .packed_int import pack_int_symetric, unpack_int_symetric, packed_int_function_dict
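

# Dequantization math (SDNQ storage scheme):
#   asymmetric (unsigned storage):    w_fp = zero_point + w_q * scale
#   symmetric (signed / fp8 storage): w_fp = w_q * scale
# torch.addcmul(input, t1, t2) computes input + t1 * t2 in a single fused op.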
def dequantize_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, dtype: torch.dtype, result_shape: torch.Size) -> torch.FloatTensor:
    # w_fp = zero_point + w_q * scale, fused via addcmul, then cast to the target dtype
    result = torch.addcmul(zero_point, weight.to(dtype=scale.dtype), scale).to(dtype=dtype)
    if result_shape is not None:
        result = result.reshape(result_shape)
    return result


def dequantize_symmetric(weight: torch.CharTensor, scale: torch.FloatTensor, dtype: torch.dtype, result_shape: torch.Size, skip_quantized_matmul: bool = False) -> torch.FloatTensor:
    if skip_quantized_matmul:
        # weights on the quantized-matmul path are kept transposed, so transpose
        # weight and scale back to the original layout before scaling
        result = weight.transpose(0, 1).to(dtype=scale.dtype).mul_(scale.transpose(0, 1)).to(dtype=dtype)
    else:
        result = weight.to(dtype=scale.dtype).mul_(scale).to(dtype=dtype)
    if result_shape is not None:
        result = result.reshape(result_shape)
    return result


def dequantize_symmetric_with_bias(weight: torch.CharTensor, scale: torch.FloatTensor, bias: torch.FloatTensor, dtype: torch.dtype, result_shape: torch.Size) -> torch.FloatTensor:
    # w_fp = bias + w_q * scale, fused via addcmul
    return torch.addcmul(bias, weight.to(dtype=scale.dtype), scale).to(dtype=dtype).reshape(result_shape)
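

# Sub-byte dtypes (int2..int7 / uint2..uint7) are stored bit-packed inside
# uint8 buffers; the helpers below unpack them to quantized_weight_shape and
# then reuse the plain dequantize paths above.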
def dequantize_packed_int_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, dtype: torch.dtype, result_shape: torch.Size, weights_dtype: str) -> torch.FloatTensor:
    # unpack the bit-packed sub-byte weights first, then run the regular asymmetric path
    return dequantize_asymmetric(packed_int_function_dict[weights_dtype]["unpack"](weight, shape), scale, zero_point, dtype, result_shape)


def dequantize_packed_int_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, dtype: torch.dtype, result_shape: torch.Size, weights_dtype: str, skip_quantized_matmul: bool = False) -> torch.FloatTensor:
    if skip_quantized_matmul:
        return dequantize_symmetric(unpack_int_symetric(weight, shape, weights_dtype, dtype=scale.dtype), scale.transpose(0, 1), dtype, result_shape)
    else:
        return dequantize_symmetric(unpack_int_symetric(weight, shape, weights_dtype, dtype=scale.dtype), scale, dtype, result_shape)


class AsymmetricWeightsDequantizer(torch.nn.Module):
    def __init__(
        self,
        scale: torch.FloatTensor,
        zero_point: torch.FloatTensor,
        result_dtype: torch.dtype,
        result_shape: torch.Size,
        original_shape: torch.Size,
        weights_dtype: str,
        **kwargs, # pylint: disable=unused-argument
    ):
        super().__init__()
        self.weights_dtype = weights_dtype
        self.original_shape = original_shape
        self.use_quantized_matmul = False
        self.result_dtype = result_dtype
        self.result_shape = result_shape
        # register as buffers so scale/zero_point follow the module across .to() moves
        self.register_buffer("scale", scale)
        self.register_buffer("zero_point", zero_point)

    def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # cast to the storage dtype declared for this weights_dtype (e.g. uint8)
        return weight.to(dtype=dtype_dict[self.weights_dtype]["torch_dtype"])

    def forward(self, weight, **kwargs): # pylint: disable=unused-argument
        return dequantize_asymmetric_compiled(weight, self.scale, self.zero_point, self.result_dtype, self.result_shape)


class SymmetricWeightsDequantizer(torch.nn.Module):
    def __init__(
        self,
        scale: torch.FloatTensor,
        result_dtype: torch.dtype,
        result_shape: torch.Size,
        original_shape: torch.Size,
        weights_dtype: str,
        use_quantized_matmul: bool = False,
        **kwargs, # pylint: disable=unused-argument
    ):
        super().__init__()
        self.weights_dtype = weights_dtype
        self.original_shape = original_shape
        self.use_quantized_matmul = use_quantized_matmul
        self.result_dtype = result_dtype
        self.result_shape = result_shape
        self.register_buffer("scale", scale)

    def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
        return weight.to(dtype=dtype_dict[self.weights_dtype]["torch_dtype"])

    def forward(self, weight, skip_quantized_matmul=False, **kwargs): # pylint: disable=unused-argument
        return dequantize_symmetric_compiled(weight, self.scale, self.result_dtype, self.result_shape, skip_quantized_matmul=skip_quantized_matmul)
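

# Usage sketch for the Module wrappers (hypothetical tensors; real callers
# pass these values in from the SDNQ quantizer):
#   dq = SymmetricWeightsDequantizer(scale=scale, result_dtype=torch.float16,
#                                    result_shape=shape, original_shape=shape, weights_dtype="int8")
#   w_q = dq.pack_weight(w_quantized)  # cast to the int8 storage dtype
#   w = dq(w_q)                        # dequantized back to result_dtype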


class PackedINTAsymmetricWeightsDequantizer(torch.nn.Module):
    def __init__(
        self,
        scale: torch.FloatTensor,
        zero_point: torch.FloatTensor,
        quantized_weight_shape: torch.Size,
        result_dtype: torch.dtype,
        result_shape: torch.Size,
        original_shape: torch.Size,
        weights_dtype: str,
        **kwargs, # pylint: disable=unused-argument
    ):
        super().__init__()
        self.weights_dtype = weights_dtype
        self.use_quantized_matmul = False
        self.original_shape = original_shape
        self.quantized_weight_shape = quantized_weight_shape
        self.result_dtype = result_dtype
        self.result_shape = result_shape
        self.register_buffer("scale", scale)
        self.register_buffer("zero_point", zero_point)

    def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # cast to the storage dtype, then bit-pack into a uint8 buffer
        return packed_int_function_dict[self.weights_dtype]["pack"](weight.to(dtype=dtype_dict[self.weights_dtype]["torch_dtype"]))

    def forward(self, weight, **kwargs): # pylint: disable=unused-argument
        return dequantize_packed_int_asymmetric_compiled(weight, self.scale, self.zero_point, self.quantized_weight_shape, self.result_dtype, self.result_shape, self.weights_dtype)


class PackedINTSymmetricWeightsDequantizer(torch.nn.Module):
    def __init__(
        self,
        scale: torch.FloatTensor,
        quantized_weight_shape: torch.Size,
        result_dtype: torch.dtype,
        result_shape: torch.Size,
        original_shape: torch.Size,
        weights_dtype: str,
        use_quantized_matmul: bool = False,
        **kwargs, # pylint: disable=unused-argument
    ):
        super().__init__()
        self.weights_dtype = weights_dtype
        self.original_shape = original_shape
        self.use_quantized_matmul = use_quantized_matmul
        self.quantized_weight_shape = quantized_weight_shape
        self.result_dtype = result_dtype
        self.result_shape = result_shape
        self.register_buffer("scale", scale)

    def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
        return pack_int_symetric(weight, self.weights_dtype)

    def forward(self, weight, skip_quantized_matmul=False, **kwargs): # pylint: disable=unused-argument
        return dequantize_packed_int_symmetric_compiled(weight, self.scale, self.quantized_weight_shape, self.result_dtype, self.result_shape, self.weights_dtype, skip_quantized_matmul=skip_quantized_matmul)
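

# Maps each supported weights_dtype string to its dequantizer: signed-int and
# fp8 storage uses symmetric scaling, unsigned storage uses asymmetric
# scale + zero_point, and sub-byte widths use the bit-packed variants.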
dequantizer_dict = {
    "int8": SymmetricWeightsDequantizer,
    "int7": PackedINTSymmetricWeightsDequantizer,
    "int6": PackedINTSymmetricWeightsDequantizer,
    "int5": PackedINTSymmetricWeightsDequantizer,
    "int4": PackedINTSymmetricWeightsDequantizer,
    "int3": PackedINTSymmetricWeightsDequantizer,
    "int2": PackedINTSymmetricWeightsDequantizer,
    "uint8": AsymmetricWeightsDequantizer,
    "uint7": PackedINTAsymmetricWeightsDequantizer,
    "uint6": PackedINTAsymmetricWeightsDequantizer,
    "uint5": PackedINTAsymmetricWeightsDequantizer,
    "uint4": PackedINTAsymmetricWeightsDequantizer,
    "uint3": PackedINTAsymmetricWeightsDequantizer,
    "uint2": PackedINTAsymmetricWeightsDequantizer,
    "uint1": AsymmetricWeightsDequantizer,
    "bool": AsymmetricWeightsDequantizer,
    "float8_e4m3fn": SymmetricWeightsDequantizer,
    "float8_e4m3fnuz": SymmetricWeightsDequantizer,
    "float8_e5m2": SymmetricWeightsDequantizer,
    "float8_e5m2fnuz": SymmetricWeightsDequantizer,
}
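

# Optionally compile the dequantize kernels; if torch.compile is unavailable
# (or the option is off), the *_compiled names referenced by forward() above
# fall back to the eager implementations.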
if shared.opts.sdnq_dequantize_compile:
    try:
        torch._dynamo.config.cache_size_limit = max(8192, torch._dynamo.config.cache_size_limit)
        dequantize_asymmetric_compiled = torch.compile(dequantize_asymmetric, fullgraph=True)
        dequantize_symmetric_compiled = torch.compile(dequantize_symmetric, fullgraph=True)
        dequantize_packed_int_asymmetric_compiled = torch.compile(dequantize_packed_int_asymmetric, fullgraph=True)
        dequantize_packed_int_symmetric_compiled = torch.compile(dequantize_packed_int_symmetric, fullgraph=True)
    except Exception as e:
        shared.log.warning(f"Quantization: type=sdnq Dequantize using torch.compile is not available: {e}")
        dequantize_asymmetric_compiled = dequantize_asymmetric
        dequantize_symmetric_compiled = dequantize_symmetric
        dequantize_packed_int_asymmetric_compiled = dequantize_packed_int_asymmetric
        dequantize_packed_int_symmetric_compiled = dequantize_packed_int_symmetric
else:
    dequantize_asymmetric_compiled = dequantize_asymmetric
    dequantize_symmetric_compiled = dequantize_symmetric
    dequantize_packed_int_asymmetric_compiled = dequantize_packed_int_asymmetric
    dequantize_packed_int_symmetric_compiled = dequantize_packed_int_symmetric
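

# Example (a minimal sketch with hypothetical shapes; assumes a per-output-channel
# int8 layout where `scale` broadcasts against the stored weight):
#   w_q = torch.randint(-128, 128, (256, 128), dtype=torch.int8)
#   scale = torch.rand(256, 1, dtype=torch.float32)
#   w = dequantize_symmetric_compiled(w_q, scale, torch.float16, torch.Size((256, 128)))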