import torch
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Callable

@dataclass
class PuLIDAttnSetting:
    """Settings for injecting a PuLID identity embedding into an attention
    layer, together with the injection computation itself.

    Attributes:
        num_zero: number of all-zero tokens appended to the identity
            embedding before it is projected to keys/values.
        ortho: subtract the projection of the ID output onto the base
            hidden states (orthogonalization, v1).
        ortho_v2: attention-weighted orthogonalization variant; takes
            precedence over ``ortho`` when both are set.
    """

    num_zero: int = 0
    ortho: bool = False
    ortho_v2: bool = False

    def eval(
        self,
        hidden_states: torch.Tensor,
        query: torch.Tensor,
        id_embedding: torch.Tensor,
        attn_heads: int,
        head_dim: int,
        id_to_k: Callable[[torch.Tensor], torch.Tensor],
        id_to_v: Callable[[torch.Tensor], torch.Tensor],
    ):
        """Cross-attend ``query`` against the identity embedding.

        Args:
            hidden_states: base attention output, shape (batch, seq, inner_dim).
            query: attention queries; assumed (batch, heads, seq, head_dim)
                to match ``scaled_dot_product_attention`` below — TODO confirm
                against the caller.
            id_embedding: identity tokens, (batch, id_tokens, embed_dim).
            attn_heads: number of attention heads used to fold K/V.
            head_dim: per-head channel count used to fold K/V.
            id_to_k: projection producing ID keys from the embedding.
            id_to_v: projection producing ID values from the embedding.

        Returns:
            A tensor shaped like ``hidden_states``: the raw ID attention
            output, optionally orthogonalized against ``hidden_states``.
        """
        assert hidden_states.ndim == 3
        batch_size, _seq_len, _inner_dim = hidden_states.shape

        # Optionally append num_zero all-zero tokens before projecting.
        if self.num_zero == 0:
            id_input = id_embedding
        else:
            padding = id_embedding.new_zeros(
                id_embedding.size(0), self.num_zero, id_embedding.size(-1)
            )
            id_input = torch.cat((id_embedding, padding), dim=1)

        id_key = id_to_k(id_input).to(query.dtype)
        id_value = id_to_v(id_input).to(query.dtype)

        # Fold into multi-head layout: (batch, heads, id_tokens, head_dim).
        id_key = id_key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
        id_value = id_value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)

        # SDPA output is (batch, heads, seq, head_dim); unfold back to
        # (batch, seq, heads * head_dim) to mirror hidden_states.
        id_hidden_states = F.scaled_dot_product_attention(
            query, id_key, id_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        id_hidden_states = (
            id_hidden_states.transpose(1, 2)
            .reshape(batch_size, -1, attn_heads * head_dim)
            .to(query.dtype)
        )

        # No orthogonalization requested: return the raw ID branch.
        if not (self.ortho or self.ortho_v2):
            return id_hidden_states

        # Both ortho variants work in float32 and restore the dtype at the end.
        orig_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        id_hidden_states = id_hidden_states.to(torch.float32)

        # Per-channel projection of the ID output onto the base hidden
        # states, summed over the sequence axis (dim=-2).
        projection = (
            torch.sum(hidden_states * id_hidden_states, dim=-2, keepdim=True)
            / torch.sum(hidden_states * hidden_states, dim=-2, keepdim=True)
            * hidden_states
        )

        if self.ortho_v2:
            # Weight the subtracted projection by the softmax attention mass
            # that lands on the first 5 ID tokens (presumably the real,
            # non-padding identity tokens — TODO confirm embedding layout).
            attn_map = query @ id_key.transpose(-2, -1)
            attn_mean = attn_map.softmax(dim=-1).mean(dim=1)
            attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True)
            orthogonal = id_hidden_states + (attn_mean - 1) * projection
        else:
            orthogonal = id_hidden_states - projection

        return orthogonal.to(orig_dtype)
# Fidelity preset: 8 zero-padding tokens, v2 (attention-weighted) orthogonalization.
PULID_SETTING_FIDELITY = PuLIDAttnSetting(num_zero=8, ortho=False, ortho_v2=True)
# Style preset: 16 zero-padding tokens, v1 orthogonalization.
PULID_SETTING_STYLE = PuLIDAttnSetting(num_zero=16, ortho=True, ortho_v2=False)