add positional embedding and mean pooled latents
parent
608f3433c9
commit
9fc189e3fb
|
|
@ -1,8 +1,12 @@
|
|||
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
||||
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
|
||||
# FFN
|
||||
|
|
@ -85,8 +89,12 @@ class Resampler(nn.Module):
|
|||
embedding_dim=768,
|
||||
output_dim=1024,
|
||||
ff_mult=4,
|
||||
max_seq_len: int = 257, # CLIP tokens + CLS token
|
||||
apply_pos_emb: bool = False,
|
||||
num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
|
||||
):
|
||||
super().__init__()
|
||||
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
|
||||
|
||||
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
||||
|
||||
|
|
@ -95,6 +103,16 @@ class Resampler(nn.Module):
|
|||
self.proj_out = nn.Linear(dim, output_dim)
|
||||
self.norm_out = nn.LayerNorm(output_dim)
|
||||
|
||||
self.to_latents_from_mean_pooled_seq = (
|
||||
nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, dim * num_latents_mean_pooled),
|
||||
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
|
||||
)
|
||||
if num_latents_mean_pooled > 0
|
||||
else None
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
|
|
@ -107,13 +125,34 @@ class Resampler(nn.Module):
|
|||
)
|
||||
|
||||
def forward(self, x):
    """Resample a sequence of input embeddings into a fixed set of latent tokens.

    Args:
        x: Input embeddings. The code indexes dimension 0 as the batch and
            dimension 1 as the sequence; presumably the full shape is
            ``(batch, seq, embedding_dim)`` -- TODO confirm against callers.

    Returns:
        The latents after every ``(attn, ff)`` layer, passed through
        ``self.proj_out`` and then ``self.norm_out``. When mean pooling is
        enabled, the mean-pooled latents come first along dimension -2,
        followed by the learned query latents.
    """
    if self.pos_emb is not None:
        # One learned embedding per position, added before the input projection.
        # The sequence length must not exceed the size of the embedding table.
        positions = torch.arange(x.shape[1], device=x.device)
        x = x + self.pos_emb(positions)

    # Give every batch element its own copy of the learned queries.
    latents = self.latents.repeat(x.size(0), 1, 1)

    x = self.proj_in(x)

    # Compare against None explicitly: truthiness of an nn.Sequential goes
    # through __len__, which is not the question being asked here.
    if self.to_latents_from_mean_pooled_seq is not None:
        # No padding mask is available, so the mean covers every position.
        # Calling without a mask avoids allocating an all-True mask on each
        # forward pass and keeps the pooled tensor in x's dtype.
        meanpooled_seq = masked_mean(x, dim=1)
        meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
        latents = torch.cat((meanpooled_latents, latents), dim=-2)

    for attn, ff in self.layers:
        # Residual connection around both the attention and feed-forward steps.
        latents = attn(x, latents) + latents
        latents = ff(latents) + latents

    latents = self.proj_out(latents)
    return self.norm_out(latents)
|
||||
|
||||
|
||||
def masked_mean(t, *, dim, mask=None):
    """Average ``t`` along ``dim``, counting only positions where ``mask`` is true.

    Args:
        t: Tensor to reduce. Presumably ``(batch, seq, feature)`` -- that is
            the only layout the visible caller uses; TODO confirm.
        dim: Dimension of ``t`` to reduce. It is also the dimension of
            ``mask`` that is summed to count the valid positions.
        mask: Optional boolean tensor, expected to be ``(batch, seq)``. A
            trailing singleton dimension is appended so it broadcasts over
            ``t``'s last dimension. ``None`` returns the plain mean.

    Returns:
        ``t`` reduced along ``dim``, in ``t``'s own dtype. Slices whose mask
        is entirely false come out as zeros, not NaN.
    """
    if mask is None:
        return t.mean(dim=dim)

    # Clamp while the count is still an integer, then move it to t's dtype.
    # Clamping an integer tensor with a float minimum leaves the divisor's
    # dtype up to type promotion instead of following t. A floor of 1 gives
    # the same results: an empty slice has a zero numerator either way.
    denom = mask.sum(dim=dim, keepdim=True).clamp(min=1).to(t.dtype)

    mask = mask.unsqueeze(-1)
    masked_t = t.masked_fill(~mask, 0.0)

    return masked_t.sum(dim=dim) / denom
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
import torch
|
||||
from resampler import Resampler
|
||||
from transformers import CLIPVisionModel
|
||||
|
||||
BATCH_SIZE = 2
|
||||
OUTPUT_DIM = 1280
|
||||
NUM_QUERIES = 8
|
||||
NUM_LATENTS_MEAN_POOLED = 4 # 0 for no mean pooling (previous behavior)
|
||||
APPLY_POS_EMB = True # False for no positional embeddings (previous behavior)
|
||||
IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
|
||||
|
||||
|
||||
def main():
    """Smoke-test ``Resampler`` on CLIP hidden states of random images.

    Loads the vision encoder named by ``IMAGE_ENCODER_NAME_OR_PATH``, builds a
    ``Resampler`` sized to the encoder's hidden size, runs a batch of random
    224x224 images through both, and checks the shape of the resulting tokens.

    Raises:
        AssertionError: If the resampler output is not
            ``(BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)``.
    """
    image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
    embedding_dim = image_encoder.config.hidden_size
    print("image_encoder hidden size: ", embedding_dim)

    image_proj_model = Resampler(
        dim=1024,
        depth=2,
        dim_head=64,
        heads=16,
        num_queries=NUM_QUERIES,
        embedding_dim=embedding_dim,
        output_dim=OUTPUT_DIM,
        ff_mult=2,
        max_seq_len=257,
        apply_pos_emb=APPLY_POS_EMB,
        num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
    )

    dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
    with torch.no_grad():
        # Take the second-to-last hidden state as the resampler input.
        image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
    print("image_embds shape: ", image_embeds.shape)

    with torch.no_grad():
        ip_tokens = image_proj_model(image_embeds)
    print("ip_tokens shape:", ip_tokens.shape)

    # Mean-pooled latents are concatenated with the learned queries, so the
    # token count is the sum of both.
    expected_shape = (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)
    assert ip_tokens.shape == expected_shape, (
        f"unexpected ip_tokens shape {tuple(ip_tokens.shape)}, expected {expected_shape}"
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue