diff --git a/modules/devices.py b/modules/devices.py
index 253bd10a7..1274bf82b 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -24,6 +24,8 @@ dtype_vae = None
 dtype_unet = None
 unet_needs_upcast = False # compatibility item
 onnx = None
+sdpa_original = None
+sdpa_pre_dynamic_atten = None
 previous_oom = 0 # oom counter
 if debug:
     log.info(f'Torch build config: {torch.__config__.show()}')
@@ -329,26 +331,52 @@ def set_sdpa_params():
         torch.backends.cuda.enable_flash_sdp('Flash attention' in opts.sdp_options)
         torch.backends.cuda.enable_mem_efficient_sdp('Memory attention' in opts.sdp_options)
         torch.backends.cuda.enable_math_sdp('Math attention' in opts.sdp_options)
+        global sdpa_original
+        if sdpa_original is not None:
+            torch.nn.functional.scaled_dot_product_attention = sdpa_original
+        else:
+            sdpa_original = torch.nn.functional.scaled_dot_product_attention
         if backend == "rocm":
             if 'Flash attention' in opts.sdp_options:
                 try:
                     # https://github.com/huggingface/diffusers/discussions/7172
                     from flash_attn import flash_attn_func
                     from functools import wraps
-                    backup_sdpa = torch.nn.functional.scaled_dot_product_attention
-                    @wraps(torch.nn.functional.scaled_dot_product_attention)
-                    def sdpa_hijack(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
-                        if query.shape[3] <= 128 and attn_mask is None and query.dtype != torch.float32:
+                    sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
+                    @wraps(sdpa_pre_flash_atten)
+                    def sdpa_flash_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+                        if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
                             return flash_attn_func(q=query.transpose(1, 2), k=key.transpose(1, 2), v=value.transpose(1, 2), dropout_p=dropout_p, causal=is_causal, softmax_scale=scale).transpose(1, 2)
                         else:
-                            return backup_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
-                    torch.nn.functional.scaled_dot_product_attention = sdpa_hijack
+                            return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+                    torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
                     log.debug('ROCm Flash Attention Hijacked')
                 except Exception as err:
                     log.error(f'ROCm Flash Attention failed: {err}')
+        if 'Sage attention' in opts.sdp_options:
+            try:
+                from sageattention import sageattn
+                from functools import wraps
+                sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention
+                @wraps(sdpa_pre_sage_atten)
+                def sdpa_sage_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+                    if query.shape[-1] in {128, 96, 64} and attn_mask is None and query.dtype != torch.float32:
+                        return sageattn(q=query, k=key, v=value, is_causal=is_causal, sm_scale=scale)
+                    else:
+                        return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+                torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
+                log.debug('SDPA Sage Attention Hijacked')
+            except Exception as err:
+                log.error(f'SDPA Sage Attention failed: {err}')
         if 'Dynamic attention' in opts.sdp_options:
-            from modules.sd_hijack_dynamic_atten import sliced_scaled_dot_product_attention
-            torch.nn.functional.scaled_dot_product_attention = sliced_scaled_dot_product_attention
+            try:
+                global sdpa_pre_dynamic_atten
+                sdpa_pre_dynamic_atten = torch.nn.functional.scaled_dot_product_attention
+                from modules.sd_hijack_dynamic_atten import sliced_scaled_dot_product_attention
+                torch.nn.functional.scaled_dot_product_attention = sliced_scaled_dot_product_attention
+                log.debug('SDPA Dynamic Attention Hijacked')
+            except Exception as err:
+                log.error(f'SDPA Dynamic Attention failed: {err}')
     except Exception:
         pass
 
diff --git a/modules/sd_hijack_dynamic_atten.py b/modules/sd_hijack_dynamic_atten.py
index 1c17e024c..9ee7d72ac 100644
--- a/modules/sd_hijack_dynamic_atten.py
+++ b/modules/sd_hijack_dynamic_atten.py
@@ -47,9 +47,9 @@ def find_slice_sizes(query_shape, query_element_size, slice_rate=4):
     return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
 
 
-
-backup_sdpa = torch.nn.functional.scaled_dot_product_attention
-@wraps(torch.nn.functional.scaled_dot_product_attention)
+if devices.sdpa_pre_dynamic_atten is None:
+    devices.sdpa_pre_dynamic_atten = torch.nn.functional.scaled_dot_product_attention
+@wraps(devices.sdpa_pre_dynamic_atten)
 def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
     do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_slice_sizes(query.shape, query.element_size(), slice_rate=shared.opts.dynamic_attention_slice_rate)
 
@@ -68,7 +68,7 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
                     for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
                         start_idx_3 = i3 * split_3_slice_size
                         end_idx_3 = (i3 + 1) * split_3_slice_size
-                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = backup_sdpa(
+                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = devices.sdpa_pre_dynamic_atten(
                             query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                             key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                             value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
@@ -76,7 +76,7 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
                                 dropout_p=dropout_p, is_causal=is_causal, **kwargs
                             )
                     else:
-                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = backup_sdpa(
+                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = devices.sdpa_pre_dynamic_atten(
                             query[start_idx:end_idx, start_idx_2:end_idx_2],
                             key[start_idx:end_idx, start_idx_2:end_idx_2],
                             value[start_idx:end_idx, start_idx_2:end_idx_2],
@@ -84,7 +84,7 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
                         dropout_p=dropout_p, is_causal=is_causal, **kwargs
                     )
             else:
-                hidden_states[start_idx:end_idx] = backup_sdpa(
+                hidden_states[start_idx:end_idx] = devices.sdpa_pre_dynamic_atten(
                     query[start_idx:end_idx],
                     key[start_idx:end_idx],
                     value[start_idx:end_idx],
@@ -94,7 +94,7 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
         if devices.backend != "directml":
             getattr(torch, query.device.type).synchronize()
     else:
-        return backup_sdpa(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
+        return devices.sdpa_pre_dynamic_atten(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
     return hidden_states
 
 
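Note (illustration, not part of the patch): the devices.py change follows a restore-then-hijack pattern, so repeated calls to set_sdpa_params() never wrap an earlier wrapper. The sketch below is a minimal standalone version of that pattern; fake_fast_attention stands in for flash_attn_func/sageattn (which may not be installed), set_sdpa_hijack is a hypothetical name, and the scale= keyword assumes PyTorch 2.1 or newer, as the patch itself does.

# Illustrative sketch of the restore-then-hijack pattern (hypothetical names)
import torch
import torch.nn.functional as F
from functools import wraps

sdpa_original = None  # module-level backup, mirroring devices.sdpa_original


def fake_fast_attention(query, key, value, is_causal=False, scale=None):
    # stand-in for an optimized kernel such as flash_attn_func or sageattn
    return sdpa_original(query, key, value, is_causal=is_causal, scale=scale)


def set_sdpa_hijack():
    global sdpa_original
    if sdpa_original is not None:
        # restore first so a repeated call never wraps an earlier wrapper
        torch.nn.functional.scaled_dot_product_attention = sdpa_original
    else:
        sdpa_original = torch.nn.functional.scaled_dot_product_attention

    @wraps(sdpa_original)
    def sdpa_hijack(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        # dispatch to the fast path only for shapes/dtypes/masks the kernel supports
        if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
            return fake_fast_attention(query, key, value, is_causal=is_causal, scale=scale)
        return sdpa_original(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)

    torch.nn.functional.scaled_dot_product_attention = sdpa_hijack


if __name__ == "__main__":
    set_sdpa_hijack()
    set_sdpa_hijack()  # idempotent: the second call restores before re-hijacking
    q = k = v = torch.randn(1, 4, 16, 64)  # float32, so this exercises the fallback path
    print(F.scaled_dot_product_attention(q, k, v).shape)  # torch.Size([1, 4, 16, 64])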
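Note (illustration, not part of the patch): sliced_scaled_dot_product_attention caps peak memory by running SDPA on slices and writing each partial result into a preallocated buffer. The sketch below reduces that idea to a single batch-dimension loop over the saved original SDPA; sliced_sdpa and slice_size are hypothetical names, and real inputs would come from the attention layers rather than random tensors.

# Illustrative sketch of batch-sliced SDPA (hypothetical names)
import torch
import torch.nn.functional as F

sdpa_original = F.scaled_dot_product_attention  # backup, as devices.sdpa_pre_dynamic_atten holds


def sliced_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, slice_size=2, **kwargs):
    # run attention on batch slices and stitch the results into a preallocated buffer
    if query.shape[0] <= slice_size:
        return sdpa_original(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
    out = torch.empty(query.shape[:-1] + (value.shape[-1],), device=query.device, dtype=query.dtype)
    for start in range(0, query.shape[0], slice_size):
        end = start + slice_size
        mask = attn_mask[start:end] if attn_mask is not None else None
        out[start:end] = sdpa_original(
            query[start:end], key[start:end], value[start:end],
            attn_mask=mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs
        )
    return out


if __name__ == "__main__":
    q = k = v = torch.randn(8, 4, 32, 64)
    print(torch.allclose(sliced_sdpa(q, k, v), sdpa_original(q, k, v), atol=1e-6))  # expected: True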