add t2i demo with ip-adapter and prior

pull/129/head
xiaohu2015 2023-11-05 18:51:41 +08:00
parent 1bea723195
commit 055393e7e9
2 changed files with 224 additions and 11 deletions

View File

@ -103,11 +103,14 @@ class IPAdapter:
ip_layers.load_state_dict(state_dict["ip_adapter"])
@torch.inference_mode()
def get_image_embeds(self, pil_image):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
if pil_image is not None:
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
else:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
@ -119,7 +122,8 @@ class IPAdapter:
def generate(
self,
pil_image,
pil_image=None,
clip_image_embeds=None,
prompt=None,
negative_prompt=None,
scale=1.0,
@ -130,11 +134,14 @@ class IPAdapter:
**kwargs,
):
self.set_scale(scale)
if isinstance(pil_image, Image.Image):
num_prompts = 1
if pil_image is not None:
if isinstance(pil_image, Image.Image):
num_prompts = 1
else:
num_prompts = len(pil_image)
else:
num_prompts = len(pil_image)
num_prompts = clip_image_embeds.size(0)
if prompt is None:
prompt = "best quality, high quality"
@ -146,7 +153,7 @@ class IPAdapter:
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=clip_image_embeds)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

206
ip_adapter_t2i_demo.ipynb Normal file

File diff suppressed because one or more lines are too long