mirror of https://github.com/InstantID/InstantID
Merge 6a1bbdf5c9 into 2145b67f96
commit
2c6e615700
|
|
@ -143,6 +143,7 @@ class IPAttnProcessor(nn.Module):
|
|||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
ip_hidden_states = None
|
||||
else:
|
||||
# get encoder_hidden_states, ip_hidden_states
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||
|
|
@ -165,31 +166,32 @@ class IPAttnProcessor(nn.Module):
|
|||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
ip_key = attn.head_to_batch_dim(ip_key)
|
||||
ip_value = attn.head_to_batch_dim(ip_value)
|
||||
|
||||
if xformers_available:
|
||||
ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
|
||||
else:
|
||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||
if ip_hidden_states is not None:
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
# region control
|
||||
if len(region_control.prompt_image_conditioning) == 1:
|
||||
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
|
||||
if region_mask is not None:
|
||||
h, w = region_mask.shape[:2]
|
||||
ratio = (h * w / query.shape[1]) ** 0.5
|
||||
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
|
||||
ip_key = attn.head_to_batch_dim(ip_key)
|
||||
ip_value = attn.head_to_batch_dim(ip_value)
|
||||
|
||||
if xformers_available:
|
||||
ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
|
||||
else:
|
||||
mask = torch.ones_like(ip_hidden_states)
|
||||
ip_hidden_states = ip_hidden_states * mask
|
||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
# region control
|
||||
if len(region_control.prompt_image_conditioning) == 1:
|
||||
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
|
||||
if region_mask is not None:
|
||||
h, w = region_mask.shape[:2]
|
||||
ratio = (h * w / query.shape[1]) ** 0.5
|
||||
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
|
||||
else:
|
||||
mask = torch.ones_like(ip_hidden_states)
|
||||
ip_hidden_states = ip_hidden_states * mask
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
|
|
@ -368,6 +370,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
|||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
ip_hidden_states = None
|
||||
else:
|
||||
# get encoder_hidden_states, ip_hidden_states
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||
|
|
@ -399,37 +402,38 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
|||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
if ip_hidden_states is not None:
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
with torch.no_grad():
|
||||
self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
|
||||
#print(self.attn_map.shape)
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
with torch.no_grad():
|
||||
self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
|
||||
#print(self.attn_map.shape)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
# region control
|
||||
if len(region_control.prompt_image_conditioning) == 1:
|
||||
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
|
||||
if region_mask is not None:
|
||||
query = query.reshape([-1, query.shape[-2], query.shape[-1]])
|
||||
h, w = region_mask.shape[:2]
|
||||
ratio = (h * w / query.shape[1]) ** 0.5
|
||||
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
|
||||
else:
|
||||
mask = torch.ones_like(ip_hidden_states)
|
||||
ip_hidden_states = ip_hidden_states * mask
|
||||
# region control
|
||||
if len(region_control.prompt_image_conditioning) == 1:
|
||||
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
|
||||
if region_mask is not None:
|
||||
query = query.reshape([-1, query.shape[-2], query.shape[-1]])
|
||||
h, w = region_mask.shape[:2]
|
||||
ratio = (h * w / query.shape[1]) ** 0.5
|
||||
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
|
||||
else:
|
||||
mask = torch.ones_like(ip_hidden_states)
|
||||
ip_hidden_states = ip_hidden_states * mask
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
|
|
|
|||
Loading…
Reference in New Issue