get_generator() implemented to use on all ip adapter pipelines

pull/296/head
Ayça 2024-02-15 09:52:48 +03:00
parent 9e82c1038a
commit 911a65f006
4 changed files with 31 additions and 37 deletions

View File

@ -8,7 +8,7 @@ from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from .utils import is_torch2_available
from .utils import is_torch2_available, get_generator
if is_torch2_available():
from .attention_processor import (
@ -204,7 +204,8 @@ class IPAdapter:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@ -401,7 +402,8 @@ class IPAdapterPlusXL(IPAdapter):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,

View File

@ -9,7 +9,7 @@ from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
from .utils import is_torch2_available
from .utils import is_torch2_available, get_generator
USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
if is_torch2_available() and (not USE_DAFAULT_ATTN):
@ -238,13 +238,7 @@ class IPAdapterFaceID:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
if seed is not None:
if isinstance(seed, list):
generator = [torch.Generator(self.device).manual_seed(seed_item) for seed_item in seed]
else:
generator = torch.Generator(self.device).manual_seed(seed)
else:
generator = None
generator = get_generator(seed)
images = self.pipe(
prompt_embeds=prompt_embeds,
@ -397,13 +391,7 @@ class IPAdapterFaceIDPlus:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
if seed is not None:
if isinstance(seed, list):
generator = [torch.Generator(self.device).manual_seed(seed_item) for seed_item in seed]
else:
generator = torch.Generator(self.device).manual_seed(seed)
else:
generator = None
generator = get_generator(seed)
images = self.pipe(
prompt_embeds=prompt_embeds,
@ -468,13 +456,7 @@ class IPAdapterFaceIDXL(IPAdapterFaceID):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
if seed is not None:
if isinstance(seed, list):
generator = [torch.Generator(self.device).manual_seed(seed_item) for seed_item in seed]
else:
generator = torch.Generator(self.device).manual_seed(seed)
else:
generator = None
generator = get_generator(seed)
images = self.pipe(
prompt_embeds=prompt_embeds,
@ -544,13 +526,7 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
if seed is not None:
if isinstance(seed, list):
generator = [torch.Generator(self.device).manual_seed(seed_item) for seed_item in seed]
else:
generator = torch.Generator(self.device).manual_seed(seed)
else:
generator = None
generator = get_generator(seed)
images = self.pipe(
prompt_embeds=prompt_embeds,

View File

@ -8,7 +8,7 @@ from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from .utils import is_torch2_available
from .utils import is_torch2_available, get_generator
USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
if is_torch2_available() and (not USE_DAFAULT_ATTN):
@ -246,7 +246,8 @@ class IPAdapterFaceID:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@ -395,7 +396,8 @@ class IPAdapterFaceIDPlus:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@ -459,7 +461,8 @@ class IPAdapterFaceIDXL(IPAdapterFaceID):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@ -528,7 +531,8 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = get_generator(seed)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,

View File

@ -79,3 +79,15 @@ def attnmaps2images(net_attn_maps):
return images
def is_torch2_available():
return hasattr(F, "scaled_dot_product_attention")
def get_generator(seed):
if seed is not None:
if isinstance(seed, list):
generator = [torch.Generator(self.device).manual_seed(seed_item) for seed_item in seed]
else:
generator = torch.Generator(self.device).manual_seed(seed)
else:
generator = None
return generator