mirror of https://github.com/vladmandic/automatic
Add basic triton test
parent 512036d291
commit 818b0c0821
@@ -21,6 +21,7 @@ cpu = torch.device("cpu")
 fp16_ok = None # set once by test_fp16
 bf16_ok = None # set once by test_bf16
+triton_ok = None # set once by test_triton

 backend = None # set by get_backend
 device = None # set by get_optimal_device

@@ -65,11 +66,10 @@ def has_zluda() -> bool:


 def has_triton() -> bool:
-    try:
-        from torch.utils._triton import has_triton as torch_has_triton
-        return torch_has_triton()
-    except Exception:
-        return False
+    global triton_ok
+    if triton_ok is not None:
+        return triton_ok
+    return test_triton()


 def get_backend(shared_cmd_opts):
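Note on the has_triton() rewrite above: torch.utils._triton.has_triton is an internal PyTorch helper that only reports whether Triton is importable and supported for the current device type; it does not prove that torch.compile can actually build and run a kernel. The rewrite therefore routes callers through the cached result of the functional probe added below, so the expensive compile test runs at most once per process. A minimal usage sketch (the import path is an assumption based on this file living at modules/devices.py; it is not part of the commit):

# Hypothetical caller; assumes this module is modules/devices.py in the repo.
from modules import devices

if devices.has_triton():       # first call falls through to test_triton() and runs the compile probe
    print("triton available")  # later calls return the cached triton_ok flag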
@@ -382,6 +382,26 @@ def test_bf16():
     return bf16_ok


+def test_triton():
+    global triton_ok
+    if triton_ok is not None:
+        return triton_ok
+    try:
+        from torch.utils._triton import has_triton as torch_has_triton
+        if torch_has_triton():
+            def test_triton_func(a, b, c):
+                return a * b + c
+            test_triton_func = torch.compile(test_triton_func, fullgraph=True)
+            test_triton_func(torch.randn(128, device=device), torch.randn(128, device=device), torch.randn(128, device=device))
+            triton_ok = True
+        else:
+            triton_ok = False
+    except Exception as e:
+        log.warning(f"Triton test fail: {e}")
+        triton_ok = False
+    return triton_ok
+
+
 def set_cudnn_params():
     if not cuda_ok:
         return
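The probe above compiles a trivial fused multiply-add with fullgraph=True and executes it once on the target device; on CUDA and ROCm builds, torch.compile's default Inductor backend lowers such element-wise graphs to Triton kernels, so a successful run is strong evidence of a working Triton toolchain. A self-contained sketch of the same smoke test, runnable outside the repo (device selection, logger setup, and the function names here are my own scaffolding, not from the commit):

import logging
import torch

log = logging.getLogger(__name__)

def triton_smoke_test() -> bool:
    # Pick a device locally; the repo uses its module-level `device` instead.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        from torch.utils._triton import has_triton as torch_has_triton
        if not torch_has_triton():
            return False  # triton not importable / not supported for this device
        def fused(a, b, c):
            return a * b + c
        fused = torch.compile(fused, fullgraph=True)  # Inductor lowers this to a Triton kernel on GPU
        x = torch.randn(128, device=dev)
        fused(x, x, x)  # executing once forces actual compilation
        return True
    except Exception as e:
        log.warning(f"Triton smoke test failed: {e}")
        return False

if __name__ == "__main__":
    print("triton:", "pass" if triton_smoke_test() else "fail")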
@@ -614,6 +634,7 @@ def set_cuda_params():
     set_cudnn_params()
     set_sdpa_params()
     set_dtype()
+    test_triton()
     if backend == 'openvino':
         from modules.intel.openvino import get_device as get_raw_openvino_device
         device_name = get_raw_openvino_device()
@@ -624,7 +645,7 @@
         tunable = [torch.cuda.tunable.is_enabled(), torch.cuda.tunable.tuning_is_enabled()]
     except Exception:
         tunable = [False, False]
-    log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upcast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} tunable={tunable} fp16={"pass" if fp16_ok else "fail"} bf16={"pass" if bf16_ok else "fail"} optimization="{opts.cross_attention_optimization}"')
+    log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upcast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} tunable={tunable} fp16={"pass" if fp16_ok else "fail"} bf16={"pass" if bf16_ok else "fail"} triton={"pass" if triton_ok else "fail"} optimization="{opts.cross_attention_optimization}"')


 def randn(seed, shape=None):