Dynamic Atten fix OmniGen

pull/3518/head
Disty0 2024-10-24 21:22:06 +03:00
parent e424b343ce
commit 3195e8ad1f
2 changed files with 10 additions and 0 deletions

View File

@ -136,6 +136,11 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
if do_split:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
if attn_mask is not None and attn_mask.shape != query.shape:
if len(query.shape) == 4:
attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2], 1))
else:
attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2]))
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size

View File

@ -57,6 +57,11 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
if do_split:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
if attn_mask is not None and attn_mask.shape != query.shape:
if len(query.shape) == 4:
attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2], 1))
else:
attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2]))
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size