# automatic/modules/model_quant_nncf.py
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
from enum import Enum
import torch
from diffusers.quantizers.base import DiffusersQuantizer
from diffusers.quantizers.quantization_config import QuantizationConfigMixin
from diffusers.utils import get_module_from_name
from accelerate import init_empty_weights
from accelerate.utils import CustomDtype
from modules import devices, shared
torch_dtype_dict = {
"int8": torch.int8,
"uint8": torch.uint8,
"int4": CustomDtype.INT4,
"uint4": CustomDtype.INT4,
}
weights_dtype_dict = {
"int8_asym": "uint8",
"int8_sym": "int8",
"int4_asym": "uint4",
"int4_sym": "int4",
"int8": "uint8",
"int4": "uint4",
}
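# asymmetric modes store weights as unsigned integers (the quantized range starts
# at 0 and uses an explicit zero point); symmetric modes store signed integers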
linear_types = ["NNCFLinear", "Linear"]
conv_types = ["NNCFConv1d", "NNCFConv2d", "NNCFConv3d", "Conv1d", "Conv2d", "Conv3d"]
conv_transpose_types = ["NNCFConvTranspose1d", "NNCFConvTranspose2d", "NNCFConvTranspose3d", "ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d"]
allowed_types = []
allowed_types.extend(linear_types)
allowed_types.extend(conv_types)
allowed_types.extend(conv_transpose_types)
class QuantizationMethod(str, Enum):
NNCF = "nncf"
# de-abstracted and modified slightly from the actual quant functions of nncf 2.16.0:
def nncf_compress_layer(layer, num_bits, is_asym_mode, torch_dtype=None, quant_conv=False, param_name=None):
if layer.__class__.__name__ in allowed_types:
if torch_dtype is None:
torch_dtype = devices.dtype
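        # choose per-channel reduction axes: statistics are reduced over every axis
        # except the output-channel axis (dim 0 for convs, dim 1 for transposed convs);
        # linear weights reduce over their last (input) dimension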
        if layer.__class__.__name__ in conv_types:
            if is_asym_mode or not quant_conv: # convs are skipped in asym mode, and unless quant_conv is enabled
                return layer
            reduction_axes = [i for i in range(layer.weight.ndim) if i != 0]
        elif layer.__class__.__name__ in conv_transpose_types:
            if is_asym_mode or not quant_conv: # convs are skipped in asym mode, and unless quant_conv is enabled
                return layer
            reduction_axes = [i for i in range(layer.weight.ndim) if i != 1]
        else:
            reduction_axes = [layer.weight.ndim - 1]
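        # with offloading active, the compressed weight is returned to whatever device
        # the layer currently occupies; otherwise it stays on the compute device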
if shared.opts.diffusers_offload_mode != "none":
return_device = layer.weight.data.device
else:
return_device = devices.device
layer.weight.data = layer.weight.data.to(devices.device, dtype=torch.float32)
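        # asymmetric: map [min, max] per channel onto [0, 2**num_bits - 1] with an
        # integer zero point; symmetric: scale by the per-channel absolute maximum
        # divided by 2**(num_bits - 1), with no zero point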
if is_asym_mode:
level_low = 0
level_high = 2**num_bits - 1
min_values = torch.amin(layer.weight, dim=reduction_axes, keepdims=True) # [a1, r, a2] -> [a1, 1, a2]
max_values = torch.amax(layer.weight, dim=reduction_axes, keepdims=True) # [a1, r, a2] -> [a1, 1, a2]
levels = level_high - level_low + 1
scale = ((max_values - min_values) / (levels - 1)).to(dtype=torch.float32)
eps = torch.finfo(scale.dtype).eps
scale = torch.where(torch.abs(scale) < eps, eps, scale)
zero_point = level_low - torch.round(min_values / scale)
zero_point = torch.clip(zero_point.to(dtype=torch.int32), level_low, level_high)
else:
factor = 2 ** (num_bits - 1)
w_abs_min = torch.abs(torch.amin(layer.weight, dim=reduction_axes, keepdims=True))
w_max = torch.amax(layer.weight, dim=reduction_axes, keepdims=True)
scale = torch.where(w_abs_min >= w_max, w_abs_min, -w_max)
scale /= factor
eps = torch.finfo(scale.dtype).eps
scale = torch.where(torch.abs(scale) < eps, eps, scale)
zero_point = None
dtype = torch.uint8 if is_asym_mode else torch.int8
level_low = 0 if is_asym_mode else -(2 ** (num_bits - 1))
level_high = 2**num_bits - 1 if is_asym_mode else 2 ** (num_bits - 1) - 1
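        # quantize: q = clip(round(w / scale + zero_point), level_low, level_high)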
compressed_weight = layer.weight.data / scale
if zero_point is not None:
compressed_weight += zero_point.to(dtype=layer.weight.dtype)
compressed_weight = torch.round(compressed_weight)
compressed_weight = torch.clip(compressed_weight, level_low, level_high).to(dtype)
if num_bits == 4:
if is_asym_mode:
decompressor = INT4AsymmetricWeightsDecompressor(
scale=scale.data,
zero_point=zero_point.data,
compressed_weight_shape=compressed_weight.shape,
result_shape=layer.weight.shape,
result_dtype=torch_dtype
)
else:
decompressor = INT4SymmetricWeightsDecompressor(
scale=scale.data,
compressed_weight_shape=compressed_weight.shape,
result_shape=layer.weight.shape,
result_dtype=torch_dtype
)
else:
if is_asym_mode:
decompressor = INT8AsymmetricWeightsDecompressor(
scale=scale.data,
zero_point=zero_point.data,
result_dtype=torch_dtype
)
else:
decompressor = INT8SymmetricWeightsDecompressor(
scale=scale.data,
result_dtype=torch_dtype
)
compressed_weight = decompressor.pack_weight(compressed_weight)
compressed_weight = compressed_weight.to(return_device)
decompressor = decompressor.to(return_device)
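        # the decompressor is registered as an NNCF pre-forward operation, so the
        # packed weight is dequantized just-in-time on every forward pass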
layer.register_pre_forward_operation(decompressor)
layer.weight.requires_grad = False
layer.weight.data = compressed_weight
return layer
def apply_nncf_to_module(model, num_bits, is_asym_mode, quant_conv=False):
has_children = list(model.children())
if not has_children:
return model
for param_name, module in model.named_children():
if module.__class__.__name__.startswith("NNCF") and hasattr(module, "weight") and module.weight is not None:
module = nncf_compress_layer(module, num_bits, is_asym_mode, torch_dtype=devices.dtype, quant_conv=quant_conv, param_name=param_name)
module = apply_nncf_to_module(module, num_bits, is_asym_mode, quant_conv=quant_conv)
return model
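# Illustrative usage, assuming `model` has already been wrapped by NNCF
# (apply_nncf_to_module only touches modules whose class name starts with "NNCF"):
#   model = apply_nncf_to_module(model, num_bits=8, is_asym_mode=False)
#   nncf_send_to_device(model, devices.device)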
def nncf_send_to_device(model, device):
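    # decompressor scales/zero points are plain tensor attributes rather than
    # registered buffers, so model.to(device) does not move them; walk them manually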
for child in model.children():
if "WeightsDecompressor" in child.__class__.__name__:
child.scale = child.scale.to(device)
if hasattr(child, "zero_point"):
child.zero_point = child.zero_point.to(device)
nncf_send_to_device(child, device)
class NNCFQuantizer(DiffusersQuantizer):
r"""
Diffusers Quantizer for NNCF
"""
requires_parameters_quantization = True
use_keep_in_fp32_modules = True
requires_calibration = False
required_packages = ["nncf"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
def check_if_quantized_param(
self,
model,
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
):
module, tensor_name = get_module_from_name(model, param_name)
return module.__class__.__name__.startswith("NNCF") and param_name.endswith(".weight")
def check_quantized_param(self, *args, **kwargs) -> bool:
"""
        needed for transformers compatibility; returns self.check_if_quantized_param
"""
return self.check_if_quantized_param(*args, **kwargs)
def create_quantized_param(
self,
model,
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: List[str],
**kwargs,
):
# load the model params to target_device first
layer, tensor_name = get_module_from_name(model, param_name)
layer._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
split_param_name = param_name.split(".")
if param_name not in self.modules_to_not_convert and not any(param in split_param_name for param in self.modules_to_not_convert):
layer = nncf_compress_layer(
layer,
self.quantization_config.num_bits,
self.quantization_config.is_asym_mode,
torch_dtype=self.torch_dtype,
param_name=param_name
)
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
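        # scale reported device memory down to ~70%, presumably leaving headroom for
        # the temporary fp32 weight copies created during compression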
max_memory = {key: val * 0.70 for key, val in max_memory.items()}
return max_memory
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
return torch_dtype_dict[self.quantization_config.weights_dtype]
def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
if torch_dtype is None:
torch_dtype = devices.dtype
self.torch_dtype = torch_dtype
return torch_dtype
def _process_model_before_weight_loading(
self,
model,
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]
if keep_in_fp32_modules is not None:
self.modules_to_not_convert.extend(keep_in_fp32_modules)
model.config.quantization_config = self.quantization_config
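        # swap T5/UMT5 gated-act MLPs for a wrapper that forces the activation dtype
        # before wo (see NNCF_T5DenseGatedActDense below), avoiding fp16 overflow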
if model.__class__.__name__ in {"T5EncoderModel", "UMT5EncoderModel"}:
for i in range(len(model.encoder.block)):
model.encoder.block[i].layer[1].DenseReluDense = NNCF_T5DenseGatedActDense(
model.encoder.block[i].layer[1].DenseReluDense,
dtype=torch.float32 if devices.dtype != torch.bfloat16 else torch.bfloat16
)
with init_empty_weights():
model, _ = replace_modules_by_nncf_modules(model)
def _process_model_after_weight_loading(self, model, **kwargs):
nncf_send_to_device(model, devices.device)
return model
def update_tp_plan(self, config):
"""
        needed for transformers compatibility; no-op function
"""
return config
def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str) -> List[str]:
"""
        needed for transformers compatibility; no-op function
"""
return unexpected_keys
def update_missing_keys_after_loading(self, model, missing_keys: List[str], prefix: str) -> List[str]:
"""
        needed for transformers compatibility; no-op function
"""
return missing_keys
def update_expected_keys(self, model, expected_keys: List[str], loaded_keys: List[str]) -> List[str]:
"""
        needed for transformers compatibility; no-op function
"""
return expected_keys
@property
def is_trainable(self):
return False
@property
def is_serializable(self):
return False
@dataclass
class NNCFConfig(QuantizationConfigMixin):
"""
    This is a wrapper class for all attributes and features of a model that has been
    quantized with `nncf`.
    Args:
        weights_dtype (`str`, *optional*, defaults to `"int8_sym"`):
            The target dtype for the weights after quantization. Supported values are
            ("int8", "int8_sym", "int8_asym", "int4", "int4_sym", "int4_asym").
modules_to_not_convert (`list`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
"""
def __init__(
self,
weights_dtype: str = "int8_sym",
modules_to_not_convert: Optional[List[str]] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.NNCF
self.weights_dtype = weights_dtype_dict[weights_dtype.lower()]
self.modules_to_not_convert = modules_to_not_convert
self.post_init()
self.num_bits = 8 if self.weights_dtype in {"int8", "uint8"} else 4
self.is_asym_mode = self.weights_dtype in {"uint8", "uint4"}
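        # per-channel integer scheme; group_size of -1 is taken to mean no grouping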
self.is_integer = True
self.group_size = -1
def post_init(self):
r"""
Safety checker that arguments are correct
"""
accepted_weights = ["int8", "uint8", "int4", "uint4"]
if self.weights_dtype not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
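# Illustrative usage (hypothetical model class and repo id; the loader wiring lives
# outside this module):
#   config = NNCFConfig(weights_dtype="int8_sym", modules_to_not_convert=["time_embed"])
#   model = SD3Transformer2DModel.from_pretrained(repo_id, quantization_config=config)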
class NNCF_T5DenseGatedActDense(torch.nn.Module): # wrapped in a class so the custom forward has a self to bind to
def __init__(self, T5DenseGatedActDense, dtype):
super().__init__()
self.wi_0 = T5DenseGatedActDense.wi_0
self.wi_1 = T5DenseGatedActDense.wi_1
self.wo = T5DenseGatedActDense.wo
self.dropout = T5DenseGatedActDense.dropout
self.act = T5DenseGatedActDense.act
self.torch_dtype = dtype
def forward(self, hidden_states):
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
        hidden_states = hidden_states.to(self.torch_dtype) # forced cast: wo must run in fp32/bf16 to avoid fp16 overflow
hidden_states = self.wo(hidden_states)
return hidden_states
# WeightsDecompressor classes and functions are modified from NNCF 2.16.0
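# Each decompressor is attached as a pre-forward operation: forward() receives the
# wrapped layer, dequantizes its packed weight, and writes the result back into
# x.weight (or returns it directly when return_decompressed_only=True).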
def decompress_asymmetric(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
input = input.to(dtype=scale.dtype)
zero_point = zero_point.to(dtype=scale.dtype)
decompressed_input = (input - zero_point) * scale
return decompressed_input
def decompress_symmetric(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
input = input.to(dtype=scale.dtype)
decompressed_input = input * scale
return decompressed_input
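# 4-bit packing stores two values per byte (low nibble first); unpacking stacks the
# two nibbles along a new trailing axis, so callers reshape back to the original
# shape. Signed int4 is stored with a +8 offset to fit the unsigned nibble range.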
def unpack_uint4(packed_tensor: torch.Tensor) -> torch.Tensor:
return torch.stack((torch.bitwise_and(packed_tensor, 15), torch.bitwise_right_shift(packed_tensor, 4)), dim=-1)
def unpack_int4(packed_tensor: torch.Tensor) -> torch.Tensor:
t = unpack_uint4(packed_tensor)
return t.to(dtype=torch.int8) - 8
def pack_uint4(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype != torch.uint8:
msg = f"Invalid tensor dtype {tensor.type}. torch.uint8 type is supported."
raise RuntimeError(msg)
packed_tensor = tensor.contiguous()
packed_tensor = packed_tensor.reshape(-1, 2)
packed_tensor = torch.bitwise_and(packed_tensor[..., ::2], 15) | packed_tensor[..., 1::2] << 4
return packed_tensor
def pack_int4(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype != torch.int8:
msg = f"Invalid tensor dtype {tensor.type}. torch.int8 type is supported."
raise RuntimeError(msg)
tensor = tensor + 8
return pack_uint4(tensor.to(dtype=torch.uint8))
class INT8AsymmetricWeightsDecompressor(torch.nn.Module):
def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor, result_dtype: Optional[torch.dtype] = None):
super().__init__()
self.scale = scale
self.zero_point = self.pack_weight(zero_point)
self.result_dtype = result_dtype
@property
def num_bits(self):
return 8
@property
def quantization_mode(self):
return "asymmetric"
def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
if torch.is_floating_point(weight):
msg = f"Invalid weight dtype {weight.type}. Integer types are supported."
raise ValueError(msg)
if torch.any((weight < 0) | (weight > 255)):
msg = "Weight values are not in [0, 255]."
raise ValueError(msg)
return weight.to(dtype=torch.uint8)
def forward(self, x, *args, return_decompressed_only=False):
result = decompress_asymmetric(x.weight, self.scale, self.zero_point)
result = result.to(dtype=self.result_dtype)
if return_decompressed_only:
return result
else:
x.weight = result
class INT8SymmetricWeightsDecompressor(torch.nn.Module):
def __init__(self, scale: torch.Tensor, result_dtype: Optional[torch.dtype] = None):
super().__init__()
self.scale = scale
self.result_dtype = result_dtype
@property
def num_bits(self):
return 8
@property
def quantization_mode(self):
return "symmetric"
def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
if torch.any((weight < -128) | (weight > 127)):
msg = "Weight values are not in [-128, 127]."
raise ValueError(msg)
return weight.to(dtype=torch.int8)
def forward(self, x, *args, return_decompressed_only=False):
result = decompress_symmetric(x.weight, self.scale)
result = result.to(dtype=self.result_dtype)
if return_decompressed_only:
return result
else:
x.weight = result
class INT4AsymmetricWeightsDecompressor(torch.nn.Module):
def __init__(
self,
scale: torch.Tensor,
zero_point: torch.Tensor,
compressed_weight_shape: Tuple[int, ...],
result_shape: Optional[Tuple[int, ...]] = None,
result_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.scale = scale
self.zero_point_shape = zero_point.shape
self.zero_point = self.pack_weight(zero_point)
self.compressed_weight_shape = compressed_weight_shape
self.result_shape = result_shape
self.result_dtype = result_dtype
@property
def num_bits(self):
return 4
@property
def quantization_mode(self):
return "asymmetric"
def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
if torch.any((weight < 0) | (weight > 15)):
msg = "Weight values are not in [0, 15]."
raise ValueError(msg)
return pack_uint4(weight.to(dtype=torch.uint8))
def forward(self, x, *args, return_decompressed_only=False):
result = unpack_uint4(x.weight)
result = result.reshape(self.compressed_weight_shape)
zero_point = unpack_uint4(self.zero_point)
zero_point = zero_point.reshape(self.zero_point_shape)
result = decompress_asymmetric(result, self.scale, zero_point)
result = result.reshape(self.result_shape) if self.result_shape is not None else result
result = result.to(dtype=self.result_dtype)
if return_decompressed_only:
return result
else:
x.weight = result
class INT4SymmetricWeightsDecompressor(torch.nn.Module):
def __init__(
self,
scale: torch.Tensor,
compressed_weight_shape: Tuple[int, ...],
result_shape: Optional[Tuple[int, ...]] = None,
result_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.scale = scale
self.compressed_weight_shape = compressed_weight_shape
self.result_shape = result_shape
self.result_dtype = result_dtype
@property
def num_bits(self):
return 4
@property
def quantization_mode(self):
return "symmetric"
def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
if torch.is_floating_point(weight):
msg = f"Invalid weight dtype {weight.type}. Integer types are supported."
raise ValueError(msg)
if torch.any((weight < -8) | (weight > 7)):
msg = "Tensor values are not in [-8, 7]."
raise ValueError(msg)
return pack_int4(weight.to(dtype=torch.int8))
    def forward(self, x, *args, return_decompressed_only=False):
result = unpack_int4(x.weight)
result = result.reshape(self.compressed_weight_shape)
result = decompress_symmetric(result, self.scale)
result = result.reshape(self.result_shape) if self.result_shape is not None else result
result = result.to(dtype=self.result_dtype)
if return_decompressed_only:
return result
else:
x.weight = result