memory efficient attention fallthrough (#330)

pull/331/head
Jabasukuriputo Wang 2023-11-20 16:38:20 -06:00 committed by GitHub
parent 4ce29a4a3c
commit 4a788668d0
1 changed file with 50 additions and 22 deletions

@@ -493,31 +493,59 @@ class CrossAttention(nn.Module):
         k = k.contiguous()
         v = v.contiguous()
-        if current_optimizer_name == "xformers":
-            import xformers.ops
-            from modules.sd_hijack_optimizations import get_xformers_flash_attention_op
-            hidden_states = xformers.ops.memory_efficient_attention(
-                q, k, v, attn_bias=mask,
-                op=get_xformers_flash_attention_op(q, k, v))
-        elif current_optimizer_name == "sdp":
-            hidden_states = torch.nn.functional.scaled_dot_product_attention(
-                q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
-            )
-        elif current_optimizer_name == "sdp-no-mem":
-            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
-                hidden_states = torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
-                )
-        elif current_optimizer_name == "sub-quadratic":
-            from modules.sd_hijack_optimizations import sub_quad_attention
-            from modules import shared
-            hidden_states = sub_quad_attention(
-                q, k, v,
-                q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size,
-                kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size,
-                chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold,
-                use_checkpoint=self.training
-            )
+        fallthrough = False
+        if current_optimizer_name == "xformers" or fallthrough:
+            fallthrough = False
+            try:
+                import xformers.ops
+                from modules.sd_hijack_optimizations import get_xformers_flash_attention_op
+                hidden_states = xformers.ops.memory_efficient_attention(
+                    q, k, v, attn_bias=mask,
+                    op=get_xformers_flash_attention_op(q, k, v))
+            except (ImportError, RuntimeError, AttributeError):
+                fallthrough = True
+        if current_optimizer_name == "sdp" or fallthrough:
+            fallthrough = False
+            try:
+                hidden_states = torch.nn.functional.scaled_dot_product_attention(
+                    q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+                )
+            except (ImportError, RuntimeError, AttributeError):
+                fallthrough = True
+        if current_optimizer_name == "sdp-no-mem" or fallthrough:
+            fallthrough = False
+            try:
+                with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+                    hidden_states = torch.nn.functional.scaled_dot_product_attention(
+                        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+                    )
+            except (ImportError, RuntimeError, AttributeError):
+                fallthrough = True
+        if current_optimizer_name == "sub-quadratic" or fallthrough:
+            fallthrough = False
+            try:
+                from modules.sd_hijack_optimizations import sub_quad_attention
+                from modules import shared
+                hidden_states = sub_quad_attention(
+                    q, k, v,
+                    q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size,
+                    kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size,
+                    chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold,
+                    use_checkpoint=self.training
+                )
+            except (ImportError, RuntimeError, AttributeError):
+                fallthrough = True
+        if fallthrough:
+            fallthrough = False
+            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+                hidden_states = self._attention(query, key, value, attention_mask)
+            else:
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
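
For reference, below is a minimal, self-contained sketch of the fallthrough pattern this commit introduces, separated from the webui-specific code. The names attention_with_fallthrough, BACKENDS, xformers_attention, sdp_attention, and math_attention are illustrative inventions for this sketch, not identifiers from the repository; the exception tuple mirrors the one in the diff. Each backend runs either because it is the selected optimizer or because an earlier backend in the chain failed, and a plain-math implementation serves as the unconditional last resort, analogous to the _attention/_sliced_attention path above.

import torch
import torch.nn.functional as F


def xformers_attention(q, k, v):
    # Optional dependency: raises ImportError when xformers is missing,
    # which is exactly the failure the fallthrough is meant to absorb.
    import xformers.ops
    return xformers.ops.memory_efficient_attention(q, k, v)


def sdp_attention(q, k, v):
    # Built-in scaled dot-product attention (requires PyTorch >= 2.0;
    # on older builds the missing attribute trips the except clause).
    return F.scaled_dot_product_attention(q, k, v)


def math_attention(q, k, v):
    # Plain softmax(QK^T / sqrt(d)) V, usable on any PyTorch build.
    scale = q.shape[-1] ** -0.5
    weights = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return weights @ v


# Priority-ordered chain, mirroring xformers -> sdp -> ... in the diff.
BACKENDS = [
    ("xformers", xformers_attention),
    ("sdp", sdp_attention),
]


def attention_with_fallthrough(q, k, v, preferred="xformers"):
    fallthrough = False
    for name, backend in BACKENDS:
        # A backend runs if it is the selected one, or if an earlier
        # backend in the chain failed and control fell through to it.
        if name == preferred or fallthrough:
            fallthrough = False
            try:
                return backend(q, k, v)
            except (ImportError, RuntimeError, AttributeError):
                fallthrough = True
    # Nothing in the chain succeeded: run the unconditional fallback,
    # analogous to the _attention/_sliced_attention path in the diff.
    return math_attention(q, k, v)


if __name__ == "__main__":
    q = k = v = torch.randn(2, 8, 64)  # (batch * heads, tokens, head_dim)
    print(attention_with_fallthrough(q, k, v).shape)  # torch.Size([2, 8, 64])

In the diff the chain is written as successive if blocks rather than elif, which is what lets control fall forward: a failed backend sets the flag and arms the next block's condition, while a successful one leaves the flag cleared so no later backend runs. The sketch expresses the same condition (name == preferred or fallthrough) with a loop.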