diff --git a/ip_adapter/ip_adapter.py b/ip_adapter/ip_adapter.py index 2afe535..dcb8824 100644 --- a/ip_adapter/ip_adapter.py +++ b/ip_adapter/ip_adapter.py @@ -8,7 +8,7 @@ from PIL import Image from safetensors import safe_open from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection -from .utils import is_torch2_available +from .utils import is_torch2_available, get_generator if is_torch2_available(): from .attention_processor import ( @@ -204,7 +204,8 @@ class IPAdapter: prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + generator = get_generator(seed, self.device) + images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -267,7 +268,7 @@ class IPAdapterXL(IPAdapter): prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - self.generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + self.generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, @@ -401,7 +402,8 @@ class IPAdapterPlusXL(IPAdapter): prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + generator = get_generator(seed, self.device) + images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index c75fddf..fe98ad5 100644 --- a/ip_adapter/ip_adapter_faceid.py +++ b/ip_adapter/ip_adapter_faceid.py @@ -9,7 +9,7 @@ from safetensors import safe_open from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor -from .utils import is_torch2_available +from .utils import is_torch2_available, get_generator USE_DAFAULT_ATTN = False # should be True for visualization_attnmap if is_torch2_available() and (not USE_DAFAULT_ATTN): @@ -238,7 +238,8 @@ class IPAdapterFaceID: prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + generator = get_generator(seed, self.device) + images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -390,7 +391,8 @@ class IPAdapterFaceIDPlus: prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + generator = get_generator(seed, self.device) + images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -454,7 +456,8 @@ class IPAdapterFaceIDXL(IPAdapterFaceID): prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + generator = get_generator(seed, self.device) + images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -523,7 +526,8 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus): prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + generator = get_generator(seed, self.device) + images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, diff --git a/ip_adapter/ip_adapter_faceid_separate.py b/ip_adapter/ip_adapter_faceid_separate.py index 67b6d1e..80ca84c 100644 --- a/ip_adapter/ip_adapter_faceid_separate.py +++ b/ip_adapter/ip_adapter_faceid_separate.py @@ -8,7 +8,7 @@ from PIL import Image from safetensors import safe_open from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection -from .utils import is_torch2_available +from .utils import is_torch2_available, get_generator USE_DAFAULT_ATTN = False # should be True for visualization_attnmap if is_torch2_available() and (not USE_DAFAULT_ATTN): @@ -246,7 +246,8 @@ class IPAdapterFaceID: prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + generator = get_generator(seed, self.device) + images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -395,7 +396,8 @@ class IPAdapterFaceIDPlus: prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + generator = get_generator(seed, self.device) + images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -459,7 +461,8 @@ class IPAdapterFaceIDXL(IPAdapterFaceID): prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + generator = get_generator(seed, self.device) + images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -528,7 +531,8 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus): prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + generator = get_generator(seed, self.device) + images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, diff --git a/ip_adapter/utils.py b/ip_adapter/utils.py index 3c99cd1..6a27335 100644 --- a/ip_adapter/utils.py +++ b/ip_adapter/utils.py @@ -79,3 +79,15 @@ def attnmaps2images(net_attn_maps): return images def is_torch2_available(): return hasattr(F, "scaled_dot_product_attention") + +def get_generator(seed, device): + + if seed is not None: + if isinstance(seed, list): + generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed] + else: + generator = torch.Generator(device).manual_seed(seed) + else: + generator = None + + return generator \ No newline at end of file