add triton backend flash-attn (experimental)

pull/3454/head
Seunghoon Lee 2024-09-24 12:53:18 +09:00
parent e246e55734
commit 1395f5bf9e
No known key found for this signature in database
GPG Key ID: 8BBA0F6A4069A002
1 changed file with 2 additions and 0 deletions

View File

@@ -185,6 +185,8 @@ else:
return bool(int(os.environ.get("TORCH_BLAS_PREFER_HIPBLASLT", "1")))
def get_flash_attention_command(agent: Agent):
if os.environ.get("FLASH_ATTENTION_USE_TRITON_ROCM", "FALSE") == "TRUE":
return "pytest git+https://github.com/ROCm/flash-attention@micmelesse/upstream_pr"
default = "git+https://github.com/ROCm/flash-attention"
if agent.arch == MicroArchitecture.RDNA:
default = "git+https://github.com/ROCm/flash-attention@howiejay/navi_support"