Merge pull request #213 from ironbar/add-cpu-support

Added CPU support for ip-adapter-face
pull/220/head
Hu Ye 2023-12-30 19:12:44 +08:00 committed by GitHub
commit 9d8960cbe8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 14 additions and 12 deletions

View File

@ -106,11 +106,12 @@ class ProjPlusModel(torch.nn.Module):
class IPAdapterFaceID:
def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4):
def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
self.device = device
self.ip_ckpt = ip_ckpt
self.lora_rank = lora_rank
self.num_tokens = num_tokens
self.torch_dtype = torch_dtype
self.pipe = sd_pipe.to(self.device)
self.set_ip_adapter()
@ -125,7 +126,7 @@ class IPAdapterFaceID:
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
id_embeddings_dim=512,
num_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
).to(self.device, dtype=self.torch_dtype)
return image_proj_model
def set_ip_adapter(self):
@ -144,11 +145,11 @@ class IPAdapterFaceID:
if cross_attention_dim is None:
attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
).to(self.device, dtype=torch.float16)
).to(self.device, dtype=self.torch_dtype)
else:
attn_procs[name] = LoRAIPAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
).to(self.device, dtype=self.torch_dtype)
unet.set_attn_processor(attn_procs)
def load_ip_adapter(self):
@ -169,7 +170,7 @@ class IPAdapterFaceID:
@torch.inference_mode()
def get_image_embeds(self, faceid_embeds):
faceid_embeds = faceid_embeds.to(self.device, dtype=torch.float16)
faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
image_prompt_embeds = self.image_proj_model(faceid_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
@ -239,19 +240,20 @@ class IPAdapterFaceID:
class IPAdapterFaceIDPlus:
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4):
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
self.device = device
self.image_encoder_path = image_encoder_path
self.ip_ckpt = ip_ckpt
self.lora_rank = lora_rank
self.num_tokens = num_tokens
self.torch_dtype = torch_dtype
self.pipe = sd_pipe.to(self.device)
self.set_ip_adapter()
# load image encoder
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
self.device, dtype=torch.float16
self.device, dtype=self.torch_dtype
)
self.clip_image_processor = CLIPImageProcessor()
# image proj model
@ -265,7 +267,7 @@ class IPAdapterFaceIDPlus:
id_embeddings_dim=512,
clip_embeddings_dim=self.image_encoder.config.hidden_size,
num_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
).to(self.device, dtype=self.torch_dtype)
return image_proj_model
def set_ip_adapter(self):
@ -284,11 +286,11 @@ class IPAdapterFaceIDPlus:
if cross_attention_dim is None:
attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
).to(self.device, dtype=torch.float16)
).to(self.device, dtype=self.torch_dtype)
else:
attn_procs[name] = LoRAIPAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
).to(self.device, dtype=self.torch_dtype)
unet.set_attn_processor(attn_procs)
def load_ip_adapter(self):
@ -311,13 +313,13 @@ class IPAdapterFaceIDPlus:
if isinstance(face_image, Image.Image):
pil_image = [face_image]
clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
uncond_clip_image_embeds = self.image_encoder(
torch.zeros_like(clip_image), output_hidden_states=True
).hidden_states[-2]
faceid_embeds = faceid_embeds.to(self.device, dtype=torch.float16)
faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
return image_prompt_embeds, uncond_image_prompt_embeds