from typing import Dict, Any import time import warnings import torch import torch.nn.functional as F warnings.filterwarnings("ignore", category=UserWarning) warmup = 2 repeats = 50 dtypes = [torch.bfloat16] # , torch.float16] backends = [ "sdpa_math", "sdpa_mem_efficient", "sdpa_flash", "sdpa_all", "flex_attention", "xformers", "flash_attn", "sage_attn", ] profiles = { "sdxl": {"l_q": 4096, "l_k": 4096, "h": 32, "d": 128}, "flux.1": {"l_q": 16717, "l_k": 16717, "h": 24, "d": 128}, "sd35": {"l_q": 16538, "l_k": 16538, "h": 24, "d": 128}, "qwen-image": {"l_q": 16384, "l_k": 16384, "h": 24, "d": 128}, "z-image": {"l_q": 4096, "l_k": 4096, "h": 32, "d": 120}, "wan2.1": {"l_q": 16384, "l_k": 16384, "h": 40, "d": 128}, } def get_stats(reset: bool = False): torch.cuda.synchronize() if reset: with torch.no_grad(): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() m = torch.cuda.max_memory_allocated() t = time.perf_counter() return m / (1024 ** 2), t def print_gpu_info(): if not torch.cuda.is_available(): print("GPU: Not available") return device = torch.cuda.current_device() props = torch.cuda.get_device_properties(device) total_mem = props.total_memory / (1024**3) free_mem, _ = torch.cuda.mem_get_info(device) free_mem = free_mem / (1024**3) major, minor = torch.cuda.get_device_capability(device) print(f"gpu: {torch.cuda.get_device_name(device)}") print(f"vram: total={total_mem:.2f}GB free={free_mem:.2f}GB") print(f"cuda: capability={major}.{minor} version={torch.version.cuda}") print(f"torch: {torch.__version__}") def benchmark_attention( backend: str, dtype: torch.dtype, b: int = 1, l_q: int = 4096, l_k: int = 4096, h: int = 32, d: int = 128, ) -> Dict[str, Any]: device = "cuda" if torch.cuda.is_available() else "cpu" # Initialize tensors q = torch.randn(b, h, l_q, d, device=device, dtype=torch.float16 if dtype.is_floating_point and dtype.itemsize == 1 else dtype, requires_grad=False).to(dtype) k = torch.randn(b, h, l_k, d, device=device, dtype=torch.float16 if dtype.is_floating_point and dtype.itemsize == 1 else dtype, requires_grad=False).to(dtype) v = torch.randn(b, h, l_k, d, device=device, dtype=torch.float16 if dtype.is_floating_point and dtype.itemsize == 1 else dtype, requires_grad=False).to(dtype) results = { "backend": backend, "dtype": str(dtype), "status": "pass", "latency_ms": 0.0, "memory_mb": 0.0, "version": "N/A", "error": "" } try: if backend.startswith("sdpa_"): from torch.nn.attention import sdpa_kernel, SDPBackend sdp_type = backend[len("sdpa_"):] # Map friendly names to new SDPA backends backend_map = { "math": [SDPBackend.MATH], "flash": [SDPBackend.FLASH_ATTENTION], "mem_efficient": [SDPBackend.EFFICIENT_ATTENTION], "all": [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] } if sdp_type not in backend_map: raise ValueError(f"Unknown SDPA type: {sdp_type}") results["version"] = torch.__version__ with sdpa_kernel(backend_map[sdp_type]): # Warmup for _ in range(warmup): _ = F.scaled_dot_product_attention(q, k, v) start_mem, start_time = get_stats(True) for _ in range(repeats): _ = F.scaled_dot_product_attention(q, k, v) end_mem, end_time = get_stats() results["latency_ms"] = (end_time - start_time) / repeats * 1000 results["memory_mb"] = end_mem - start_mem elif backend == "flash_attn": from flash_attn import flash_attn_func, __version__ as fa_version results["version"] = fa_version # Flash attention usually expects (B, L, H, D) q_fa = q.transpose(1, 2) k_fa = k.transpose(1, 2) v_fa = v.transpose(1, 2) for _ in range(warmup): _ = flash_attn_func(q_fa, k_fa, v_fa) start_mem, start_time = get_stats(True) for _ in range(repeats): _ = flash_attn_func(q_fa, k_fa, v_fa) end_mem, end_time = get_stats() results["latency_ms"] = (end_time - start_time) / repeats * 1000 results["memory_mb"] = end_mem - start_mem elif backend == "xformers": from xformers.ops import memory_efficient_attention from xformers import __version__ as xf_version results["version"] = xf_version # xformers also usually prefers (B, L, H, D) q_xf = q.transpose(1, 2) k_xf = k.transpose(1, 2) v_xf = v.transpose(1, 2) for _ in range(warmup): _ = memory_efficient_attention(q_xf, k_xf, v_xf) start_mem, start_time = get_stats(True) for _ in range(repeats): _ = memory_efficient_attention(q_xf, k_xf, v_xf) end_mem, end_time = get_stats() results["latency_ms"] = (end_time - start_time) / repeats * 1000 results["memory_mb"] = end_mem - start_mem elif backend == "sage_attn": from sageattention import sageattn import sageattention # Attempt to get version from package metadata or a common attribute try: import importlib.metadata results["version"] = importlib.metadata.version("sageattention") except Exception: results["version"] = getattr(sageattention, "__version__", "N/A") # SageAttention expects (B, H, L, D) logic for _ in range(warmup): _ = sageattn(q, k, v) start_mem, start_time = get_stats(True) for _ in range(repeats): _ = sageattn(q, k, v) end_mem, end_time = get_stats() results["latency_ms"] = (end_time - start_time) / repeats * 1000 results["memory_mb"] = end_mem - start_mem elif backend == "flex_attention": from torch.nn.attention.flex_attention import flex_attention results["version"] = torch.__version__ # flex_attention requires torch.compile for performance flex_attention_compiled = torch.compile(flex_attention, dynamic=False) # Warmup (important to trigger compilation) for _ in range(warmup): _ = flex_attention_compiled(q, k, v) start_mem, start_time = get_stats(True) for _ in range(repeats): _ = flex_attention_compiled(q, k, v) end_mem, end_time = get_stats() results["latency_ms"] = (end_time - start_time) / repeats * 1000 results["memory_mb"] = end_mem - start_mem except Exception as e: results["status"] = "fail" results["error"] = str(e)[:49] print(e) return results def main(): all_results = [] print_gpu_info() print(f'config: warmup={warmup} repeats={repeats} dtypes={dtypes}') for name, config in profiles.items(): print(f"profile: {name} (L_q={config['l_q']}, L_k={config['l_k']}, H={config['h']}, D={config['d']})") for dtype in dtypes: print(f" dtype: {dtype}") print(f" {'backend':<20} | {'version':<12} | {'status':<8} | {'latency':<10} | {'memory':<12} | ") for backend in backends: res = benchmark_attention( backend, dtype, l_q=config["l_q"], l_k=config["l_k"], h=config["h"], d=config["d"], ) all_results.append(res) latency = f"{res['latency_ms']:.4f} ms" memory = f"{res['memory_mb']:.2f} MB" print(f" {res['backend']:<20} | {res['version']:<12} | {res['status']:<8} | {latency:<10} | {memory:<12} | {res['error']}") if __name__ == "__main__": main()