mirror of https://github.com/vladmandic/automatic
Dynamic Atten fix OmniGen
parent
e424b343ce
commit
3195e8ad1f
|
|
@ -136,6 +136,11 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
|
|||
if do_split:
|
||||
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||
if attn_mask is not None and attn_mask.shape != query.shape:
|
||||
if len(query.shape) == 4:
|
||||
attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2], 1))
|
||||
else:
|
||||
attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2]))
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
|
|
|
|||
|
|
@ -57,6 +57,11 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
|
|||
if do_split:
|
||||
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||
if attn_mask is not None and attn_mask.shape != query.shape:
|
||||
if len(query.shape) == 4:
|
||||
attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2], 1))
|
||||
else:
|
||||
attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2]))
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
|
|
|
|||
Loading…
Reference in New Issue