diff --git a/ip_adapter/resampler.py b/ip_adapter/resampler.py
index 5a17bda..2426667 100644
--- a/ip_adapter/resampler.py
+++ b/ip_adapter/resampler.py
@@ -1,8 +1,12 @@
 # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
+
 import math
 
 import torch
 import torch.nn as nn
+from einops import rearrange
+from einops.layers.torch import Rearrange
 
 
 # FFN
@@ -85,8 +89,12 @@ class Resampler(nn.Module):
         embedding_dim=768,
         output_dim=1024,
         ff_mult=4,
+        max_seq_len: int = 257,  # CLIP tokens + CLS token
+        apply_pos_emb: bool = False,
+        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
     ):
         super().__init__()
+        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
 
         self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
 
@@ -95,6 +103,16 @@ class Resampler(nn.Module):
         self.proj_out = nn.Linear(dim, output_dim)
         self.norm_out = nn.LayerNorm(output_dim)
 
+        self.to_latents_from_mean_pooled_seq = (
+            nn.Sequential(
+                nn.LayerNorm(dim),
+                nn.Linear(dim, dim * num_latents_mean_pooled),
+                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
+            )
+            if num_latents_mean_pooled > 0
+            else None
+        )
+
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(
@@ -107,13 +125,34 @@ class Resampler(nn.Module):
             )
 
     def forward(self, x):
+        if self.pos_emb is not None:
+            n, device = x.shape[1], x.device
+            pos_emb = self.pos_emb(torch.arange(n, device=device))
+            x = x + pos_emb
+
         latents = self.latents.repeat(x.size(0), 1, 1)
 
         x = self.proj_in(x)
 
+        if self.to_latents_from_mean_pooled_seq:
+            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
+            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
+            latents = torch.cat((meanpooled_latents, latents), dim=-2)
+
         for attn, ff in self.layers:
             latents = attn(x, latents) + latents
             latents = ff(latents) + latents
 
         latents = self.proj_out(latents)
         return self.norm_out(latents)
+
+
+def masked_mean(t, *, dim, mask=None):
+    if mask is None:
+        return t.mean(dim=dim)
+
+    denom = mask.sum(dim=dim, keepdim=True)
+    mask = rearrange(mask, "b n -> b n 1")
+    masked_t = t.masked_fill(~mask, 0.0)
+
+    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
diff --git a/ip_adapter/test_resampler.py b/ip_adapter/test_resampler.py
new file mode 100644
index 0000000..8978c8e
--- /dev/null
+++ b/ip_adapter/test_resampler.py
@@ -0,0 +1,44 @@
+import torch
+from resampler import Resampler
+from transformers import CLIPVisionModel
+
+BATCH_SIZE = 2
+OUTPUT_DIM = 1280
+NUM_QUERIES = 8
+NUM_LATENTS_MEAN_POOLED = 4  # 0 for no mean pooling (previous behavior)
+APPLY_POS_EMB = True  # False for no positional embeddings (previous behavior)
+IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+
+
+def main():
+    image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
+    embedding_dim = image_encoder.config.hidden_size
+    print(f"image_encoder hidden size: ", embedding_dim)
+
+    image_proj_model = Resampler(
+        dim=1024,
+        depth=2,
+        dim_head=64,
+        heads=16,
+        num_queries=NUM_QUERIES,
+        embedding_dim=embedding_dim,
+        output_dim=OUTPUT_DIM,
+        ff_mult=2,
+        max_seq_len=257,
+        apply_pos_emb=APPLY_POS_EMB,
+        num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
+    )
+
+    dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
+    with torch.no_grad():
+        image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
+    print("image_embds shape: ", image_embeds.shape)
+
+    with torch.no_grad():
+        ip_tokens = image_proj_model(image_embeds)
+    print("ip_tokens shape:", ip_tokens.shape)
+    assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)
+
+
+if __name__ == "__main__":
+    main()