float8_e5m2 (#29)

main
Luís Eduardo Ribeiro Guerra 2024-12-24 09:56:13 -03:00 committed by GitHub
parent a7ea23b594
commit a8c04410aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 5 deletions

View File

@ -9,7 +9,8 @@ from modules import sd_models, sd_vae
# Source dtypes that each conversion helper is allowed to down-cast.
# Only floating-point dtypes are listed, so integer tensors pass through
# untouched: position_ids in clip is int64, model_ema.num_updates is int32.
dtypes_to_fp16 = {
    torch.float32,
    torch.float64,
    torch.bfloat16,
}
dtypes_to_bf16 = {
    torch.float32,
    torch.float64,
    torch.float16,
}
# The 8-bit float targets accept every wider floating-point dtype.
dtypes_to_fp8 = {
    torch.float32,
    torch.float64,
    torch.bfloat16,
    torch.float16,
}
dtypes_to_float8_e4m3fn = {
    torch.float32,
    torch.float64,
    torch.bfloat16,
    torch.float16,
}
dtypes_to_float8_e5m2 = {
    torch.float32,
    torch.float64,
    torch.bfloat16,
    torch.float16,
}
class MockModelInfo:
@ -27,9 +28,11 @@ def conv_bf16(t: Tensor):
return t.bfloat16() if t.dtype in dtypes_to_bf16 else t
def conv_fp8(t: Tensor):
    """Cast *t* to ``torch.float8_e4m3fn`` when its dtype is convertible.

    Args:
        t: Tensor to convert.

    Returns:
        A ``torch.float8_e4m3fn`` copy of *t* if ``t.dtype`` is listed in
        ``dtypes_to_fp8``; otherwise *t* itself, unchanged.
    """
    # Tensors whose dtype is not in the allow-list are passed through as-is.
    if t.dtype not in dtypes_to_fp8:
        return t
    return t.to(torch.float8_e4m3fn)
def conv_float8_e4m3fn(t: Tensor):
    """Cast *t* to ``torch.float8_e4m3fn`` when its dtype is convertible.

    Args:
        t: Tensor to convert.

    Returns:
        A ``torch.float8_e4m3fn`` copy of *t* if ``t.dtype`` is listed in
        ``dtypes_to_float8_e4m3fn``; otherwise *t* itself, unchanged.
    """
    # Tensors whose dtype is not in the allow-list are passed through as-is.
    if t.dtype not in dtypes_to_float8_e4m3fn:
        return t
    return t.to(torch.float8_e4m3fn)
def conv_float8_e5m2(t: Tensor):
    """Cast *t* to ``torch.float8_e5m2`` when its dtype is convertible.

    Args:
        t: Tensor to convert.

    Returns:
        A ``torch.float8_e5m2`` copy of *t* if ``t.dtype`` is listed in
        ``dtypes_to_float8_e5m2``; otherwise *t* itself, unchanged.
    """
    # Tensors whose dtype is not in the allow-list are passed through as-is.
    if t.dtype not in dtypes_to_float8_e5m2:
        return t
    return t.to(torch.float8_e5m2)
def conv_full(t):
    """Return *t* unchanged (no precision conversion is applied).

    Args:
        t: Value to pass through.

    Returns:
        The very same object that was passed in.
    """
    return t
@ -40,7 +43,8 @@ _g_precision_func = {
"fp32": conv_full,
"fp16": conv_fp16,
"bf16": conv_bf16,
"fp8": conv_fp8,
"float8_e4m3fn": conv_float8_e4m3fn,
"float8_e5m2": conv_float8_e5m2,
}
@ -171,6 +175,7 @@ def do_convert(model_info: MockModelInfo,
if not is_sdxl:
fix_model(state_dict, fix_clip=fix_clip, force_position_id=force_position_id)
if precision == "fp8":
assert torch.__version__ >= "2.1.0", "PyTorch 2.1.0 or newer is required for fp8 conversion"

View File

@ -34,7 +34,7 @@ def add_tab():
input_directory = gr.Textbox(label="Input Directory")
with gr.Row():
precision = gr.Radio(choices=["fp32", "fp16", "bf16", "fp8"], value="fp16", label="Precision")
precision = gr.Radio(choices=["fp32", "fp16", "bf16", "float8_e4m3fn","float8_e5m2"], value="fp16", label="Precision")
m_type = gr.Radio(choices=["disabled", "no-ema", "ema-only"], value="disabled", label="Pruning Methods")
with gr.Row():