pull/299/merge
Harikrishna KP 2026-02-11 20:07:33 +05:30 committed by GitHub
commit 2c6e615700
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 52 additions and 48 deletions

View File

@ -143,6 +143,7 @@ class IPAttnProcessor(nn.Module):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
ip_hidden_states = None
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
@ -165,31 +166,32 @@ class IPAttnProcessor(nn.Module):
hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
if xformers_available:
ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
else:
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
if ip_hidden_states is not None:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
# region control
if len(region_control.prompt_image_conditioning) == 1:
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
if region_mask is not None:
h, w = region_mask.shape[:2]
ratio = (h * w / query.shape[1]) ** 0.5
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
if xformers_available:
ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
else:
mask = torch.ones_like(ip_hidden_states)
ip_hidden_states = ip_hidden_states * mask
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
# region control
if len(region_control.prompt_image_conditioning) == 1:
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
if region_mask is not None:
h, w = region_mask.shape[:2]
ratio = (h * w / query.shape[1]) ** 0.5
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
else:
mask = torch.ones_like(ip_hidden_states)
ip_hidden_states = ip_hidden_states * mask
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
@ -368,6 +370,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
ip_hidden_states = None
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
@ -399,37 +402,38 @@ class IPAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
if ip_hidden_states is not None:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
with torch.no_grad():
self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
#print(self.attn_map.shape)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
with torch.no_grad():
self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
#print(self.attn_map.shape)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# region control
if len(region_control.prompt_image_conditioning) == 1:
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
if region_mask is not None:
query = query.reshape([-1, query.shape[-2], query.shape[-1]])
h, w = region_mask.shape[:2]
ratio = (h * w / query.shape[1]) ** 0.5
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
else:
mask = torch.ones_like(ip_hidden_states)
ip_hidden_states = ip_hidden_states * mask
# region control
if len(region_control.prompt_image_conditioning) == 1:
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
if region_mask is not None:
query = query.reshape([-1, query.shape[-2], query.shape[-1]])
h, w = region_mask.shape[:2]
ratio = (h * w / query.shape[1]) ** 0.5
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
else:
mask = torch.ones_like(ip_hidden_states)
ip_hidden_states = ip_hidden_states * mask
hidden_states = hidden_states + self.scale * ip_hidden_states
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)