diff --git a/README.md b/README.md
index eecf35c..b15d295 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,7 @@ we present IP-Adapter, an effective and lightweight adapter to achieve image pro
 ![arch](assets/figs/fig1.png)
 
 ## Release
+- [2024/01/19] 🔥 Add IP-Adapter-FaceID-Portrait, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
 - [2024/01/17] 🔥 Add an experimental version of IP-Adapter-FaceID-PlusV2 for SDXL, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
 - [2024/01/04] 🔥 Add an experimental version of IP-Adapter-FaceID for SDXL, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
 - [2023/12/29] 🔥 Add an experimental version of IP-Adapter-FaceID-PlusV2, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
diff --git a/ip_adapter/ip_adapter_faceid_separate.py b/ip_adapter/ip_adapter_faceid_separate.py
index 4355864..41caa1a 100644
--- a/ip_adapter/ip_adapter_faceid_separate.py
+++ b/ip_adapter/ip_adapter_faceid_separate.py
@@ -8,7 +8,6 @@ from PIL import Image
 from safetensors import safe_open
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
-from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
 from .utils import is_torch2_available
 
 USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
@@ -118,10 +117,11 @@ class ProjPlusModel(torch.nn.Module):
 
 
 class IPAdapterFaceID:
-    def __init__(self, sd_pipe, ip_ckpt, device, num_tokens=4, torch_dtype=torch.float16):
+    def __init__(self, sd_pipe, ip_ckpt, device, num_tokens=4, n_cond=1, torch_dtype=torch.float16):
         self.device = device
         self.ip_ckpt = ip_ckpt
         self.num_tokens = num_tokens
+        self.n_cond = n_cond
         self.torch_dtype = torch_dtype
 
         self.pipe = sd_pipe.to(self.device)
@@ -157,7 +157,7 @@ class IPAdapterFaceID:
                 attn_procs[name] = AttnProcessor()
             else:
                 attn_procs[name] = IPAttnProcessor(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens,
+                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens*self.n_cond,
                 ).to(self.device, dtype=self.torch_dtype)
         unet.set_attn_processor(attn_procs)
 
@@ -178,15 +178,26 @@ class IPAdapterFaceID:
 
     @torch.inference_mode()
     def get_image_embeds(self, faceid_embeds):
-        
+
+        multi_face = False
+        if faceid_embeds.dim() == 3:
+            multi_face = True
+            b, n, c = faceid_embeds.shape
+            faceid_embeds = faceid_embeds.reshape(b*n, c)
+
         faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
         image_prompt_embeds = self.image_proj_model(faceid_embeds)
         uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
+        if multi_face:
+            c = image_prompt_embeds.size(-1)
+            image_prompt_embeds = image_prompt_embeds.reshape(b, -1, c)
+            uncond_image_prompt_embeds = uncond_image_prompt_embeds.reshape(b, -1, c)
+
         return image_prompt_embeds, uncond_image_prompt_embeds
 
     def set_scale(self, scale):
         for attn_processor in self.pipe.unet.attn_processors.values():
-            if isinstance(attn_processor, LoRAIPAttnProcessor):
+            if isinstance(attn_processor, IPAttnProcessor):
                 attn_processor.scale = scale
 
     def generate(