Merge pull request #224 from tencent-ailab/revert-223-main

Revert "Support for other models."
pull/226/head
Hu Ye 2024-01-04 11:50:00 +08:00 committed by GitHub
commit 6a8f2a2dad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 371 additions and 7 deletions

View File

@@ -162,7 +162,6 @@ class IPAttnProcessor(nn.Module):
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
self.attn_map = ip_attention_probs
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
@@ -379,9 +378,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
with torch.no_grad():
self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
#print(self.attn_map.shape)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)

View File

@@ -50,9 +50,10 @@ def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
for name, attn_map in attn_maps.items():
attn_map = attn_map.cpu() if detach else attn_map
attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
attn_map = upscale(attn_map, image_size)
net_attn_maps.append(attn_map)
attn_map = torch.chunk(attn_map, batch_size)[idx] # (20, 32*32, 77) -> (10, 32*32, 77) # negative & positive CFG
attn_map = upscale(attn_map, image_size) # (10,32*32,77) -> (77,64*64)
net_attn_maps.append(attn_map) # (10,32*32,77) -> (77,64*64)
net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)

367
visual_attnmap.ipynb Normal file

File diff suppressed because one or more lines are too long