# pylint: disable=redefined-builtin,no-member,protected-access import torch from .common import dtype_dict def pack_int(tensor: torch.Tensor, weights_dtype: str) -> torch.Tensor: if not dtype_dict[weights_dtype]["is_unsigned"]: tensor = tensor.sub(dtype_dict[weights_dtype]["min"]) return packed_int_function_dict[weights_dtype]["pack"](tensor.to(dtype=dtype_dict[weights_dtype]["storage_dtype"])) def unpack_int(packed_tensor: torch.Tensor, weights_dtype: str, shape: torch.Size, dtype: torch.dtype = None) -> torch.Tensor: packed_tensor = packed_int_function_dict[weights_dtype]["unpack"](packed_tensor, shape) if not dtype_dict[weights_dtype]["is_unsigned"]: packed_tensor = packed_tensor.to(dtype=dtype_dict[weights_dtype]["torch_dtype"] if dtype is None else dtype).add_(dtype_dict[weights_dtype]["min"]) return packed_tensor def pack_uint14(tensor: torch.Tensor) -> torch.Tensor: packed_tensor = tensor.contiguous().view(-1, 8) packed_tensor = torch.bitwise_or( packed_tensor[:, :7], torch.bitwise_and( torch.stack( ( torch.bitwise_left_shift(packed_tensor[:, 7], 2), torch.bitwise_left_shift(packed_tensor[:, 7], 4), torch.bitwise_left_shift(packed_tensor[:, 7], 6), torch.bitwise_left_shift(packed_tensor[:, 7], 8), torch.bitwise_left_shift(packed_tensor[:, 7], 10), torch.bitwise_left_shift(packed_tensor[:, 7], 12), torch.bitwise_left_shift(packed_tensor[:, 7], 14), ), dim=-1 ), 49152 ), ) return packed_tensor def pack_uint12(tensor: torch.Tensor) -> torch.Tensor: packed_tensor = tensor.contiguous().view(-1, 4) packed_tensor = torch.bitwise_or( packed_tensor[:, :3], torch.bitwise_and( torch.stack( ( torch.bitwise_left_shift(packed_tensor[:, 3], 4), torch.bitwise_left_shift(packed_tensor[:, 3], 8), torch.bitwise_left_shift(packed_tensor[:, 3], 12), ), dim=-1 ), 61440 ) ) return packed_tensor def pack_uint10(tensor: torch.ByteTensor) -> torch.ByteTensor: packed_tensor = tensor.contiguous().view(-1, 8) packed_tensor = torch.cat( ( torch.bitwise_or(packed_tensor[:, :3], torch.bitwise_left_shift(packed_tensor[:, 5:8], 10)), torch.bitwise_or( packed_tensor[:, 3], torch.bitwise_or( torch.bitwise_and(torch.bitwise_left_shift(packed_tensor[:, 5], 4), 15360), torch.bitwise_and(torch.bitwise_left_shift(packed_tensor[:, 7], 6), 49152), ), ).unsqueeze(-1), torch.bitwise_or( packed_tensor[:, 4], torch.bitwise_or( torch.bitwise_and(torch.bitwise_left_shift(packed_tensor[:, 6], 4), 15360), torch.bitwise_and(torch.bitwise_left_shift(packed_tensor[:, 7], 8), 49152), ), ).unsqueeze(-1), ), dim=-1 ) return packed_tensor def pack_uint7(tensor: torch.ByteTensor) -> torch.ByteTensor: packed_tensor = tensor.contiguous().view(-1, 8) packed_tensor = torch.bitwise_or( packed_tensor[:, :7], torch.bitwise_and( torch.stack( ( torch.bitwise_left_shift(packed_tensor[:, 7], 1), torch.bitwise_left_shift(packed_tensor[:, 7], 2), torch.bitwise_left_shift(packed_tensor[:, 7], 3), torch.bitwise_left_shift(packed_tensor[:, 7], 4), torch.bitwise_left_shift(packed_tensor[:, 7], 5), torch.bitwise_left_shift(packed_tensor[:, 7], 6), torch.bitwise_left_shift(packed_tensor[:, 7], 7), ), dim=-1 ), 128 ), ) return packed_tensor def pack_uint6(tensor: torch.ByteTensor) -> torch.ByteTensor: packed_tensor = tensor.contiguous().view(-1, 4) packed_tensor = torch.bitwise_or( packed_tensor[:, :3], torch.bitwise_and( torch.stack( ( torch.bitwise_left_shift(packed_tensor[:, 3], 2), torch.bitwise_left_shift(packed_tensor[:, 3], 4), torch.bitwise_left_shift(packed_tensor[:, 3], 6), ), dim=-1 ), 192 ) ) return packed_tensor def pack_uint5(tensor: torch.ByteTensor) -> torch.ByteTensor: packed_tensor = tensor.contiguous().view(-1, 8) packed_tensor = torch.cat( ( torch.bitwise_or(packed_tensor[:, :3], torch.bitwise_left_shift(packed_tensor[:, 5:8], 5)), torch.bitwise_or( packed_tensor[:, 3], torch.bitwise_or( torch.bitwise_and(torch.bitwise_left_shift(packed_tensor[:, 5], 2), 96), torch.bitwise_and(torch.bitwise_left_shift(packed_tensor[:, 7], 3), 128), ), ).unsqueeze(-1), torch.bitwise_or( packed_tensor[:, 4], torch.bitwise_or( torch.bitwise_and(torch.bitwise_left_shift(packed_tensor[:, 6], 2), 96), torch.bitwise_and(torch.bitwise_left_shift(packed_tensor[:, 7], 4), 128), ), ).unsqueeze(-1), ), dim=-1 ) return packed_tensor def pack_uint4(tensor: torch.ByteTensor) -> torch.ByteTensor: packed_tensor = tensor.contiguous().view(-1, 2) packed_tensor = torch.bitwise_or(packed_tensor[:, 0], torch.bitwise_left_shift(packed_tensor[:, 1], 4)) return packed_tensor def pack_uint3(tensor: torch.ByteTensor) -> torch.ByteTensor: packed_tensor = tensor.contiguous().view(-1, 8) packed_tensor = torch.bitwise_or( torch.bitwise_or(packed_tensor[:, :3], torch.bitwise_left_shift(packed_tensor[:, 3:6], 3)), torch.cat( ( torch.bitwise_left_shift(packed_tensor[:, 6:8], 6), torch.bitwise_or( torch.bitwise_and(torch.bitwise_left_shift(packed_tensor[:, 6], 4), 64), torch.bitwise_and(torch.bitwise_left_shift(packed_tensor[:, 7], 5), 128), ).unsqueeze(-1), ), dim=-1 ) ) return packed_tensor def pack_uint2(tensor: torch.ByteTensor) -> torch.ByteTensor: packed_tensor = tensor.contiguous().view(-1, 4) packed_tensor = torch.bitwise_or( torch.bitwise_or(packed_tensor[:, 0], torch.bitwise_left_shift(packed_tensor[:, 1], 2)), torch.bitwise_or(torch.bitwise_left_shift(packed_tensor[:, 2], 4), torch.bitwise_left_shift(packed_tensor[:, 3], 6)), ) return packed_tensor def pack_uint1(tensor: torch.Tensor) -> torch.Tensor: packed_tensor = tensor.contiguous().view(-1, 8) packed_tensor = torch.bitwise_or( torch.bitwise_or( torch.bitwise_or(packed_tensor[:, 0], torch.bitwise_left_shift(packed_tensor[:, 1], 1)), torch.bitwise_or(torch.bitwise_left_shift(packed_tensor[:, 2], 2), torch.bitwise_left_shift(packed_tensor[:, 3], 3)) ), torch.bitwise_or( torch.bitwise_or(torch.bitwise_left_shift(packed_tensor[:, 4], 4), torch.bitwise_left_shift(packed_tensor[:, 5], 5)), torch.bitwise_or(torch.bitwise_left_shift(packed_tensor[:, 6], 6), torch.bitwise_left_shift(packed_tensor[:, 7], 7)) ), ) return packed_tensor def unpack_uint14(packed_tensor: torch.Tensor, shape: torch.Size) -> torch.Tensor: result = torch.cat( ( torch.bitwise_and(packed_tensor[:, :7], 16383), torch.bitwise_or( torch.bitwise_or( torch.bitwise_or( torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 0], 2), 12288), torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 1], 4), 3072), ), torch.bitwise_or( torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 2], 6), 768), torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 3], 8), 192), ), ), torch.bitwise_or( torch.bitwise_or( torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 4], 10), 48), torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 5], 12), 12), ), torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 6], 14), 3), ), ).unsqueeze(-1) ), dim=-1 ).view(shape) return result def unpack_uint12(packed_tensor: torch.Tensor, shape: torch.Size) -> torch.Tensor: result = torch.cat( ( torch.bitwise_and(packed_tensor[:, :3], 4095), torch.bitwise_or( torch.bitwise_or( torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 0], 4), 3840), torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 1], 8), 240), ), torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 2], 12), 15) ).unsqueeze(-1) ), dim=-1 ).view(shape) return result def unpack_uint10(packed_tensor: torch.Tensor, shape: torch.Size) -> torch.Tensor: result_bitwise_right_shift = torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, :3], 10), 63) result = torch.cat( ( torch.bitwise_and(packed_tensor[:, :5], 1023), torch.bitwise_or( result_bitwise_right_shift[:, :2], torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 3:5], 4), 960), ), torch.bitwise_or( result_bitwise_right_shift[:, 2], torch.bitwise_or( torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 3], 6), 768), torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 4], 8), 192), ), ).unsqueeze(-1), ), dim=-1 ).view(shape) return result def unpack_uint7(packed_tensor: torch.ByteTensor, shape: torch.Size) -> torch.ByteTensor: result = torch.cat( ( torch.bitwise_and(packed_tensor[:, :7], 127), torch.bitwise_or( torch.bitwise_or( torch.bitwise_or( torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 0], 1), 64), torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 1], 2), 32), ), torch.bitwise_or( torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 2], 3), 16), torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 3], 4), 8), ), ), torch.bitwise_or( torch.bitwise_or( torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 4], 5), 4), torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 5], 6), 2), ), torch.bitwise_right_shift(packed_tensor[:, 6], 7), ), ).unsqueeze(-1) ), dim=-1 ).view(shape) return result def unpack_uint6(packed_tensor: torch.ByteTensor, shape: torch.Size) -> torch.ByteTensor: result = torch.cat( ( torch.bitwise_and(packed_tensor[:, :3], 63), torch.bitwise_or( torch.bitwise_or( torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 0], 2), 48), torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 1], 4), 12), ), torch.bitwise_right_shift(packed_tensor[:, 2], 6) ).unsqueeze(-1) ), dim=-1 ).view(shape) return result def unpack_uint5(packed_tensor: torch.ByteTensor, shape: torch.Size) -> torch.ByteTensor: result_bitwise_right_shift = torch.bitwise_right_shift(packed_tensor[:, :3], 5) result = torch.cat( ( torch.bitwise_and(packed_tensor[:, :5], 31), torch.bitwise_or( result_bitwise_right_shift[:, :2], torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 3:5], 2), 24), ), torch.bitwise_or( result_bitwise_right_shift[:, 2], torch.bitwise_or( torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 3], 3), 16), torch.bitwise_and(torch.bitwise_right_shift(packed_tensor[:, 4], 4), 8), ), ).unsqueeze(-1), ), dim=-1 ).view(shape) return result def unpack_uint4(packed_tensor: torch.ByteTensor, shape: torch.Size) -> torch.ByteTensor: result = torch.stack((torch.bitwise_and(packed_tensor, 15), torch.bitwise_right_shift(packed_tensor, 4)), dim=-1).view(shape) return result def unpack_uint3(packed_tensor: torch.ByteTensor, shape: torch.Size) -> torch.ByteTensor: result = torch.bitwise_and( torch.cat( ( packed_tensor[:, :3], torch.bitwise_right_shift(packed_tensor[:, :3], 3), torch.bitwise_or( torch.bitwise_right_shift(packed_tensor[:, :2], 6), torch.bitwise_and( torch.stack( ( torch.bitwise_right_shift(packed_tensor[:, 2], 4), torch.bitwise_right_shift(packed_tensor[:, 2], 5), ), dim=-1 ), 4 ), ), ), dim=-1 ), 7 ).view(shape) return result def unpack_uint2(packed_tensor: torch.ByteTensor, shape: torch.Size) -> torch.ByteTensor: result = torch.bitwise_and( torch.stack( ( packed_tensor, torch.bitwise_right_shift(packed_tensor, 2), torch.bitwise_right_shift(packed_tensor, 4), torch.bitwise_right_shift(packed_tensor, 6), ), dim=-1 ), 3 ).view(shape) return result def unpack_uint1(packed_tensor: torch.Tensor, shape: torch.Size) -> torch.Tensor: result = torch.bitwise_and( torch.stack( ( packed_tensor, torch.bitwise_right_shift(packed_tensor, 1), torch.bitwise_right_shift(packed_tensor, 2), torch.bitwise_right_shift(packed_tensor, 3), torch.bitwise_right_shift(packed_tensor, 4), torch.bitwise_right_shift(packed_tensor, 5), torch.bitwise_right_shift(packed_tensor, 6), torch.bitwise_right_shift(packed_tensor, 7), ), dim=-1 ), 1 ).view(shape) return result packed_int_function_dict = { "uint14": {"pack": pack_uint14, "unpack": unpack_uint14}, "uint12": {"pack": pack_uint12, "unpack": unpack_uint12}, "uint10": {"pack": pack_uint10, "unpack": unpack_uint10}, "uint7": {"pack": pack_uint7, "unpack": unpack_uint7}, "uint6": {"pack": pack_uint6, "unpack": unpack_uint6}, "uint5": {"pack": pack_uint5, "unpack": unpack_uint5}, "uint4": {"pack": pack_uint4, "unpack": unpack_uint4}, "uint3": {"pack": pack_uint3, "unpack": unpack_uint3}, "uint2": {"pack": pack_uint2, "unpack": unpack_uint2}, "uint1": {"pack": pack_uint1, "unpack": unpack_uint1}, } packed_int_function_dict["int14"] = packed_int_function_dict["uint14"] packed_int_function_dict["int12"] = packed_int_function_dict["uint12"] packed_int_function_dict["int10"] = packed_int_function_dict["uint10"] packed_int_function_dict["int7"] = packed_int_function_dict["uint7"] packed_int_function_dict["int6"] = packed_int_function_dict["uint6"] packed_int_function_dict["int5"] = packed_int_function_dict["uint5"] packed_int_function_dict["int4"] = packed_int_function_dict["uint4"] packed_int_function_dict["int3"] = packed_int_function_dict["uint3"] packed_int_function_dict["int2"] = packed_int_function_dict["uint2"] packed_int_function_dict["bool"] = packed_int_function_dict["uint1"]