diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py
index 318e131..5fd8a89 100644
--- a/ip_adapter/ip_adapter_faceid.py
+++ b/ip_adapter/ip_adapter_faceid.py
@@ -106,11 +106,12 @@ class ProjPlusModel(torch.nn.Module):
 
 
 class IPAdapterFaceID:
-    def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4):
+    def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
         self.device = device
         self.ip_ckpt = ip_ckpt
         self.lora_rank = lora_rank
         self.num_tokens = num_tokens
+        self.torch_dtype = torch_dtype
 
         self.pipe = sd_pipe.to(self.device)
         self.set_ip_adapter()
@@ -125,7 +126,7 @@ class IPAdapterFaceID:
             cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
             id_embeddings_dim=512,
             num_tokens=self.num_tokens,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=self.torch_dtype)
         return image_proj_model
 
     def set_ip_adapter(self):
@@ -144,11 +145,11 @@ class IPAdapterFaceID:
             if cross_attention_dim is None:
                 attn_procs[name] = LoRAAttnProcessor(
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
-                ).to(self.device, dtype=torch.float16)
+                ).to(self.device, dtype=self.torch_dtype)
             else:
                 attn_procs[name] = LoRAIPAttnProcessor(
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
-                ).to(self.device, dtype=torch.float16)
+                ).to(self.device, dtype=self.torch_dtype)
         unet.set_attn_processor(attn_procs)
 
     def load_ip_adapter(self):
@@ -169,7 +170,7 @@ class IPAdapterFaceID:
 
     @torch.inference_mode()
     def get_image_embeds(self, faceid_embeds):
-        faceid_embeds = faceid_embeds.to(self.device, dtype=torch.float16)
+        faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
         image_prompt_embeds = self.image_proj_model(faceid_embeds)
         uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
         return image_prompt_embeds, uncond_image_prompt_embeds
@@ -239,19 +240,20 @@ class IPAdapterFaceID:
 
 
 class IPAdapterFaceIDPlus:
-    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4):
+    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
         self.device = device
         self.image_encoder_path = image_encoder_path
         self.ip_ckpt = ip_ckpt
         self.lora_rank = lora_rank
         self.num_tokens = num_tokens
+        self.torch_dtype = torch_dtype
 
         self.pipe = sd_pipe.to(self.device)
         self.set_ip_adapter()
 
         # load image encoder
         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
-            self.device, dtype=torch.float16
+            self.device, dtype=self.torch_dtype
         )
         self.clip_image_processor = CLIPImageProcessor()
         # image proj model
@@ -265,7 +267,7 @@ class IPAdapterFaceIDPlus:
             id_embeddings_dim=512,
             clip_embeddings_dim=self.image_encoder.config.hidden_size,
             num_tokens=self.num_tokens,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=self.torch_dtype)
         return image_proj_model
 
     def set_ip_adapter(self):
@@ -284,11 +286,11 @@ class IPAdapterFaceIDPlus:
             if cross_attention_dim is None:
                 attn_procs[name] = LoRAAttnProcessor(
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
-                ).to(self.device, dtype=torch.float16)
+                ).to(self.device, dtype=self.torch_dtype)
             else:
                 attn_procs[name] = LoRAIPAttnProcessor(
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
-                ).to(self.device, dtype=torch.float16)
+                ).to(self.device, dtype=self.torch_dtype)
         unet.set_attn_processor(attn_procs)
 
     def load_ip_adapter(self):
@@ -311,13 +313,13 @@ class IPAdapterFaceIDPlus:
         if isinstance(face_image, Image.Image):
             pil_image = [face_image]
         clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float16)
+        clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         uncond_clip_image_embeds = self.image_encoder(
             torch.zeros_like(clip_image), output_hidden_states=True
         ).hidden_states[-2]
 
-        faceid_embeds = faceid_embeds.to(self.device, dtype=torch.float16)
+        faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
         image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
         uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
         return image_prompt_embeds, uncond_image_prompt_embeds
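
The patch replaces every hardcoded torch.float16 in IPAdapterFaceID and IPAdapterFaceIDPlus with a torch_dtype parameter that still defaults to torch.float16, so existing callers are unaffected. A minimal usage sketch of the new parameter follows, assuming a standard diffusers StableDiffusionPipeline; the base model ID and checkpoint path are placeholders, and bfloat16 is just one example of an alternate dtype:

    import torch
    from diffusers import StableDiffusionPipeline
    from ip_adapter.ip_adapter_faceid import IPAdapterFaceID

    # Any dtype the pipeline supports; the adapter modules now follow it
    # instead of being forced to float16.
    dtype = torch.bfloat16

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # placeholder base model
        torch_dtype=dtype,
    )
    ip_model = IPAdapterFaceID(
        pipe,
        "ip-adapter-faceid_sd15.bin",  # placeholder checkpoint path
        "cuda",
        torch_dtype=dtype,  # new argument; omit it to keep the old float16 behavior
    )

Keeping the pipeline dtype and the adapter dtype identical matters here: the projection model and LoRA attention processors are cast once at construction, so a mismatch with the UNet's dtype would fail at the first matmul.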