mirror of https://github.com/InstantID/InstantID
update vae decode
parent
0e0c3d0b14
commit
7caac0ec8a
|
|
@ -743,15 +743,31 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
|
|||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
|
|
|
|||
|
|
@ -1180,15 +1180,31 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
|
|||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue