diff --git a/modules/intel/ipex/attention.py b/modules/intel/ipex/attention.py index dead035e0..3e58c0761 100644 --- a/modules/intel/ipex/attention.py +++ b/modules/intel/ipex/attention.py @@ -136,6 +136,11 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo if do_split: batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + if attn_mask is not None and attn_mask.shape != query.shape: + if len(query.shape) == 4: + attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2], 1)) + else: + attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2])) for i in range(batch_size_attention // split_slice_size): start_idx = i * split_slice_size end_idx = (i + 1) * split_slice_size diff --git a/modules/sd_hijack_dynamic_atten.py b/modules/sd_hijack_dynamic_atten.py index 9ee7d72ac..48a5760d3 100644 --- a/modules/sd_hijack_dynamic_atten.py +++ b/modules/sd_hijack_dynamic_atten.py @@ -57,6 +57,11 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo if do_split: batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + if attn_mask is not None and attn_mask.shape != query.shape: + if len(query.shape) == 4: + attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2], 1)) + else: + attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2])) for i in range(batch_size_attention // split_slice_size): start_idx = i * split_slice_size end_idx = (i + 1) * split_slice_size