add t2i demo with ip-adapter and prior
parent
1bea723195
commit
055393e7e9
|
|
@ -103,11 +103,14 @@ class IPAdapter:
|
|||
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_image_embeds(self, pil_image):
|
||||
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
|
||||
if pil_image is not None:
|
||||
if isinstance(pil_image, Image.Image):
|
||||
pil_image = [pil_image]
|
||||
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
|
||||
else:
|
||||
clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
|
||||
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
||||
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
|
|
@ -119,7 +122,8 @@ class IPAdapter:
|
|||
|
||||
def generate(
|
||||
self,
|
||||
pil_image,
|
||||
pil_image=None,
|
||||
clip_image_embeds=None,
|
||||
prompt=None,
|
||||
negative_prompt=None,
|
||||
scale=1.0,
|
||||
|
|
@ -131,10 +135,13 @@ class IPAdapter:
|
|||
):
|
||||
self.set_scale(scale)
|
||||
|
||||
if pil_image is not None:
|
||||
if isinstance(pil_image, Image.Image):
|
||||
num_prompts = 1
|
||||
else:
|
||||
num_prompts = len(pil_image)
|
||||
else:
|
||||
num_prompts = clip_image_embeds.size(0)
|
||||
|
||||
if prompt is None:
|
||||
prompt = "best quality, high quality"
|
||||
|
|
@ -146,7 +153,7 @@ class IPAdapter:
|
|||
if not isinstance(negative_prompt, List):
|
||||
negative_prompt = [negative_prompt] * num_prompts
|
||||
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=clip_image_embeds)
|
||||
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
||||
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
||||
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue