Add basic triton test

pull/4317/head
Disty0 2025-10-26 10:44:04 +03:00
parent 512036d291
commit 818b0c0821
1 changed files with 27 additions and 6 deletions

View File

@ -21,6 +21,7 @@ cpu = torch.device("cpu")
fp16_ok = None # set once by test_fp16
bf16_ok = None # set once by test_bf16
triton_ok = None # set once by test_triton
backend = None # set by get_backend
device = None # set by get_optimal_device
@ -65,11 +66,10 @@ def has_zluda() -> bool:
def has_triton() -> bool:
try:
from torch.utils._triton import has_triton as torch_has_triton
return torch_has_triton()
except Exception:
return False
global triton_ok
if triton_ok is not None:
return triton_ok
return test_triton()
def get_backend(shared_cmd_opts):
@ -382,6 +382,26 @@ def test_bf16():
return bf16_ok
def test_triton():
global triton_ok
if triton_ok is not None:
return triton_ok
try:
from torch.utils._triton import has_triton as torch_has_triton
if torch_has_triton():
def test_triton_func(a,b,c):
return a * b + c
test_triton_func = torch.compile(test_triton_func, fullgraph=True)
test_triton_func(torch.randn(128, device=device), torch.randn(128, device=device), torch.randn(128, device=device))
triton_ok = True
else:
triton_ok = False
except Exception as e:
log.warning(f"Triton test fail: {e}")
triton_ok = False
return triton_ok
def set_cudnn_params():
if not cuda_ok:
return
@ -614,6 +634,7 @@ def set_cuda_params():
set_cudnn_params()
set_sdpa_params()
set_dtype()
test_triton()
if backend == 'openvino':
from modules.intel.openvino import get_device as get_raw_openvino_device
device_name = get_raw_openvino_device()
@ -624,7 +645,7 @@ def set_cuda_params():
tunable = [torch.cuda.tunable.is_enabled(), torch.cuda.tunable.tuning_is_enabled()]
except Exception:
tunable = [False, False]
log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upcast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} tunable={tunable} fp16={"pass" if fp16_ok else "fail"} bf16={"pass" if bf16_ok else "fail"} optimization="{opts.cross_attention_optimization}"')
log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upcast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} tunable={tunable} fp16={"pass" if fp16_ok else "fail"} bf16={"pass" if bf16_ok else "fail"} triton={"pass" if triton_ok else "fail"} optimization="{opts.cross_attention_optimization}"')
def randn(seed, shape=None):