From bd7d2c1998f12b63a559a1e75b2577fbe2237218 Mon Sep 17 00:00:00 2001 From: jetthu Date: Wed, 8 Nov 2023 07:13:26 +0800 Subject: [PATCH] style: Format code indent, import order and remove trailing whitespace --- ip_adapter/__init__.py | 9 +- ip_adapter/attention_processor.py | 57 +++++----- ip_adapter/custom_pipelines.py | 17 ++- ip_adapter/ip_adapter.py | 170 ++++++++++++++++++------------ ip_adapter/resampler.py | 30 +++--- 5 files changed, 161 insertions(+), 122 deletions(-) diff --git a/ip_adapter/__init__.py b/ip_adapter/__init__.py index c0da1ff..301128c 100644 --- a/ip_adapter/__init__.py +++ b/ip_adapter/__init__.py @@ -1 +1,8 @@ -from .ip_adapter import IPAdapter, IPAdapterXL, IPAdapterPlus, IPAdapterPlusXL +from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL + +__all__ = [ + "IPAdapter", + "IPAdapterPlus", + "IPAdapterPlusXL", + "IPAdapterXL", +] diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py index 9b6eaed..fecc7ed 100644 --- a/ip_adapter/attention_processor.py +++ b/ip_adapter/attention_processor.py @@ -8,6 +8,7 @@ class AttnProcessor(nn.Module): r""" Default processor for performing attention-related computations. """ + def __init__( self, hidden_size=None, @@ -74,8 +75,8 @@ class AttnProcessor(nn.Module): hidden_states = hidden_states / attn.rescale_output_factor return hidden_states - - + + class IPAttnProcessor(nn.Module): r""" Attention processor for IP-Adapater. @@ -135,7 +136,10 @@ class IPAttnProcessor(nn.Module): else: # get encoder_hidden_states, ip_hidden_states end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) @@ -149,18 +153,18 @@ class IPAttnProcessor(nn.Module): attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) - + # for ip-adapter ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) - + ip_key = attn.head_to_batch_dim(ip_key) ip_value = attn.head_to_batch_dim(ip_value) - + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) - + hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj @@ -177,12 +181,13 @@ class IPAttnProcessor(nn.Module): hidden_states = hidden_states / attn.rescale_output_factor return hidden_states - - + + class AttnProcessor2_0(torch.nn.Module): r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ + def __init__( self, hidden_size=None, @@ -265,8 +270,8 @@ class AttnProcessor2_0(torch.nn.Module): hidden_states = hidden_states / attn.rescale_output_factor return hidden_states - - + + class IPAttnProcessor2_0(torch.nn.Module): r""" Attention processor for IP-Adapater for PyTorch 2.0. @@ -334,7 +339,10 @@ class IPAttnProcessor2_0(torch.nn.Module): else: # get encoder_hidden_states, ip_hidden_states end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) @@ -357,11 +365,11 @@ class IPAttnProcessor2_0(torch.nn.Module): hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - + # for ip-adapter ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) - + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) @@ -370,10 +378,10 @@ class IPAttnProcessor2_0(torch.nn.Module): ip_hidden_states = F.scaled_dot_product_attention( query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) - + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) ip_hidden_states = ip_hidden_states.to(query.dtype) - + hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj @@ -401,14 +409,7 @@ class CNAttnProcessor: def __init__(self, num_tokens=4): self.num_tokens = num_tokens - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None - ): + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): residual = hidden_states if attn.spatial_norm is not None: @@ -434,7 +435,7 @@ class CNAttnProcessor: encoder_hidden_states = hidden_states else: end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text + encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) @@ -470,7 +471,7 @@ class CNAttnProcessor2_0: Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ - def __init__(self, num_tokens=4): + def __init__(self, num_tokens=4): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.num_tokens = num_tokens @@ -513,7 +514,7 @@ class CNAttnProcessor2_0: encoder_hidden_states = hidden_states else: end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text + encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) @@ -550,4 +551,4 @@ class CNAttnProcessor2_0: hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states \ No newline at end of file + return hidden_states diff --git a/ip_adapter/custom_pipelines.py b/ip_adapter/custom_pipelines.py index abe769c..7d43d2c 100644 --- a/ip_adapter/custom_pipelines.py +++ b/ip_adapter/custom_pipelines.py @@ -1,13 +1,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch - from diffusers import StableDiffusionXLPipeline from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg - from .utils import is_torch2_available + if is_torch2_available(): from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor else: @@ -15,16 +14,15 @@ else: class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline): - def set_scale(self, scale): for attn_processor in self.unet.attn_processors.values(): if isinstance(attn_processor, IPAttnProcessor): attn_processor.scale = scale @torch.no_grad() - def __call__( + def __call__( # noqa: C901 self, - prompt: Union[str, List[str]] = None, + prompt: Optional[Union[str, List[str]]] = None, prompt_2: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, @@ -316,7 +314,7 @@ class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline): ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] - + # get init conditioning scale for attn_processor in self.unet.attn_processors.values(): if isinstance(attn_processor, IPAttnProcessor): @@ -325,9 +323,8 @@ class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline): with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - if (i / len(timesteps) < control_guidance_start) or ((i + 1) / len(timesteps) > control_guidance_end): - self.set_scale(0.) + self.set_scale(0.0) else: self.set_scale(conditioning_scale) @@ -381,7 +378,7 @@ class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline): else: image = latents - if not output_type == "latent": + if output_type != "latent": # apply watermark if available if self.watermark is not None: image = self.watermark.apply_watermark(image) @@ -394,4 +391,4 @@ class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline): if not return_dict: return (image,) - return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file + return StableDiffusionXLPipelineOutput(images=image) diff --git a/ip_adapter/ip_adapter.py b/ip_adapter/ip_adapter.py index 8daf73a..0060d7f 100644 --- a/ip_adapter/ip_adapter.py +++ b/ip_adapter/ip_adapter.py @@ -4,55 +4,67 @@ from typing import List import torch from diffusers import StableDiffusionPipeline from diffusers.pipelines.controlnet import MultiControlNetModel -from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor from PIL import Image from safetensors import safe_open +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from .utils import is_torch2_available + if is_torch2_available(): - from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor, CNAttnProcessor2_0 as CNAttnProcessor + from .attention_processor import ( + AttnProcessor2_0 as AttnProcessor, + ) + from .attention_processor import ( + CNAttnProcessor2_0 as CNAttnProcessor, + ) + from .attention_processor import ( + IPAttnProcessor2_0 as IPAttnProcessor, + ) else: - from .attention_processor import IPAttnProcessor, AttnProcessor, CNAttnProcessor + from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor from .resampler import Resampler class ImageProjModel(torch.nn.Module): """Projection Model""" + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): super().__init__() - + self.cross_attention_dim = cross_attention_dim self.clip_extra_context_tokens = clip_extra_context_tokens self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) self.norm = torch.nn.LayerNorm(cross_attention_dim) - + def forward(self, image_embeds): embeds = image_embeds - clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) clip_extra_context_tokens = self.norm(clip_extra_context_tokens) return clip_extra_context_tokens class IPAdapter: - def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): - self.device = device self.image_encoder_path = image_encoder_path self.ip_ckpt = ip_ckpt self.num_tokens = num_tokens - + self.pipe = sd_pipe.to(self.device) self.set_ip_adapter() - + # load image encoder - self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16) + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=torch.float16 + ) self.clip_image_processor = CLIPImageProcessor() # image proj model self.image_proj_model = self.init_proj() - + self.load_ip_adapter() - + def init_proj(self): image_proj_model = ImageProjModel( cross_attention_dim=self.pipe.unet.config.cross_attention_dim, @@ -60,7 +72,7 @@ class IPAdapter: clip_extra_context_tokens=self.num_tokens, ).to(self.device, dtype=torch.float16) return image_proj_model - + def set_ip_adapter(self): unet = self.pipe.unet attn_procs = {} @@ -77,8 +89,12 @@ class IPAdapter: if cross_attention_dim is None: attn_procs[name] = AttnProcessor() else: - attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, - scale=1.0,num_tokens= self.num_tokens).to(self.device, dtype=torch.float16) + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.num_tokens, + ).to(self.device, dtype=torch.float16) unet.set_attn_processor(attn_procs) if hasattr(self.pipe, "controlnet"): if isinstance(self.pipe.controlnet, MultiControlNetModel): @@ -86,7 +102,7 @@ class IPAdapter: controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) else: self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) - + def load_ip_adapter(self): if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": state_dict = {"image_proj": {}, "ip_adapter": {}} @@ -101,7 +117,7 @@ class IPAdapter: self.image_proj_model.load_state_dict(state_dict["image_proj"]) ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) ip_layers.load_state_dict(state_dict["ip_adapter"]) - + @torch.inference_mode() def get_image_embeds(self, pil_image=None, clip_image_embeds=None): if pil_image is not None: @@ -114,12 +130,12 @@ class IPAdapter: image_prompt_embeds = self.image_proj_model(clip_image_embeds) uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) return image_prompt_embeds, uncond_image_prompt_embeds - + def set_scale(self, scale): for attn_processor in self.pipe.unet.attn_processors.values(): if isinstance(attn_processor, IPAttnProcessor): attn_processor.scale = scale - + def generate( self, pil_image=None, @@ -136,24 +152,23 @@ class IPAdapter: self.set_scale(scale) if pil_image is not None: - if isinstance(pil_image, Image.Image): - num_prompts = 1 - else: - num_prompts = len(pil_image) + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) else: num_prompts = clip_image_embeds.size(0) - + if prompt is None: prompt = "best quality, high quality" if negative_prompt is None: negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - + if not isinstance(prompt, List): prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts - - image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=clip_image_embeds) + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( + pil_image=pil_image, clip_image_embeds=clip_image_embeds + ) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) @@ -162,11 +177,16 @@ class IPAdapter: with torch.inference_mode(): prompt_embeds = self.pipe._encode_prompt( - prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt) + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2) prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - + generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None images = self.pipe( prompt_embeds=prompt_embeds, @@ -176,13 +196,13 @@ class IPAdapter: generator=generator, **kwargs, ).images - + return images - - + + class IPAdapterXL(IPAdapter): """SDXL""" - + def generate( self, pil_image, @@ -195,22 +215,19 @@ class IPAdapterXL(IPAdapter): **kwargs, ): self.set_scale(scale) - - if isinstance(pil_image, Image.Image): - num_prompts = 1 - else: - num_prompts = len(pil_image) - + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + if prompt is None: prompt = "best quality, high quality" if negative_prompt is None: negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - + if not isinstance(prompt, List): prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts - + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) @@ -219,11 +236,20 @@ class IPAdapterXL(IPAdapter): uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) with torch.inference_mode(): - prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt( - prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - + generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None images = self.pipe( prompt_embeds=prompt_embeds, @@ -234,10 +260,10 @@ class IPAdapterXL(IPAdapter): generator=generator, **kwargs, ).images - + return images - - + + class IPAdapterPlus(IPAdapter): """IP-Adapter with fine-grained features""" @@ -250,10 +276,10 @@ class IPAdapterPlus(IPAdapter): num_queries=self.num_tokens, embedding_dim=self.image_encoder.config.hidden_size, output_dim=self.pipe.unet.config.cross_attention_dim, - ff_mult=4 + ff_mult=4, ).to(self.device, dtype=torch.float16) return image_proj_model - + @torch.inference_mode() def get_image_embeds(self, pil_image=None, clip_image_embeds=None): if isinstance(pil_image, Image.Image): @@ -262,7 +288,9 @@ class IPAdapterPlus(IPAdapter): clip_image = clip_image.to(self.device, dtype=torch.float16) clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] image_prompt_embeds = self.image_proj_model(clip_image_embeds) - uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2] + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) return image_prompt_embeds, uncond_image_prompt_embeds @@ -279,10 +307,10 @@ class IPAdapterPlusXL(IPAdapter): num_queries=self.num_tokens, embedding_dim=self.image_encoder.config.hidden_size, output_dim=self.pipe.unet.config.cross_attention_dim, - ff_mult=4 + ff_mult=4, ).to(self.device, dtype=torch.float16) return image_proj_model - + @torch.inference_mode() def get_image_embeds(self, pil_image): if isinstance(pil_image, Image.Image): @@ -291,10 +319,12 @@ class IPAdapterPlusXL(IPAdapter): clip_image = clip_image.to(self.device, dtype=torch.float16) clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] image_prompt_embeds = self.image_proj_model(clip_image_embeds) - uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2] + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) return image_prompt_embeds, uncond_image_prompt_embeds - + def generate( self, pil_image, @@ -307,22 +337,19 @@ class IPAdapterPlusXL(IPAdapter): **kwargs, ): self.set_scale(scale) - - if isinstance(pil_image, Image.Image): - num_prompts = 1 - else: - num_prompts = len(pil_image) - + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + if prompt is None: prompt = "best quality, high quality" if negative_prompt is None: negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - + if not isinstance(prompt, List): prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts - + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) @@ -331,11 +358,20 @@ class IPAdapterPlusXL(IPAdapter): uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) with torch.inference_mode(): - prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt( - prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - + generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None images = self.pipe( prompt_embeds=prompt_embeds, @@ -346,5 +382,5 @@ class IPAdapterPlusXL(IPAdapter): generator=generator, **kwargs, ).images - + return images diff --git a/ip_adapter/resampler.py b/ip_adapter/resampler.py index 4521c8c..5a17bda 100644 --- a/ip_adapter/resampler.py +++ b/ip_adapter/resampler.py @@ -14,11 +14,11 @@ def FeedForward(dim, mult=4): nn.GELU(), nn.Linear(inner_dim, dim, bias=False), ) - - + + def reshape_tensor(x, heads): bs, length, width = x.shape - #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) x = x.view(bs, length, heads, -1) # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) x = x.transpose(1, 2) @@ -42,7 +42,6 @@ class PerceiverAttention(nn.Module): self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) self.to_out = nn.Linear(inner_dim, dim, bias=False) - def forward(self, x, latents): """ Args: @@ -53,23 +52,23 @@ class PerceiverAttention(nn.Module): """ x = self.norm1(x) latents = self.norm2(latents) - + b, l, _ = latents.shape q = self.to_q(latents) kv_input = torch.cat((x, latents), dim=-2) k, v = self.to_kv(kv_input).chunk(2, dim=-1) - + q = reshape_tensor(q, self.heads) k = reshape_tensor(k, self.heads) v = reshape_tensor(v, self.heads) # attention scale = 1 / math.sqrt(math.sqrt(self.dim_head)) - weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) out = weight @ v - + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) return self.to_out(out) @@ -88,14 +87,14 @@ class Resampler(nn.Module): ff_mult=4, ): super().__init__() - + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) - + self.proj_in = nn.Linear(embedding_dim, dim) self.proj_out = nn.Linear(dim, output_dim) self.norm_out = nn.LayerNorm(output_dim) - + self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append( @@ -108,14 +107,13 @@ class Resampler(nn.Module): ) def forward(self, x): - latents = self.latents.repeat(x.size(0), 1, 1) - + x = self.proj_in(x) - + for attn, ff in self.layers: latents = attn(x, latents) + latents latents = ff(latents) + latents - + latents = self.proj_out(latents) - return self.norm_out(latents) \ No newline at end of file + return self.norm_out(latents)