# mirror of https://github.com/vladmandic/automatic
import torch
from .common import dtype_dict
from .packed_int import pack_int, unpack_int
# Storage names for the non-byte-aligned float widths. Widths 8 and 16 are
# deliberately absent: those formats are stored directly in a native torch
# dtype (see pack_float / unpack_float) instead of being bit-packed.
float_bits_to_uint_dict = {bits: f"uint{bits}" for bits in range(1, 16) if bits != 8}
def pack_float(x: torch.FloatTensor, weights_dtype: str) -> torch.Tensor:
    """Quantize a float tensor into the custom low-bit float format `weights_dtype`.

    The input is reinterpreted as raw float32 bits and narrowed to the format's
    sign/exponent/mantissa layout described by `dtype_dict[weights_dtype]`
    (keys used here: "exponent", "mantissa", "num_bits", "is_unsigned",
    "min_normal", "storage_dtype").

    Returns a tensor of packed codes: bit-packed via `pack_int` for widths that
    are not 8 or 16 bits, otherwise cast to the format's native storage dtype.
    NOTE(review): the returned tensor is flat/bit-packed for odd widths — the
    original shape must be supplied to `unpack_float` to restore it.
    """
    exponent_bits = dtype_dict[weights_dtype]["exponent"]
    mantissa_bits = dtype_dict[weights_dtype]["mantissa"]
    total_bits = dtype_dict[weights_dtype]["num_bits"]

    # `sign_mask` marks the target-format bits that are sourced from the
    # exponent-narrowed (right-shifted) copy below: the exponent MSB, plus the
    # sign bit for signed formats (unsigned formats have no sign bit, so only
    # the top bit is taken from the shifted copy).
    if dtype_dict[weights_dtype]["is_unsigned"]:
        sign_mask = (1 << (total_bits-1)) # pylint: disable=superfluous-parens
    else:
        sign_mask = (1 << (total_bits-1)) + (1 << (total_bits-2))

    # float32 has 23 mantissa bits and 8 exponent bits; these are how many of
    # each get dropped when narrowing to the target format.
    mantissa_difference = 23 - mantissa_bits
    exponent_difference = 8 - exponent_bits
    # One ULP of the target mantissa, expressed in float32 mantissa bits.
    mantissa_mask = (1 << mantissa_difference) # pylint: disable=superfluous-parens

    # Reinterpret the float32 bit pattern as int32 so we can do bit arithmetic.
    x = x.to(dtype=torch.float32).view(torch.int32)

    # Round-to-nearest on the mantissa bits about to be truncated: inspect the
    # top 4 dropped bits (range [mantissa_difference-4, mantissa_difference))
    # and add one target ULP when they exceed half an ULP. Exact halves are
    # NOT rounded up (strict `gt`), i.e. ties truncate toward zero.
    x = torch.where(
        torch.gt(
            torch.bitwise_and(x, -(1 << (mantissa_difference-4)) & ~(-mantissa_mask)),
            (1 << (mantissa_difference-1)),
        ),
        torch.add(x, mantissa_mask),
        x,
    )

    # Flush magnitudes below the format's smallest normal to zero — presumably
    # the target format has no subnormal encodings; TODO confirm via dtype_dict.
    x = torch.where(torch.lt(x.view(torch.float32).abs(), dtype_dict[weights_dtype]["min_normal"]), 0, x)

    # Drop the excess mantissa bits.
    x = torch.bitwise_right_shift(x, mantissa_difference)
    # Recombine: the sign bit / exponent MSB come from a copy shifted down by
    # the number of dropped exponent bits (selected by `sign_mask`); the low
    # exponent bits and mantissa come from the unshifted value. Finally keep
    # only the lowest `total_bits` bits (~(-(1 << total_bits)) == (1 << total_bits) - 1).
    x = torch.bitwise_and(
        torch.bitwise_or(
            torch.bitwise_and(torch.bitwise_right_shift(x, exponent_difference), sign_mask),
            torch.bitwise_and(x, ~sign_mask),
        ),
        ~(-(1 << total_bits)),
    ).view(torch.uint32)

    # Odd widths get bit-packed into a dense buffer; 8/16-bit formats map
    # directly onto a native storage dtype.
    if total_bits not in {8, 16}:
        x = pack_int(x, float_bits_to_uint_dict[total_bits])
    else:
        x = x.to(dtype=dtype_dict[weights_dtype]["storage_dtype"])

    return x
def unpack_float(x: torch.Tensor, weights_dtype: str, shape: torch.Size) -> torch.FloatTensor:
    """Inverse of `pack_float`: decode packed low-bit float codes back to float32.

    `shape` is the original tensor shape; it is only consumed by `unpack_int`
    for the non-byte-aligned widths (8/16-bit inputs are assumed to already
    carry their shape — TODO confirm against pack_float's callers).

    Returns a float32 tensor viewing the reconstructed IEEE-754 bit patterns.
    """
    exponent_bits = dtype_dict[weights_dtype]["exponent"]
    mantissa_bits = dtype_dict[weights_dtype]["mantissa"]
    total_bits = dtype_dict[weights_dtype]["num_bits"]

    # Same mask as in pack_float: the packed bits that belong at the top of
    # the float32 word (sign bit for signed formats, plus the exponent MSB).
    if dtype_dict[weights_dtype]["is_unsigned"]:
        sign_mask = (1 << (total_bits-1)) # pylint: disable=superfluous-parens
    else:
        sign_mask = (1 << (total_bits-1)) + (1 << (total_bits-2))

    # Number of mantissa / exponent bits that were dropped when packing
    # (float32 has 23 mantissa and 8 exponent bits).
    mantissa_difference = 23 - mantissa_bits
    exponent_difference = 8 - exponent_bits

    # Odd widths were bit-packed; expand them back to one code per element.
    if total_bits not in {8, 16}:
        x = unpack_int(x, float_bits_to_uint_dict[total_bits], shape)

    # Work on the raw bits as int32 (so later right-shifts are arithmetic).
    x = x.to(dtype=torch.uint32).view(torch.int32)
    # Undo pack_float's recombination: move the sign/exponent-MSB bits back up
    # by the dropped exponent width, then shift the whole field up into
    # float32 position (left by the dropped mantissa width).
    x = torch.bitwise_left_shift(
        torch.bitwise_or(
            torch.bitwise_left_shift(torch.bitwise_and(x, sign_mask), exponent_difference),
            torch.bitwise_and(x, ~sign_mask),
        ),
        mantissa_difference,
    )

    # Re-bias the exponent via a sign-extension trick: 1073741824 == 0x40000000
    # (the float32 exponent MSB) and 1065353216 == 0x3F800000 (the full float32
    # exponent field, i.e. the bit pattern of 1.0f). When the reconstructed
    # exponent MSB is clear, -(~x & 0x40000000) == 0xC0000000; the arithmetic
    # right shift smears ones downward, and masking with 0x3F800000 fills the
    # exponent bits that were dropped during packing with copies of the
    # complemented MSB — converting the small-format bias to the float32 bias.
    x = torch.bitwise_or(
        x,
        torch.bitwise_and(
            torch.bitwise_right_shift(
                -torch.bitwise_and(torch.bitwise_not(x), 1073741824),
                exponent_difference,
            ),
            1065353216,
        ),
    )

    # 1090519039 == 0x40FFFFFF. NOTE(review): this appears to detect the
    # all-zero encoding — the bias-fill above would otherwise decode packed
    # zero to a spurious power of two — and flush it back to exactly 0.0;
    # elements with no bits inside `overflow_mask` are zeroed. Confirm against
    # pack_float's min_normal flush.
    overflow_mask = (~(-(1 << (22 + exponent_bits))) | 1090519039)
    x = torch.where(torch.bitwise_and(x, overflow_mask).to(dtype=torch.bool), x, 0)
    # Reinterpret the assembled bit patterns as float32 values.
    x = x.view(torch.float32)

    return x