automatic/scripts/lbm/embedders/latents_concat/latents_concat_embedder_mod...

78 lines
2.7 KiB
Python

from typing import Any, Dict
import torch
import torchvision.transforms.functional as F
from ...vae import AutoencoderKLDiffusers
from ..base import BaseConditioner
from .latents_concat_embedder_config import LatentsConcatEmbedderConfig
class LatentsConcatEmbedder(BaseConditioner):
    """
    Conditioner computing VAE embeddings from given images and resizing the masks.

    The resized masks and the VAE embeddings are concatenated along dimension 1,
    so that the result can then be concatenated to the noise in the latent space.

    Args:
        config (LatentsConcatEmbedderConfig): Configs to create the embedder.
            ``config.image_keys`` and ``config.mask_keys`` name the batch entries
            that are VAE-encoded and resized, respectively.
    """

    def __init__(self, config: LatentsConcatEmbedderConfig):
        super().__init__(config)

    def forward(
        self, batch: Dict[str, Any], vae: AutoencoderKLDiffusers, *args, **kwargs
    ) -> dict:
        """
        Build the concatenated conditioning tensor from the configured masks and images.

        Args:
            batch (dict): A batch of images to be processed by this embedder. In the batch,
                the images must range between [-1, 1] and the masks range between [0, 1].
                Every configured image and mask must share the same last two dimensions.
            vae (AutoencoderKLDiffusers): VAE used to encode the images. Its
                ``downsampling_factor`` is used to derive the latent size from the
                input size.
            *args: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            output (dict): A single-entry dict. The value is the concatenation
                (along dimension 1) of the resized masks followed by the VAE
                embeddings; the key is ``self.dim2outputkey`` looked up with the
                number of dimensions of that tensor.

        Raises:
            ValueError: If neither image keys nor mask keys are configured.
            AssertionError: If the images and masks do not all have the same
                last two dimensions.
        """
        image_keys = self.config.image_keys
        mask_keys = self.config.mask_keys

        # Last two dimensions of every input, images first, then masks.
        spatial_dims = [batch[key].shape[-2:] for key in (*image_keys, *mask_keys)]

        # Without any input there is nothing to size the latents from or to
        # concatenate; fail with a clear message instead of an IndexError below.
        if not spatial_dims:
            raise ValueError(
                "LatentsConcatEmbedder requires at least one image key or mask key."
            )

        # Check if images and masks are all of the same size
        reference_dims = spatial_dims[0]
        assert all(
            dims == reference_dims for dims in spatial_dims
        ), "All images and masks must have the same dimensions."

        # Find the latent dimensions. The reference entry comes from the first
        # image key when there is one, otherwise from the first mask key.
        latent_dims = (
            reference_dims[-2] // vae.downsampling_factor,
            reference_dims[-1] // vae.downsampling_factor,
        )

        # Resize the masks to the latent size; they come first in the output.
        outputs = [
            F.resize(
                batch[mask_key],
                size=latent_dims,
                interpolation=F.InterpolationMode.BILINEAR,
            )
            for mask_key in mask_keys
        ]

        # Compute VAE embeddings from the images
        outputs.extend(vae.encode(batch[image_key]) for image_key in image_keys)

        # Concat all the outputs
        concatenated = torch.concat(outputs, dim=1)

        # The output key depends on the number of dimensions of the result
        # (dim2outputkey is not defined in this class; presumably provided by
        # BaseConditioner).
        return {self.dim2outputkey[concatenated.dim()]: concatenated}