commit
5a18b1f366
|
|
@ -8,7 +8,7 @@ from PIL import Image
|
|||
from safetensors import safe_open
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from .utils import is_torch2_available
|
||||
from .utils import is_torch2_available, get_generator
|
||||
|
||||
if is_torch2_available():
|
||||
from .attention_processor import (
|
||||
|
|
@ -204,7 +204,8 @@ class IPAdapter:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
@ -267,7 +268,7 @@ class IPAdapterXL(IPAdapter):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
self.generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
self.generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
@ -401,7 +402,8 @@ class IPAdapterPlusXL(IPAdapter):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from safetensors import safe_open
|
|||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
|
||||
from .utils import is_torch2_available
|
||||
from .utils import is_torch2_available, get_generator
|
||||
|
||||
USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
|
||||
if is_torch2_available() and (not USE_DAFAULT_ATTN):
|
||||
|
|
@ -238,7 +238,8 @@ class IPAdapterFaceID:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
@ -390,7 +391,8 @@ class IPAdapterFaceIDPlus:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
@ -454,7 +456,8 @@ class IPAdapterFaceIDXL(IPAdapterFaceID):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
@ -523,7 +526,8 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from PIL import Image
|
|||
from safetensors import safe_open
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from .utils import is_torch2_available
|
||||
from .utils import is_torch2_available, get_generator
|
||||
|
||||
USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
|
||||
if is_torch2_available() and (not USE_DAFAULT_ATTN):
|
||||
|
|
@ -246,7 +246,8 @@ class IPAdapterFaceID:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
@ -395,7 +396,8 @@ class IPAdapterFaceIDPlus:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
@ -459,7 +461,8 @@ class IPAdapterFaceIDXL(IPAdapterFaceID):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
@ -528,7 +531,8 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
|
|||
|
|
@ -79,3 +79,15 @@ def attnmaps2images(net_attn_maps):
|
|||
return images
|
||||
def is_torch2_available():
    """Return True when the installed torch exposes PyTorch 2.x fused attention.

    Probes for ``F.scaled_dot_product_attention`` (added in PyTorch 2.0);
    callers use this flag to select the torch-2 attention processors.
    """
    return getattr(F, "scaled_dot_product_attention", None) is not None
|
||||
|
||||
def get_generator(seed, device):
    """Build seeded torch RNG state for *device* from a flexible seed spec.

    Args:
        seed: ``None`` for no fixed seed, a single int, or a list of ints
            (one generator per entry, e.g. per-image seeds in a batch).
        device: device spec forwarded to ``torch.Generator``.

    Returns:
        ``None`` when *seed* is ``None``; a seeded ``torch.Generator`` for an
        int seed; a list of seeded generators for a list of seeds.
    """
    if seed is None:
        return None
    if isinstance(seed, list):
        return [torch.Generator(device).manual_seed(s) for s in seed]
    return torch.Generator(device).manual_seed(seed)
|
||||
Loading…
Reference in New Issue