float8_e5m2 (#29)
parent
a7ea23b594
commit
a8c04410aa
|
|
@ -9,7 +9,8 @@ from modules import sd_models, sd_vae
|
|||
# position_ids in clip is int64. model_ema.num_updates is int32
|
||||
dtypes_to_fp16 = {torch.float32, torch.float64, torch.bfloat16}
|
||||
dtypes_to_bf16 = {torch.float32, torch.float64, torch.float16}
|
||||
dtypes_to_fp8 = {torch.float32, torch.float64, torch.bfloat16, torch.float16}
|
||||
dtypes_to_float8_e4m3fn = {torch.float32, torch.float64, torch.bfloat16, torch.float16}
|
||||
dtypes_to_float8_e5m2 = {torch.float32, torch.float64, torch.bfloat16, torch.float16}
|
||||
|
||||
|
||||
class MockModelInfo:
|
||||
|
|
@ -27,9 +28,11 @@ def conv_bf16(t: Tensor):
|
|||
return t.bfloat16() if t.dtype in dtypes_to_bf16 else t
|
||||
|
||||
|
||||
def conv_fp8(t: Tensor):
|
||||
return t.to(torch.float8_e4m3fn) if t.dtype in dtypes_to_fp8 else t
|
||||
def conv_float8_e4m3fn(t: Tensor):
|
||||
return t.to(torch.float8_e4m3fn) if t.dtype in dtypes_to_float8_e4m3fn else t
|
||||
|
||||
def conv_float8_e5m2(t: Tensor):
|
||||
return t.to(torch.float8_e5m2) if t.dtype in dtypes_to_float8_e5m2 else t
|
||||
|
||||
def conv_full(t):
|
||||
return t
|
||||
|
|
@ -40,7 +43,8 @@ _g_precision_func = {
|
|||
"fp32": conv_full,
|
||||
"fp16": conv_fp16,
|
||||
"bf16": conv_bf16,
|
||||
"fp8": conv_fp8,
|
||||
"float8_e4m3fn": conv_float8_e4m3fn,
|
||||
"float8_e5m2": conv_float8_e5m2,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -171,6 +175,7 @@ def do_convert(model_info: MockModelInfo,
|
|||
if not is_sdxl:
|
||||
fix_model(state_dict, fix_clip=fix_clip, force_position_id=force_position_id)
|
||||
|
||||
|
||||
if precision == "fp8":
|
||||
assert torch.__version__ >= "2.1.0", "PyTorch 2.1.0 or newer is required for fp8 conversion"
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def add_tab():
|
|||
input_directory = gr.Textbox(label="Input Directory")
|
||||
|
||||
with gr.Row():
|
||||
precision = gr.Radio(choices=["fp32", "fp16", "bf16", "fp8"], value="fp16", label="Precision")
|
||||
precision = gr.Radio(choices=["fp32", "fp16", "bf16", "float8_e4m3fn","float8_e5m2"], value="fp16", label="Precision")
|
||||
m_type = gr.Radio(choices=["disabled", "no-ema", "ema-only"], value="disabled", label="Pruning Methods")
|
||||
|
||||
with gr.Row():
|
||||
|
|
|
|||
Loading…
Reference in New Issue