From 035a8e32fc9b07c6aefd008d3a01771e0bf0734e Mon Sep 17 00:00:00 2001 From: aycaecemgul Date: Tue, 28 May 2024 17:35:51 +0300 Subject: [PATCH] repeat face embeds according to num samples to match with prompt list length --- ip_adapter/ip_adapter_faceid_separate.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/ip_adapter/ip_adapter_faceid_separate.py b/ip_adapter/ip_adapter_faceid_separate.py index 80ca84c..7c34e7c 100644 --- a/ip_adapter/ip_adapter_faceid_separate.py +++ b/ip_adapter/ip_adapter_faceid_separate.py @@ -214,7 +214,6 @@ class IPAdapterFaceID: ): self.set_scale(scale) - num_prompts = faceid_embeds.size(0) if prompt is None: @@ -224,6 +223,10 @@ class IPAdapterFaceID: if not isinstance(prompt, List): prompt = [prompt] * num_prompts + else: + faceid_embeds = faceid_embeds.repeat(num_samples, 1, 1) + num_samples = 1 + if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts @@ -254,6 +257,7 @@ class IPAdapterFaceID: guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, + num_images_per_prompt=num_samples, **kwargs, ).images @@ -435,6 +439,10 @@ class IPAdapterFaceIDXL(IPAdapterFaceID): if not isinstance(prompt, List): prompt = [prompt] * num_prompts + else: + faceid_embeds = faceid_embeds.repeat(num_samples, 1, 1) + num_samples = 1 + if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts @@ -470,6 +478,7 @@ class IPAdapterFaceIDXL(IPAdapterFaceID): negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, num_inference_steps=num_inference_steps, generator=generator, + num_images_per_prompt=num_samples, **kwargs, ).images