Merge pull request #213 from ironbar/add-cpu-support
Added support for cpu for ip-adapter-face
pull/220/head
commit
9d8960cbe8
|
|
@@ -106,11 +106,12 @@ class ProjPlusModel(torch.nn.Module):
|
|||
|
||||
|
||||
class IPAdapterFaceID:
|
||||
def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4):
|
||||
def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
|
||||
self.device = device
|
||||
self.ip_ckpt = ip_ckpt
|
||||
self.lora_rank = lora_rank
|
||||
self.num_tokens = num_tokens
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
self.pipe = sd_pipe.to(self.device)
|
||||
self.set_ip_adapter()
|
||||
|
|
@@ -125,7 +126,7 @@ class IPAdapterFaceID:
|
|||
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
|
||||
id_embeddings_dim=512,
|
||||
num_tokens=self.num_tokens,
|
||||
).to(self.device, dtype=torch.float16)
|
||||
).to(self.device, dtype=self.torch_dtype)
|
||||
return image_proj_model
|
||||
|
||||
def set_ip_adapter(self):
|
||||
|
|
@@ -144,11 +145,11 @@ class IPAdapterFaceID:
|
|||
if cross_attention_dim is None:
|
||||
attn_procs[name] = LoRAAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
|
||||
).to(self.device, dtype=torch.float16)
|
||||
).to(self.device, dtype=self.torch_dtype)
|
||||
else:
|
||||
attn_procs[name] = LoRAIPAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
|
||||
).to(self.device, dtype=torch.float16)
|
||||
).to(self.device, dtype=self.torch_dtype)
|
||||
unet.set_attn_processor(attn_procs)
|
||||
|
||||
def load_ip_adapter(self):
|
||||
|
|
@@ -169,7 +170,7 @@ class IPAdapterFaceID:
|
|||
@torch.inference_mode()
|
||||
def get_image_embeds(self, faceid_embeds):
|
||||
|
||||
faceid_embeds = faceid_embeds.to(self.device, dtype=torch.float16)
|
||||
faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
|
||||
image_prompt_embeds = self.image_proj_model(faceid_embeds)
|
||||
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
|
|
@@ -239,19 +240,20 @@ class IPAdapterFaceID:
|
|||
|
||||
|
||||
class IPAdapterFaceIDPlus:
|
||||
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4):
|
||||
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
|
||||
self.device = device
|
||||
self.image_encoder_path = image_encoder_path
|
||||
self.ip_ckpt = ip_ckpt
|
||||
self.lora_rank = lora_rank
|
||||
self.num_tokens = num_tokens
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
self.pipe = sd_pipe.to(self.device)
|
||||
self.set_ip_adapter()
|
||||
|
||||
# load image encoder
|
||||
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
|
||||
self.device, dtype=torch.float16
|
||||
self.device, dtype=self.torch_dtype
|
||||
)
|
||||
self.clip_image_processor = CLIPImageProcessor()
|
||||
# image proj model
|
||||
|
|
@@ -265,7 +267,7 @@ class IPAdapterFaceIDPlus:
|
|||
id_embeddings_dim=512,
|
||||
clip_embeddings_dim=self.image_encoder.config.hidden_size,
|
||||
num_tokens=self.num_tokens,
|
||||
).to(self.device, dtype=torch.float16)
|
||||
).to(self.device, dtype=self.torch_dtype)
|
||||
return image_proj_model
|
||||
|
||||
def set_ip_adapter(self):
|
||||
|
|
@@ -284,11 +286,11 @@ class IPAdapterFaceIDPlus:
|
|||
if cross_attention_dim is None:
|
||||
attn_procs[name] = LoRAAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
|
||||
).to(self.device, dtype=torch.float16)
|
||||
).to(self.device, dtype=self.torch_dtype)
|
||||
else:
|
||||
attn_procs[name] = LoRAIPAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
|
||||
).to(self.device, dtype=torch.float16)
|
||||
).to(self.device, dtype=self.torch_dtype)
|
||||
unet.set_attn_processor(attn_procs)
|
||||
|
||||
def load_ip_adapter(self):
|
||||
|
|
@@ -311,13 +313,13 @@ class IPAdapterFaceIDPlus:
|
|||
if isinstance(face_image, Image.Image):
|
||||
pil_image = [face_image]
|
||||
clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
|
||||
clip_image = clip_image.to(self.device, dtype=torch.float16)
|
||||
clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
|
||||
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
||||
uncond_clip_image_embeds = self.image_encoder(
|
||||
torch.zeros_like(clip_image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
|
||||
faceid_embeds = faceid_embeds.to(self.device, dtype=torch.float16)
|
||||
faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
|
||||
image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
|
||||
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
|
|
|
|||
Loading…
Reference in New Issue