add t2i demo with ip-adapter and prior

pull/129/head
xiaohu2015 2023-11-05 18:51:41 +08:00
parent 1bea723195
commit 055393e7e9
2 changed files with 224 additions and 11 deletions

View File

@ -103,11 +103,14 @@ class IPAdapter:
ip_layers.load_state_dict(state_dict["ip_adapter"]) ip_layers.load_state_dict(state_dict["ip_adapter"])
@torch.inference_mode() @torch.inference_mode()
def get_image_embeds(self, pil_image): def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
if isinstance(pil_image, Image.Image): if pil_image is not None:
pil_image = [pil_image] if isinstance(pil_image, Image.Image):
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values pil_image = [pil_image]
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
else:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
image_prompt_embeds = self.image_proj_model(clip_image_embeds) image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds return image_prompt_embeds, uncond_image_prompt_embeds
@ -119,7 +122,8 @@ class IPAdapter:
def generate( def generate(
self, self,
pil_image, pil_image=None,
clip_image_embeds=None,
prompt=None, prompt=None,
negative_prompt=None, negative_prompt=None,
scale=1.0, scale=1.0,
@ -130,11 +134,14 @@ class IPAdapter:
**kwargs, **kwargs,
): ):
self.set_scale(scale) self.set_scale(scale)
if isinstance(pil_image, Image.Image): if pil_image is not None:
num_prompts = 1 if isinstance(pil_image, Image.Image):
num_prompts = 1
else:
num_prompts = len(pil_image)
else: else:
num_prompts = len(pil_image) num_prompts = clip_image_embeds.size(0)
if prompt is None: if prompt is None:
prompt = "best quality, high quality" prompt = "best quality, high quality"
@ -146,7 +153,7 @@ class IPAdapter:
if not isinstance(negative_prompt, List): if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=clip_image_embeds)
bs_embed, seq_len, _ = image_prompt_embeds.shape bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

206
ip_adapter_t2i_demo.ipynb Normal file

File diff suppressed because one or more lines are too long