addpositional embedding and mean pooled latents

pull/135/head
danbochman 2023-11-08 14:36:39 +01:00
parent 608f3433c9
commit 9fc189e3fb
No known key found for this signature in database
GPG Key ID: B0DE112E399D1082
2 changed files with 83 additions and 0 deletions

View File

@ -1,8 +1,12 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
import math
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
# FFN
@ -85,8 +89,12 @@ class Resampler(nn.Module):
embedding_dim=768,
output_dim=1024,
ff_mult=4,
max_seq_len: int = 257, # CLIP tokens + CLS token
apply_pos_emb: bool = False,
num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
):
super().__init__()
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
@ -95,6 +103,16 @@ class Resampler(nn.Module):
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.to_latents_from_mean_pooled_seq = (
nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * num_latents_mean_pooled),
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
)
if num_latents_mean_pooled > 0
else None
)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
@ -107,13 +125,34 @@ class Resampler(nn.Module):
)
def forward(self, x):
if self.pos_emb is not None:
n, device = x.shape[1], x.device
pos_emb = self.pos_emb(torch.arange(n, device=device))
x = x + pos_emb
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
if self.to_latents_from_mean_pooled_seq:
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
latents = torch.cat((meanpooled_latents, latents), dim=-2)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
def masked_mean(t, *, dim, mask=None):
if mask is None:
return t.mean(dim=dim)
denom = mask.sum(dim=dim, keepdim=True)
mask = rearrange(mask, "b n -> b n 1")
masked_t = t.masked_fill(~mask, 0.0)
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)

View File

@ -0,0 +1,44 @@
import torch
from resampler import Resampler
from transformers import CLIPVisionModel
BATCH_SIZE = 2
OUTPUT_DIM = 1280
NUM_QUERIES = 8
NUM_LATENTS_MEAN_POOLED = 4 # 0 for no mean pooling (previous behavior)
APPLY_POS_EMB = True # False for no positional embeddings (previous behavior)
IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
def main():
image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
embedding_dim = image_encoder.config.hidden_size
print(f"image_encoder hidden size: ", embedding_dim)
image_proj_model = Resampler(
dim=1024,
depth=2,
dim_head=64,
heads=16,
num_queries=NUM_QUERIES,
embedding_dim=embedding_dim,
output_dim=OUTPUT_DIM,
ff_mult=2,
max_seq_len=257,
apply_pos_emb=APPLY_POS_EMB,
num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
)
dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
with torch.no_grad():
image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
print("image_embds shape: ", image_embeds.shape)
with torch.no_grad():
ip_tokens = image_proj_model(image_embeds)
print("ip_tokens shape:", ip_tokens.shape)
assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)
if __name__ == "__main__":
main()