Make SDPA hijacks chainable and add Sage Attention

pull/3490/head
Disty0 2024-10-12 21:19:38 +03:00
parent ea0dfebe2d
commit 2e2cb43406
2 changed files with 42 additions and 15 deletions
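The pattern behind the commit title, for orientation: rather than each hijack stashing the one true original SDPA, every hijack snapshots whatever implementation is currently installed and falls back to it, so enabling several hijacks stacks them. A minimal sketch of that pattern (the function and variable names below are illustrative, not from the commit):

import torch
from functools import wraps

sdpa_original = None  # pristine SDPA, mirroring the new global in devices.py

def reset_sdpa():
    # Run once before installing hijacks so a re-run rebuilds the chain
    # instead of wrapping the previous chain again.
    global sdpa_original
    if sdpa_original is not None:
        torch.nn.functional.scaled_dot_product_attention = sdpa_original
    else:
        sdpa_original = torch.nn.functional.scaled_dot_product_attention

def install_example_hijack():
    # Snapshot the CURRENT implementation, which may already be another hijack,
    # and fall back to it instead of the original: this is what makes the chain.
    sdpa_previous = torch.nn.functional.scaled_dot_product_attention
    @wraps(sdpa_previous)
    def example_hijack(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        # A real hijack would handle its supported shapes/dtypes here
        # and defer everything else to the previous link in the chain.
        return sdpa_previous(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
    torch.nn.functional.scaled_dot_product_attention = example_hijack

reset_sdpa()
install_example_hijack()  # install more hijacks here; each wraps the one before it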

modules/devices.py

@@ -24,6 +24,8 @@ dtype_vae = None
 dtype_unet = None
 unet_needs_upcast = False # compatibility item
 onnx = None
+sdpa_original = None
+sdpa_pre_dyanmic_atten = None
 previous_oom = 0 # oom counter
 if debug:
     log.info(f'Torch build config: {torch.__config__.show()}')
@@ -329,26 +331,51 @@ def set_sdpa_params():
         torch.backends.cuda.enable_flash_sdp('Flash attention' in opts.sdp_options)
         torch.backends.cuda.enable_mem_efficient_sdp('Memory attention' in opts.sdp_options)
         torch.backends.cuda.enable_math_sdp('Math attention' in opts.sdp_options)
+        global sdpa_original
+        if sdpa_original is not None:
+            torch.nn.functional.scaled_dot_product_attention = sdpa_original
+        else:
+            sdpa_original = torch.nn.functional.scaled_dot_product_attention
         if backend == "rocm":
             if 'Flash attention' in opts.sdp_options:
                 try:
                     # https://github.com/huggingface/diffusers/discussions/7172
                     from flash_attn import flash_attn_func
                     from functools import wraps
-                    backup_sdpa = torch.nn.functional.scaled_dot_product_attention
-                    @wraps(torch.nn.functional.scaled_dot_product_attention)
-                    def sdpa_hijack(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
-                        if query.shape[3] <= 128 and attn_mask is None and query.dtype != torch.float32:
+                    sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
+                    @wraps(sdpa_pre_flash_atten)
+                    def sdpa_flash_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+                        if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
                             return flash_attn_func(q=query.transpose(1, 2), k=key.transpose(1, 2), v=value.transpose(1, 2), dropout_p=dropout_p, causal=is_causal, softmax_scale=scale).transpose(1, 2)
                         else:
-                            return backup_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
-                    torch.nn.functional.scaled_dot_product_attention = sdpa_hijack
+                            return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+                    torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
                     log.debug('ROCm Flash Attention Hijacked')
                 except Exception as err:
                     log.error(f'ROCm Flash Attention failed: {err}')
+        if 'Sage attention' in opts.sdp_options:
+            try:
+                from sageattention import sageattn
+                sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention
+                @wraps(sdpa_pre_sage_atten)
+                def sdpa_sage_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+                    if query.shape[-1] in {128, 96, 64} and attn_mask is None and query.dtype != torch.float32:
+                        return sageattn(q=query, k=key, v=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale).transpose(1, 2)
+                    else:
+                        return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+                torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
+                log.debug('SDPA Sage Attention Hijacked')
+            except Exception as err:
+                log.error(f'SDPA Sage Attention failed: {err}')
         if 'Dynamic attention' in opts.sdp_options:
-            from modules.sd_hijack_dynamic_atten import sliced_scaled_dot_product_attention
-            torch.nn.functional.scaled_dot_product_attention = sliced_scaled_dot_product_attention
+            try:
+                global sdpa_pre_dyanmic_atten
+                sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
+                from modules.sd_hijack_dynamic_atten import sliced_scaled_dot_product_attention
+                torch.nn.functional.scaled_dot_product_attention = sliced_scaled_dot_product_attention
+                log.debug('SDPA Dynamic Attention Hijacked')
+            except Exception as err:
+                log.error(f'SDPA Dynamic Attention failed: {err}')
     except Exception:
         pass
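
Net effect of the hunk above: each enabled hijack captures the implementation installed just before it, so the dispatch order at call time is the reverse of the install order. A hypothetical fall-through trace with all three options active and every layer declining (for example, an unsupported head dim); the chain follows from the code above, but the trace itself is not part of the commit:

# torch.nn.functional.scaled_dot_product_attention
#   -> sliced_scaled_dot_product_attention  # Dynamic attention, installed last, so outermost
#   -> sdpa_sage_atten                      # Sage attention
#   -> sdpa_flash_atten                     # ROCm Flash attention (only when backend == "rocm")
#   -> sdpa_original                        # PyTorch's built-in SDPA

Because the function now restores sdpa_original first whenever it is already set, calling set_sdpa_params() again rebuilds this chain from scratch rather than wrapping the old one.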

modules/sd_hijack_dynamic_atten.py

@@ -47,9 +47,9 @@ def find_slice_sizes(query_shape, query_element_size, slice_rate=4):
     return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
-backup_sdpa = torch.nn.functional.scaled_dot_product_attention
-@wraps(torch.nn.functional.scaled_dot_product_attention)
+if devices.sdpa_pre_dyanmic_atten is None:
+    devices.sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
+@wraps(devices.sdpa_pre_dyanmic_atten)
 def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
     do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_slice_sizes(query.shape, query.element_size(), slice_rate=shared.opts.dynamic_attention_slice_rate)
@@ -68,7 +68,7 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
                         for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
                             start_idx_3 = i3 * split_3_slice_size
                             end_idx_3 = (i3 + 1) * split_3_slice_size
-                            hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = backup_sdpa(
+                            hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = devices.sdpa_pre_dyanmic_atten(
                                 query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                 key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                 value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
@@ -76,7 +76,7 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
                                 dropout_p=dropout_p, is_causal=is_causal, **kwargs
                             )
                     else:
-                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = backup_sdpa(
+                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = devices.sdpa_pre_dyanmic_atten(
                             query[start_idx:end_idx, start_idx_2:end_idx_2],
                             key[start_idx:end_idx, start_idx_2:end_idx_2],
                             value[start_idx:end_idx, start_idx_2:end_idx_2],
@@ -84,7 +84,7 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
                             dropout_p=dropout_p, is_causal=is_causal, **kwargs
                         )
             else:
-                hidden_states[start_idx:end_idx] = backup_sdpa(
+                hidden_states[start_idx:end_idx] = devices.sdpa_pre_dyanmic_atten(
                     query[start_idx:end_idx],
                     key[start_idx:end_idx],
                     value[start_idx:end_idx],
@@ -94,7 +94,7 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
         if devices.backend != "directml":
             getattr(torch, query.device.type).synchronize()
     else:
-        return backup_sdpa(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
+        return devices.sdpa_pre_dyanmic_atten(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
     return hidden_states
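
For context, sliced_scaled_dot_product_attention bounds peak VRAM by preallocating the output and running the saved pre-hijack SDPA on slices of the input; the diff only swaps backup_sdpa, captured once at import time, for devices.sdpa_pre_dyanmic_atten, which set_sdpa_params() fills in before the import so the slice calls chain through any other active hijacks. A minimal single-dimension sketch of the slicing idea (simplified from the function above; assumes attn_mask is None, whereas the full version also splits across tokens and the third dimension):

import torch

def sliced_sdpa_batch_only(query, key, value, num_slices, **kwargs):
    # Preallocate the full output, then fill it one batch slice at a time so only
    # one slice's attention matrices are alive on the GPU at once.
    hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
    slice_size = query.shape[0] // num_slices
    for i in range(num_slices):
        start, end = i * slice_size, (i + 1) * slice_size
        hidden_states[start:end] = torch.nn.functional.scaled_dot_product_attention(
            query[start:end], key[start:end], value[start:end], **kwargs
        )
    return hidden_states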