mirror of https://github.com/vladmandic/automatic
430 lines
17 KiB
Python
430 lines
17 KiB
Python
# pylint: disable=redefined-builtin,no-member,protected-access
|
|
|
|
import torch
|
|
|
|
from .common import dtype_dict
|
|
|
|
|
|
def pack_int(tensor: torch.Tensor, weights_dtype: str) -> torch.Tensor:
    """Pack ``tensor`` with the packer registered for ``weights_dtype``.

    For signed types the tensor is first offset by ``dtype_dict[...]["min"]``
    (presumably moving every value into a non-negative range), then cast to
    the type's ``storage_dtype`` and handed to the matching ``pack_*`` routine.
    """
    type_info = dtype_dict[weights_dtype]
    offset_tensor = tensor if type_info["is_unsigned"] else tensor.sub(type_info["min"])
    pack_fn = packed_int_function_dict[weights_dtype]["pack"]
    return pack_fn(offset_tensor.to(dtype=type_info["storage_dtype"]))
|
|
|
|
|
|
def unpack_int(packed_tensor: torch.Tensor, weights_dtype: str, shape: torch.Size, dtype: torch.dtype = None) -> torch.Tensor:
    """Inverse of ``pack_int``.

    Args:
        packed_tensor: tensor produced by ``pack_int`` for ``weights_dtype``.
        weights_dtype: key into ``dtype_dict`` and ``packed_int_function_dict``.
        shape: shape of the unpacked result.
        dtype: optional output dtype. When ``None``, signed types are returned as
            ``dtype_dict[weights_dtype]["torch_dtype"]`` and unsigned types keep
            the dtype produced by the unpack routine.

    Returns:
        The unpacked tensor, with the signed ``min`` offset added back for
        signed types.
    """
    result = packed_int_function_dict[weights_dtype]["unpack"](packed_tensor, shape)
    type_info = dtype_dict[weights_dtype]
    if not type_info["is_unsigned"]:
        # Undo the ``sub(min)`` applied in pack_int. ``.to`` happens first so the
        # in-place add runs in the requested (possibly signed/float) dtype.
        target_dtype = type_info["torch_dtype"] if dtype is None else dtype
        result = result.to(dtype=target_dtype).add_(type_info["min"])
    elif dtype is not None:
        # Honour an explicitly requested dtype for unsigned types as well;
        # it used to be silently ignored on this path.
        result = result.to(dtype=dtype)
    return result
|
|
|
|
|
|
def pack_uint14(tensor: torch.Tensor) -> torch.Tensor:
    """Pack groups of eight 14-bit values into seven words.

    Values 0-6 of each group are OR-ed unchanged into words 0-6 (so they are
    assumed to fit in 14 bits). The eighth value is cut into seven 2-bit
    chunks that fill the two bits selected by mask 0xC000 (49152) of words 0-6,
    most-significant chunk first. The mask implies at least 16-bit storage.
    """
    groups = tensor.contiguous().view(-1, 8)
    spill = groups[:, 7]
    # shift by 2, 4, ..., 14 so successively lower bit pairs land in bits 14-15
    spill_chunks = torch.stack(
        [torch.bitwise_left_shift(spill, shift) for shift in range(2, 16, 2)],
        dim=-1,
    )
    return torch.bitwise_or(groups[:, :7], torch.bitwise_and(spill_chunks, 49152))
|
|
|
|
|
|
def pack_uint12(tensor: torch.Tensor) -> torch.Tensor:
    """Pack groups of four 12-bit values into three words.

    Values 0-2 are OR-ed unchanged into words 0-2. The fourth value is split
    into three nibbles placed in bits 12-15 (mask 0xF000 = 61440) of words
    0-2, high nibble first. The mask implies at least 16-bit storage.
    """
    groups = tensor.contiguous().view(-1, 4)
    spill = groups[:, 3]
    nibbles = torch.stack(
        [torch.bitwise_left_shift(spill, shift) for shift in (4, 8, 12)],
        dim=-1,
    )
    return torch.bitwise_or(groups[:, :3], torch.bitwise_and(nibbles, 61440))
|
|
|
|
|
|
def pack_uint10(tensor: torch.Tensor) -> torch.Tensor:
    """Pack groups of eight 10-bit values ``v0..v7`` into five words ``w0..w4``.

    This operates on wide words, not bytes: the shift by 10 and the masks
    0x3C00 / 0xC000 need 16 bits. It is annotated ``torch.Tensor`` to match the
    other 16-bit packers. Layout, assuming 16-bit storage so that bits shifted
    past bit 15 are discarded:

    * ``w0..w4`` bits 0-9: ``v0..v4``
    * ``w0..w2`` bits 10-15: bits 0-5 of ``v5..v7``
    * ``w3`` bits 10-13: bits 6-9 of ``v5``; bits 14-15: bits 8-9 of ``v7``
    * ``w4`` bits 10-13: bits 6-9 of ``v6``; bits 14-15: bits 6-7 of ``v7``
    """
    groups = tensor.contiguous().view(-1, 8)
    # w0..w2: v0..v2 plus the low six bits of v5..v7 in bits 10-15
    head = torch.bitwise_or(groups[:, :3], torch.bitwise_left_shift(groups[:, 5:8], 10))
    # w3: v3, bits 6-9 of v5 (mask 0x3C00), bits 8-9 of v7 (mask 0xC000)
    word3 = torch.bitwise_or(
        groups[:, 3],
        torch.bitwise_or(
            torch.bitwise_and(torch.bitwise_left_shift(groups[:, 5], 4), 15360),
            torch.bitwise_and(torch.bitwise_left_shift(groups[:, 7], 6), 49152),
        ),
    )
    # w4: v4, bits 6-9 of v6 (mask 0x3C00), bits 6-7 of v7 (mask 0xC000)
    word4 = torch.bitwise_or(
        groups[:, 4],
        torch.bitwise_or(
            torch.bitwise_and(torch.bitwise_left_shift(groups[:, 6], 4), 15360),
            torch.bitwise_and(torch.bitwise_left_shift(groups[:, 7], 8), 49152),
        ),
    )
    return torch.cat((head, word3.unsqueeze(-1), word4.unsqueeze(-1)), dim=-1)
|
|
|
|
|
|
def pack_uint7(tensor: torch.ByteTensor) -> torch.ByteTensor:
    """Pack groups of eight 7-bit values into seven bytes.

    Values 0-6 are OR-ed unchanged into bytes 0-6. Bits 6..0 of the eighth
    value are placed, one per byte, into bit 7 (mask 128) of bytes 0..6.
    """
    groups = tensor.contiguous().view(-1, 8)
    spill = groups[:, 7]
    # shift by 1..7 so bits 6, 5, ..., 0 of the spill value reach bit 7
    spill_bits = torch.stack(
        [torch.bitwise_left_shift(spill, shift) for shift in range(1, 8)],
        dim=-1,
    )
    return torch.bitwise_or(groups[:, :7], torch.bitwise_and(spill_bits, 128))
|
|
|
|
|
|
def pack_uint6(tensor: torch.ByteTensor) -> torch.ByteTensor:
    """Pack groups of four 6-bit values into three bytes.

    Values 0-2 are OR-ed unchanged into bytes 0-2. The fourth value is split
    into three 2-bit pieces stored in bits 6-7 (mask 192) of bytes 0-2,
    highest piece first.
    """
    groups = tensor.contiguous().view(-1, 4)
    spill = groups[:, 3]
    bit_pairs = torch.stack(
        [torch.bitwise_left_shift(spill, shift) for shift in (2, 4, 6)],
        dim=-1,
    )
    return torch.bitwise_or(groups[:, :3], torch.bitwise_and(bit_pairs, 192))
|
|
|
|
|
|
def pack_uint5(tensor: torch.ByteTensor) -> torch.ByteTensor:
    """Pack groups of eight 5-bit values ``v0..v7`` into five bytes ``b0..b4``.

    * ``b0..b4`` bits 0-4: ``v0..v4``
    * ``b0..b2`` bits 5-7: bits 0-2 of ``v5..v7``
    * ``b3`` bits 5-6: bits 3-4 of ``v5``; bit 7: bit 4 of ``v7``
    * ``b4`` bits 5-6: bits 3-4 of ``v6``; bit 7: bit 3 of ``v7``
    """
    groups = tensor.contiguous().view(-1, 8)
    head = torch.bitwise_or(groups[:, :3], torch.bitwise_left_shift(groups[:, 5:8], 5))
    extra3 = torch.bitwise_or(
        torch.bitwise_and(torch.bitwise_left_shift(groups[:, 5], 2), 96),
        torch.bitwise_and(torch.bitwise_left_shift(groups[:, 7], 3), 128),
    )
    extra4 = torch.bitwise_or(
        torch.bitwise_and(torch.bitwise_left_shift(groups[:, 6], 2), 96),
        torch.bitwise_and(torch.bitwise_left_shift(groups[:, 7], 4), 128),
    )
    tail = torch.stack(
        (torch.bitwise_or(groups[:, 3], extra3), torch.bitwise_or(groups[:, 4], extra4)),
        dim=-1,
    )
    return torch.cat((head, tail), dim=-1)
|
|
|
|
|
|
def pack_uint4(tensor: torch.ByteTensor) -> torch.ByteTensor:
    """Pack consecutive pairs of 4-bit values into one byte each.

    The first value of a pair occupies the low nibble, the second the high one.
    """
    pairs = tensor.contiguous().view(-1, 2)
    low_nibble = pairs[:, 0]
    high_nibble = torch.bitwise_left_shift(pairs[:, 1], 4)
    return torch.bitwise_or(low_nibble, high_nibble)
|
|
|
|
|
|
def pack_uint3(tensor: torch.ByteTensor) -> torch.ByteTensor:
    """Pack groups of eight 3-bit values ``v0..v7`` into three bytes ``b0..b2``.

    * ``b0..b2`` bits 0-2: ``v0..v2``; bits 3-5: ``v3..v5``
    * ``b0`` / ``b1`` bits 6-7: bits 0-1 of ``v6`` / ``v7``
    * ``b2`` bit 6: bit 2 of ``v6``; bit 7: bit 2 of ``v7``
    """
    groups = tensor.contiguous().view(-1, 8)
    low_six_bits = torch.bitwise_or(groups[:, :3], torch.bitwise_left_shift(groups[:, 3:6], 3))
    third_byte_top = torch.bitwise_or(
        torch.bitwise_and(torch.bitwise_left_shift(groups[:, 6], 4), 64),
        torch.bitwise_and(torch.bitwise_left_shift(groups[:, 7], 5), 128),
    )
    top_bits = torch.cat(
        (torch.bitwise_left_shift(groups[:, 6:8], 6), third_byte_top.unsqueeze(-1)),
        dim=-1,
    )
    return torch.bitwise_or(low_six_bits, top_bits)
|
|
|
|
|
|
def pack_uint2(tensor: torch.ByteTensor) -> torch.ByteTensor:
    """Pack groups of four 2-bit values into one byte, first value in the low bits."""
    quads = tensor.contiguous().view(-1, 4)
    packed = quads[:, 0]
    for slot in range(1, 4):
        # slot k occupies bits 2k .. 2k+1
        packed = torch.bitwise_or(packed, torch.bitwise_left_shift(quads[:, slot], 2 * slot))
    return packed
|
|
|
|
|
|
def pack_uint1(tensor: torch.Tensor) -> torch.Tensor:
    """Pack groups of eight 1-bit values into one byte; value ``k`` goes to bit ``k``."""
    octets = tensor.contiguous().view(-1, 8)
    packed = octets[:, 0]
    for bit in range(1, 8):
        packed = torch.bitwise_or(packed, torch.bitwise_left_shift(octets[:, bit], bit))
    return packed
|
|
|
|
|
|
def unpack_uint14(packed_tensor: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """Inverse of ``pack_uint14``: seven words per row back to eight 14-bit values.

    ``packed_tensor`` is expected to be 2-D with seven columns, as produced by
    ``pack_uint14``; the result is reshaped to ``shape``.
    """
    first_seven = torch.bitwise_and(packed_tensor[:, :7], 16383)
    # word i holds two bits of the eighth value at bits 14-15; shifting right by
    # 2*(i+1) moves them to the position selected by the matching mask
    masks = (12288, 3072, 768, 192, 48, 12, 3)
    pieces = [
        torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, word], 2 * word + 2), mask)
        for word, mask in enumerate(masks)
    ]
    eighth = pieces[0]
    for piece in pieces[1:]:
        eighth = torch.bitwise_or(eighth, piece)
    return torch.cat((first_seven, eighth.unsqueeze(-1)), dim=-1).view(shape)
|
|
|
|
|
|
def unpack_uint12(packed_tensor: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """Inverse of ``pack_uint12``: three words per row back to four 12-bit values."""
    first_three = torch.bitwise_and(packed_tensor[:, :3], 4095)
    # word i carries one nibble of the fourth value in bits 12-15
    nibbles = [
        torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, word], 4 * (word + 1)), mask)
        for word, mask in enumerate((3840, 240, 15))
    ]
    fourth = torch.bitwise_or(torch.bitwise_or(nibbles[0], nibbles[1]), nibbles[2])
    return torch.cat((first_three, fourth.unsqueeze(-1)), dim=-1).view(shape)
|
|
|
|
|
|
def unpack_uint10(packed_tensor: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """Inverse of ``pack_uint10``: five words per row back to eight 10-bit values."""
    word3 = packed_tensor[:, 3]
    word4 = packed_tensor[:, 4]
    # low six bits of v5..v7, stored in bits 10-15 of words 0-2
    spill_low = torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, :3], 10), 63)
    v0_to_v4 = torch.bitwise_and(packed_tensor[:, :5], 1023)
    # bits 6-9 of v5 / v6 live in bits 10-13 of words 3 / 4
    v5_v6 = torch.bitwise_or(
        spill_low[:, :2],
        torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 3:5], 4), 960),
    )
    # v7: bits 8-9 from word 3, bits 6-7 from word 4 (both at bits 14-15)
    v7_high = torch.bitwise_or(
        torch.bitwise_and(torch.bitwise_right_shift(word3, 6), 768),
        torch.bitwise_and(torch.bitwise_right_shift(word4, 8), 192),
    )
    v7 = torch.bitwise_or(spill_low[:, 2], v7_high)
    return torch.cat((v0_to_v4, v5_v6, v7.unsqueeze(-1)), dim=-1).view(shape)
|
|
|
|
|
|
def unpack_uint7(packed_tensor: torch.ByteTensor, shape: torch.Size) -> torch.ByteTensor:
    """Inverse of ``pack_uint7``: seven bytes per row back to eight 7-bit values."""
    first_seven = torch.bitwise_and(packed_tensor[:, :7], 127)
    # byte i stores bit (6 - i) of the eighth value in its bit 7
    pieces = [
        torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, i], i + 1), 64 >> i)
        for i in range(6)
    ]
    pieces.append(torch.bitwise_right_shift(packed_tensor[:, 6], 7))
    eighth = pieces[0]
    for piece in pieces[1:]:
        eighth = torch.bitwise_or(eighth, piece)
    return torch.cat((first_seven, eighth.unsqueeze(-1)), dim=-1).view(shape)
|
|
|
|
|
|
def unpack_uint6(packed_tensor: torch.ByteTensor, shape: torch.Size) -> torch.ByteTensor:
    """Inverse of ``pack_uint6``: three bytes per row back to four 6-bit values."""
    first_three = torch.bitwise_and(packed_tensor[:, :3], 63)
    top_pair = torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 0], 2), 48)
    middle_pair = torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 1], 4), 12)
    bottom_pair = torch.bitwise_right_shift(packed_tensor[:, 2], 6)
    fourth = torch.bitwise_or(torch.bitwise_or(top_pair, middle_pair), bottom_pair)
    return torch.cat((first_three, fourth.unsqueeze(-1)), dim=-1).view(shape)
|
|
|
|
|
|
def unpack_uint5(packed_tensor: torch.ByteTensor, shape: torch.Size) -> torch.ByteTensor:
    """Inverse of ``pack_uint5``: five bytes per row back to eight 5-bit values."""
    byte3 = packed_tensor[:, 3]
    byte4 = packed_tensor[:, 4]
    # bits 0-2 of v5..v7, stored in bits 5-7 of bytes 0-2
    spill_low = torch.bitwise_right_shift(packed_tensor[:, :3], 5)
    v0_to_v4 = torch.bitwise_and(packed_tensor[:, :5], 31)
    # bits 3-4 of v5 / v6 live in bits 5-6 of bytes 3 / 4
    v5_v6 = torch.bitwise_or(
        spill_low[:, :2],
        torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 3:5], 2), 24),
    )
    # v7: bit 4 from byte 3, bit 3 from byte 4 (both at bit 7)
    v7_high = torch.bitwise_or(
        torch.bitwise_and(torch.bitwise_right_shift(byte3, 3), 16),
        torch.bitwise_and(torch.bitwise_right_shift(byte4, 4), 8),
    )
    v7 = torch.bitwise_or(spill_low[:, 2], v7_high)
    return torch.cat((v0_to_v4, v5_v6, v7.unsqueeze(-1)), dim=-1).view(shape)
|
|
|
|
|
|
def unpack_uint4(packed_tensor: torch.ByteTensor, shape: torch.Size) -> torch.ByteTensor:
    """Inverse of ``pack_uint4``: split each byte into (low nibble, high nibble)."""
    low_nibbles = torch.bitwise_and(packed_tensor, 15)
    high_nibbles = torch.bitwise_right_shift(packed_tensor, 4)
    interleaved = torch.stack((low_nibbles, high_nibbles), dim=-1)
    return interleaved.view(shape)
|
|
|
|
|
|
def unpack_uint3(packed_tensor: torch.ByteTensor, shape: torch.Size) -> torch.ByteTensor:
    """Inverse of ``pack_uint3``: three bytes per row back to eight 3-bit values."""
    third_byte = packed_tensor[:, 2]
    # bit 2 of v6 / v7 sits at bit 6 / 7 of byte 2; move both to bit 2
    spill_bit2 = torch.bitwise_and(
        torch.stack(
            (torch.bitwise_right_shift(third_byte, 4), torch.bitwise_right_shift(third_byte, 5)),
            dim=-1,
        ),
        4,
    )
    v6_v7 = torch.bitwise_or(torch.bitwise_right_shift(packed_tensor[:, :2], 6), spill_bit2)
    fields = torch.cat(
        (packed_tensor[:, :3], torch.bitwise_right_shift(packed_tensor[:, :3], 3), v6_v7),
        dim=-1,
    )
    # a single final mask keeps only the 3-bit payload of every field
    return torch.bitwise_and(fields, 7).view(shape)
|
|
|
|
|
|
def unpack_uint2(packed_tensor: torch.ByteTensor, shape: torch.Size) -> torch.ByteTensor:
    """Inverse of ``pack_uint2``: split each byte into four 2-bit values, low bits first."""
    shifted = [packed_tensor] + [torch.bitwise_right_shift(packed_tensor, shift) for shift in (2, 4, 6)]
    return torch.bitwise_and(torch.stack(shifted, dim=-1), 3).view(shape)
|
|
|
|
|
|
def unpack_uint1(packed_tensor: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """Inverse of ``pack_uint1``: expand each byte into its eight bits, bit 0 first."""
    shifted = [packed_tensor] + [torch.bitwise_right_shift(packed_tensor, shift) for shift in range(1, 8)]
    return torch.bitwise_and(torch.stack(shifted, dim=-1), 1).view(shape)
|
|
|
|
|
|
# Registry of pack/unpack routines keyed by weight dtype name.
packed_int_function_dict = {
    f"uint{bits}": {"pack": pack_fn, "unpack": unpack_fn}
    for bits, pack_fn, unpack_fn in (
        (14, pack_uint14, unpack_uint14),
        (12, pack_uint12, unpack_uint12),
        (10, pack_uint10, unpack_uint10),
        (7, pack_uint7, unpack_uint7),
        (6, pack_uint6, unpack_uint6),
        (5, pack_uint5, unpack_uint5),
        (4, pack_uint4, unpack_uint4),
        (3, pack_uint3, unpack_uint3),
        (2, pack_uint2, unpack_uint2),
        (1, pack_uint1, unpack_uint1),
    )
}

# Signed variants reuse the unsigned routines (the same entry objects):
# pack_int subtracts dtype_dict[...]["min"] before packing and unpack_int adds
# it back, so the bit-level layout is shared.
packed_int_function_dict.update(
    {f"int{bits}": packed_int_function_dict[f"uint{bits}"] for bits in (14, 12, 10, 7, 6, 5, 4, 3, 2)}
)
packed_int_function_dict["bool"] = packed_int_function_dict["uint1"]
|