# automatic/test/benchmark_attention.py
# (originally listed as: 237 lines, 8.2 KiB, Python)

from typing import Dict, Any
import time
import warnings
import torch
import torch.nn.functional as F
# Silence UserWarnings from attention backends so the benchmark table stays readable.
warnings.filterwarnings("ignore", category=UserWarning)
# Untimed iterations run before measuring (triggers lazy init / compilation).
warmup = 2
# Timed iterations averaged into the reported latency.
repeats = 50
# Dtypes to benchmark; float16 is disabled by default.
dtypes = [torch.bfloat16] # , torch.float16]
# Attention implementations to benchmark; missing packages report status "fail".
backends = [
"sdpa_math",
"sdpa_mem_efficient",
"sdpa_flash",
"sdpa_all",
"flex_attention",
"xformers",
"flash_attn",
"sage_attn",
]
# Attention shapes: l_q/l_k = query/key sequence lengths, h = heads, d = head dim.
# NOTE(review): values presumably correspond to each model's default resolution — confirm.
profiles = {
"sdxl": {"l_q": 4096, "l_k": 4096, "h": 32, "d": 128},
"flux.1": {"l_q": 16717, "l_k": 16717, "h": 24, "d": 128},
"sd35": {"l_q": 16538, "l_k": 16538, "h": 24, "d": 128},
"qwen-image": {"l_q": 16384, "l_k": 16384, "h": 24, "d": 128},
"z-image": {"l_q": 4096, "l_k": 4096, "h": 32, "d": 120},
"wan2.1": {"l_q": 16384, "l_k": 16384, "h": 40, "d": 128},
}
def get_stats(reset: bool = False):
    """Synchronize the GPU and sample (peak memory in MiB, timestamp in seconds).

    Args:
        reset: when True, empty the CUDA caching allocator and reset the
            peak-memory counter first, so the next reading covers only the
            region measured after this call.

    Returns:
        Tuple of (peak allocated CUDA memory in MiB, ``time.perf_counter()``).
    """
    # Drain all pending kernels so the timestamp reflects completed GPU work.
    torch.cuda.synchronize()
    if reset:
        # Note: the original wrapped these in torch.no_grad(), which has no
        # effect here — neither call interacts with autograd.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    m = torch.cuda.max_memory_allocated()
    t = time.perf_counter()
    return m / (1024 ** 2), t
def print_gpu_info():
    """Print a short summary of the active GPU, CUDA runtime, and torch build."""
    if not torch.cuda.is_available():
        print("GPU: Not available")
        return
    idx = torch.cuda.current_device()
    properties = torch.cuda.get_device_properties(idx)
    total_gb = properties.total_memory / (1024 ** 3)
    free_bytes, _ = torch.cuda.mem_get_info(idx)
    free_gb = free_bytes / (1024 ** 3)
    cap_major, cap_minor = torch.cuda.get_device_capability(idx)
    print(f"gpu: {torch.cuda.get_device_name(idx)}")
    print(f"vram: total={total_gb:.2f}GB free={free_gb:.2f}GB")
    print(f"cuda: capability={cap_major}.{cap_minor} version={torch.version.cuda}")
    print(f"torch: {torch.__version__}")
def benchmark_attention(
backend: str,
dtype: torch.dtype,
b: int = 1,
l_q: int = 4096,
l_k: int = 4096,
h: int = 32,
d: int = 128,
) -> Dict[str, Any]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize tensors
q = torch.randn(b, h, l_q, d, device=device, dtype=torch.float16 if dtype.is_floating_point and dtype.itemsize == 1 else dtype, requires_grad=False).to(dtype)
k = torch.randn(b, h, l_k, d, device=device, dtype=torch.float16 if dtype.is_floating_point and dtype.itemsize == 1 else dtype, requires_grad=False).to(dtype)
v = torch.randn(b, h, l_k, d, device=device, dtype=torch.float16 if dtype.is_floating_point and dtype.itemsize == 1 else dtype, requires_grad=False).to(dtype)
results = {
"backend": backend,
"dtype": str(dtype),
"status": "pass",
"latency_ms": 0.0,
"memory_mb": 0.0,
"version": "N/A",
"error": ""
}
try:
if backend.startswith("sdpa_"):
from torch.nn.attention import sdpa_kernel, SDPBackend
sdp_type = backend[len("sdpa_"):]
# Map friendly names to new SDPA backends
backend_map = {
"math": [SDPBackend.MATH],
"flash": [SDPBackend.FLASH_ATTENTION],
"mem_efficient": [SDPBackend.EFFICIENT_ATTENTION],
"all": [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
}
if sdp_type not in backend_map:
raise ValueError(f"Unknown SDPA type: {sdp_type}")
results["version"] = torch.__version__
with sdpa_kernel(backend_map[sdp_type]):
# Warmup
for _ in range(warmup):
_ = F.scaled_dot_product_attention(q, k, v)
start_mem, start_time = get_stats(True)
for _ in range(repeats):
_ = F.scaled_dot_product_attention(q, k, v)
end_mem, end_time = get_stats()
results["latency_ms"] = (end_time - start_time) / repeats * 1000
results["memory_mb"] = end_mem - start_mem
elif backend == "flash_attn":
from flash_attn import flash_attn_func, __version__ as fa_version
results["version"] = fa_version
# Flash attention usually expects (B, L, H, D)
q_fa = q.transpose(1, 2)
k_fa = k.transpose(1, 2)
v_fa = v.transpose(1, 2)
for _ in range(warmup):
_ = flash_attn_func(q_fa, k_fa, v_fa)
start_mem, start_time = get_stats(True)
for _ in range(repeats):
_ = flash_attn_func(q_fa, k_fa, v_fa)
end_mem, end_time = get_stats()
results["latency_ms"] = (end_time - start_time) / repeats * 1000
results["memory_mb"] = end_mem - start_mem
elif backend == "xformers":
from xformers.ops import memory_efficient_attention
from xformers import __version__ as xf_version
results["version"] = xf_version
# xformers also usually prefers (B, L, H, D)
q_xf = q.transpose(1, 2)
k_xf = k.transpose(1, 2)
v_xf = v.transpose(1, 2)
for _ in range(warmup):
_ = memory_efficient_attention(q_xf, k_xf, v_xf)
start_mem, start_time = get_stats(True)
for _ in range(repeats):
_ = memory_efficient_attention(q_xf, k_xf, v_xf)
end_mem, end_time = get_stats()
results["latency_ms"] = (end_time - start_time) / repeats * 1000
results["memory_mb"] = end_mem - start_mem
elif backend == "sage_attn":
from sageattention import sageattn
import sageattention
# Attempt to get version from package metadata or a common attribute
try:
import importlib.metadata
results["version"] = importlib.metadata.version("sageattention")
except Exception:
results["version"] = getattr(sageattention, "__version__", "N/A")
# SageAttention expects (B, H, L, D) logic
for _ in range(warmup):
_ = sageattn(q, k, v)
start_mem, start_time = get_stats(True)
for _ in range(repeats):
_ = sageattn(q, k, v)
end_mem, end_time = get_stats()
results["latency_ms"] = (end_time - start_time) / repeats * 1000
results["memory_mb"] = end_mem - start_mem
elif backend == "flex_attention":
from torch.nn.attention.flex_attention import flex_attention
results["version"] = torch.__version__
# flex_attention requires torch.compile for performance
flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
# Warmup (important to trigger compilation)
for _ in range(warmup):
_ = flex_attention_compiled(q, k, v)
start_mem, start_time = get_stats(True)
for _ in range(repeats):
_ = flex_attention_compiled(q, k, v)
end_mem, end_time = get_stats()
results["latency_ms"] = (end_time - start_time) / repeats * 1000
results["memory_mb"] = end_mem - start_mem
except Exception as e:
results["status"] = "fail"
results["error"] = str(e)[:49]
print(e)
return results
def main():
    """Run every backend over every profile/dtype combination and print a table."""
    collected = []
    print_gpu_info()
    print(f'config: warmup={warmup} repeats={repeats} dtypes={dtypes}')
    for profile_name, shape in profiles.items():
        print(f"profile: {profile_name} (L_q={shape['l_q']}, L_k={shape['l_k']}, H={shape['h']}, D={shape['d']})")
        for dt in dtypes:
            print(f" dtype: {dt}")
            # Table header, column widths matching the result rows below.
            print(f" {'backend':<20} | {'version':<12} | {'status':<8} | {'latency':<10} | {'memory':<12} | ")
            for backend_name in backends:
                outcome = benchmark_attention(
                    backend_name,
                    dt,
                    l_q=shape["l_q"],
                    l_k=shape["l_k"],
                    h=shape["h"],
                    d=shape["d"],
                )
                collected.append(outcome)
                latency_col = f"{outcome['latency_ms']:.4f} ms"
                memory_col = f"{outcome['memory_mb']:.2f} MB"
                print(f" {outcome['backend']:<20} | {outcome['version']:<12} | {outcome['status']:<8} | {latency_col:<10} | {memory_col:<12} | {outcome['error']}")
# Script entry point: run the full benchmark sweep when executed directly.
if __name__ == "__main__":
    main()