mirror of https://github.com/vladmandic/automatic
Make SDPA hijacks chainable and add Sage Attention
parent
ea0dfebe2d
commit
2e2cb43406
|
|
@ -24,6 +24,8 @@ dtype_vae = None
|
|||
dtype_unet = None
|
||||
unet_needs_upcast = False # compatibility item
|
||||
onnx = None
|
||||
sdpa_original = None
|
||||
sdpa_pre_dyanmic_atten = None
|
||||
previous_oom = 0 # oom counter
|
||||
if debug:
|
||||
log.info(f'Torch build config: {torch.__config__.show()}')
|
||||
|
|
@ -329,26 +331,51 @@ def set_sdpa_params():
|
|||
torch.backends.cuda.enable_flash_sdp('Flash attention' in opts.sdp_options)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp('Memory attention' in opts.sdp_options)
|
||||
torch.backends.cuda.enable_math_sdp('Math attention' in opts.sdp_options)
|
||||
global sdpa_original
|
||||
if sdpa_original is not None:
|
||||
torch.nn.functional.scaled_dot_product_attention = sdpa_original
|
||||
else:
|
||||
sdpa_original = torch.nn.functional.scaled_dot_product_attention
|
||||
if backend == "rocm":
|
||||
if 'Flash attention' in opts.sdp_options:
|
||||
try:
|
||||
# https://github.com/huggingface/diffusers/discussions/7172
|
||||
from flash_attn import flash_attn_func
|
||||
from functools import wraps
|
||||
backup_sdpa = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
||||
def sdpa_hijack(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
|
||||
if query.shape[3] <= 128 and attn_mask is None and query.dtype != torch.float32:
|
||||
sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(sdpa_pre_flash_atten)
|
||||
def sdpa_flash_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
|
||||
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
|
||||
return flash_attn_func(q=query.transpose(1, 2), k=key.transpose(1, 2), v=value.transpose(1, 2), dropout_p=dropout_p, causal=is_causal, softmax_scale=scale).transpose(1, 2)
|
||||
else:
|
||||
return backup_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
|
||||
torch.nn.functional.scaled_dot_product_attention = sdpa_hijack
|
||||
return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
|
||||
torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
|
||||
log.debug('ROCm Flash Attention Hijacked')
|
||||
except Exception as err:
|
||||
log.error(f'ROCm Flash Attention failed: {err}')
|
||||
if 'Sage attention' in opts.sdp_options:
    try:
        # Import locally so this branch does not depend on the ROCm flash-attention
        # branch having run first (that is the only other place `wraps` is imported here).
        from functools import wraps
        from sageattention import sageattn
        # Capture whatever SDPA implementation is currently installed so hijacks stay chainable:
        # the fallback path below defers to the previous hijack (or the original torch function).
        sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention

        @wraps(sdpa_pre_sage_atten)
        def sdpa_sage_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
            # Sage attention is only used for the head dims listed here, without a mask and for non-fp32 inputs;
            # everything else goes to the previously installed SDPA.
            if query.shape[-1] in {128, 96, 64} and attn_mask is None and query.dtype != torch.float32:
                # query/key/value are passed in the same layout SDPA receives them, so the result must be
                # returned as-is. Unlike the flash-attn path (which transposes its inputs and therefore its
                # output), no transpose is applied here.
                # NOTE(review): confirm that the installed sageattention version honours `scale`/`dropout_p`
                # rather than silently ignoring unknown keyword arguments.
                return sageattn(q=query, k=key, v=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
            return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)

        torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
        log.debug('SDPA Sage Attention Hijacked')
    except Exception as err:
        # Best-effort feature: a missing package or failed hijack must not break SDPA setup.
        log.error(f'SDPA Sage Attention failed: {err}')
|
||||
if 'Dynamic attention' in opts.sdp_options:
|
||||
from modules.sd_hijack_dynamic_atten import sliced_scaled_dot_product_attention
|
||||
torch.nn.functional.scaled_dot_product_attention = sliced_scaled_dot_product_attention
|
||||
try:
|
||||
global sdpa_pre_dyanmic_atten
|
||||
sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
from modules.sd_hijack_dynamic_atten import sliced_scaled_dot_product_attention
|
||||
torch.nn.functional.scaled_dot_product_attention = sliced_scaled_dot_product_attention
|
||||
log.debug('SDPA Dynamic Attention Hijacked')
|
||||
except Exception as err:
|
||||
log.error(f'SDPA Dynamic Attention failed: {err}')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -47,9 +47,9 @@ def find_slice_sizes(query_shape, query_element_size, slice_rate=4):
|
|||
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
|
||||
|
||||
backup_sdpa = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
||||
if devices.sdpa_pre_dyanmic_atten is None:
|
||||
devices.sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(devices.sdpa_pre_dyanmic_atten)
|
||||
def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_slice_sizes(query.shape, query.element_size(), slice_rate=shared.opts.dynamic_attention_slice_rate)
|
||||
|
||||
|
|
@ -68,7 +68,7 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
|
|||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = backup_sdpa(
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = devices.sdpa_pre_dyanmic_atten(
|
||||
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
|
|
@ -76,7 +76,7 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
|
|||
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = backup_sdpa(
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = devices.sdpa_pre_dyanmic_atten(
|
||||
query[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
value[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
|
|
@ -84,7 +84,7 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
|
|||
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx] = backup_sdpa(
|
||||
hidden_states[start_idx:end_idx] = devices.sdpa_pre_dyanmic_atten(
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
|
|
@ -94,7 +94,7 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
|
|||
if devices.backend != "directml":
|
||||
getattr(torch, query.device.type).synchronize()
|
||||
else:
|
||||
return backup_sdpa(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||
return devices.sdpa_pre_dyanmic_atten(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue