diff --git a/modules/attention.py b/modules/attention.py
index bef7cc3ea..96296a687 100644
--- a/modules/attention.py
+++ b/modules/attention.py
@@ -83,9 +83,9 @@ def set_ck_flash_attention(backend: str, device: torch.device):
             kwargs["enable_gqa"] = enable_gqa
             return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)
         torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
-        log.debug('Torch attention: type="CK Flash attention"')
+        log.debug('Torch attention: type="Flash attention"')
     except Exception as err:
-        log.error(f'Torch attention: type="CK Flash attention" {err}')
+        log.error(f'Torch attention: type="Flash attention" {err}')
 
 def set_sage_attention(backend: str, device: torch.device):
     try:
diff --git a/modules/devices.py b/modules/devices.py
index 270e4d1bc..ed29d5421 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -477,7 +477,7 @@ def set_sdpa_params():
 
     if 'Triton Flash attention' in opts.sdp_overrides:
         attention.set_triton_flash_attention(backend)
-    if 'CK Flash attention' in opts.sdp_overrides:
+    if 'Flash attention' in opts.sdp_overrides:
         attention.set_ck_flash_attention(backend, device)
 
     if 'Sage attention' in opts.sdp_overrides:
diff --git a/modules/shared_defaults.py b/modules/shared_defaults.py
index 669c0ecf1..fef7462fd 100644
--- a/modules/shared_defaults.py
+++ b/modules/shared_defaults.py
@@ -43,7 +43,7 @@ def get_default_modes(cmd_opts, mem_stat):
 
     default_sdp_choices = ['Flash', 'Memory', 'Math']
     default_sdp_options = ['Flash', 'Memory', 'Math']
-    default_sdp_override_choices = ['Dynamic attention', 'CK Flash attention', 'Sage attention']
+    default_sdp_override_choices = ['Dynamic attention', 'Flash attention', 'Sage attention']
     default_sdp_override_options = []
 
     if devices.backend == "zluda":