memory efficient attention fallthrough (#330)
parent 4ce29a4a3c
commit 4a788668d0
@@ -493,31 +493,59 @@ class CrossAttention(nn.Module):
         k = k.contiguous()
         v = v.contiguous()
 
-        if current_optimizer_name == "xformers":
-            import xformers.ops
-            from modules.sd_hijack_optimizations import get_xformers_flash_attention_op
-            hidden_states = xformers.ops.memory_efficient_attention(
-                q, k, v, attn_bias=mask,
-                op=get_xformers_flash_attention_op(q, k, v))
-        elif current_optimizer_name == "sdp":
-            hidden_states = torch.nn.functional.scaled_dot_product_attention(
-                q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
-            )
-        elif current_optimizer_name == "sdp-no-mem":
-            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
-                hidden_states = torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
-                )
-        elif current_optimizer_name == "sub-quadratic":
-            from modules.sd_hijack_optimizations import sub_quad_attention
-            from modules import shared
-            hidden_states = sub_quad_attention(
-                q, k, v,
-                q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size,
-                kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size,
-                chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold,
-                use_checkpoint=self.training
-            )
+        fallthrough = False
+
+        if current_optimizer_name == "xformers" or fallthrough:
+            fallthrough = False
+            try:
+                import xformers.ops
+                from modules.sd_hijack_optimizations import get_xformers_flash_attention_op
+                hidden_states = xformers.ops.memory_efficient_attention(
+                    q, k, v, attn_bias=mask,
+                    op=get_xformers_flash_attention_op(q, k, v))
+            except (ImportError, RuntimeError, AttributeError):
+                fallthrough = True
+
+        if current_optimizer_name == "sdp" or fallthrough:
+            fallthrough = False
+            try:
+                hidden_states = torch.nn.functional.scaled_dot_product_attention(
+                    q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+                )
+            except (ImportError, RuntimeError, AttributeError):
+                fallthrough = True
+
+        if current_optimizer_name == "sdp-no-mem" or fallthrough:
+            fallthrough = False
+            try:
+                with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+                    hidden_states = torch.nn.functional.scaled_dot_product_attention(
+                        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+                    )
+            except (ImportError, RuntimeError, AttributeError):
+                fallthrough = True
+
+        if current_optimizer_name == "sub-quadratic" or fallthrough:
+            fallthrough = False
+            try:
+                from modules.sd_hijack_optimizations import sub_quad_attention
+                from modules import shared
+                hidden_states = sub_quad_attention(
+                    q, k, v,
+                    q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size,
+                    kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size,
+                    chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold,
+                    use_checkpoint=self.training
+                )
+            except (ImportError, RuntimeError, AttributeError):
+                fallthrough = True
+
+        if fallthrough:
+            fallthrough = False
+            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+                hidden_states = self._attention(query, key, value, attention_mask)
+            else:
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
 
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
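The commit replaces the strict if/elif dispatch with a chain of guarded blocks: each preferred backend is attempted inside try/except, and an ImportError, RuntimeError, or AttributeError raises a fallthrough flag that routes execution to the next backend in priority order, ending at the built-in attention/sliced-attention path. Below is a minimal, self-contained sketch of that control flow only; the backend names and functions are stand-ins invented for illustration, not the repository's real optimizers or API.

# Sketch only: backend_a, backend_b, and fallback are made-up stand-ins; only
# the dispatch shape mirrors the patch above.

def backend_a(x):
    raise RuntimeError("backend_a unavailable")  # simulate a backend that fails at runtime

def backend_b(x):
    return f"backend_b({x})"  # simulate a backend that works

def fallback(x):
    return f"fallback({x})"  # stand-in for the plain _attention/_sliced_attention path

def run_with_fallthrough(x, preferred="a"):
    fallthrough = False
    result = None

    # Each block mirrors one `if <name> or fallthrough:` section of the patch:
    # reset the flag, attempt the backend, and raise the flag on failure so the
    # next block in priority order runs.
    if preferred == "a" or fallthrough:
        fallthrough = False
        try:
            result = backend_a(x)
        except (ImportError, RuntimeError, AttributeError):
            fallthrough = True

    if preferred == "b" or fallthrough:
        fallthrough = False
        try:
            result = backend_b(x)
        except (ImportError, RuntimeError, AttributeError):
            fallthrough = True

    # Last resort when every attempted backend failed (or none matched).
    if fallthrough or result is None:
        result = fallback(x)
    return result

print(run_with_fallthrough("q", preferred="a"))  # backend_a fails, falls through to backend_b

Compared with nesting try/except blocks, the flag-based chain lets any backend hand off to every later one in the priority order while keeping a single last-resort path, without duplicating the fallback code inside each except clause.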