From e36ca465ec6606c546f0de20968eb837141972e1 Mon Sep 17 00:00:00 2001 From: InstantX <153269709+ResearcherXman@users.noreply.github.com> Date: Fri, 2 Feb 2024 01:05:48 +0800 Subject: [PATCH] Fix typos Better support IPAttnProcessor2_0 --- ip_adapter/attention_processor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py index 745143d..bfbb50d 100644 --- a/ip_adapter/attention_processor.py +++ b/ip_adapter/attention_processor.py @@ -421,6 +421,7 @@ class IPAttnProcessor2_0(torch.nn.Module): if len(region_control.prompt_image_conditioning) == 1: region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None) if region_mask is not None: + query = query.reshape([-1, query.shape[-2], query.shape[-1]]) h, w = region_mask.shape[:2] ratio = (h * w / query.shape[1]) ** 0.5 mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1]) @@ -443,4 +444,4 @@ class IPAttnProcessor2_0(torch.nn.Module): hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states \ No newline at end of file + return hidden_states