diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py index bfbb50d..383df70 100644 --- a/ip_adapter/attention_processor.py +++ b/ip_adapter/attention_processor.py @@ -143,6 +143,7 @@ class IPAttnProcessor(nn.Module): if encoder_hidden_states is None: encoder_hidden_states = hidden_states + ip_hidden_states = None else: # get encoder_hidden_states, ip_hidden_states end_pos = encoder_hidden_states.shape[1] - self.num_tokens @@ -165,31 +166,32 @@ class IPAttnProcessor(nn.Module): hidden_states = attn.batch_to_head_dim(hidden_states) # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) - - if xformers_available: - ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None) - else: - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + if ip_hidden_states is not None: + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) - # region control - if len(region_control.prompt_image_conditioning) == 1: - region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None) - if region_mask is not None: - h, w = region_mask.shape[:2] - ratio = (h * w / query.shape[1]) ** 0.5 - mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1]) + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + if xformers_available: + ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None) else: - mask = torch.ones_like(ip_hidden_states) - ip_hidden_states = ip_hidden_states * mask + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) - hidden_states = hidden_states + self.scale * ip_hidden_states + # region control + if len(region_control.prompt_image_conditioning) == 1: + region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None) + if region_mask is not None: + h, w = region_mask.shape[:2] + ratio = (h * w / query.shape[1]) ** 0.5 + mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1]) + else: + mask = torch.ones_like(ip_hidden_states) + ip_hidden_states = ip_hidden_states * mask + + hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -368,6 +370,7 @@ class IPAttnProcessor2_0(torch.nn.Module): if encoder_hidden_states is None: encoder_hidden_states = hidden_states + ip_hidden_states = None else: # get encoder_hidden_states, ip_hidden_states end_pos = encoder_hidden_states.shape[1] - self.num_tokens @@ -399,37 +402,38 @@ class IPAttnProcessor2_0(torch.nn.Module): hidden_states = hidden_states.to(query.dtype) # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) + if ip_hidden_states is not None: + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - with torch.no_grad(): - self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) - #print(self.attn_map.shape) + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + with torch.no_grad(): + self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) + #print(self.attn_map.shape) - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) - # region control - if len(region_control.prompt_image_conditioning) == 1: - region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None) - if region_mask is not None: - query = query.reshape([-1, query.shape[-2], query.shape[-1]]) - h, w = region_mask.shape[:2] - ratio = (h * w / query.shape[1]) ** 0.5 - mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1]) - else: - mask = torch.ones_like(ip_hidden_states) - ip_hidden_states = ip_hidden_states * mask + # region control + if len(region_control.prompt_image_conditioning) == 1: + region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None) + if region_mask is not None: + query = query.reshape([-1, query.shape[-2], query.shape[-1]]) + h, w = region_mask.shape[:2] + ratio = (h * w / query.shape[1]) ** 0.5 + mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1]) + else: + mask = torch.ones_like(ip_hidden_states) + ip_hidden_states = ip_hidden_states * mask - hidden_states = hidden_states + self.scale * ip_hidden_states + hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states)