diff --git a/modules/devices.py b/modules/devices.py
index 961ffb384..2cb5ee653 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -21,6 +21,7 @@ cpu = torch.device("cpu")
 
 fp16_ok = None # set once by test_fp16
 bf16_ok = None # set once by test_bf16
+triton_ok = None # set once by test_triton
 
 backend = None # set by get_backend
 device = None # set by get_optimal_device
@@ -65,11 +66,10 @@ def has_zluda() -> bool:
 
 
 def has_triton() -> bool:
-    try:
-        from torch.utils._triton import has_triton as torch_has_triton
-        return torch_has_triton()
-    except Exception:
-        return False
+    global triton_ok
+    if triton_ok is not None:
+        return triton_ok
+    return test_triton()
 
 
 def get_backend(shared_cmd_opts):
@@ -382,6 +382,26 @@ def test_bf16():
     return bf16_ok
 
 
+def test_triton():
+    global triton_ok
+    if triton_ok is not None:
+        return triton_ok
+    try:
+        from torch.utils._triton import has_triton as torch_has_triton
+        if torch_has_triton():
+            def test_triton_func(a,b,c):
+                return a * b + c
+            test_triton_func = torch.compile(test_triton_func, fullgraph=True)
+            test_triton_func(torch.randn(128, device=device), torch.randn(128, device=device), torch.randn(128, device=device))
+            triton_ok = True
+        else:
+            triton_ok = False
+    except Exception as e:
+        log.warning(f"Triton test fail: {e}")
+        triton_ok = False
+    return triton_ok
+
+
 def set_cudnn_params():
     if not cuda_ok:
         return
@@ -614,6 +634,7 @@ def set_cuda_params():
     set_cudnn_params()
     set_sdpa_params()
     set_dtype()
+    test_triton()
     if backend == 'openvino':
         from modules.intel.openvino import get_device as get_raw_openvino_device
         device_name = get_raw_openvino_device()
@@ -624,7 +645,7 @@ def set_cuda_params():
         tunable = [torch.cuda.tunable.is_enabled(), torch.cuda.tunable.tuning_is_enabled()]
     except Exception:
         tunable = [False, False]
-    log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upcast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} tunable={tunable} fp16={"pass" if fp16_ok else "fail"} bf16={"pass" if bf16_ok else "fail"} optimization="{opts.cross_attention_optimization}"')
+    log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upcast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} tunable={tunable} fp16={"pass" if fp16_ok else "fail"} bf16={"pass" if bf16_ok else "fail"} triton={"pass" if triton_ok else "fail"} optimization="{opts.cross_attention_optimization}"')
 
 
 def randn(seed, shape=None):