device added as a parameter

pull/296/head
Ayça 2024-02-15 10:55:54 +03:00
parent 911a65f006
commit abd3f860d1
4 changed files with 13 additions and 13 deletions

View File

@ -204,7 +204,7 @@ class IPAdapter:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
@ -402,7 +402,7 @@ class IPAdapterPlusXL(IPAdapter):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,

View File

@ -238,7 +238,7 @@ class IPAdapterFaceID:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
@ -391,7 +391,7 @@ class IPAdapterFaceIDPlus:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
@ -456,7 +456,7 @@ class IPAdapterFaceIDXL(IPAdapterFaceID):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
@ -526,7 +526,7 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,

View File

@ -246,7 +246,7 @@ class IPAdapterFaceID:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
@ -396,7 +396,7 @@ class IPAdapterFaceIDPlus:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
@ -461,7 +461,7 @@ class IPAdapterFaceIDXL(IPAdapterFaceID):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
@ -531,7 +531,7 @@ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,

View File

@ -80,13 +80,13 @@ def attnmaps2images(net_attn_maps):
def is_torch2_available():
return hasattr(F, "scaled_dot_product_attention")
def get_generator(seed):
def get_generator(seed, device):
if seed is not None:
if isinstance(seed, list):
generator = [torch.Generator(self.device).manual_seed(seed_item) for seed_item in seed]
generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
else:
generator = torch.Generator(self.device).manual_seed(seed)
generator = torch.Generator(device).manual_seed(seed)
else:
generator = None