automatic/modules/sdnq/packed_float.py

110 lines
3.2 KiB
Python

import torch
from .common import dtype_dict
from .packed_int import pack_int, unpack_int
# Maps a packed float's total bit width to the name of the unsigned integer
# layout handed to pack_int / unpack_int. Width 8 is deliberately missing and
# the table stops at 15: pack_float / unpack_float bypass this table for
# 8- and 16-bit formats.
float_bits_to_uint_dict = {
    width: f"uint{width}"
    for width in range(1, 16)
    if width != 8
}
def pack_float(x: torch.FloatTensor, weights_dtype: str) -> torch.Tensor:
    """Re-encode ``x`` into the narrow float bit layout named by ``weights_dtype``.

    The input is converted to float32 and its raw bits are manipulated as
    int32: the mantissa is rounded and truncated, the exponent field is
    narrowed, and the result is masked down to ``num_bits`` bits.

    Args:
        x: Tensor to encode; it is converted to ``torch.float32`` first.
        weights_dtype: Key into ``dtype_dict``. The entry must provide
            ``exponent``, ``mantissa``, ``num_bits``, ``is_unsigned``,
            ``min_normal`` and (for 8/16-bit formats) ``storage_dtype``.

    Returns:
        For 8- and 16-bit formats, the encoded bits converted to the entry's
        ``storage_dtype``. For every other width, whatever ``pack_int``
        returns for the matching ``uintN`` layout.
    """
    spec = dtype_dict[weights_dtype]
    exp_width = spec["exponent"]
    man_width = spec["mantissa"]
    n_bits = spec["num_bits"]

    # Bits that are moved down from their float32 position: the top bit of
    # the packed value, plus the bit below it for signed formats
    # (presumably sign bit + leading exponent bit -- confirm against dtype_dict).
    top_bits = 1 << (n_bits - 1)
    if not spec["is_unsigned"]:
        top_bits += 1 << (n_bits - 2)

    dropped_mantissa = 23 - man_width   # float32 mantissa bits being discarded
    dropped_exponent = 8 - exp_width    # float32 exponent bits being discarded
    target_ulp = 1 << dropped_mantissa  # one unit in the last kept mantissa bit

    # Rounding looks only at the four highest discarded mantissa bits.
    round_window = -(1 << (dropped_mantissa - 4)) & (target_ulp - 1)
    half_ulp = 1 << (dropped_mantissa - 1)

    bits = x.to(dtype=torch.float32).view(torch.int32)

    # Round up by one target ulp when the inspected bits exceed one half.
    should_round_up = (bits & round_window) > half_ulp
    bits = torch.where(should_round_up, bits + target_ulp, bits)

    # Flush values whose magnitude is below the format's min_normal to zero.
    magnitude = bits.view(torch.float32).abs()
    bits = torch.where(magnitude < spec["min_normal"], 0, bits)

    # Drop the low mantissa bits (arithmetic shift on int32).
    bits = torch.bitwise_right_shift(bits, dropped_mantissa)

    # Pull the top bits down by the exponent difference, keep everything else
    # in place, then clip to the packed width.
    relocated_top = torch.bitwise_right_shift(bits, dropped_exponent) & top_bits
    untouched_low = bits & ~top_bits
    width_mask = (1 << n_bits) - 1
    bits = ((relocated_top | untouched_low) & width_mask).view(torch.uint32)

    if n_bits in {8, 16}:
        return bits.to(dtype=spec["storage_dtype"])
    return pack_int(bits, float_bits_to_uint_dict[n_bits])
def unpack_float(x: torch.Tensor, weights_dtype: str, shape: torch.Size) -> torch.FloatTensor:
    """Expand bits stored in the ``weights_dtype`` layout back to float32.

    Args:
        x: Encoded tensor. For widths other than 8 and 16 it is first passed
            through ``unpack_int``; otherwise it is used as-is.
        weights_dtype: Key into ``dtype_dict``. The entry must provide
            ``exponent``, ``mantissa``, ``num_bits`` and ``is_unsigned``.
        shape: Forwarded to ``unpack_int``; unused for 8- and 16-bit formats.

    Returns:
        The reconstructed bit patterns viewed as ``torch.float32``.
    """
    spec = dtype_dict[weights_dtype]
    exp_width = spec["exponent"]
    man_width = spec["mantissa"]
    n_bits = spec["num_bits"]

    # Bits that must be moved up to their float32 position: the top bit of
    # the packed value, plus the bit below it for signed formats
    # (presumably sign bit + leading exponent bit -- confirm against dtype_dict).
    top_bits = 1 << (n_bits - 1)
    if not spec["is_unsigned"]:
        top_bits += 1 << (n_bits - 2)

    added_mantissa = 23 - man_width  # mantissa bits to re-create as zeros
    added_exponent = 8 - exp_width   # exponent bits to re-create

    if n_bits not in {8, 16}:
        x = unpack_int(x, float_bits_to_uint_dict[n_bits], shape)
    bits = x.to(dtype=torch.uint32).view(torch.int32)

    # Lift the top bits by the exponent difference, keep the rest in place,
    # then shift everything up to the float32 mantissa/exponent positions.
    lifted_top = torch.bitwise_left_shift(bits & top_bits, added_exponent)
    untouched_low = bits & ~top_bits
    bits = torch.bitwise_left_shift(lifted_top | untouched_low, added_mantissa)

    # Where bit 30 is clear, set the exponent bits between it and the stored
    # low exponent bits: negating 0x40000000 yields a run of ones from bit 30
    # upward, the arithmetic shift extends that run downward, and 0x3F800000
    # keeps only bits 23..29.
    bit30_clear = torch.bitwise_not(bits) & 0x40000000
    exponent_fill = torch.bitwise_right_shift(-bit30_clear, added_exponent) & 0x3F800000
    bits = bits | exponent_fill

    # Elements with no bit set under this mask are replaced by 0; the rest
    # are kept unchanged.
    overflow_mask = ((1 << (22 + exp_width)) - 1) | 0x40FFFFFF
    has_payload = (bits & overflow_mask).to(dtype=torch.bool)
    bits = torch.where(has_payload, bits, 0)

    return bits.view(torch.float32)