Revert "Support for other models."
parent
b9390fa168
commit
7e9a3748c7
|
|
@ -162,7 +162,6 @@ class IPAttnProcessor(nn.Module):
|
||||||
ip_value = attn.head_to_batch_dim(ip_value)
|
ip_value = attn.head_to_batch_dim(ip_value)
|
||||||
|
|
||||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||||
self.attn_map = ip_attention_probs
|
|
||||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||||
|
|
||||||
|
|
@ -379,9 +378,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||||
ip_hidden_states = F.scaled_dot_product_attention(
|
ip_hidden_states = F.scaled_dot_product_attention(
|
||||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||||
)
|
)
|
||||||
with torch.no_grad():
|
|
||||||
self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
|
|
||||||
#print(self.attn_map.shape)
|
|
||||||
|
|
||||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||||
|
|
|
||||||
|
|
@ -50,9 +50,10 @@ def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detac
|
||||||
|
|
||||||
for name, attn_map in attn_maps.items():
|
for name, attn_map in attn_maps.items():
|
||||||
attn_map = attn_map.cpu() if detach else attn_map
|
attn_map = attn_map.cpu() if detach else attn_map
|
||||||
attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
|
attn_map = torch.chunk(attn_map, batch_size)[idx] # (20, 32*32, 77) -> (10, 32*32, 77) # negative & positive CFG
|
||||||
attn_map = upscale(attn_map, image_size)
|
|
||||||
net_attn_maps.append(attn_map)
|
attn_map = upscale(attn_map, image_size) # (10,32*32,77) -> (77,64*64)
|
||||||
|
net_attn_maps.append(attn_map) # (10,32*32,77) -> (77,64*64)
|
||||||
|
|
||||||
net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
|
net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue