get_generator() implemented to use on all ip adapter pipelines
parent
9e82c1038a
commit
911a65f006
|
|
@ -8,7 +8,7 @@ from PIL import Image
|
|||
from safetensors import safe_open
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from .utils import is_torch2_available
|
||||
from .utils import is_torch2_available, get_generator
|
||||
|
||||
if is_torch2_available():
|
||||
from .attention_processor import (
|
||||
|
|
@ -204,7 +204,8 @@ class IPAdapter:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
@ -401,7 +402,8 @@ class IPAdapterPlusXL(IPAdapter):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from safetensors import safe_open
|
|||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
|
||||
from .utils import is_torch2_available
|
||||
from .utils import is_torch2_available, get_generator
|
||||
|
||||
USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
|
||||
if is_torch2_available() and (not USE_DAFAULT_ATTN):
|
||||
|
|
@ -238,13 +238,7 @@ class IPAdapterFaceID:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
if seed is not None:
|
||||
if isinstance(seed, list):
|
||||
generator = [torch.Generator(self.device).manual_seed(seed_item) for seed_item in seed]
|
||||
else:
|
||||
generator = torch.Generator(self.device).manual_seed(seed)
|
||||
else:
|
||||
generator = None
|
||||
generator = get_generator(seed)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
@ -397,13 +391,7 @@ class IPAdapterFaceIDPlus:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
if seed is not None:
|
||||
if isinstance(seed, list):
|
||||
generator = [torch.Generator(self.device).manual_seed(seed_item) for seed_item in seed]
|
||||
else:
|
||||
generator = torch.Generator(self.device).manual_seed(seed)
|
||||
else:
|
||||
generator = None
|
||||
generator = get_generator(seed)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
@ -468,13 +456,7 @@ class IPAdapterFaceIDXL(IPAdapterFaceID):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
if seed is not None:
|
||||
if isinstance(seed, list):
|
||||
generator = [torch.Generator(self.device).manual_seed(seed_item) for seed_item in seed]
|
||||
else:
|
||||
generator = torch.Generator(self.device).manual_seed(seed)
|
||||
else:
|
||||
generator = None
|
||||
generator = get_generator(seed)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
@ -544,13 +526,7 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
if seed is not None:
|
||||
if isinstance(seed, list):
|
||||
generator = [torch.Generator(self.device).manual_seed(seed_item) for seed_item in seed]
|
||||
else:
|
||||
generator = torch.Generator(self.device).manual_seed(seed)
|
||||
else:
|
||||
generator = None
|
||||
generator = get_generator(seed)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from PIL import Image
|
|||
from safetensors import safe_open
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from .utils import is_torch2_available
|
||||
from .utils import is_torch2_available, get_generator
|
||||
|
||||
USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
|
||||
if is_torch2_available() and (not USE_DAFAULT_ATTN):
|
||||
|
|
@ -246,7 +246,8 @@ class IPAdapterFaceID:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
@ -395,7 +396,8 @@ class IPAdapterFaceIDPlus:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
@ -459,7 +461,8 @@ class IPAdapterFaceIDXL(IPAdapterFaceID):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
@ -528,7 +531,8 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
generator = get_generator(seed)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
|
|
|
|||
|
|
@ -79,3 +79,15 @@ def attnmaps2images(net_attn_maps):
|
|||
return images
|
||||
def is_torch2_available():
|
||||
return hasattr(F, "scaled_dot_product_attention")
|
||||
|
||||
def get_generator(seed):
|
||||
|
||||
if seed is not None:
|
||||
if isinstance(seed, list):
|
||||
generator = [torch.Generator(self.device).manual_seed(seed_item) for seed_item in seed]
|
||||
else:
|
||||
generator = torch.Generator(self.device).manual_seed(seed)
|
||||
else:
|
||||
generator = None
|
||||
|
||||
return generator
|
||||
Loading…
Reference in New Issue