mirror of https://github.com/vladmandic/automatic
78 lines
2.7 KiB
Python
78 lines
2.7 KiB
Python
from typing import Any, Dict
|
|
import torch
|
|
import torchvision.transforms.functional as F
|
|
from ...vae import AutoencoderKLDiffusers
|
|
from ..base import BaseConditioner
|
|
from .latents_concat_embedder_config import LatentsConcatEmbedderConfig
|
|
|
|
|
|
class LatentsConcatEmbedder(BaseConditioner):
|
|
"""
|
|
Class computing VAE embeddings from given images and resizing the masks.
|
|
Then outputs are then concatenated to the noise in the latent space.
|
|
|
|
Args:
|
|
config (LatentsConcatEmbedderConfig): Configs to create the embedder
|
|
"""
|
|
|
|
def __init__(self, config: LatentsConcatEmbedderConfig):
|
|
BaseConditioner.__init__(self, config)
|
|
|
|
def forward(
|
|
self, batch: Dict[str, Any], vae: AutoencoderKLDiffusers, *args, **kwargs
|
|
) -> dict:
|
|
"""
|
|
Args:
|
|
batch (dict): A batch of images to be processed by this embedder. In the batch,
|
|
the images must range between [-1, 1] and the masks range between [0, 1].
|
|
vae (AutoencoderKLDiffusers): VAE
|
|
|
|
Returns:
|
|
output (dict): outputs
|
|
"""
|
|
|
|
# Check if image are of the same size
|
|
dims_list = []
|
|
for image_key in self.config.image_keys:
|
|
dims_list.append(batch[image_key].shape[-2:])
|
|
for mask_key in self.config.mask_keys:
|
|
dims_list.append(batch[mask_key].shape[-2:])
|
|
assert all(
|
|
dims == dims_list[0] for dims in dims_list
|
|
), "All images and masks must have the same dimensions."
|
|
|
|
# Find the latent dimensions
|
|
if len(self.config.image_keys) > 0:
|
|
latent_dims = (
|
|
batch[self.config.image_keys[0]].shape[-2] // vae.downsampling_factor,
|
|
batch[self.config.image_keys[0]].shape[-1] // vae.downsampling_factor,
|
|
)
|
|
else:
|
|
latent_dims = (
|
|
batch[self.config.mask_keys[0]].shape[-2] // vae.downsampling_factor,
|
|
batch[self.config.mask_keys[0]].shape[-1] // vae.downsampling_factor,
|
|
)
|
|
|
|
outputs = []
|
|
|
|
# Resize the masks and concat them
|
|
for mask_key in self.config.mask_keys:
|
|
curr_latents = F.resize(
|
|
batch[mask_key],
|
|
size=latent_dims,
|
|
interpolation=F.InterpolationMode.BILINEAR,
|
|
)
|
|
outputs.append(curr_latents)
|
|
|
|
# Compute VAE embeddings from the images
|
|
for image_key in self.config.image_keys:
|
|
vae_embs = vae.encode(batch[image_key])
|
|
outputs.append(vae_embs)
|
|
|
|
# Concat all the outputs
|
|
outputs = torch.concat(outputs, dim=1)
|
|
|
|
outputs = {self.dim2outputkey[outputs.dim()]: outputs}
|
|
|
|
return outputs
|