SDNQ add new stack of custom floating point types and remove irrelevant qtypes from the ui list

pull/4500/head
Disty0 2025-12-26 20:09:17 +03:00
parent 6a2b7d37ab
commit 4a4784eafa
6 changed files with 508 additions and 108 deletions

View File

@ -8,18 +8,22 @@ from modules import shared, devices
sdnq_version = "0.1.3"
dtype_dict = {
### Integers
"int32": {"min": -2147483648, "max": 2147483647, "num_bits": 32, "sign": 1, "exponent": 0, "mantissa": 31, "target_dtype": torch.int32, "torch_dtype": torch.int32, "storage_dtype": torch.int32, "is_unsigned": False, "is_integer": True, "is_packed": False},
"int16": {"min": -32768, "max": 32767, "num_bits": 16, "sign": 1, "exponent": 0, "mantissa": 15, "target_dtype": torch.int16, "torch_dtype": torch.int16, "storage_dtype": torch.int16, "is_unsigned": False, "is_integer": True, "is_packed": False},
"int8": {"min": -128, "max": 127, "num_bits": 8, "sign": 1, "exponent": 0, "mantissa": 7, "target_dtype": torch.int8, "torch_dtype": torch.int8, "storage_dtype": torch.int8, "is_unsigned": False, "is_integer": True, "is_packed": False},
### Custom Integers
"int7": {"min": -64, "max": 63, "num_bits": 7, "sign": 1, "exponent": 0, "mantissa": 6, "target_dtype": "int7", "torch_dtype": torch.int8, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": True, "is_packed": True},
"int6": {"min": -32, "max": 31, "num_bits": 6, "sign": 1, "exponent": 0, "mantissa": 5, "target_dtype": "int6", "torch_dtype": torch.int8, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": True, "is_packed": True},
"int5": {"min": -16, "max": 15, "num_bits": 5, "sign": 1, "exponent": 0, "mantissa": 4, "target_dtype": "int5", "torch_dtype": torch.int8, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": True, "is_packed": True},
"int4": {"min": -8, "max": 7, "num_bits": 4, "sign": 1, "exponent": 0, "mantissa": 3, "target_dtype": "int4", "torch_dtype": torch.int8, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": True, "is_packed": True},
"int3": {"min": -4, "max": 3, "num_bits": 3, "sign": 1, "exponent": 0, "mantissa": 2, "target_dtype": "int3", "torch_dtype": torch.int8, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": True, "is_packed": True},
"int2": {"min": -2, "max": 1, "num_bits": 2, "sign": 1, "exponent": 0, "mantissa": 1, "target_dtype": "int2", "torch_dtype": torch.int8, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": True, "is_packed": True},
### Unsigned Integers
"uint32": {"min": 0, "max": 4294967295, "num_bits": 32, "sign": 0, "exponent": 0, "mantissa": 32, "target_dtype": torch.uint32, "torch_dtype": torch.uint32, "storage_dtype": torch.uint32, "is_unsigned": True, "is_integer": True, "is_packed": False},
"uint16": {"min": 0, "max": 65535, "num_bits": 16, "sign": 0, "exponent": 0, "mantissa": 16, "target_dtype": torch.uint16, "torch_dtype": torch.uint16, "storage_dtype": torch.uint16, "is_unsigned": True, "is_integer": True, "is_packed": False},
"uint8": {"min": 0, "max": 255, "num_bits": 8, "sign": 0, "exponent": 0, "mantissa": 8, "target_dtype": torch.uint8, "torch_dtype": torch.uint8, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": True, "is_packed": False},
### Custom Unsigned Integers
"uint7": {"min": 0, "max": 127, "num_bits": 7, "sign": 0, "exponent": 0, "mantissa": 7, "target_dtype": "uint7", "torch_dtype": torch.uint8, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": True, "is_packed": True},
"uint6": {"min": 0, "max": 63, "num_bits": 6, "sign": 0, "exponent": 0, "mantissa": 6, "target_dtype": "uint6", "torch_dtype": torch.uint8, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": True, "is_packed": True},
"uint5": {"min": 0, "max": 31, "num_bits": 5, "sign": 0, "exponent": 0, "mantissa": 5, "target_dtype": "uint5", "torch_dtype": torch.uint8, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": True, "is_packed": True},
@ -27,18 +31,107 @@ dtype_dict = {
"uint3": {"min": 0, "max": 7, "num_bits": 3, "sign": 0, "exponent": 0, "mantissa": 3, "target_dtype": "uint3", "torch_dtype": torch.uint8, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": True, "is_packed": True},
"uint2": {"min": 0, "max": 3, "num_bits": 2, "sign": 0, "exponent": 0, "mantissa": 2, "target_dtype": "uint2", "torch_dtype": torch.uint8, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": True, "is_packed": True},
"uint1": {"min": 0, "max": 1, "num_bits": 1, "sign": 0, "exponent": 0, "mantissa": 1, "target_dtype": torch.bool, "torch_dtype": torch.bool, "storage_dtype": torch.bool, "is_unsigned": True, "is_integer": True, "is_packed": True},
### Floats
"float32": {"min": -3.40282e+38, "max": 3.40282e+38, "num_bits": 32, "sign": 1, "exponent": 8, "mantissa": 23, "target_dtype": torch.float32, "torch_dtype": torch.float32, "storage_dtype": torch.float32, "is_unsigned": False, "is_integer": False, "is_packed": False},
"bfloat16": {"min": -3.38953e+38, "max": 3.38953e+38, "num_bits": 16, "sign": 1, "exponent": 8, "mantissa": 7, "target_dtype": torch.bfloat16, "torch_dtype": torch.bfloat16, "storage_dtype": torch.bfloat16, "is_unsigned": False, "is_integer": False, "is_packed": False},
"float16": {"min": -65504, "max": 65504, "num_bits": 16, "sign": 1, "exponent": 5, "mantissa": 10, "target_dtype": torch.float16, "torch_dtype": torch.float16, "storage_dtype": torch.float16, "is_unsigned": False, "is_integer": False, "is_packed": False},
"float8_e4m3fn": {"min": -448, "max": 448, "num_bits": 8, "sign": 1, "exponent": 4, "mantissa": 3, "target_dtype": torch.float8_e4m3fn, "torch_dtype": torch.float8_e4m3fn, "storage_dtype": torch.float8_e4m3fn, "is_unsigned": False, "is_integer": False, "is_packed": False},
"float8_e5m2": {"min": -57344, "max": 57344, "num_bits": 8, "sign": 1, "exponent": 5, "mantissa": 2, "target_dtype": torch.float8_e5m2, "torch_dtype": torch.float8_e5m2, "storage_dtype": torch.float8_e5m2, "is_unsigned": False, "is_integer": False, "is_packed": False},
"float16": {"min": -65504.0, "max": 65504.0, "num_bits": 16, "sign": 1, "exponent": 5, "mantissa": 10, "target_dtype": torch.float16, "torch_dtype": torch.float16, "storage_dtype": torch.float16, "is_unsigned": False, "is_integer": False, "is_packed": False},
"float8_e4m3fn": {"min": -448.0, "max": 448.0, "num_bits": 8, "sign": 1, "exponent": 4, "mantissa": 3, "target_dtype": torch.float8_e4m3fn, "torch_dtype": torch.float8_e4m3fn, "storage_dtype": torch.float8_e4m3fn, "is_unsigned": False, "is_integer": False, "is_packed": False},
"float8_e5m2": {"min": -57344.0, "max": 57344.0, "num_bits": 8, "sign": 1, "exponent": 5, "mantissa": 2, "target_dtype": torch.float8_e5m2, "torch_dtype": torch.float8_e5m2, "storage_dtype": torch.float8_e5m2, "is_unsigned": False, "is_integer": False, "is_packed": False},
### Custom Floats
"float16_e1m14fn": {"min": -3.9998779296875, "max": 3.9998779296875, "num_bits": 16, "sign": 1, "exponent": 1, "mantissa": 14, "min_normal": 1.00006103515625, "target_dtype": torch.float16, "torch_dtype": torch.float32, "storage_dtype": torch.uint16, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float16_e2m13fn": {"min": -7.99951171875, "max": 7.99951171875, "num_bits": 16, "sign": 1, "exponent": 2, "mantissa": 13, "min_normal": 0.50006103515625, "target_dtype": torch.float16, "torch_dtype": torch.float32, "storage_dtype": torch.uint16, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float16_e3m12fn": {"min": -31.99609375, "max": 31.99609375, "num_bits": 16, "sign": 1, "exponent": 3, "mantissa": 12, "min_normal": 0.125030517578125, "target_dtype": torch.float16, "torch_dtype": torch.float32, "storage_dtype": torch.uint16, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float16_e4m11fn": {"min": -511.875, "max": 511.875, "num_bits": 16, "sign": 1, "exponent": 4, "mantissa": 11, "min_normal": 0.007816314697265625, "target_dtype": torch.float16, "torch_dtype": torch.float32, "storage_dtype": torch.uint16, "is_unsigned": False, "is_integer": False, "is_packed": True},
# float16_e5m10 is native in PyTorch
"float8_e1m6fn": {"min": -3.96875, "max": 3.96875, "num_bits": 8, "sign": 1, "exponent": 1, "mantissa": 6, "min_normal": 1.015625, "target_dtype": "fp8", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float8_e2m5fn": {"min": -7.875, "max": 7.875, "num_bits": 8, "sign": 1, "exponent": 2, "mantissa": 5, "min_normal": 0.515625, "target_dtype": "fp8", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float8_e3m4fn": {"min": -31.0, "max": 31.0, "num_bits": 8, "sign": 1, "exponent": 3, "mantissa": 4, "min_normal": 0.1328125, "target_dtype": "fp8", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
# float8_e4m3fn is native in PyTorch
# float8_e5m2 is native in PyTorch
"float7_e1m5fn": {"min": -3.9375, "max": 3.9375, "num_bits": 7, "sign": 1, "exponent": 1, "mantissa": 5, "min_normal": 1.03125, "target_dtype": "fp7", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float7_e2m4fn": {"min": -7.75, "max": 7.75, "num_bits": 7, "sign": 1, "exponent": 2, "mantissa": 4, "min_normal": 0.53125, "target_dtype": "fp7", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float7_e3m3fn": {"min": -30.0, "max": 30.0, "num_bits": 7, "sign": 1, "exponent": 3, "mantissa": 3, "min_normal": 0.140625, "target_dtype": "fp7", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float7_e4m2fn": {"min": -448.0, "max": 448.0, "num_bits": 7, "sign": 1, "exponent": 4, "mantissa": 2, "min_normal": 0.009765625, "target_dtype": "fp7", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float7_e5m1fn": {"min": -98304.0, "max": 98304.0, "num_bits": 7, "sign": 1, "exponent": 5, "mantissa": 1, "min_normal": 4.57763671875e-05, "target_dtype": "fp7", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
#
"float6_e1m4fn": {"min": -3.875, "max": 3.875, "num_bits": 6, "sign": 1, "exponent": 1, "mantissa": 4, "min_normal": 1.0625, "target_dtype": "fp6", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float6_e2m3fn": {"min": -7.5, "max": 7.5, "num_bits": 6, "sign": 1, "exponent": 2, "mantissa": 3, "min_normal": 0.5625, "target_dtype": "fp6", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float6_e3m2fn": {"min": -28.0, "max": 28.0, "num_bits": 6, "sign": 1, "exponent": 3, "mantissa": 2, "min_normal": 0.15625, "target_dtype": "fp6", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float6_e4m1fn": {"min": -384.0, "max": 384.0, "num_bits": 6, "sign": 1, "exponent": 4, "mantissa": 1, "min_normal": 0.01171875, "target_dtype": "fp6", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float6_e5m0fn": {"min": -65536.0, "max": 65536.0, "num_bits": 6, "sign": 1, "exponent": 5, "mantissa": 0, "min_normal": 6.103515625e-05, "target_dtype": "fp6", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
#
"float5_e1m3fn": {"min": -3.75, "max": 3.75, "num_bits": 5, "sign": 1, "exponent": 1, "mantissa": 3, "min_normal": 1.125, "target_dtype": "fp5", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float5_e2m2fn": {"min": -7.0, "max": 7.0, "num_bits": 5, "sign": 1, "exponent": 2, "mantissa": 2, "min_normal": 0.625, "target_dtype": "fp5", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float5_e3m1fn": {"min": -24.0, "max": 24.0, "num_bits": 5, "sign": 1, "exponent": 3, "mantissa": 1, "min_normal": 0.1875, "target_dtype": "fp5", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float5_e4m0fn": {"min": -256.0, "max": 256.0, "num_bits": 5, "sign": 1, "exponent": 4, "mantissa": 0, "min_normal": 0.015625, "target_dtype": "fp5", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
#
"float4_e1m2fn": {"min": -3.5, "max": 3.5, "num_bits": 4, "sign": 1, "exponent": 1, "mantissa": 2, "min_normal": 1.25, "target_dtype": "fp4", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float4_e2m1fn": {"min": -6.0, "max": 6.0, "num_bits": 4, "sign": 1, "exponent": 2, "mantissa": 1, "min_normal": 0.75, "target_dtype": "fp4", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float4_e3m0fn": {"min": -16.0, "max": 16.0, "num_bits": 4, "sign": 1, "exponent": 3, "mantissa": 0, "min_normal": 0.25, "target_dtype": "fp4", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
#
"float3_e1m1fn": {"min": -3.0, "max": 3.0, "num_bits": 3, "sign": 1, "exponent": 1, "mantissa": 1, "min_normal": 1.5, "target_dtype": "fp3", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
"float3_e2m0fn": {"min": -4.0, "max": 4.0, "num_bits": 3, "sign": 1, "exponent": 2, "mantissa": 0, "min_normal": 1.0, "target_dtype": "fp3", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
#
"float2_e1m0fn": {"min": -2.0, "max": 2.0, "num_bits": 2, "sign": 1, "exponent": 1, "mantissa": 0, "min_normal": 2.0, "target_dtype": "fp2", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": False, "is_integer": False, "is_packed": True},
### Custom Unsigned Floats
"float16_e1m15fnu": {"min": 0, "max": 3.99993896484375, "num_bits": 16, "sign": 0, "exponent": 1, "mantissa": 15, "min_normal": 1.000030517578125, "target_dtype": torch.float16, "torch_dtype": torch.float32, "storage_dtype": torch.uint16, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float16_e2m14fnu": {"min": 0, "max": 7.999755859375, "num_bits": 16, "sign": 0, "exponent": 2, "mantissa": 14, "min_normal": 0.500030517578125, "target_dtype": torch.float16, "torch_dtype": torch.float32, "storage_dtype": torch.uint16, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float16_e3m13fnu": {"min": 0, "max": 31.998046875, "num_bits": 16, "sign": 0, "exponent": 3, "mantissa": 13, "min_normal": 0.1250152587890625, "target_dtype": torch.float16, "torch_dtype": torch.float32, "storage_dtype": torch.uint16, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float16_e4m12fnu": {"min": 0, "max": 511.9375, "num_bits": 16, "sign": 0, "exponent": 4, "mantissa": 12, "min_normal": 0.007814407348632812, "target_dtype": torch.float16, "torch_dtype": torch.float32, "storage_dtype": torch.uint16, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float16_e5m11fnu": {"min": 0, "max": 131040.0, "num_bits": 16, "sign": 0, "exponent": 5, "mantissa": 11, "min_normal": 3.053247928619385e-05, "target_dtype": torch.float16, "torch_dtype": torch.float32, "storage_dtype": torch.uint16, "is_unsigned": True, "is_integer": False, "is_packed": True},
#
"float8_e1m7fnu": {"min": 0, "max": 3.984375, "num_bits": 8, "sign": 0, "exponent": 1, "mantissa": 7, "min_normal": 1.0078125, "target_dtype": "fp8", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float8_e2m6fnu": {"min": 0, "max": 7.9375, "num_bits": 8, "sign": 0, "exponent": 2, "mantissa": 6, "min_normal": 0.5078125, "target_dtype": "fp8", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float8_e3m5fnu": {"min": 0, "max": 31.5, "num_bits": 8, "sign": 0, "exponent": 3, "mantissa": 5, "min_normal": 0.12890625, "target_dtype": "fp8", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float8_e4m4fnu": {"min": 0, "max": 496.0, "num_bits": 8, "sign": 0, "exponent": 4, "mantissa": 4, "min_normal": 0.00830078125, "target_dtype": "fp8", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float8_e5m3fnu": {"min": 0, "max": 122880.0, "num_bits": 8, "sign": 0, "exponent": 5, "mantissa": 3, "min_normal": 3.4332275390625e-05, "target_dtype": "fp8", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
#
"float7_e1m6fnu": {"min": 0, "max": 3.96875, "num_bits": 7, "sign": 0, "exponent": 1, "mantissa": 6, "min_normal": 1.015625, "target_dtype": "fp7", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float7_e2m5fnu": {"min": 0, "max": 7.875, "num_bits": 7, "sign": 0, "exponent": 2, "mantissa": 5, "min_normal": 0.515625, "target_dtype": "fp7", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float7_e3m4fnu": {"min": 0, "max": 31.0, "num_bits": 7, "sign": 0, "exponent": 3, "mantissa": 4, "min_normal": 0.1328125, "target_dtype": "fp7", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float7_e4m3fnu": {"min": 0, "max": 480.0, "num_bits": 7, "sign": 0, "exponent": 4, "mantissa": 3, "min_normal": 0.0087890625, "target_dtype": "fp7", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float7_e5m2fnu": {"min": 0, "max": 114688.0, "num_bits": 7, "sign": 0, "exponent": 5, "mantissa": 2, "min_normal": 3.814697265625e-05, "target_dtype": "fp7", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
#
"float6_e1m5fnu": {"min": 0, "max": 3.9375, "num_bits": 6, "sign": 0, "exponent": 1, "mantissa": 5, "min_normal": 1.03125, "target_dtype": "fp6", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float6_e2m4fnu": {"min": 0, "max": 7.75, "num_bits": 6, "sign": 0, "exponent": 2, "mantissa": 4, "min_normal": 0.53125, "target_dtype": "fp6", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float6_e3m3fnu": {"min": 0, "max": 30.0, "num_bits": 6, "sign": 0, "exponent": 3, "mantissa": 3, "min_normal": 0.140625, "target_dtype": "fp6", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float6_e4m2fnu": {"min": 0, "max": 448.0, "num_bits": 6, "sign": 0, "exponent": 4, "mantissa": 2, "min_normal": 0.009765625, "target_dtype": "fp6", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float6_e5m1fnu": {"min": 0, "max": 98304.0, "num_bits": 6, "sign": 0, "exponent": 5, "mantissa": 1, "min_normal": 4.57763671875e-05, "target_dtype": "fp6", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
#
"float5_e1m4fnu": {"min": 0, "max": 3.875, "num_bits": 5, "sign": 0, "exponent": 1, "mantissa": 4, "min_normal": 1.0625, "target_dtype": "fp5", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float5_e2m3fnu": {"min": 0, "max": 7.5, "num_bits": 5, "sign": 0, "exponent": 2, "mantissa": 3, "min_normal": 0.5625, "target_dtype": "fp5", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float5_e3m2fnu": {"min": 0, "max": 28.0, "num_bits": 5, "sign": 0, "exponent": 3, "mantissa": 2, "min_normal": 0.15625, "target_dtype": "fp5", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float5_e4m1fnu": {"min": 0, "max": 384.0, "num_bits": 5, "sign": 0, "exponent": 4, "mantissa": 1, "min_normal": 0.01171875, "target_dtype": "fp5", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float5_e5m0fnu": {"min": 0, "max": 65536.0, "num_bits": 5, "sign": 0, "exponent": 5, "mantissa": 0, "min_normal": 6.103515625e-05, "target_dtype": "fp5", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
#
"float4_e1m3fnu": {"min": 0, "max": 3.75, "num_bits": 4, "sign": 0, "exponent": 1, "mantissa": 3, "min_normal": 1.125, "target_dtype": "fp4", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float4_e2m2fnu": {"min": 0, "max": 7.0, "num_bits": 4, "sign": 0, "exponent": 2, "mantissa": 2, "min_normal": 0.625, "target_dtype": "fp4", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float4_e3m1fnu": {"min": 0, "max": 24.0, "num_bits": 4, "sign": 0, "exponent": 3, "mantissa": 1, "min_normal": 0.1875, "target_dtype": "fp4", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float4_e4m0fnu": {"min": 0, "max": 256.0, "num_bits": 4, "sign": 0, "exponent": 4, "mantissa": 0, "min_normal": 0.015625, "target_dtype": "fp4", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
#
"float3_e1m2fnu": {"min": 0, "max": 3.5, "num_bits": 3, "sign": 0, "exponent": 1, "mantissa": 2, "min_normal": 1.25, "target_dtype": "fp3", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float3_e2m1fnu": {"min": 0, "max": 6.0, "num_bits": 3, "sign": 0, "exponent": 2, "mantissa": 1, "min_normal": 0.75, "target_dtype": "fp3", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float3_e3m0fnu": {"min": 0, "max": 16.0, "num_bits": 3, "sign": 0, "exponent": 3, "mantissa": 0, "min_normal": 0.25, "target_dtype": "fp3", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
#
"float2_e1m1fnu": {"min": 0, "max": 3.0, "num_bits": 2, "sign": 0, "exponent": 1, "mantissa": 1, "min_normal": 1.5, "target_dtype": "fp2", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
"float2_e2m0fnu": {"min": 0, "max": 4.0, "num_bits": 2, "sign": 0, "exponent": 2, "mantissa": 0, "min_normal": 1.0, "target_dtype": "fp2", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
#
"float1_e1m0fnu": {"min": 0, "max": 2.0, "num_bits": 1, "sign": 0, "exponent": 1, "mantissa": 0, "min_normal": 2.0, "target_dtype": "fp1", "torch_dtype": torch.float32, "storage_dtype": torch.uint8, "is_unsigned": True, "is_integer": False, "is_packed": True},
}
# Register short-name aliases; each alias shares the exact same config dict
# object as its canonical entry above.
dtype_dict.update(
    fp32=dtype_dict["float32"],
    bf16=dtype_dict["bfloat16"],
    fp16=dtype_dict["float16"],
    fp8=dtype_dict["float8_e4m3fn"],
    fp7=dtype_dict["float7_e3m3fn"],
    fp6=dtype_dict["float6_e2m3fn"],
    fp5=dtype_dict["float5_e2m2fn"],
    fp4=dtype_dict["float4_e2m1fn"],
    fp3=dtype_dict["float3_e1m1fn"],
    fp2=dtype_dict["float2_e1m0fn"],
    fp1=dtype_dict["float1_e1m0fnu"],
    bool=dtype_dict["uint1"],
    int1=dtype_dict["uint1"],
)
torch_dtype_dict = {
torch.int32: "int32",
@ -55,19 +148,42 @@ torch_dtype_dict = {
}
# The fp8 "fnuz" variants only exist on newer PyTorch builds (primarily ROCm),
# so register them conditionally to keep older torch versions importable.
# Fix: the original assigned each key twice (an int-literal min/max dict
# immediately overwritten by the float-literal one — leftover diff residue);
# only the final float-literal entries are kept.
if hasattr(torch, "float8_e4m3fnuz"):
    dtype_dict["float8_e4m3fnuz"] = {"min": -240.0, "max": 240.0, "num_bits": 8, "sign": 1, "exponent": 4, "mantissa": 3, "target_dtype": "fp8", "torch_dtype": torch.float8_e4m3fnuz, "storage_dtype": torch.float8_e4m3fnuz, "is_unsigned": False, "is_integer": False, "is_packed": False}
    torch_dtype_dict[torch.float8_e4m3fnuz] = "float8_e4m3fnuz"
if hasattr(torch, "float8_e5m2fnuz"):
    dtype_dict["float8_e5m2fnuz"] = {"min": -57344.0, "max": 57344.0, "num_bits": 8, "sign": 1, "exponent": 5, "mantissa": 2, "target_dtype": "fp8", "torch_dtype": torch.float8_e5m2fnuz, "storage_dtype": torch.float8_e5m2fnuz, "is_unsigned": False, "is_integer": False, "is_packed": False}
    torch_dtype_dict[torch.float8_e5m2fnuz] = "float8_e5m2fnuz"
# Module class names that SDNQ is allowed to quantize.
linear_types = {"Linear"}
conv_types = {"Conv1d", "Conv2d", "Conv3d"}
conv_transpose_types = {"ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d"}
allowed_types = linear_types | conv_types | conv_transpose_types
# Every key in dtype_dict (canonical names plus aliases) is a valid weight dtype.
accepted_weight_dtypes = set(dtype_dict)
# Dtypes with a quantized-matmul execution path.
accepted_matmul_dtypes = {"int8", "fp8", "fp16", "float8_e4m3fnuz", "float16"}
weights_dtype_order = [
"uint1", "float1_e1m0fnu",
"int2", "float2_e1m0fn",
"uint2", "float2_e1m1fnu", "float2_e2m0fnu",
"int3", "float3_e1m1fn", "float3_e2m0fn",
"uint3", "float3_e1m2fnu", "float3_e2m1fnu", "float3_e3m0fnu",
"int4", "float4_e1m2fn", "float4_e2m1fn", "float4_e3m0fn",
"uint4", "float4_e1m3fnu", "float4_e2m2fnu", "float4_e3m1fnu", "float4_e4m0fnu"
"int5", "float5_e1m3fn", "float5_e2m2fn", "float5_e3m1fn", "float5_e4m0fn",
"uint5", "float5_e1m4fnu", "float5_e2m3fnu", "float5_e3m2fnu", "float5_e4m1fnu", "float5_e5m0fnu",
"int6", "float6_e1m4fn", "float6_e2m3fn", "float6_e3m2fn", "float6_e4m1fn", "float6_e5m0fn",
"uint6", "float6_e1m5fnu", "float6_e2m4fnu", "float6_e3m3fnu", "float6_e4m2fnu", "float6_e5m1fnu"
"int7", "float7_e1m5fn", "float7_e2m4fn", "float7_e3m3fn", "float7_e4m2fn", "float7_e5m1fn",
"uint7", "float7_e1m6fnu", "float7_e2m5fnu", "float7_e3m4fnu", "float7_e4m3fnu", "float7_e5m2fnu",
"int8", "float8_e4m3fn", "float8_e5m2", "float8_e1m6fn", "float8_e2m5fn", "float8_e3m4fn",
"uint8", "float8_e1m7fnu", "float8_e2m6fnu", "float8_e3m5fnu", "float8_e4m4fnu", "float8_e5m3fnu",
]
# Extended choice list used for fp32 source weights: everything above plus the
# 16-bit integer and custom 16-bit float formats.
weights_dtype_order_fp32 = [
    *weights_dtype_order,
    "int16", "float16", "float16_e1m14fn", "float16_e2m13fn", "float16_e3m12fn", "float16_e4m11fn",
    "uint16", "float16_e1m15fnu", "float16_e2m14fnu", "float16_e3m13fnu", "float16_e4m12fnu", "float16_e5m11fnu",
]
# True on pre-RDNA3 AMD GPUs: on the ROCm backend, parse the numeric part of the
# arch name (e.g. "gfx1030" -> 1030) and compare against 1100 (RDNA3).
# NOTE(review): int(...[3:]) raises ValueError for arch names containing letters
# (e.g. "gfx90a") — confirm such devices cannot reach this code path.
is_rdna2 = bool(devices.backend == "rocm" and int(getattr(torch.cuda.get_device_properties(devices.device), "gcnArchName", "gfx0000")[3:]) < 1100)
use_torch_compile = shared.opts.sdnq_dequantize_compile # this setting requires a full restart of the webui to apply

View File

@ -8,6 +8,7 @@ import torch
from modules import devices
from .common import dtype_dict, compile_func, use_contiguous_mm, use_tensorwise_fp8_matmul
from .packed_int import unpack_int_symetric, unpack_int_asymetric
from .packed_float import unpack_float
@devices.inference_context()
@ -25,7 +26,7 @@ def dequantize_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, ze
if result.ndim > 2 and weight.ndim > 2: # convs
result = result.add_(torch.mm(svd_up, svd_down).unflatten(-1, (*result.shape[1:],)))
else:
result = result.addmm_(svd_up, svd_down)
result = result.to(dtype=svd_up.dtype).addmm_(svd_up, svd_down)
if dtype is not None:
result = result.to(dtype=dtype)
return result
@ -48,7 +49,7 @@ def dequantize_symmetric(weight: torch.CharTensor, scale: torch.FloatTensor, svd
if result.ndim > 2 and weight.ndim > 2: # convs
result = result.add_(torch.mm(svd_up, svd_down).unflatten(-1, (*result.shape[1:],)))
else:
result = result.addmm_(svd_up, svd_down)
result = result.to(dtype=svd_up.dtype).addmm_(svd_up, svd_down)
if dtype is not None:
result = result.to(dtype=dtype)
return result
@ -70,10 +71,20 @@ def dequantize_packed_int_asymmetric(weight: torch.ByteTensor, scale: torch.Floa
@devices.inference_context()
def dequantize_packed_int_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor:
    """Unpack a bit-packed symmetric integer weight and dequantize it.

    Fix: the source contained two consecutive ``def`` lines for this function
    (the old signature with ``dtype`` before the svd arguments, followed by the
    new one) — leftover diff residue that is a syntax error; only the new
    signature is kept.
    """
    return dequantize_symmetric(unpack_int_symetric(weight, shape, weights_dtype, dtype=scale.dtype), scale, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=result_shape, skip_quantized_matmul=skip_quantized_matmul, re_quantize_for_matmul=re_quantize_for_matmul)
@devices.inference_context()
def dequantize_packed_float_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False) -> torch.FloatTensor:
    """Unpack a bit-packed custom-float weight, then apply asymmetric dequantization."""
    unpacked = unpack_float(weight, shape, weights_dtype)
    return dequantize_asymmetric(
        unpacked, scale, zero_point,
        svd_up=svd_up, svd_down=svd_down, dtype=dtype,
        result_shape=result_shape, skip_quantized_matmul=skip_quantized_matmul,
    )
@devices.inference_context()
def dequantize_packed_float_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor:
    """Unpack a bit-packed custom-float weight, then apply symmetric dequantization."""
    unpacked = unpack_float(weight, shape, weights_dtype)
    return dequantize_symmetric(
        unpacked, scale,
        svd_up=svd_up, svd_down=svd_down, dtype=dtype,
        result_shape=result_shape, skip_quantized_matmul=skip_quantized_matmul,
        re_quantize_for_matmul=re_quantize_for_matmul,
    )
@devices.inference_context()
def quantize_int_mm(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "int8") -> Tuple[torch.Tensor, torch.FloatTensor]:
scale = torch.amax(input.abs(), dim=dim, keepdims=True).div_(dtype_dict[matmul_dtype]["max"])
@ -156,6 +167,16 @@ def re_quantize_matmul_packed_int_symmetric(weight: torch.ByteTensor, scale: tor
return re_quantize_matmul_symmetric(unpack_int_symetric(weight, shape, weights_dtype, dtype=scale.dtype), scale, matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=result_shape)
@devices.inference_context()
def re_quantize_matmul_packed_float_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: torch.Size, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) -> Tuple[torch.Tensor, torch.FloatTensor]:
    """Unpack a packed custom-float weight and re-quantize it asymmetrically for quantized matmul."""
    unpacked_weight = unpack_float(weight, shape, weights_dtype)
    return re_quantize_matmul_asymmetric(
        unpacked_weight,
        scale,
        zero_point,
        matmul_dtype,
        svd_up=svd_up,
        svd_down=svd_down,
        result_shape=result_shape,
    )
@devices.inference_context()
def re_quantize_matmul_packed_float_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: Optional[torch.Size] = None, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) -> Tuple[torch.Tensor, torch.FloatTensor]:
    """Unpack a packed custom-float weight and re-quantize it symmetrically for quantized matmul.

    Bug fix: ``unpack_float`` takes no ``dtype`` argument (its signature is
    ``unpack_float(x, shape, weights_dtype)`` and it always returns float32),
    unlike ``unpack_int_symetric``. The ``dtype=scale.dtype`` kwarg copied from
    the packed-int variant would raise ``TypeError`` on every call; it is
    dropped here, matching the other packed-float call sites.
    """
    return re_quantize_matmul_symmetric(unpack_float(weight, shape, weights_dtype), scale, matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=result_shape)
@devices.inference_context()
def dequantize_layer_weight(self: torch.nn.Module, inplace: bool = False):
weight = torch.nn.Parameter(self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down, skip_quantized_matmul=self.sdnq_dequantizer.use_quantized_matmul), requires_grad=True)
@ -248,10 +269,16 @@ class SDNQDequantizer:
@devices.inference_context()
def re_quantize_matmul(self, weight, scale, zero_point, svd_up, svd_down): # pylint: disable=unused-argument
if self.is_packed:
if self.is_unsigned:
return re_quantize_matmul_packed_int_asymmetric_compiled(weight, scale, zero_point, self.quantized_weight_shape, self.weights_dtype, self.quantized_matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=self.result_shape)
if self.is_integer:
if self.is_unsigned:
return re_quantize_matmul_packed_int_asymmetric_compiled(weight, scale, zero_point, self.quantized_weight_shape, self.weights_dtype, self.quantized_matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=self.result_shape)
else:
return re_quantize_matmul_packed_int_symmetric_compiled(weight, scale, self.quantized_weight_shape, self.weights_dtype, self.quantized_matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=self.result_shape)
else:
return re_quantize_matmul_packed_int_symmetric_compiled(weight, scale, self.quantized_weight_shape, self.weights_dtype, self.quantized_matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=self.result_shape)
if self.is_unsigned:
return re_quantize_matmul_packed_float_asymmetric_compiled(weight, scale, zero_point, self.quantized_weight_shape, self.weights_dtype, self.quantized_matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=self.result_shape)
else:
return re_quantize_matmul_packed_float_symmetric_compiled(weight, scale, self.quantized_weight_shape, self.weights_dtype, self.quantized_matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=self.result_shape)
else:
if self.is_unsigned:
return re_quantize_matmul_asymmetric_compiled(weight, scale, zero_point, self.quantized_matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=self.result_shape)
@ -263,16 +290,28 @@ class SDNQDequantizer:
if dtype is None:
dtype = self.result_dtype
if self.is_packed:
if self.is_unsigned:
if skip_compile: # compiled training needs to be traced with the original function
return dequantize_packed_int_asymmetric(weight, scale, zero_point, self.quantized_weight_shape, self.weights_dtype, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=self.result_shape, skip_quantized_matmul=skip_quantized_matmul)
if self.is_integer:
if self.is_unsigned:
if skip_compile: # compiled training needs to be traced with the original function
return dequantize_packed_int_asymmetric(weight, scale, zero_point, self.quantized_weight_shape, self.weights_dtype, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=self.result_shape, skip_quantized_matmul=skip_quantized_matmul)
else:
return dequantize_packed_int_asymmetric_compiled(weight, scale, zero_point, self.quantized_weight_shape, self.weights_dtype, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=self.result_shape, skip_quantized_matmul=skip_quantized_matmul)
else:
return dequantize_packed_int_asymmetric_compiled(weight, scale, zero_point, self.quantized_weight_shape, self.weights_dtype, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=self.result_shape, skip_quantized_matmul=skip_quantized_matmul)
if skip_compile:
return dequantize_packed_int_symmetric(weight, scale, self.quantized_weight_shape, self.weights_dtype, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=self.result_shape, skip_quantized_matmul=skip_quantized_matmul, re_quantize_for_matmul=self.re_quantize_for_matmul)
else:
return dequantize_packed_int_symmetric_compiled(weight, scale, self.quantized_weight_shape, self.weights_dtype, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=self.result_shape, skip_quantized_matmul=skip_quantized_matmul, re_quantize_for_matmul=self.re_quantize_for_matmul)
else:
if skip_compile:
return dequantize_packed_int_symmetric(weight, scale, self.quantized_weight_shape, self.weights_dtype, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=self.result_shape, skip_quantized_matmul=skip_quantized_matmul, re_quantize_for_matmul=self.re_quantize_for_matmul)
if self.is_unsigned:
if skip_compile: # compiled training needs to be traced with the original function
return dequantize_packed_float_asymmetric(weight, scale, zero_point, self.quantized_weight_shape, self.weights_dtype, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=self.result_shape, skip_quantized_matmul=skip_quantized_matmul)
else:
return dequantize_packed_float_asymmetric_compiled(weight, scale, zero_point, self.quantized_weight_shape, self.weights_dtype, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=self.result_shape, skip_quantized_matmul=skip_quantized_matmul)
else:
return dequantize_packed_int_symmetric_compiled(weight, scale, self.quantized_weight_shape, self.weights_dtype, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=self.result_shape, skip_quantized_matmul=skip_quantized_matmul, re_quantize_for_matmul=self.re_quantize_for_matmul)
if skip_compile:
return dequantize_packed_float_symmetric(weight, scale, self.quantized_weight_shape, self.weights_dtype, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=self.result_shape, skip_quantized_matmul=skip_quantized_matmul, re_quantize_for_matmul=self.re_quantize_for_matmul)
else:
return dequantize_packed_float_symmetric_compiled(weight, scale, self.quantized_weight_shape, self.weights_dtype, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=self.result_shape, skip_quantized_matmul=skip_quantized_matmul, re_quantize_for_matmul=self.re_quantize_for_matmul)
else:
if self.is_unsigned:
if skip_compile:
@ -290,7 +329,11 @@ dequantize_asymmetric_compiled = compile_func(dequantize_asymmetric)
dequantize_symmetric_compiled = compile_func(dequantize_symmetric)
dequantize_packed_int_asymmetric_compiled = compile_func(dequantize_packed_int_asymmetric)
dequantize_packed_int_symmetric_compiled = compile_func(dequantize_packed_int_symmetric)
dequantize_packed_float_asymmetric_compiled = compile_func(dequantize_packed_float_asymmetric)
dequantize_packed_float_symmetric_compiled = compile_func(dequantize_packed_float_symmetric)
re_quantize_matmul_asymmetric_compiled = compile_func(re_quantize_matmul_asymmetric)
re_quantize_matmul_symmetric_compiled = compile_func(re_quantize_matmul_symmetric)
re_quantize_matmul_packed_int_asymmetric_compiled = compile_func(re_quantize_matmul_packed_int_asymmetric)
re_quantize_matmul_packed_int_symmetric_compiled = compile_func(re_quantize_matmul_packed_int_symmetric)
re_quantize_matmul_packed_float_asymmetric_compiled = compile_func(re_quantize_matmul_packed_float_asymmetric)
re_quantize_matmul_packed_float_symmetric_compiled = compile_func(re_quantize_matmul_packed_float_symmetric)

View File

@ -3,7 +3,7 @@ import json
import torch
from diffusers.models.modeling_utils import ModelMixin
from .common import dtype_dict, use_tensorwise_fp8_matmul, check_torch_compile
from .common import dtype_dict, use_tensorwise_fp8_matmul, check_torch_compile, conv_types, linear_types
from .quantizer import SDNQConfig, sdnq_post_load_quant, prepare_weight_for_matmul, prepare_svd_for_matmul, get_quant_args_from_config
from .forward import get_forward_func
from .file_loader import load_files
@ -106,7 +106,7 @@ def load_sdnq_model(model_path: str, model_cls: ModelMixin = None, file_name: st
else:
model = model_cls(**model_config)
model = sdnq_post_load_quant(model, torch_dtype=dtype, add_skip_keys=False, **get_quant_args_from_config(quantization_config))
model = sdnq_post_load_quant(model, torch_dtype=dtype, add_skip_keys=False, use_dynamic_quantization=False, **get_quant_args_from_config(quantization_config))
key_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
files = []
@ -170,6 +170,18 @@ def apply_sdnq_options_to_module(model, dtype: torch.dtype = None, dequantize_fp
return model
for module_name, module in model.named_children():
if hasattr(module, "sdnq_dequantizer"):
layer_class_name = module.__class__.__name__
current_use_quantized_matmul = use_quantized_matmul
if current_use_quantized_matmul:
if layer_class_name in conv_types:
output_channel_size, channel_size = module.sdnq_dequantizer.original_shape[:2]
elif layer_class_name in linear_types:
output_channel_size, channel_size = module.sdnq_dequantizer.original_shape
else:
current_use_quantized_matmul = False
current_use_quantized_matmul = current_use_quantized_matmul and channel_size >= 32 and output_channel_size >= 32
current_use_quantized_matmul = current_use_quantized_matmul and output_channel_size % 16 == 0 and channel_size % 16 == 0
if dtype is not None and module.sdnq_dequantizer.result_dtype != torch.float32:
module.sdnq_dequantizer.result_dtype = dtype
@ -177,7 +189,7 @@ def apply_sdnq_options_to_module(model, dtype: torch.dtype = None, dequantize_fp
dequantize_fp32
or dtype_dict[module.sdnq_dequantizer.weights_dtype]["num_bits"] > 8
or (
(use_quantized_matmul or (use_quantized_matmul is None and module.sdnq_dequantizer.use_quantized_matmul))
(current_use_quantized_matmul or (current_use_quantized_matmul is None and module.sdnq_dequantizer.use_quantized_matmul))
and not dtype_dict[module.sdnq_dequantizer.quantized_matmul_dtype]["is_integer"]
and (not use_tensorwise_fp8_matmul or dtype_dict[module.sdnq_dequantizer.quantized_matmul_dtype]["num_bits"] == 16)
)
@ -191,19 +203,19 @@ def apply_sdnq_options_to_module(model, dtype: torch.dtype = None, dequantize_fp
module.svd_up.data = module.svd_up.to(dtype=scale_dtype)
module.svd_down.data = module.svd_down.to(dtype=scale_dtype)
if use_quantized_matmul is not None and use_quantized_matmul != module.sdnq_dequantizer.use_quantized_matmul:
if current_use_quantized_matmul is not None and current_use_quantized_matmul != module.sdnq_dequantizer.use_quantized_matmul:
if not module.sdnq_dequantizer.re_quantize_for_matmul:
module.scale.t_()
module.weight.t_()
if use_quantized_matmul:
if current_use_quantized_matmul:
module.weight.data = prepare_weight_for_matmul(module.weight)
else:
module.scale.data = module.scale.contiguous()
module.weight.data = module.weight.contiguous()
if module.svd_up is not None:
module.svd_up.data, module.svd_down.data = prepare_svd_for_matmul(module.svd_up.t_(), module.svd_down.t_(), use_quantized_matmul)
module.sdnq_dequantizer.use_quantized_matmul = use_quantized_matmul
module.forward = get_forward_func(module.__class__.__name__, module.sdnq_dequantizer.quantized_matmul_dtype, use_quantized_matmul)
module.svd_up.data, module.svd_down.data = prepare_svd_for_matmul(module.svd_up.t_(), module.svd_down.t_(), current_use_quantized_matmul)
module.sdnq_dequantizer.use_quantized_matmul = current_use_quantized_matmul
module.forward = get_forward_func(module.__class__.__name__, module.sdnq_dequantizer.quantized_matmul_dtype, current_use_quantized_matmul)
module.forward = module.forward.__get__(module, module.__class__)
setattr(model, module_name, module)
else:

View File

@ -0,0 +1,102 @@
import torch
from .common import dtype_dict
from .packed_int import pack_int_asymetric, unpack_int_asymetric
# Map sub-byte custom-float widths to the unsigned packed-int dtype name used
# as their storage container by pack_int_asymetric / unpack_int_asymetric.
float_bits_to_uint_dict = {bits: f"uint{bits}" for bits in range(1, 8)}
def pack_float(x: torch.FloatTensor, weights_dtype: str) -> torch.Tensor:
    """Bit-pack a float tensor into the custom low-bit floating point format ``weights_dtype``.

    The input is cast to float32 and reinterpreted as int32 so the IEEE-754
    fields can be manipulated directly: the mantissa is rounded and truncated
    to ``mantissa`` bits, the exponent is truncated to ``exponent`` bits, and
    the resulting ``num_bits``-wide codes are stored either packed into uint8
    (sub-byte widths, via ``pack_int_asymetric``) or cast to the format's
    ``storage_dtype``.
    """
    exponent_bits = dtype_dict[weights_dtype]["exponent"]
    mantissa_bits = dtype_dict[weights_dtype]["mantissa"]
    total_bits = dtype_dict[weights_dtype]["num_bits"]
    # "sign_mask" covers the field bits that must be shifted down by
    # exponent_difference: the exponent MSB (fp32 bit 30, which lands at
    # bit total_bits-1 for unsigned formats) plus, for signed formats, the
    # sign bit (which lands at total_bits-1, pushing the exponent MSB to
    # total_bits-2).
    if dtype_dict[weights_dtype]["is_unsigned"]:
        sign_mask = (1 << (total_bits-1))
    else:
        sign_mask = (1 << (total_bits-1)) + (1 << (total_bits-2))
    mantissa_difference = 23 - mantissa_bits  # number of fp32 mantissa bits being dropped
    exponent_difference = 8 - exponent_bits  # number of fp32 exponent bits being dropped
    mantissa_mask = (1 << mantissa_difference)  # one target-format ULP, in fp32 bit positions
    x = x.to(dtype=torch.float32).view(torch.int32)
    # Round-to-nearest on the mantissa: inspect the top 4 of the dropped
    # mantissa bits and add one ULP when they exceed half a ULP
    # (exact ties are left rounded down).
    x = torch.where(
        torch.greater(
            torch.bitwise_and(x, -(1 << (mantissa_difference-4)) & ~(-mantissa_mask)),
            (1 << (mantissa_difference-1)),
        ),
        torch.add(x, mantissa_mask),
        x,
    )
    # Flush magnitudes below the target format's smallest normal value to zero.
    # NOTE(review): assumes every packed-float entry in dtype_dict defines
    # "min_normal" — the entries visible in this file's header do not show it; confirm.
    x = torch.where(torch.lt(x.view(torch.float32).abs(), dtype_dict[weights_dtype]["min_normal"]), 0, x)
    x = torch.bitwise_right_shift(x, mantissa_difference)  # discard the dropped mantissa bits
    # Recombine the fields: shift the sign/exponent-MSB group down next to the
    # kept low exponent bits, then mask the result to the target bit width.
    x = torch.bitwise_and(
        torch.bitwise_or(
            torch.bitwise_and(torch.bitwise_right_shift(x, exponent_difference), sign_mask),
            torch.bitwise_and(x, ~sign_mask),
        ),
        ~(-(1 << total_bits)),
    ).view(torch.uint32)
    if total_bits < 8:
        # Sub-byte codes are packed like unsigned integers of the same width.
        x = pack_int_asymetric(x, float_bits_to_uint_dict[total_bits])
    else:
        x = x.to(dtype=dtype_dict[weights_dtype]["storage_dtype"])
    return x
def unpack_float(x: torch.Tensor, shape: torch.Size, weights_dtype: str) -> torch.FloatTensor:
    """Inverse of ``pack_float``: expand packed custom-float codes back to float32.

    Sub-byte codes are first unpacked from their uint8 storage (via
    ``unpack_int_asymetric``), then the sign, exponent and mantissa fields are
    shifted back to their IEEE-754 float32 bit positions and the truncated
    exponent is re-extended to 8 bits. Returns a float32 tensor.
    """
    exponent_bits = dtype_dict[weights_dtype]["exponent"]
    mantissa_bits = dtype_dict[weights_dtype]["mantissa"]
    total_bits = dtype_dict[weights_dtype]["num_bits"]
    # Same field mask as in pack_float: exponent MSB plus, for signed formats,
    # the sign bit — the group that was shifted down by exponent_difference.
    if dtype_dict[weights_dtype]["is_unsigned"]:
        sign_mask = (1 << (total_bits-1))
    else:
        sign_mask = (1 << (total_bits-1)) + (1 << (total_bits-2))
    mantissa_difference = 23 - mantissa_bits
    exponent_difference = 8 - exponent_bits
    if total_bits < 8:
        x = unpack_int_asymetric(x, shape, float_bits_to_uint_dict[total_bits])
    x = x.to(dtype=torch.uint32).view(torch.int32)
    # Undo the packing shifts: move the sign/exponent-MSB group back up by
    # exponent_difference, then the whole code up by mantissa_difference,
    # landing every kept bit in its float32 position.
    x = torch.bitwise_left_shift(
        torch.bitwise_or(
            torch.bitwise_left_shift(torch.bitwise_and(x, sign_mask), exponent_difference),
            torch.bitwise_and(x, ~sign_mask),
        ),
        mantissa_difference,
    )
    # Re-extend the truncated exponent: when the exponent MSB (bit 30,
    # 1073741824 == 0x40000000) is clear, the arithmetic right shift of the
    # negated bit fills the missing middle exponent bits with ones, masked to
    # the fp32 exponent field (1065353216 == 0x3F800000); when it is set they
    # remain zero.
    x = torch.bitwise_or(
        x,
        torch.bitwise_and(
            torch.bitwise_right_shift(
                -torch.bitwise_and(torch.bitwise_not(x), 1073741824),
                exponent_difference,
            ),
            1065353216,
        ),
    )
    # Zero-flush: codes with none of their payload bits set (mask covers the
    # reconstructed mantissa/low-exponent bits plus fp32 bits 30-31,
    # -1073741824 == 0xC0000000) decode to exactly 0.0.
    overflow_mask = (~(-(1 << (22 + exponent_bits))) | -1073741824)
    x = torch.where(torch.bitwise_and(x, overflow_mask).to(dtype=torch.bool), x, 0)
    x = x.view(torch.float32)
    return x

View File

@ -15,9 +15,10 @@ from diffusers.utils import get_module_from_name
from accelerate import init_empty_weights
from modules import devices, shared
from .common import sdnq_version, dtype_dict, common_skip_keys, module_skip_keys_dict, accepted_weight_dtypes, accepted_matmul_dtypes, allowed_types, linear_types, conv_types, conv_transpose_types, compile_func, use_tensorwise_fp8_matmul, use_contiguous_mm, check_torch_compile
from .common import sdnq_version, dtype_dict, common_skip_keys, module_skip_keys_dict, accepted_weight_dtypes, accepted_matmul_dtypes, weights_dtype_order, weights_dtype_order_fp32, allowed_types, linear_types, conv_types, conv_transpose_types, compile_func, use_tensorwise_fp8_matmul, use_contiguous_mm, check_torch_compile
from .dequantizer import SDNQDequantizer, dequantize_sdnq_model
from .packed_int import pack_int_symetric, pack_int_asymetric
from .packed_float import pack_float
from .forward import get_forward_func
@ -131,6 +132,7 @@ def get_quant_args_from_config(quantization_config: Union["SDNQConfig", dict]) -
quantization_config_dict.pop("return_device", None)
quantization_config_dict.pop("non_blocking", None)
quantization_config_dict.pop("add_skip_keys", None)
quantization_config_dict.pop("use_dynamic_quantization", None)
quantization_config_dict.pop("use_static_quantization", None)
quantization_config_dict.pop("use_stochastic_rounding", None)
quantization_config_dict.pop("use_grad_ckpt", None)
@ -202,7 +204,7 @@ def add_module_skip_keys(model, modules_to_not_convert: List[str] = None, module
@devices.inference_context()
def sdnq_quantize_layer_weight(weight, layer_class_name=None, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, use_svd=False, use_quantized_matmul=False, use_stochastic_rounding=False, dequantize_fp32=False, param_name=None): # pylint: disable=unused-argument
def sdnq_quantize_layer_weight(weight, layer_class_name=None, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, use_svd=False, use_quantized_matmul=False, use_stochastic_rounding=False, dequantize_fp32=False, using_pre_calculated_svd=False, param_name=None): # pylint: disable=unused-argument
num_of_groups = 1
is_conv_type = False
is_conv_transpose_type = False
@ -226,6 +228,7 @@ def sdnq_quantize_layer_weight(weight, layer_class_name=None, weights_dtype="int
dtype_dict[weights_dtype]["is_unsigned"]
or dtype_dict[weights_dtype]["is_integer"] != dtype_dict[quantized_matmul_dtype]["is_integer"]
or dtype_dict[weights_dtype]["num_bits"] > dtype_dict[quantized_matmul_dtype]["num_bits"]
or (dtype_dict[weights_dtype]["is_packed"] and not dtype_dict[weights_dtype]["is_integer"])
)
if layer_class_name in conv_types:
@ -278,9 +281,9 @@ def sdnq_quantize_layer_weight(weight, layer_class_name=None, weights_dtype="int
if use_quantized_matmul and not re_quantize_for_matmul and dtype_dict[weights_dtype]["num_bits"] >= 6:
group_size = -1
elif is_linear_type:
group_size = 2 ** ((2 if svd_up is None else 3) + dtype_dict[weights_dtype]["num_bits"])
group_size = 2 ** ((3 if (svd_up is not None or using_pre_calculated_svd) else 2) + dtype_dict[weights_dtype]["num_bits"])
else:
group_size = 2 ** ((1 if svd_up is None else 2) + dtype_dict[weights_dtype]["num_bits"])
group_size = 2 ** ((2 if (svd_up is not None or using_pre_calculated_svd) else 1) + dtype_dict[weights_dtype]["num_bits"])
if group_size > 0:
if group_size >= channel_size:
@ -366,10 +369,13 @@ def sdnq_quantize_layer_weight(weight, layer_class_name=None, weights_dtype="int
)
if dtype_dict[weights_dtype]["is_packed"]:
if dtype_dict[weights_dtype]["is_unsigned"]:
weight = pack_int_asymetric(weight, weights_dtype)
if dtype_dict[weights_dtype]["is_integer"]:
if dtype_dict[weights_dtype]["is_unsigned"]:
weight = pack_int_asymetric(weight, weights_dtype)
else:
weight = pack_int_symetric(weight, weights_dtype)
else:
weight = pack_int_symetric(weight, weights_dtype)
weight = pack_float(weight, weights_dtype)
else:
weight = weight.to(dtype=dtype_dict[weights_dtype]["torch_dtype"])
@ -377,11 +383,63 @@ def sdnq_quantize_layer_weight(weight, layer_class_name=None, weights_dtype="int
@devices.inference_context()
def sdnq_quantize_layer(layer, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_stochastic_rounding=False, dequantize_fp32=False, non_blocking=False, quantization_device=None, return_device=None, param_name=None): # pylint: disable=unused-argument
def sdnq_quantize_layer_weight_dynamic(weight, layer_class_name=None, weights_dtype="int2", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, use_quantized_matmul=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=False, svd_up=None, svd_down=None, param_name=None): # pylint: disable=unused-argument
if torch_dtype is None:
torch_dtype = weight.dtype
weights_dtype_order_to_use = weights_dtype_order_fp32 if torch_dtype in {torch.float32, torch.float64} else weights_dtype_order
weight = weight.to(dtype=torch.float32)
weight_std = weight.std().square()
if use_svd:
try:
svd_weight, svd_up, svd_down = apply_svdquant(weight, rank=svd_rank, niter=svd_steps)
svd_up, svd_down = prepare_svd_for_matmul(svd_up, svd_down, use_quantized_matmul)
svd_up = svd_up.to(dtype=torch_dtype)
svd_down = svd_down.to(dtype=torch_dtype)
except Exception:
svd_up, svd_down = None, None
svd_weight = weight
else:
svd_up, svd_down = None, None
svd_weight = weight
quantization_loss = None
svd_is_transposed = False
for i in range(weights_dtype_order_to_use.index(weights_dtype), len(weights_dtype_order_to_use)):
quantized_weight, scale, zero_point, _, _, sdnq_dequantizer = sdnq_quantize_layer_weight(
svd_weight,
layer_class_name=layer_class_name,
weights_dtype=weights_dtype_order_to_use[i],
quantized_matmul_dtype=quantized_matmul_dtype,
torch_dtype=torch_dtype,
group_size=group_size,
svd_rank=svd_rank,
svd_steps=svd_steps,
use_svd=False,
using_pre_calculated_svd=use_svd,
use_quantized_matmul=use_quantized_matmul,
use_stochastic_rounding=use_stochastic_rounding,
dequantize_fp32=dequantize_fp32,
param_name=param_name,
)
if not svd_is_transposed and sdnq_dequantizer.use_quantized_matmul:
svd_up = svd_up.t_()
svd_down = svd_down.t_()
svd_is_transposed = True
quantization_loss = torch.nn.functional.mse_loss(weight, sdnq_dequantizer(quantized_weight, scale, zero_point, svd_up, svd_down, skip_quantized_matmul=sdnq_dequantizer.use_quantized_matmul, dtype=torch.float32)).div_(weight_std)
if quantization_loss <= dynamic_loss_threshold:
return (quantized_weight, scale, zero_point, svd_up, svd_down, sdnq_dequantizer)
return None
@devices.inference_context()
def sdnq_quantize_layer(layer, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=False, non_blocking=False, modules_to_not_convert=None, modules_dtype_dict=None, quantization_device=None, return_device=None, param_name=None): # pylint: disable=unused-argument
layer_class_name = layer.__class__.__name__
if layer_class_name in conv_transpose_types or layer_class_name in conv_types:
if not quant_conv:
return layer
return layer, modules_to_not_convert, modules_dtype_dict
use_quantized_matmul = use_quantized_matmul_conv
layer.weight.requires_grad_(False)
@ -390,46 +448,83 @@ def sdnq_quantize_layer(layer, weights_dtype="int8", quantized_matmul_dtype=None
if quantization_device is not None:
layer.weight.data = layer.weight.to(quantization_device, non_blocking=non_blocking)
(
layer.weight.data,
layer.scale, layer.zero_point,
layer.svd_up, layer.svd_down,
layer.sdnq_dequantizer,
) = sdnq_quantize_layer_weight(
layer.weight,
layer_class_name=layer_class_name,
weights_dtype=weights_dtype,
quantized_matmul_dtype=quantized_matmul_dtype,
torch_dtype=torch_dtype,
group_size=group_size,
svd_rank=svd_rank,
svd_steps=svd_steps,
use_svd=use_svd,
use_quantized_matmul=use_quantized_matmul,
use_stochastic_rounding=use_stochastic_rounding,
dequantize_fp32=dequantize_fp32,
param_name=param_name,
)
if use_dynamic_quantization:
weight_data = sdnq_quantize_layer_weight_dynamic(
layer.weight,
layer_class_name=layer_class_name,
weights_dtype=weights_dtype,
quantized_matmul_dtype=quantized_matmul_dtype,
torch_dtype=torch_dtype,
group_size=group_size,
svd_rank=svd_rank,
svd_steps=svd_steps,
dynamic_loss_threshold=dynamic_loss_threshold,
use_svd=use_svd,
use_quantized_matmul=use_quantized_matmul,
use_stochastic_rounding=use_stochastic_rounding,
dequantize_fp32=dequantize_fp32,
param_name=param_name,
)
else:
weight_data = sdnq_quantize_layer_weight(
layer.weight,
layer_class_name=layer_class_name,
weights_dtype=weights_dtype,
quantized_matmul_dtype=quantized_matmul_dtype,
torch_dtype=torch_dtype,
group_size=group_size,
svd_rank=svd_rank,
svd_steps=svd_steps,
use_svd=use_svd,
use_quantized_matmul=use_quantized_matmul,
use_stochastic_rounding=use_stochastic_rounding,
dequantize_fp32=dequantize_fp32,
param_name=param_name,
)
layer.weight = torch.nn.Parameter(layer.weight.to(return_device, non_blocking=non_blocking), requires_grad=False)
layer.scale = torch.nn.Parameter(layer.scale.to(return_device, non_blocking=non_blocking), requires_grad=False)
if layer.zero_point is not None:
layer.zero_point = torch.nn.Parameter(layer.zero_point.to(return_device, non_blocking=non_blocking), requires_grad=False)
if layer.svd_up is not None:
layer.svd_up = torch.nn.Parameter(layer.svd_up.to(return_device, non_blocking=non_blocking), requires_grad=False)
layer.svd_down = torch.nn.Parameter(layer.svd_down.to(return_device, non_blocking=non_blocking), requires_grad=False)
if weight_data is not None:
(
layer.weight.data,
layer.scale, layer.zero_point,
layer.svd_up, layer.svd_down,
layer.sdnq_dequantizer,
) = weight_data
del weight_data
layer = layer.to(return_device, non_blocking=non_blocking)
layer.forward = get_forward_func(layer_class_name, layer.sdnq_dequantizer.quantized_matmul_dtype, layer.sdnq_dequantizer.use_quantized_matmul)
layer.forward = layer.forward.__get__(layer, layer.__class__)
return layer
layer.weight = torch.nn.Parameter(layer.weight.to(return_device, non_blocking=non_blocking), requires_grad=False)
layer.scale = torch.nn.Parameter(layer.scale.to(return_device, non_blocking=non_blocking), requires_grad=False)
if layer.zero_point is not None:
layer.zero_point = torch.nn.Parameter(layer.zero_point.to(return_device, non_blocking=non_blocking), requires_grad=False)
if layer.svd_up is not None:
layer.svd_up = torch.nn.Parameter(layer.svd_up.to(return_device, non_blocking=non_blocking), requires_grad=False)
layer.svd_down = torch.nn.Parameter(layer.svd_down.to(return_device, non_blocking=non_blocking), requires_grad=False)
layer = layer.to(return_device, non_blocking=non_blocking)
layer.forward = get_forward_func(layer_class_name, layer.sdnq_dequantizer.quantized_matmul_dtype, layer.sdnq_dequantizer.use_quantized_matmul)
layer.forward = layer.forward.__get__(layer, layer.__class__)
if use_dynamic_quantization:
if modules_dtype_dict is None:
modules_dtype_dict = {}
if layer.sdnq_dequantizer.weights_dtype not in modules_dtype_dict.keys():
modules_dtype_dict[layer.sdnq_dequantizer.weights_dtype] = [param_name]
else:
modules_dtype_dict[layer.sdnq_dequantizer.weights_dtype].append(param_name)
else:
layer = layer.to(return_device, dtype=torch_dtype, non_blocking=non_blocking)
if use_dynamic_quantization:
if modules_to_not_convert is None:
modules_to_not_convert = []
modules_to_not_convert.append(param_name)
return layer, modules_to_not_convert, modules_dtype_dict
@devices.inference_context()
def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_stochastic_rounding=False, dequantize_fp32=False, non_blocking=False, quantization_device=None, return_device=None, modules_to_not_convert: List[str] = None, modules_dtype_dict: Dict[str, List[str]] = None, full_param_name=""): # pylint: disable=unused-argument
def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=False, non_blocking=False, modules_to_not_convert: List[str] = None, modules_dtype_dict: Dict[str, List[str]] = None, quantization_device=None, return_device=None, full_param_name=""): # pylint: disable=unused-argument
has_children = list(model.children())
if not has_children:
return model
return model, modules_to_not_convert, modules_dtype_dict
if modules_to_not_convert is None:
modules_to_not_convert = []
if modules_dtype_dict is None:
@ -447,7 +542,7 @@ def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=Non
if layer_class_name in allowed_types and module.weight.dtype in {torch.float32, torch.float16, torch.bfloat16}:
if (layer_class_name in conv_types or layer_class_name in conv_transpose_types) and not quant_conv:
continue
setattr(model, module_name, sdnq_quantize_layer(
module, modules_to_not_convert, modules_dtype_dict = sdnq_quantize_layer(
module,
weights_dtype=get_minimum_dtype(weights_dtype, param_name, modules_dtype_dict),
quantized_matmul_dtype=quantized_matmul_dtype,
@ -455,39 +550,48 @@ def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=Non
group_size=group_size,
svd_rank=svd_rank,
svd_steps=svd_steps,
dynamic_loss_threshold=dynamic_loss_threshold,
use_svd=use_svd,
quant_conv=quant_conv,
use_quantized_matmul=use_quantized_matmul,
use_quantized_matmul_conv=use_quantized_matmul_conv,
use_dynamic_quantization=use_dynamic_quantization,
use_stochastic_rounding=use_stochastic_rounding,
dequantize_fp32=dequantize_fp32,
non_blocking=non_blocking,
quantization_device=quantization_device,
return_device=return_device,
modules_to_not_convert=modules_to_not_convert,
modules_dtype_dict=modules_dtype_dict,
param_name=param_name,
))
setattr(model, module_name, apply_sdnq_to_module(
module,
weights_dtype=weights_dtype,
quantized_matmul_dtype=quantized_matmul_dtype,
torch_dtype=torch_dtype,
group_size=group_size,
svd_rank=svd_rank,
svd_steps=svd_steps,
use_svd=use_svd,
quant_conv=quant_conv,
use_quantized_matmul=use_quantized_matmul,
use_quantized_matmul_conv=use_quantized_matmul_conv,
use_stochastic_rounding=use_stochastic_rounding,
dequantize_fp32=dequantize_fp32,
non_blocking=non_blocking,
quantization_device=quantization_device,
return_device=return_device,
modules_to_not_convert=modules_to_not_convert,
modules_dtype_dict=modules_dtype_dict,
full_param_name=param_name,
))
return model
)
setattr(model, module_name, module)
module, modules_to_not_convert, modules_dtype_dict = apply_sdnq_to_module(
module,
dynamic_loss_threshold=dynamic_loss_threshold,
weights_dtype=weights_dtype,
quantized_matmul_dtype=quantized_matmul_dtype,
torch_dtype=torch_dtype,
group_size=group_size,
svd_rank=svd_rank,
svd_steps=svd_steps,
use_svd=use_svd,
quant_conv=quant_conv,
use_quantized_matmul=use_quantized_matmul,
use_quantized_matmul_conv=use_quantized_matmul_conv,
use_dynamic_quantization=use_dynamic_quantization,
use_stochastic_rounding=use_stochastic_rounding,
dequantize_fp32=dequantize_fp32,
non_blocking=non_blocking,
quantization_device=quantization_device,
return_device=return_device,
modules_to_not_convert=modules_to_not_convert,
modules_dtype_dict=modules_dtype_dict,
full_param_name=param_name,
)
setattr(model, module_name, module)
return model, modules_to_not_convert, modules_dtype_dict
@devices.inference_context()
@ -499,18 +603,20 @@ def sdnq_post_load_quant(
group_size: int = 0,
svd_rank: int = 32,
svd_steps: int = 8,
dynamic_loss_threshold: float = 1e-2,
use_svd: bool = False,
quant_conv: bool = False,
use_quantized_matmul: bool = False,
use_quantized_matmul_conv: bool = False,
use_dynamic_quantization: bool = False,
use_stochastic_rounding: bool = False,
dequantize_fp32: bool = False,
non_blocking: bool = False,
add_skip_keys:bool = True,
quantization_device: Optional[torch.device] = None,
return_device: Optional[torch.device] = None,
modules_to_not_convert: List[str] = None,
modules_dtype_dict: Dict[str, List[str]] = None,
quantization_device: Optional[torch.device] = None,
return_device: Optional[torch.device] = None,
):
if modules_to_not_convert is None:
modules_to_not_convert = []
@ -527,22 +633,24 @@ def sdnq_post_load_quant(
group_size=group_size,
svd_rank=svd_rank,
svd_steps=svd_steps,
dynamic_loss_threshold=dynamic_loss_threshold,
use_svd=use_svd,
quant_conv=quant_conv,
use_quantized_matmul=use_quantized_matmul,
use_quantized_matmul_conv=use_quantized_matmul_conv,
use_dynamic_quantization=use_dynamic_quantization,
use_stochastic_rounding=use_stochastic_rounding,
dequantize_fp32=dequantize_fp32,
non_blocking=non_blocking,
add_skip_keys=add_skip_keys,
quantization_device=quantization_device,
return_device=return_device,
modules_to_not_convert=modules_to_not_convert,
modules_dtype_dict=modules_dtype_dict,
quantization_device=quantization_device,
return_device=return_device,
)
model.eval()
model = apply_sdnq_to_module(
model, modules_to_not_convert, modules_dtype_dict = apply_sdnq_to_module(
model,
weights_dtype=weights_dtype,
quantized_matmul_dtype=quantized_matmul_dtype,
@ -550,19 +658,24 @@ def sdnq_post_load_quant(
group_size=group_size,
svd_rank=svd_rank,
svd_steps=svd_steps,
dynamic_loss_threshold=dynamic_loss_threshold,
use_svd=use_svd,
quant_conv=quant_conv,
use_quantized_matmul=use_quantized_matmul,
use_quantized_matmul_conv=use_quantized_matmul_conv,
use_dynamic_quantization=use_dynamic_quantization,
use_stochastic_rounding=use_stochastic_rounding,
dequantize_fp32=dequantize_fp32,
non_blocking=non_blocking,
quantization_device=quantization_device,
return_device=return_device,
modules_to_not_convert=modules_to_not_convert,
modules_dtype_dict=modules_dtype_dict,
quantization_device=quantization_device,
return_device=return_device,
)
quantization_config.modules_to_not_convert = modules_to_not_convert
quantization_config.modules_dtype_dict = modules_dtype_dict
model.quantization_config = quantization_config
if hasattr(model, "config"):
try:
@ -695,7 +808,7 @@ class SDNQQuantizer(DiffusersQuantizer, HfQuantizer):
layer, _ = get_module_from_name(model, param_name)
layer.weight = torch.nn.Parameter(param_value, requires_grad=False)
layer = sdnq_quantize_layer(
layer, self.quantization_config.modules_to_not_convert, self.quantization_config.modules_dtype_dict = sdnq_quantize_layer(
layer,
weights_dtype=weights_dtype,
quantized_matmul_dtype=self.quantization_config.quantized_matmul_dtype,
@ -703,25 +816,30 @@ class SDNQQuantizer(DiffusersQuantizer, HfQuantizer):
group_size=self.quantization_config.group_size,
svd_rank=self.quantization_config.svd_rank,
svd_steps=self.quantization_config.svd_steps,
dynamic_loss_threshold=self.quantization_config.dynamic_loss_threshold,
use_svd=self.quantization_config.use_svd,
quant_conv=self.quantization_config.quant_conv,
use_quantized_matmul=self.quantization_config.use_quantized_matmul,
use_quantized_matmul_conv=self.quantization_config.use_quantized_matmul_conv,
use_dynamic_quantization=self.quantization_config.use_dynamic_quantization,
use_stochastic_rounding=self.quantization_config.use_stochastic_rounding,
dequantize_fp32=self.quantization_config.dequantize_fp32,
non_blocking=self.quantization_config.non_blocking,
modules_to_not_convert=self.quantization_config.modules_to_not_convert,
modules_dtype_dict=self.quantization_config.modules_dtype_dict,
quantization_device=None,
return_device=return_device,
param_name=param_name,
)
layer.weight._is_hf_initialized = True # pylint: disable=protected-access
layer.scale._is_hf_initialized = True # pylint: disable=protected-access
if layer.zero_point is not None:
layer.zero_point._is_hf_initialized = True # pylint: disable=protected-access
if layer.svd_up is not None:
layer.svd_up._is_hf_initialized = True # pylint: disable=protected-access
layer.svd_down._is_hf_initialized = True # pylint: disable=protected-access
if hasattr(layer, "scale"):
layer.scale._is_hf_initialized = True # pylint: disable=protected-access
if layer.zero_point is not None:
layer.zero_point._is_hf_initialized = True # pylint: disable=protected-access
if layer.svd_up is not None:
layer.svd_up._is_hf_initialized = True # pylint: disable=protected-access
layer.svd_down._is_hf_initialized = True # pylint: disable=protected-access
def get_quantize_ops(self):
return SDNQQuantize(self)
@ -757,7 +875,7 @@ class SDNQQuantizer(DiffusersQuantizer, HfQuantizer):
self.quantization_config.add_skip_keys = False
with init_empty_weights():
model = sdnq_post_load_quant(model, torch_dtype=self.torch_dtype, add_skip_keys=False, **get_quant_args_from_config(self.quantization_config))
model = sdnq_post_load_quant(model, torch_dtype=self.torch_dtype, add_skip_keys=False, use_dynamic_quantization=False, **get_quant_args_from_config(self.quantization_config))
if self.quantization_config.add_skip_keys:
if keep_in_fp32_modules is not None:
@ -849,6 +967,8 @@ class SDNQConfig(QuantizationConfigMixin):
group_size = 0 will automatically select a group size based on weights_dtype.
svd_rank (`int`, *optional*, defaults to `32`):
The rank size used for the SVDQuant algorithm.
dynamic_loss_threshold (`float`, *optional*, defaults to `1e-2`):
The target quantization MSE loss threshold to use for dynamic quantization.
svd_steps (`int`, *optional*, defaults to `8`):
The number of iterations to use in svd lowrank estimation.
use_svd (`bool`, *optional*, defaults to `False`):
@ -861,6 +981,9 @@ class SDNQConfig(QuantizationConfigMixin):
Same as use_quantized_matmul_conv but for the convolutional layers with UNets like SDXL.
use_stochastic_rounding (`bool`, *optional*, defaults to `False`):
Enabling this option will use stochastic rounding on the quantization step.
use_dynamic_quantization (`bool`, *optional*, defaults to `False`):
Enabling this option will dynamically select a quantization type based on `dynamic_loss_threshold`.
When this option is enabled, `weights_dtype` is used as the minimum allowed quantization type.
dequantize_fp32 (`bool`, *optional*, defaults to `False`):
Enabling this option will use FP32 on the dequantization step.
non_blocking (`bool`, *optional*, defaults to `False`):
@ -885,12 +1008,14 @@ class SDNQConfig(QuantizationConfigMixin):
group_size: int = 0,
svd_rank: int = 32,
svd_steps: int = 8,
dynamic_loss_threshold: float = 1e-2,
use_svd: bool = False,
use_grad_ckpt: bool = True,
quant_conv: bool = False,
use_quantized_matmul: bool = False,
use_quantized_matmul_conv: bool = False,
use_static_quantization: bool = True,
use_dynamic_quantization: bool = False,
use_stochastic_rounding: bool = False,
dequantize_fp32: bool = False,
non_blocking: bool = False,
@ -911,6 +1036,7 @@ class SDNQConfig(QuantizationConfigMixin):
self.quant_method = QuantizationMethod.SDNQ
self.group_size = group_size
self.svd_rank = svd_rank
self.dynamic_loss_threshold = dynamic_loss_threshold
self.svd_steps = svd_steps
self.use_svd = use_svd
self.use_grad_ckpt = use_grad_ckpt
@ -918,6 +1044,7 @@ class SDNQConfig(QuantizationConfigMixin):
self.use_quantized_matmul = use_quantized_matmul
self.use_quantized_matmul_conv = use_quantized_matmul_conv
self.use_static_quantization = use_static_quantization
self.use_dynamic_quantization = use_dynamic_quantization
self.use_stochastic_rounding = use_stochastic_rounding
self.dequantize_fp32 = dequantize_fp32
self.non_blocking = non_blocking

View File

@ -70,7 +70,7 @@ restricted_opts = {
}
resize_modes = ["None", "Fixed", "Crop", "Fill", "Outpaint", "Context aware"]
max_workers = 12
sdnq_quant_modes = ["int8", "float8_e4m3fn", "int7", "int6", "int5", "uint4", "uint3", "uint2", "float8_e5m2", "float8_e4m3fnuz", "float8_e5m2fnuz", "float16", "int16", "uint16", "uint8", "uint7", "uint6", "uint5", "int4", "int3", "int2", "uint1"]
sdnq_quant_modes = ["int8", "int7", "int6", "uint5", "uint4", "uint3", "uint2", "float8_e4m3fn", "float7_e3m3fn", "float6_e2m3fn", "float5_e2m2fn", "float4_e2m1fn", "float3_e1m1fn", "float2_e1m0fn"]
sdnq_matmul_modes = ["auto", "int8", "float8_e4m3fn", "float16"]
default_hfcache_dir = os.environ.get("SD_HFCACHEDIR", None) or os.path.join(paths.models_path, 'huggingface')
state = shared_state.State()