diff --git a/ip_adapter/ip_adapter.py b/ip_adapter/ip_adapter.py index ef6d05c..ec70eb5 100644 --- a/ip_adapter/ip_adapter.py +++ b/ip_adapter/ip_adapter.py @@ -204,7 +204,7 @@ class IPAdapter: prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - generator = get_generator(seed) + generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, @@ -402,7 +402,7 @@ class IPAdapterPlusXL(IPAdapter): prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - generator = get_generator(seed) + generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index 27e3b64..fe98ad5 100644 --- a/ip_adapter/ip_adapter_faceid.py +++ b/ip_adapter/ip_adapter_faceid.py @@ -238,7 +238,7 @@ class IPAdapterFaceID: prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - generator = get_generator(seed) + generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, @@ -391,7 +391,7 @@ class IPAdapterFaceIDPlus: prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - generator = get_generator(seed) + generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, @@ -456,7 +456,7 @@ class IPAdapterFaceIDXL(IPAdapterFaceID): prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - generator = get_generator(seed) + generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, @@ -526,7 +526,7 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus): prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - generator = get_generator(seed) + generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, diff --git a/ip_adapter/ip_adapter_faceid_separate.py b/ip_adapter/ip_adapter_faceid_separate.py index 8881d8b..80ca84c 100644 --- a/ip_adapter/ip_adapter_faceid_separate.py +++ b/ip_adapter/ip_adapter_faceid_separate.py @@ -246,7 +246,7 @@ class IPAdapterFaceID: prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - generator = get_generator(seed) + generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, @@ -396,7 +396,7 @@ class IPAdapterFaceIDPlus: prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - generator = get_generator(seed) + generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, @@ -461,7 +461,7 @@ class IPAdapterFaceIDXL(IPAdapterFaceID): prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - generator = get_generator(seed) + generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, @@ -531,7 +531,7 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus): prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - generator = get_generator(seed) + generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, diff --git a/ip_adapter/utils.py b/ip_adapter/utils.py index 46c446c..6a27335 100644 --- a/ip_adapter/utils.py +++ b/ip_adapter/utils.py @@ -80,13 +80,13 @@ def attnmaps2images(net_attn_maps): def is_torch2_available(): return hasattr(F, "scaled_dot_product_attention") -def get_generator(seed): +def get_generator(seed, device): if seed is not None: if isinstance(seed, list): - generator = [torch.Generator(self.device).manual_seed(seed_item) for seed_item in seed] + generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed] else: - generator = torch.Generator(self.device).manual_seed(seed) + generator = torch.Generator(device).manual_seed(seed) else: generator = None