Merge pull request #210 from xiaohu2015/main

add ip-adapter-faceid-plusv2
pull/213/head
Hu Ye 2023-12-29 21:25:54 +08:00 committed by GitHub
commit 9cd5a487e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 7 deletions

View File

@ -16,6 +16,7 @@ we present IP-Adapter, an effective and lightweight adapter to achieve image pro
![arch](assets/figs/fig1.png)
## Release
- [2023/12/29] 🔥 Add an experimental version of IP-Adapter-FaceID-PlusV2; more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
- [2023/12/27] 🔥 Add an experimental version of IP-Adapter-FaceID-Plus; more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
- [2023/12/20] 🔥 Add an experimental version of IP-Adapter-FaceID; more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
- [2023/11/22] IP-Adapter is available in [Diffusers](https://github.com/huggingface/diffusers/pull/5713) thanks to Diffusers Team.

View File

@ -94,13 +94,15 @@ class ProjPlusModel(torch.nn.Module):
ff_mult=4,
)
def forward(self, id_embeds, clip_embeds):
def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
x = self.proj(id_embeds)
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
x = self.norm(x)
x = self.perceiver_resampler(x, clip_embeds)
return x
out = self.perceiver_resampler(x, clip_embeds)
if shortcut:
out = x + scale * out
return out
class IPAdapterFaceID:
@ -305,7 +307,7 @@ class IPAdapterFaceIDPlus:
ip_layers.load_state_dict(state_dict["ip_adapter"])
@torch.inference_mode()
def get_image_embeds(self, faceid_embeds, face_image):
def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut):
if isinstance(face_image, Image.Image):
pil_image = [face_image]
clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
@ -316,8 +318,8 @@ class IPAdapterFaceIDPlus:
).hidden_states[-2]
faceid_embeds = faceid_embeds.to(self.device, dtype=torch.float16)
image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds)
image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
return image_prompt_embeds, uncond_image_prompt_embeds
def set_scale(self, scale):
@ -336,6 +338,8 @@ class IPAdapterFaceIDPlus:
seed=None,
guidance_scale=7.5,
num_inference_steps=30,
s_scale=1.0,
shortcut=False,
**kwargs,
):
self.set_scale(scale)
@ -353,7 +357,7 @@ class IPAdapterFaceIDPlus:
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image)
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)