device added as a parameter
parent
911a65f006
commit
abd3f860d1
|
|
@ -204,7 +204,7 @@ class IPAdapter:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = get_generator(seed)
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
@ -402,7 +402,7 @@ class IPAdapterPlusXL(IPAdapter):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = get_generator(seed)
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
|
|||
|
|
@ -238,7 +238,7 @@ class IPAdapterFaceID:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = get_generator(seed)
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
@ -391,7 +391,7 @@ class IPAdapterFaceIDPlus:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = get_generator(seed)
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
@ -456,7 +456,7 @@ class IPAdapterFaceIDXL(IPAdapterFaceID):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = get_generator(seed)
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
@ -526,7 +526,7 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = get_generator(seed)
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
|
|||
|
|
@ -246,7 +246,7 @@ class IPAdapterFaceID:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = get_generator(seed)
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
@ -396,7 +396,7 @@ class IPAdapterFaceIDPlus:
|
|||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = get_generator(seed)
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
@ -461,7 +461,7 @@ class IPAdapterFaceIDXL(IPAdapterFaceID):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = get_generator(seed)
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
@ -531,7 +531,7 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
|
|||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = get_generator(seed)
|
||||
generator = get_generator(seed, self.device)
|
||||
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
|
|
|
|||
|
|
@ -80,13 +80,13 @@ def attnmaps2images(net_attn_maps):
|
|||
def is_torch2_available():
|
||||
return hasattr(F, "scaled_dot_product_attention")
|
||||
|
||||
def get_generator(seed):
|
||||
def get_generator(seed, device):
|
||||
|
||||
if seed is not None:
|
||||
if isinstance(seed, list):
|
||||
generator = [torch.Generator(self.device).manual_seed(seed_item) for seed_item in seed]
|
||||
generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
|
||||
else:
|
||||
generator = torch.Generator(self.device).manual_seed(seed)
|
||||
generator = torch.Generator(device).manual_seed(seed)
|
||||
else:
|
||||
generator = None
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue