Merge pull request #224 from tencent-ailab/revert-223-main
Revert "Support for other models."pull/226/head
commit
6a8f2a2dad
|
|
@ -162,7 +162,6 @@ class IPAttnProcessor(nn.Module):
|
|||
ip_value = attn.head_to_batch_dim(ip_value)
|
||||
|
||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||
self.attn_map = ip_attention_probs
|
||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||
|
||||
|
|
@ -379,9 +378,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
|||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
with torch.no_grad():
|
||||
self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
|
||||
#print(self.attn_map.shape)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
|
|
|||
|
|
@ -50,9 +50,10 @@ def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detac
|
|||
|
||||
for name, attn_map in attn_maps.items():
|
||||
attn_map = attn_map.cpu() if detach else attn_map
|
||||
attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
|
||||
attn_map = upscale(attn_map, image_size)
|
||||
net_attn_maps.append(attn_map)
|
||||
attn_map = torch.chunk(attn_map, batch_size)[idx] # (20, 32*32, 77) -> (10, 32*32, 77) # negative & positive CFG
|
||||
|
||||
attn_map = upscale(attn_map, image_size) # (10,32*32,77) -> (77,64*64)
|
||||
net_attn_maps.append(attn_map) # (10,32*32,77) -> (77,64*64)
|
||||
|
||||
net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
|
||||
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue