mirror of https://github.com/vladmandic/automatic
Rename CK Flash attention to just Flash attention
parent
a93715e0da
commit
2bbbb684cc
|
|
@ -83,9 +83,9 @@ def set_ck_flash_attention(backend: str, device: torch.device):
|
|||
kwargs["enable_gqa"] = enable_gqa
|
||||
return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)
|
||||
torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
|
||||
log.debug('Torch attention: type="CK Flash attention"')
|
||||
log.debug('Torch attention: type="Flash attention"')
|
||||
except Exception as err:
|
||||
log.error(f'Torch attention: type="CK Flash attention" {err}')
|
||||
log.error(f'Torch attention: type="Flash attention" {err}')
|
||||
|
||||
def set_sage_attention(backend: str, device: torch.device):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -477,7 +477,7 @@ def set_sdpa_params():
|
|||
if 'Triton Flash attention' in opts.sdp_overrides:
|
||||
attention.set_triton_flash_attention(backend)
|
||||
|
||||
if 'CK Flash attention' in opts.sdp_overrides:
|
||||
if 'Flash attention' in opts.sdp_overrides:
|
||||
attention.set_ck_flash_attention(backend, device)
|
||||
|
||||
if 'Sage attention' in opts.sdp_overrides:
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def get_default_modes(cmd_opts, mem_stat):
|
|||
default_sdp_choices = ['Flash', 'Memory', 'Math']
|
||||
default_sdp_options = ['Flash', 'Memory', 'Math']
|
||||
|
||||
default_sdp_override_choices = ['Dynamic attention', 'CK Flash attention', 'Sage attention']
|
||||
default_sdp_override_choices = ['Dynamic attention', 'Flash attention', 'Sage attention']
|
||||
default_sdp_override_options = []
|
||||
|
||||
if devices.backend == "zluda":
|
||||
|
|
|
|||
Loading…
Reference in New Issue