add t2i demo with ip-adapter and prior
parent
1bea723195
commit
055393e7e9
|
|
@ -103,11 +103,14 @@ class IPAdapter:
|
||||||
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def get_image_embeds(self, pil_image):
|
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
|
||||||
if isinstance(pil_image, Image.Image):
|
if pil_image is not None:
|
||||||
pil_image = [pil_image]
|
if isinstance(pil_image, Image.Image):
|
||||||
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
pil_image = [pil_image]
|
||||||
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
|
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||||
|
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
|
||||||
|
else:
|
||||||
|
clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
|
||||||
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
||||||
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
|
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
|
||||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||||
|
|
@ -119,7 +122,8 @@ class IPAdapter:
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
pil_image,
|
pil_image=None,
|
||||||
|
clip_image_embeds=None,
|
||||||
prompt=None,
|
prompt=None,
|
||||||
negative_prompt=None,
|
negative_prompt=None,
|
||||||
scale=1.0,
|
scale=1.0,
|
||||||
|
|
@ -130,11 +134,14 @@ class IPAdapter:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.set_scale(scale)
|
self.set_scale(scale)
|
||||||
|
|
||||||
if isinstance(pil_image, Image.Image):
|
if pil_image is not None:
|
||||||
num_prompts = 1
|
if isinstance(pil_image, Image.Image):
|
||||||
|
num_prompts = 1
|
||||||
|
else:
|
||||||
|
num_prompts = len(pil_image)
|
||||||
else:
|
else:
|
||||||
num_prompts = len(pil_image)
|
num_prompts = clip_image_embeds.size(0)
|
||||||
|
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
prompt = "best quality, high quality"
|
prompt = "best quality, high quality"
|
||||||
|
|
@ -146,7 +153,7 @@ class IPAdapter:
|
||||||
if not isinstance(negative_prompt, List):
|
if not isinstance(negative_prompt, List):
|
||||||
negative_prompt = [negative_prompt] * num_prompts
|
negative_prompt = [negative_prompt] * num_prompts
|
||||||
|
|
||||||
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
|
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=clip_image_embeds)
|
||||||
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
||||||
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
||||||
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue