Merge pull request #296 from aycaecemgul/main

seed can be a list
pull/338/head
Hu Ye 2024-02-17 21:08:51 +08:00 committed by GitHub
commit 5a18b1f366
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 36 additions and 14 deletions

View File

@ -8,7 +8,7 @@ from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from .utils import is_torch2_available
from .utils import is_torch2_available, get_generator
if is_torch2_available():
from .attention_processor import (
@ -204,7 +204,8 @@ class IPAdapter:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@ -267,7 +268,7 @@ class IPAdapterXL(IPAdapter):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
self.generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
self.generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
@ -401,7 +402,8 @@ class IPAdapterPlusXL(IPAdapter):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,

View File

@ -9,7 +9,7 @@ from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
from .utils import is_torch2_available
from .utils import is_torch2_available, get_generator
USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
if is_torch2_available() and (not USE_DAFAULT_ATTN):
@ -238,7 +238,8 @@ class IPAdapterFaceID:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@ -390,7 +391,8 @@ class IPAdapterFaceIDPlus:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@ -454,7 +456,8 @@ class IPAdapterFaceIDXL(IPAdapterFaceID):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@ -523,7 +526,8 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,

View File

@ -8,7 +8,7 @@ from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from .utils import is_torch2_available
from .utils import is_torch2_available, get_generator
USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
if is_torch2_available() and (not USE_DAFAULT_ATTN):
@ -246,7 +246,8 @@ class IPAdapterFaceID:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@ -395,7 +396,8 @@ class IPAdapterFaceIDPlus:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@ -459,7 +461,8 @@ class IPAdapterFaceIDXL(IPAdapterFaceID):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@ -528,7 +531,8 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,

View File

@ -79,3 +79,15 @@ def attnmaps2images(net_attn_maps):
return images
def is_torch2_available():
return hasattr(F, "scaled_dot_product_attention")
def get_generator(seed, device):
if seed is not None:
if isinstance(seed, list):
generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
else:
generator = torch.Generator(device).manual_seed(seed)
else:
generator = None
return generator