# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import inspect
import math
import sys
import warnings
from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

Device = Union[str, torch.device]

# Default values for rotation and translation matrices.
_R = torch.eye(3)[None]  # (1, 3, 3)
_T = torch.zeros(1, 3)  # (1, 3)

# Provide get_origin and get_args even in Python 3.7.
if sys.version_info >= (3, 8, 0):
    from typing import get_args, get_origin
elif sys.version_info >= (3, 7, 0):

    def get_origin(cls):  # pragma: no cover
        return getattr(cls, "__origin__", None)

    def get_args(cls):  # pragma: no cover
        return getattr(cls, "__args__", None)

else:
    raise ImportError("This module requires Python 3.7+")

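
# Illustrative sketch (not part of the original module): the shim above makes
# `get_origin` / `get_args` behave like their Python 3.8+ typing counterparts.
def _example_typing_shim() -> None:
    assert get_origin(Union[int, str]) is Union
    assert get_args(Union[int, str]) == (int, str)
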
################################################################
##  ██████╗██╗      █████╗ ███████╗███████╗███████╗███████╗  ##
## ██╔════╝██║     ██╔══██╗██╔════╝██╔════╝██╔════╝██╔════╝  ##
## ██║     ██║     ███████║███████╗███████╗█████╗  ███████╗  ##
## ██║     ██║     ██╔══██║╚════██║╚════██║██╔══╝  ╚════██║  ##
## ╚██████╗███████╗██║  ██║███████║███████║███████╗███████║  ##
##  ╚═════╝╚══════╝╚═╝  ╚═╝╚══════╝╚══════╝╚══════╝╚══════╝  ##
################################################################

class Transform3d:
    """
    A Transform3d object encapsulates a batch of N 3D transformations, and knows
    how to transform points and normal vectors. Suppose that t is a Transform3d;
    then we can do the following:

    .. code-block:: python

        N = len(t)
        points = torch.randn(N, P, 3)
        normals = torch.randn(N, P, 3)
        points_transformed = t.transform_points(points)     # => (N, P, 3)
        normals_transformed = t.transform_normals(normals)  # => (N, P, 3)

    BROADCASTING
    Transform3d objects support broadcasting. Suppose that t1 and tN are
    Transform3d objects with len(t1) == 1 and len(tN) == N respectively. Then we
    can broadcast transforms like this:

    .. code-block:: python

        t1.transform_points(torch.randn(P, 3))     # => (P, 3)
        t1.transform_points(torch.randn(1, P, 3))  # => (1, P, 3)
        t1.transform_points(torch.randn(M, P, 3))  # => (M, P, 3)
        tN.transform_points(torch.randn(P, 3))     # => (N, P, 3)
        tN.transform_points(torch.randn(1, P, 3))  # => (N, P, 3)

    COMBINING TRANSFORMS
    Transform3d objects can be combined in two ways: composing and stacking.
    Composing is function composition. Given Transform3d objects t1, t2, t3,
    the following all compute the same thing:

    .. code-block:: python

        y1 = t3.transform_points(t2.transform_points(t1.transform_points(x)))
        y2 = t1.compose(t2).compose(t3).transform_points(x)
        y3 = t1.compose(t2, t3).transform_points(x)

    Composing transforms should broadcast.

    .. code-block:: python

        if len(t1) == 1 and len(t2) == N, then len(t1.compose(t2)) == N.

    We can also stack a sequence of Transform3d objects, which represents
    composition along the batch dimension; then the following should compute the
    same thing.

    .. code-block:: python

        N, M = len(tN), len(tM)
        xN = torch.randn(N, P, 3)
        xM = torch.randn(M, P, 3)
        y1 = torch.cat([tN.transform_points(xN), tM.transform_points(xM)], dim=0)
        y2 = tN.stack(tM).transform_points(torch.cat([xN, xM], dim=0))

    BUILDING TRANSFORMS
    We provide convenience methods for easily building Transform3d objects
    as compositions of basic transforms.

    .. code-block:: python

        # Scale by 0.5, then translate by (1, 2, 3)
        t1 = Transform3d().scale(0.5).translate(1, 2, 3)

        # Scale each axis by a different amount, then translate, then scale
        t2 = Transform3d().scale(1, 3, 3).translate(2, 3, 1).scale(2.0)

        t3 = t1.compose(t2)
        tN = t1.stack(t3, t3)

    BACKPROP THROUGH TRANSFORMS
    When building transforms, we can also parameterize them by Torch tensors;
    in this case we can backprop through the construction and application of
    Transform objects, so they could be learned via gradient descent or
    predicted by a neural network.

    .. code-block:: python

        s1_params = torch.randn(N, requires_grad=True)
        t_params = torch.randn(N, 3, requires_grad=True)
        s2_params = torch.randn(N, 3, requires_grad=True)

        t = Transform3d().scale(s1_params).translate(t_params).scale(s2_params)
        x = torch.randn(N, 3)
        y = t.transform_points(x)
        loss = compute_loss(y)
        loss.backward()

        with torch.no_grad():
            s1_params -= lr * s1_params.grad
            t_params -= lr * t_params.grad
            s2_params -= lr * s2_params.grad

    CONVENTIONS
    We adopt a right-hand coordinate system, meaning that rotation about an axis
    with a positive angle results in a counter clockwise rotation.

    This class assumes that transformations are applied on inputs which
    are row vectors. The internal representation of the Nx4x4 transformation
    matrix is of the form:

    .. code-block:: python

        M = [
                [Rxx, Ryx, Rzx, 0],
                [Rxy, Ryy, Rzy, 0],
                [Rxz, Ryz, Rzz, 0],
                [Tx,  Ty,  Tz,  1],
            ]

    To apply the transformation to points which are row vectors, the M matrix
    can be pre multiplied by the points:

    .. code-block:: python

        points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point
        transformed_points = points * M

    """

    def __init__(
        self,
        dtype: torch.dtype = torch.float32,
        device: Device = "cpu",
        matrix: Optional[torch.Tensor] = None,
    ) -> None:
        """
        Args:
            dtype: The data type of the transformation matrix
                to be used if `matrix = None`.
            device: The device for storing the implemented transformation.
                If `matrix != None`, uses the device of input `matrix`.
            matrix: A tensor of shape (4, 4) or of shape (minibatch, 4, 4)
                representing the 4x4 3D transformation matrix.
                If `None`, initializes with identity using
                the specified `device` and `dtype`.
        """

        if matrix is None:
            self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4)
        else:
            if matrix.ndim not in (2, 3):
                raise ValueError('"matrix" has to be a 2- or a 3-dimensional tensor.')
            if matrix.shape[-2] != 4 or matrix.shape[-1] != 4:
                raise ValueError(
                    '"matrix" has to be a tensor of shape (minibatch, 4, 4)'
                )
            # set dtype and device from matrix
            dtype = matrix.dtype
            device = matrix.device
            self._matrix = matrix.view(-1, 4, 4)

        self._transforms = []  # store transforms to compose
        self._lu = None
        self.device = make_device(device)
        self.dtype = dtype

    def __len__(self) -> int:
        return self.get_matrix().shape[0]

    def __getitem__(
        self, index: Union[int, List[int], slice, torch.Tensor]
    ) -> "Transform3d":
        """
        Args:
            index: Specifying the index of the transform to retrieve.
                Can be an int, slice, list of ints, boolean or long tensor.
                Supports negative indices.

        Returns:
            Transform3d object with selected transforms. The tensors are not cloned.
        """
        if isinstance(index, int):
            index = [index]
        return self.__class__(matrix=self.get_matrix()[index])

    def compose(self, *others: "Transform3d") -> "Transform3d":
        """
        Return a new Transform3d representing the composition of self with the
        given other transforms, which will be stored as an internal list.

        Args:
            *others: Any number of Transform3d objects

        Returns:
            A new Transform3d with the stored transforms
        """
        out = Transform3d(dtype=self.dtype, device=self.device)
        out._matrix = self._matrix.clone()
        for other in others:
            if not isinstance(other, Transform3d):
                msg = "Only possible to compose Transform3d objects; got %s"
                raise ValueError(msg % type(other))
        out._transforms = self._transforms + list(others)
        return out

    def get_matrix(self) -> torch.Tensor:
        """
        Return a matrix which is the result of composing this transform
        with others stored in self.transforms. Where necessary transforms
        are broadcast against each other.
        For example, if self.transforms contains transforms t1, t2, and t3, and
        given a set of points x, the following should be true:

        .. code-block:: python

            y1 = t1.compose(t2, t3).transform(x)
            y2 = t3.transform(t2.transform(t1.transform(x)))
            y1.get_matrix() == y2.get_matrix()

        Returns:
            A transformation matrix representing the composed inputs.
        """
        composed_matrix = self._matrix.clone()
        if len(self._transforms) > 0:
            for other in self._transforms:
                other_matrix = other.get_matrix()
                composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
        return composed_matrix

    def _get_matrix_inverse(self) -> torch.Tensor:
        """
        Return the inverse of self._matrix.
        """
        return torch.inverse(self._matrix)

    def inverse(self, invert_composed: bool = False) -> "Transform3d":
        """
        Returns a new Transform3d object that represents an inverse of the
        current transformation.

        Args:
            invert_composed:
                - True: First compose the list of stored transformations
                  and then apply inverse to the result. This is
                  potentially slower for classes of transformations
                  with inverses that can be computed efficiently
                  (e.g. rotations and translations).
                - False: Invert the individual stored transformations
                  independently without composing them.

        Returns:
            A new Transform3d object containing the inverse of the original
            transformation.
        """

        tinv = Transform3d(dtype=self.dtype, device=self.device)

        if invert_composed:
            # first compose then invert
            tinv._matrix = torch.inverse(self.get_matrix())
        else:
            # self._get_matrix_inverse() implements efficient inverse
            # of self._matrix
            i_matrix = self._get_matrix_inverse()

            # 2 cases:
            if len(self._transforms) > 0:
                # a) Either we have a non-empty list of transforms:
                # Here we take self._matrix and append its inverse at the
                # end of the reverted _transforms list. After composing
                # the transformations with get_matrix(), this correctly
                # right-multiplies by the inverse of self._matrix
                # at the end of the composition.
                tinv._transforms = [t.inverse() for t in reversed(self._transforms)]
                last = Transform3d(dtype=self.dtype, device=self.device)
                last._matrix = i_matrix
                tinv._transforms.append(last)
            else:
                # b) Or there are no stored transformations
                # we just set inverted matrix
                tinv._matrix = i_matrix

        return tinv

    def stack(self, *others: "Transform3d") -> "Transform3d":
        """
        Return a new batched Transform3d representing the batch elements from
        self and all the given other transforms all batched together.

        Args:
            *others: Any number of Transform3d objects

        Returns:
            A new Transform3d.
        """
        transforms = [self] + list(others)
        matrix = torch.cat([t.get_matrix() for t in transforms], dim=0)
        out = Transform3d(dtype=self.dtype, device=self.device)
        out._matrix = matrix
        return out

    def transform_points(self, points, eps: Optional[float] = None) -> torch.Tensor:
        """
        Use this transform to transform a set of 3D points. Assumes row major
        ordering of the input points.

        Args:
            points: Tensor of shape (P, 3) or (N, P, 3)
            eps: If eps != None, the argument is used to clamp the
                last coordinate before performing the final division.
                The clamping corresponds to:
                last_coord := (last_coord.sign() + (last_coord == 0)) *
                torch.clamp(last_coord.abs(), eps),
                i.e. the last coordinates that are exactly 0 will
                be clamped to +eps.

        Returns:
            points_out: points of shape (N, P, 3) or (P, 3) depending
            on the dimensions of the transform
        """
        points_batch = points.clone()
        if points_batch.dim() == 2:
            points_batch = points_batch[None]  # (P, 3) -> (1, P, 3)
        if points_batch.dim() != 3:
            msg = "Expected points to have dim = 2 or dim = 3: got shape %r"
            raise ValueError(msg % repr(points.shape))

        N, P, _3 = points_batch.shape
        ones = torch.ones(N, P, 1, dtype=points.dtype, device=points.device)
        points_batch = torch.cat([points_batch, ones], dim=2)

        composed_matrix = self.get_matrix()
        points_out = _broadcast_bmm(points_batch, composed_matrix)
        denom = points_out[..., 3:]  # denominator
        if eps is not None:
            denom_sign = denom.sign() + (denom == 0.0).type_as(denom)
            denom = denom_sign * torch.clamp(denom.abs(), eps)
        points_out = points_out[..., :3] / denom

        # When transform is (1, 4, 4) and points is (P, 3) return
        # points_out of shape (P, 3)
        if points_out.shape[0] == 1 and points.dim() == 2:
            points_out = points_out.reshape(points.shape)

        return points_out

    def transform_normals(self, normals) -> torch.Tensor:
        """
        Use this transform to transform a set of normal vectors.

        Args:
            normals: Tensor of shape (P, 3) or (N, P, 3)

        Returns:
            normals_out: Tensor of shape (P, 3) or (N, P, 3) depending
            on the dimensions of the transform
        """
        if normals.dim() not in [2, 3]:
            msg = "Expected normals to have dim = 2 or dim = 3: got shape %r"
            raise ValueError(msg % (normals.shape,))
        composed_matrix = self.get_matrix()

        # TODO: inverse is bad! Solve a linear system instead
        mat = composed_matrix[:, :3, :3]
        normals_out = _broadcast_bmm(normals, mat.transpose(1, 2).inverse())

        # This doesn't pass unit tests. TODO investigate further
        # if self._lu is None:
        #     self._lu = self._matrix[:, :3, :3].transpose(1, 2).lu()
        # normals_out = normals.lu_solve(*self._lu)

        # When transform is (1, 4, 4) and normals is (P, 3) return
        # normals_out of shape (P, 3)
        if normals_out.shape[0] == 1 and normals.dim() == 2:
            normals_out = normals_out.reshape(normals.shape)

        return normals_out

    def translate(self, *args, **kwargs) -> "Transform3d":
        return self.compose(
            Translate(device=self.device, dtype=self.dtype, *args, **kwargs)
        )

    def scale(self, *args, **kwargs) -> "Transform3d":
        return self.compose(
            Scale(device=self.device, dtype=self.dtype, *args, **kwargs)
        )

    def rotate(self, *args, **kwargs) -> "Transform3d":
        return self.compose(
            Rotate(device=self.device, dtype=self.dtype, *args, **kwargs)
        )

    def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d":
        return self.compose(
            RotateAxisAngle(device=self.device, dtype=self.dtype, *args, **kwargs)
        )

    def clone(self) -> "Transform3d":
        """
        Deep copy of Transforms object. All internal tensors are cloned
        individually.

        Returns:
            new Transforms object.
        """
        other = Transform3d(dtype=self.dtype, device=self.device)
        if self._lu is not None:
            other._lu = [elem.clone() for elem in self._lu]
        other._matrix = self._matrix.clone()
        other._transforms = [t.clone() for t in self._transforms]
        return other

    def to(
        self,
        device: Device,
        copy: bool = False,
        dtype: Optional[torch.dtype] = None,
    ) -> "Transform3d":
        """
        Match functionality of torch.Tensor.to()
        If copy = True or the self Tensor is on a different device, the
        returned tensor is a copy of self with the desired torch.device.
        If copy = False and the self Tensor already has the correct torch.device,
        then self is returned.

        Args:
            device: Device (as str or torch.device) for the new tensor.
            copy: Boolean indicator whether or not to clone self. Default False.
            dtype: If not None, casts the internal tensor variables
                to a given torch.dtype.

        Returns:
            Transform3d object.
        """
        device_ = make_device(device)
        dtype_ = self.dtype if dtype is None else dtype
        skip_to = self.device == device_ and self.dtype == dtype_

        if not copy and skip_to:
            return self

        other = self.clone()

        if skip_to:
            return other

        other.device = device_
        other.dtype = dtype_
        other._matrix = other._matrix.to(device=device_, dtype=dtype_)
        other._transforms = [
            t.to(device_, copy=copy, dtype=dtype_) for t in other._transforms
        ]
        return other

    def cpu(self) -> "Transform3d":
        return self.to("cpu")

    def cuda(self) -> "Transform3d":
        return self.to("cuda")

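
# Illustrative sketch (not part of the original module): a minimal end-to-end
# use of Transform3d - compose basic transforms, apply them, and undo them.
def _example_transform3d_roundtrip() -> None:
    t = Transform3d().scale(2.0).translate(1.0, 2.0, 3.0)
    pts = torch.randn(4, 5, 3)
    out = t.transform_points(pts)             # (4, 5, 3), broadcast over N=1
    back = t.inverse().transform_points(out)  # the inverse undoes the transform
    assert torch.allclose(pts, back, atol=1e-5)
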
class Translate(Transform3d):
    def __init__(
        self,
        x,
        y=None,
        z=None,
        dtype: torch.dtype = torch.float32,
        device: Optional[Device] = None,
    ) -> None:
        """
        Create a new Transform3d representing 3D translations.

        Option I: Translate(xyz, dtype=torch.float32, device='cpu')
            xyz should be a tensor of shape (N, 3)

        Option II: Translate(x, y, z, dtype=torch.float32, device='cpu')
            Here x, y, and z will be broadcast against each other and
            concatenated to form the translation. Each can be:
                - A python scalar
                - A torch scalar
                - A 1D torch tensor
        """
        xyz = _handle_input(x, y, z, dtype, device, "Translate")
        super().__init__(device=xyz.device, dtype=dtype)
        N = xyz.shape[0]

        mat = torch.eye(4, dtype=dtype, device=self.device)
        mat = mat.view(1, 4, 4).repeat(N, 1, 1)
        mat[:, 3, :3] = xyz
        self._matrix = mat

    def _get_matrix_inverse(self) -> torch.Tensor:
        """
        Return the inverse of self._matrix.
        """
        inv_mask = self._matrix.new_ones([1, 4, 4])
        inv_mask[0, 3, :3] = -1.0
        i_matrix = self._matrix * inv_mask
        return i_matrix

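
# Illustrative sketch (not part of the original module; assumes the full module
# including `_handle_input`): Translate inverts by negating the translation
# row, which matches the generic matrix inverse.
def _example_translate_inverse() -> None:
    t = Translate(torch.tensor([[1.0, 2.0, 3.0]]))
    fast = t._get_matrix_inverse()
    generic = torch.inverse(t.get_matrix())
    assert torch.allclose(fast, generic, atol=1e-6)
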
class Rotate(Transform3d):
    def __init__(
        self,
        R: torch.Tensor,
        dtype: torch.dtype = torch.float32,
        device: Optional[Device] = None,
        orthogonal_tol: float = 1e-5,
    ) -> None:
        """
        Create a new Transform3d representing 3D rotation using a rotation
        matrix as the input.

        Args:
            R: a tensor of shape (3, 3) or (N, 3, 3)
            orthogonal_tol: tolerance for the test of the orthogonality of R
        """
        device_ = get_device(R, device)
        super().__init__(device=device_, dtype=dtype)
        if R.dim() == 2:
            R = R[None]
        if R.shape[-2:] != (3, 3):
            msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
            raise ValueError(msg % repr(R.shape))
        R = R.to(device=device_, dtype=dtype)
        _check_valid_rotation_matrix(R, tol=orthogonal_tol)
        N = R.shape[0]
        mat = torch.eye(4, dtype=dtype, device=device_)
        mat = mat.view(1, 4, 4).repeat(N, 1, 1)
        mat[:, :3, :3] = R
        self._matrix = mat

    def _get_matrix_inverse(self) -> torch.Tensor:
        """
        Return the inverse of self._matrix.
        """
        return self._matrix.permute(0, 2, 1).contiguous()

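
# Illustrative sketch (not part of the original module): for a pure rotation
# the efficient inverse is just the transpose of the 4x4 matrix.
def _example_rotate_inverse() -> None:
    R = _axis_angle_rotation("Z", torch.tensor([math.pi / 4]))
    t = Rotate(R)
    fast = t._get_matrix_inverse()
    assert torch.allclose(fast, torch.inverse(t.get_matrix()), atol=1e-6)
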
class TensorAccessor(nn.Module):
    """
    A helper class to be used with the __getitem__ method. This can be used for
    getting/setting the values for an attribute of a class at one particular
    index. This is useful when the attributes of a class are batched tensors
    and one element in the batch needs to be modified.
    """

    def __init__(self, class_object, index: Union[int, slice]) -> None:
        """
        Args:
            class_object: this should be an instance of a class which has
                attributes which are tensors representing a batch of
                values.
            index: int/slice, an index indicating the position in the batch.
                In __setattr__ and __getattr__ only the value of class
                attributes at this index will be accessed.
        """
        self.__dict__["class_object"] = class_object
        self.__dict__["index"] = index

    def __setattr__(self, name: str, value: Any):
        """
        Update the attribute given by `name` to the value given by `value`
        at the index specified by `self.index`.

        Args:
            name: str, name of the attribute.
            value: value to set the attribute to.
        """
        v = getattr(self.class_object, name)
        if not torch.is_tensor(v):
            msg = "Can only set values on attributes which are tensors; got %r"
            raise AttributeError(msg % type(v))

        # Convert the attribute to a tensor if it is not a tensor.
        if not torch.is_tensor(value):
            value = torch.tensor(
                value, device=v.device, dtype=v.dtype, requires_grad=v.requires_grad
            )

        # Check the shapes match the existing shape and the shape of the index.
        if v.dim() > 1 and value.dim() > 1 and value.shape[1:] != v.shape[1:]:
            msg = "Expected value to have shape %r; got %r"
            raise ValueError(msg % (v.shape, value.shape))
        if (
            v.dim() == 0
            and isinstance(self.index, slice)
            and len(value) != len(self.index)
        ):
            msg = "Expected value to have len %r; got %r"
            raise ValueError(msg % (len(self.index), len(value)))
        self.class_object.__dict__[name][self.index] = value

    def __getattr__(self, name: str):
        """
        Return the value of the attribute given by "name" on self.class_object
        at the index specified in self.index.

        Args:
            name: string of the attribute name
        """
        if hasattr(self.class_object, name):
            return self.class_object.__dict__[name][self.index]
        else:
            msg = "Attribute %s not found on %r"
            raise AttributeError(msg % (name, self.class_object.__class__.__name__))

BROADCAST_TYPES = (float, int, list, tuple, torch.Tensor, np.ndarray)

class TensorProperties(nn.Module):
    """
    A mix-in class for storing tensors as properties with helper methods.
    """

    def __init__(
        self,
        dtype: torch.dtype = torch.float32,
        device: Device = "cpu",
        **kwargs,
    ) -> None:
        """
        Args:
            dtype: data type to set for the inputs
            device: Device (as str or torch.device)
            kwargs: any number of keyword arguments. Any arguments which are
                of type (float/int/list/tuple/tensor/array) are broadcasted and
                other keyword arguments are set as attributes.
        """
        super().__init__()
        self.device = make_device(device)
        self._N = 0
        if kwargs is not None:
            # broadcast all inputs which are float/int/list/tuple/tensor/array
            # set as attributes anything else e.g. strings, bools
            args_to_broadcast = {}
            for k, v in kwargs.items():
                if v is None or isinstance(v, (str, bool)):
                    setattr(self, k, v)
                elif isinstance(v, BROADCAST_TYPES):
                    args_to_broadcast[k] = v
                else:
                    msg = "Arg %s with type %r is not broadcastable"
                    warnings.warn(msg % (k, type(v)))

            names = args_to_broadcast.keys()
            # convert from type dict.values to tuple
            values = tuple(v for v in args_to_broadcast.values())

            if len(values) > 0:
                broadcasted_values = convert_to_tensors_and_broadcast(
                    *values, device=device
                )

                # Set broadcasted values as attributes on self.
                for i, n in enumerate(names):
                    setattr(self, n, broadcasted_values[i])
                    if self._N == 0:
                        self._N = broadcasted_values[i].shape[0]

    def __len__(self) -> int:
        return self._N

    def isempty(self) -> bool:
        return self._N == 0

    def __getitem__(self, index: Union[int, slice]) -> TensorAccessor:
        """
        Args:
            index: an int or slice used to index all the fields.

        Returns:
            if `index` is an index int/slice return a TensorAccessor class
            with getattribute/setattribute methods which return/update the value
            at the index in the original class.
        """
        if isinstance(index, (int, slice)):
            return TensorAccessor(class_object=self, index=index)

        msg = "Expected index of type int or slice; got %r"
        raise ValueError(msg % type(index))

    # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
    def to(self, device: Device = "cpu") -> "TensorProperties":
        """
        In place operation to move class properties which are tensors to a
        specified device. If self has a property "device", update this as well.
        """
        device_ = make_device(device)
        for k in dir(self):
            v = getattr(self, k)
            if k == "device":
                setattr(self, k, device_)
            if torch.is_tensor(v) and v.device != device_:
                setattr(self, k, v.to(device_))
        return self

    def cpu(self) -> "TensorProperties":
        return self.to("cpu")

    # pyre-fixme[14]: `cuda` overrides method defined in `Module` inconsistently.
    def cuda(self, device: Optional[int] = None) -> "TensorProperties":
        return self.to(f"cuda:{device}" if device is not None else "cuda")

    def clone(self, other) -> "TensorProperties":
        """
        Update the tensor properties of other with the cloned properties of self.
        """
        for k in dir(self):
            v = getattr(self, k)
            if inspect.ismethod(v) or k.startswith("__"):
                continue
            if torch.is_tensor(v):
                v_clone = v.clone()
            else:
                v_clone = copy.deepcopy(v)
            setattr(other, k, v_clone)
        return other

    def gather_props(self, batch_idx) -> "TensorProperties":
        """
        This is an in place operation to reformat all tensor class attributes
        based on a set of given indices using torch.gather. This is useful when
        attributes which are batched tensors e.g. shape (N, 3) need to be
        multiplied with another tensor which has a different first dimension
        e.g. packed vertices of shape (V, 3).

        Example

        .. code-block:: python

            self.specular_color = (N, 3) tensor of specular colors for each mesh

        A lighting calculation may use

        .. code-block:: python

            verts_packed = meshes.verts_packed()  # (V, 3)

        To multiply these two tensors the batch dimension needs to be the same.
        To achieve this we can do

        .. code-block:: python

            batch_idx = meshes.verts_packed_to_mesh_idx()  # (V)

        This gives the index of the mesh for each vertex in verts_packed.

        .. code-block:: python

            self.gather_props(batch_idx)
            self.specular_color = (V, 3) tensor with the specular color for
                                  each packed vertex.

        torch.gather requires the index tensor to have the same shape as the
        input tensor so this method takes care of the reshaping of the index
        tensor to use with class attributes with arbitrary dimensions.

        Args:
            batch_idx: shape (B, ...) where `...` represents an arbitrary
                number of dimensions

        Returns:
            self with all properties reshaped. e.g. a property with shape (N, 3)
            is transformed to shape (B, 3).
        """
        # Iterate through the attributes of the class which are tensors.
        for k in dir(self):
            v = getattr(self, k)
            if torch.is_tensor(v):
                if v.shape[0] > 1:
                    # There are different values for each batch element
                    # so gather these using the batch_idx.
                    # First clone the input batch_idx tensor before
                    # modifying it.
                    _batch_idx = batch_idx.clone()
                    idx_dims = _batch_idx.shape
                    tensor_dims = v.shape
                    if len(idx_dims) > len(tensor_dims):
                        msg = "batch_idx cannot have more dimensions than %s. "
                        msg += "got shape %r and %s has shape %r"
                        raise ValueError(msg % (k, idx_dims, k, tensor_dims))
                    if idx_dims != tensor_dims:
                        # To use torch.gather the index tensor (_batch_idx) has
                        # to have the same shape as the input tensor.
                        new_dims = len(tensor_dims) - len(idx_dims)
                        new_shape = idx_dims + (1,) * new_dims
                        expand_dims = (-1,) + tensor_dims[1:]
                        _batch_idx = _batch_idx.view(*new_shape)
                        _batch_idx = _batch_idx.expand(*expand_dims)

                    v = v.gather(0, _batch_idx)
                    setattr(self, k, v)
        return self

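
# Illustrative sketch (not part of the original module; assumes the full module
# including `convert_to_tensors_and_broadcast`): use the accessor returned by
# __getitem__ to edit one batch element, then gather_props to expand per-batch
# attributes to per-element attributes. `_Colors` is a hypothetical subclass.
def _example_tensor_properties() -> None:
    class _Colors(TensorProperties):
        def __init__(self, **kwargs) -> None:
            super().__init__(**kwargs)

    props = _Colors(color=torch.zeros(3, 3))
    props[1].color = torch.tensor([1.0, 0.0, 0.0])  # write via TensorAccessor
    assert props.color[1, 0] == 1.0

    idx = torch.tensor([0, 0, 1, 2, 2])  # e.g. vertex -> mesh index
    props.gather_props(idx)
    assert props.color.shape == (5, 3)  # (N, 3) expanded to (V, 3)
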
class CamerasBase(TensorProperties):
    """
    `CamerasBase` implements a base class for all cameras.

    For cameras, there are four different coordinate systems (or spaces)
    - World coordinate system: This is the system in which the object lives - the world.
    - Camera view coordinate system: This is the system that has its origin on
        the camera and the Z-axis perpendicular to the image plane.
        In PyTorch3D, we assume that +X points left, and +Y points up and
        +Z points out from the image plane.
        The transformation from world --> view happens after applying a rotation (R)
        and translation (T).
    - NDC coordinate system: This is the normalized coordinate system that confines
        in a volume the rendered part of the object or scene. Also known as view volume.
        For square images, given the PyTorch3D convention, (+1, +1, znear)
        is the top left near corner, and (-1, -1, zfar) is the bottom right far
        corner of the volume.
        The transformation from view --> NDC happens after applying the camera
        projection matrix (P) if defined in NDC space.
        For non square images, we scale the points such that the smallest side
        has range [-1, 1] and the largest side has range [-u, u], with u > 1.
    - Screen coordinate system: This is another representation of the view volume with
        the XY coordinates defined in image space instead of a normalized space.

    A better illustration of the coordinate systems can be found in
    pytorch3d/docs/notes/cameras.md.

    It defines methods that are common to all camera models:
        - `get_camera_center` that returns the optical center of the camera in
          world coordinates
        - `get_world_to_view_transform` which returns a 3D transform from
          world coordinates to the camera view coordinates (R, T)
        - `get_full_projection_transform` which composes the projection
          transform (P) with the world-to-view transform (R, T)
        - `transform_points` which takes a set of input points in world coordinates and
          projects to the space the camera is defined in (NDC or screen)
        - `get_ndc_camera_transform` which defines the transform from screen/NDC to
          PyTorch3D's NDC space
        - `transform_points_ndc` which takes a set of points in world coordinates and
          projects them to PyTorch3D's NDC space
        - `transform_points_screen` which takes a set of points in world coordinates and
          projects them to screen space

    For each new camera, one should implement the `get_projection_transform`
    routine that returns the mapping from camera view coordinates to camera
    coordinates (NDC or screen).

    Another useful function that is specific to each camera model is
    `unproject_points` which sends points from camera coordinates (NDC or screen)
    back to camera view or world coordinates depending on the `world_coordinates`
    boolean argument of the function.
    """

    # Used in __getitem__ to index the relevant fields
    # When creating a new camera, this should be set in the __init__
    _FIELDS: Tuple[str, ...] = ()

    # Names of fields which are a constant property of the whole batch, rather
    # than themselves a batch of data.
    # When joining objects into a batch, they will have to agree.
    _SHARED_FIELDS: Tuple[str, ...] = ()

    def get_projection_transform(self):
        """
        Calculate the projective transformation matrix.

        Args:
            **kwargs: parameters for the projection can be passed in as keyword
                arguments to override the default values set in `__init__`.

        Return:
            a `Transform3d` object which represents a batch of projection
            matrices of shape (N, 4, 4)
        """
        raise NotImplementedError()

    def unproject_points(self, xy_depth: torch.Tensor, **kwargs):
        """
        Transform input points from camera coordinates (NDC or screen)
        to the world / camera coordinates.

        Each of the input points `xy_depth` of shape (..., 3) is
        a concatenation of the x, y location and its depth.

        For instance, for an input 2D tensor of shape `(num_points, 3)`
        `xy_depth` takes the following form:
        `xy_depth[i] = [x[i], y[i], depth[i]]`,
        for each point at an index `i`.

        The following example demonstrates the relationship between
        `transform_points` and `unproject_points`:

        .. code-block:: python

            cameras = # camera object derived from CamerasBase
            xyz = # 3D points of shape (batch_size, num_points, 3)
            # transform xyz to the camera view coordinates
            xyz_cam = cameras.get_world_to_view_transform().transform_points(xyz)
            # extract the depth of each point as the 3rd coord of xyz_cam
            depth = xyz_cam[:, :, 2:]
            # project the points xyz to the camera
            xy = cameras.transform_points(xyz)[:, :, :2]
            # append depth to xy
            xy_depth = torch.cat((xy, depth), dim=2)
            # unproject to the world coordinates
            xyz_unproj_world = cameras.unproject_points(xy_depth, world_coordinates=True)
            print(torch.allclose(xyz, xyz_unproj_world))  # True
            # unproject to the camera coordinates
            xyz_unproj = cameras.unproject_points(xy_depth, world_coordinates=False)
            print(torch.allclose(xyz_cam, xyz_unproj))  # True

        Args:
            xy_depth: torch tensor of shape (..., 3).
            world_coordinates: If `True`, unprojects the points back to world
                coordinates using the camera extrinsics `R` and `T`.
                `False` ignores `R` and `T` and unprojects to
                the camera view coordinates.
            from_ndc: If `False` (default), assumes xy part of input is in
                NDC space if self.in_ndc(), otherwise in screen space. If
                `True`, assumes xy is in NDC space even if the camera
                is defined in screen space.

        Returns
            new_points: unprojected points with the same shape as `xy_depth`.
        """
        raise NotImplementedError()

    def get_camera_center(self, **kwargs) -> torch.Tensor:
        """
        Return the 3D location of the camera optical center
        in the world coordinates.

        Args:
            **kwargs: parameters for the camera extrinsics can be passed in
                as keyword arguments to override the default values
                set in __init__.

        Setting T here will update the values set in init as this
        value may be needed later on in the rendering pipeline e.g. for
        lighting calculations.

        Returns:
            C: a batch of 3D locations of shape (N, 3) denoting
            the locations of the center of each camera in the batch.
        """
        w2v_trans = self.get_world_to_view_transform(**kwargs)
        P = w2v_trans.inverse().get_matrix()
        # the camera center is the translation component (the first 3 elements
        # of the last row) of the inverted world-to-view
        # transform (4x4 RT matrix)
        C = P[:, 3, :3]
        return C

    def get_world_to_view_transform(self, **kwargs) -> Transform3d:
        """
        Return the world-to-view transform.

        Args:
            **kwargs: parameters for the camera extrinsics can be passed in
                as keyword arguments to override the default values
                set in __init__.

        Setting R and T here will update the values set in init as these
        values may be needed later on in the rendering pipeline e.g. for
        lighting calculations.

        Returns:
            A Transform3d object which represents a batch of transforms
            of shape (N, 4, 4)
        """
        R: torch.Tensor = kwargs.get("R", self.R)
        T: torch.Tensor = kwargs.get("T", self.T)
        self.R = R  # pyre-ignore[16]
        self.T = T  # pyre-ignore[16]
        world_to_view_transform = get_world_to_view_transform(R=R, T=T)
        return world_to_view_transform

    def get_full_projection_transform(self, **kwargs) -> Transform3d:
        """
        Return the full world-to-camera transform composing the
        world-to-view and view-to-camera transforms.
        If the camera is defined in NDC space, the projected points are in NDC space.
        If the camera is defined in screen space, the projected points are in screen space.

        Args:
            **kwargs: parameters for the projection transforms can be passed in
                as keyword arguments to override the default values
                set in __init__.

        Setting R and T here will update the values set in init as these
        values may be needed later on in the rendering pipeline e.g. for
        lighting calculations.

        Returns:
            a Transform3d object which represents a batch of transforms
            of shape (N, 4, 4)
        """
        self.R: torch.Tensor = kwargs.get("R", self.R)  # pyre-ignore[16]
        self.T: torch.Tensor = kwargs.get("T", self.T)  # pyre-ignore[16]
        world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
        view_to_proj_transform = self.get_projection_transform(**kwargs)
        return world_to_view_transform.compose(view_to_proj_transform)

    def transform_points(
        self, points, eps: Optional[float] = None, **kwargs
    ) -> torch.Tensor:
        """
        Transform input points from world to camera space with the
        projection matrix defined by the camera.

        For `CamerasBase.transform_points`, setting `eps > 0`
        stabilizes gradients since it leads to avoiding division
        by excessively low numbers for points close to the camera plane.

        Args:
            points: torch tensor of shape (..., 3).
            eps: If eps != None, the argument is used to clamp the
                divisor in the homogeneous normalization of the points
                transformed to the ndc space. Please see
                `transforms.Transform3d.transform_points` for details.

        Returns
            new_points: transformed points with the same shape as the input.
        """
        world_to_proj_transform = self.get_full_projection_transform(**kwargs)
        return world_to_proj_transform.transform_points(points, eps=eps)

    def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
        """
        Returns the transform from camera projection space (screen or NDC) to NDC space.
        For cameras that can be specified in screen space, this transform
        allows points to be converted from screen to NDC space.
        The default transform scales the points from [0, W]x[0, H]
        to [-1, 1]x[-u, u] or [-u, u]x[-1, 1] where u > 1 is the aspect ratio of the image.
        This function should be modified per camera definitions if need be,
        e.g. for Perspective/Orthographic cameras we provide a custom implementation.
        This transform assumes PyTorch3D coordinate system conventions for
        both the NDC space and the input points.

        This transform interfaces with the PyTorch3D renderer which assumes
        input points to the renderer to be in NDC space.
        """
        if self.in_ndc():
            return Transform3d(device=self.device, dtype=torch.float32)
        else:
            # For custom cameras which can be defined in screen space,
            # users might have to implement the screen to NDC transform based
            # on the definition of the camera parameters.
            # See PerspectiveCameras/OrthographicCameras for an example.
            # We don't flip xy because we assume that world points are in
            # PyTorch3D coordinates, and thus conversion from screen to ndc
            # is a mere scaling from image to [-1, 1] scale.
            image_size = kwargs.get("image_size", self.get_image_size())
            return get_screen_to_ndc_transform(
                self, with_xyflip=False, image_size=image_size
            )

    def transform_points_ndc(
        self, points, eps: Optional[float] = None, **kwargs
    ) -> torch.Tensor:
        """
        Transforms points from PyTorch3D world/camera space to NDC space.
        Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up.
        Output points are in NDC space: +X left, +Y up, origin at image center.

        Args:
            points: torch tensor of shape (..., 3).
            eps: If eps != None, the argument is used to clamp the
                divisor in the homogeneous normalization of the points
                transformed to the ndc space. Please see
                `transforms.Transform3d.transform_points` for details.

                For `CamerasBase.transform_points`, setting `eps > 0`
                stabilizes gradients since it leads to avoiding division
                by excessively low numbers for points close to the
                camera plane.

        Returns
            new_points: transformed points with the same shape as the input.
        """
        world_to_ndc_transform = self.get_full_projection_transform(**kwargs)
        if not self.in_ndc():
            to_ndc_transform = self.get_ndc_camera_transform(**kwargs)
            world_to_ndc_transform = world_to_ndc_transform.compose(to_ndc_transform)

        return world_to_ndc_transform.transform_points(points, eps=eps)

    def transform_points_screen(
        self, points, eps: Optional[float] = None, **kwargs
    ) -> torch.Tensor:
        """
        Transforms points from PyTorch3D world/camera space to screen space.
        Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up.
        Output points are in screen space: +X right, +Y down, origin at top left corner.

        Args:
            points: torch tensor of shape (..., 3).
            eps: If eps != None, the argument is used to clamp the
                divisor in the homogeneous normalization of the points
                transformed to the ndc space. Please see
                `transforms.Transform3d.transform_points` for details.

                For `CamerasBase.transform_points`, setting `eps > 0`
                stabilizes gradients since it leads to avoiding division
                by excessively low numbers for points close to the
                camera plane.

        Returns
            new_points: transformed points with the same shape as the input.
        """
        points_ndc = self.transform_points_ndc(points, eps=eps, **kwargs)
        image_size = kwargs.get("image_size", self.get_image_size())
        return get_ndc_to_screen_transform(
            self, with_xyflip=True, image_size=image_size
        ).transform_points(points_ndc, eps=eps)

    def clone(self):
        """
        Returns a copy of `self`.
        """
        cam_type = type(self)
        other = cam_type(device=self.device)
        return super().clone(other)

    def is_perspective(self):
        raise NotImplementedError()

    def in_ndc(self):
        """
        Specifies whether the camera is defined in NDC space
        or in screen (image) space
        """
        raise NotImplementedError()

    def get_znear(self):
        return self.znear if hasattr(self, "znear") else None

    def get_image_size(self):
        """
        Returns the image size, if provided, expected in the form of (height, width)
        The image size is used for conversion of projected points to screen coordinates.
        """
        return self.image_size if hasattr(self, "image_size") else None

    def __getitem__(
        self, index: Union[int, List[int], torch.LongTensor]
    ) -> "CamerasBase":
        """
        Override for the __getitem__ method in TensorProperties which needs to be
        refactored.

        Args:
            index: an int/list/long tensor used to index all the fields in the
                cameras given by self._FIELDS.

        Returns:
            if `index` is an index int/list/long tensor return an instance of the
            current cameras class with only the values at the selected index.
        """

        kwargs = {}

        if not isinstance(index, (int, list, torch.LongTensor, torch.cuda.LongTensor)):
            msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r"
            raise ValueError(msg % type(index))

        if isinstance(index, int):
            index = [index]

        if max(index) >= len(self):
            raise ValueError(f"Index {max(index)} is out of bounds for select cameras")

        for field in self._FIELDS:
            val = getattr(self, field, None)
            if val is None:
                continue

            # e.g. "in_ndc" is set as attribute "_in_ndc" on the class
            # but provided as "in_ndc" on initialization
            if field.startswith("_"):
                field = field[1:]

            if isinstance(val, (str, bool)):
                kwargs[field] = val
            elif isinstance(val, torch.Tensor):
                # In the init, all inputs will be converted to
                # tensors before setting as attributes
                kwargs[field] = val[index]
            else:
                raise ValueError(f"Field {field} type is not supported for indexing")

        kwargs["device"] = self.device
        return self.__class__(**kwargs)

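
# Illustrative sketch (not part of the original module): with the row-vector
# convention X_cam = X_world @ R + T, the optical center returned by
# get_camera_center satisfies C = -T @ R^T. Uses FoVPerspectiveCameras,
# defined below; the name is resolved at call time.
def _example_camera_center() -> None:
    cams = FoVPerspectiveCameras(T=torch.tensor([[1.0, 2.0, 3.0]]))
    C = cams.get_camera_center()
    expected = -torch.bmm(cams.T[:, None, :], cams.R.transpose(1, 2))[:, 0]
    assert torch.allclose(C, expected, atol=1e-6)
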
class FoVPerspectiveCameras(CamerasBase):
    """
    A class which stores a batch of parameters to generate a batch of
    projection matrices by specifying the field of view.
    The definition of the parameters follows the OpenGL perspective camera.

    The extrinsics of the camera (R and T matrices) can also be set in the
    initializer or passed in to `get_full_projection_transform` to get
    the full transformation from world -> ndc.

    The `transform_points` method calculates the full world -> ndc transform
    and then applies it to the input points.

    The transforms can also be returned separately as Transform3d objects.

    * Setting the Aspect Ratio for Non Square Images *

    If the desired output image size is non square (i.e. a tuple of (H, W) where H != W)
    the aspect ratio needs special consideration: There are two aspect ratios
    to be aware of:
        - the aspect ratio of each pixel
        - the aspect ratio of the output image
    The `aspect_ratio` setting in the FoVPerspectiveCameras sets the
    pixel aspect ratio. When using this camera with the differentiable rasterizer
    be aware that in the rasterizer we assume square pixels, but allow
    variable image aspect ratio (i.e. rectangular images).

    In most cases you will want to set the camera `aspect_ratio=1.0`
    (i.e. square pixels) and only vary the output image dimensions in pixels
    for rasterization.
    """

    # For __getitem__
    _FIELDS = (
        "K",
        "znear",
        "zfar",
        "aspect_ratio",
        "fov",
        "R",
        "T",
        "degrees",
    )

    _SHARED_FIELDS = ("degrees",)

    def __init__(
        self,
        znear=1.0,
        zfar=100.0,
        aspect_ratio=1.0,
        fov=60.0,
        degrees: bool = True,
        R: torch.Tensor = _R,
        T: torch.Tensor = _T,
        K: Optional[torch.Tensor] = None,
        device: Device = "cpu",
    ) -> None:
        """
        Args:
            znear: near clipping plane of the view frustum.
            zfar: far clipping plane of the view frustum.
            aspect_ratio: aspect ratio of the image pixels.
                1.0 indicates square pixels.
            fov: field of view angle of the camera.
            degrees: bool, set to True if fov is specified in degrees.
            R: Rotation matrix of shape (N, 3, 3)
            T: Translation matrix of shape (N, 3)
            K: (optional) A calibration matrix of shape (N, 4, 4)
                If provided, don't need znear, zfar, fov, aspect_ratio, degrees
            device: Device (as str or torch.device)
        """
        # The initializer formats all inputs to torch tensors and broadcasts
        # all the inputs to have the same batch dimension where necessary.
        super().__init__(
            device=device,
            znear=znear,
            zfar=zfar,
            aspect_ratio=aspect_ratio,
            fov=fov,
            R=R,
            T=T,
            K=K,
        )

        # No need to convert to tensor or broadcast.
        self.degrees = degrees

    def compute_projection_matrix(
        self, znear, zfar, fov, aspect_ratio, degrees: bool
    ) -> torch.Tensor:
        """
        Compute the calibration matrix K of shape (N, 4, 4)

        Args:
            znear: near clipping plane of the view frustum.
            zfar: far clipping plane of the view frustum.
            fov: field of view angle of the camera.
            aspect_ratio: aspect ratio of the image pixels.
                1.0 indicates square pixels.
            degrees: bool, set to True if fov is specified in degrees.

        Returns:
            torch.FloatTensor of the calibration matrix with shape (N, 4, 4)
        """
        K = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32)
        ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
        if degrees:
            fov = (np.pi / 180) * fov

        if not torch.is_tensor(fov):
            fov = torch.tensor(fov, device=self.device)
        tanHalfFov = torch.tan((fov / 2))
        max_y = tanHalfFov * znear
        min_y = -max_y
        max_x = max_y * aspect_ratio
        min_x = -max_x

        # NOTE: In OpenGL the projection matrix changes the handedness of the
        # coordinate frame. i.e the NDC space positive z direction is the
        # camera space negative z direction. This is because the sign of the z
        # in the projection matrix is set to -1.0.
        # In pytorch3d we maintain a right handed coordinate system throughout
        # so the z sign is 1.0.
        z_sign = 1.0

        K[:, 0, 0] = 2.0 * znear / (max_x - min_x)
        K[:, 1, 1] = 2.0 * znear / (max_y - min_y)
        K[:, 0, 2] = (max_x + min_x) / (max_x - min_x)
        K[:, 1, 2] = (max_y + min_y) / (max_y - min_y)
        K[:, 3, 2] = z_sign * ones

        # NOTE: This maps the z coordinate to the range [0, 1] where z = 0 if
        # the point is at the near clipping plane and z = 1 when the point is
        # at the far clipping plane.
        K[:, 2, 2] = z_sign * zfar / (zfar - znear)
        K[:, 2, 3] = -(zfar * znear) / (zfar - znear)

        return K

    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
        Calculate the perspective projection matrix with a symmetric
        viewing frustum. Use column major order.
        The viewing frustum will be projected into ndc, s.t.
        (max_x, max_y) -> (+1, +1)
        (min_x, min_y) -> (-1, -1)

        Args:
            **kwargs: parameters for the projection can be passed in as keyword
                arguments to override the default values set in `__init__`.

        Return:
            a Transform3d object which represents a batch of projection
            matrices of shape (N, 4, 4)

        .. code-block:: python

            h1 = (max_y + min_y)/(max_y - min_y)
            w1 = (max_x + min_x)/(max_x - min_x)
            tanhalffov = tan((fov/2))
            s1 = 1/tanhalffov
            s2 = 1/(tanhalffov * (aspect_ratio))

            # To map z to the range [0, 1] use:
            f1 =  far / (far - near)
            f2 = -(far * near) / (far - near)

            # Projection matrix
            K = [
                    [s1,   0,   w1,   0],
                    [0,   s2,   h1,   0],
                    [0,    0,   f1,  f2],
                    [0,    0,    1,   0],
            ]
        """
        K = kwargs.get("K", self.K)
        if K is not None:
            if K.shape != (self._N, 4, 4):
                msg = "Expected K to have shape of (%r, 4, 4)"
                raise ValueError(msg % (self._N))
        else:
            K = self.compute_projection_matrix(
                kwargs.get("znear", self.znear),
                kwargs.get("zfar", self.zfar),
                kwargs.get("fov", self.fov),
                kwargs.get("aspect_ratio", self.aspect_ratio),
                kwargs.get("degrees", self.degrees),
            )

        # Transpose the projection matrix as PyTorch3D transforms use row vectors.
        transform = Transform3d(
            matrix=K.transpose(1, 2).contiguous(), device=self.device
        )
        return transform

    def unproject_points(
        self,
        xy_depth: torch.Tensor,
        world_coordinates: bool = True,
        scaled_depth_input: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        """
        FoV cameras further allow for passing depth in world units
        (`scaled_depth_input=False`) or in the [0, 1]-normalized units
        (`scaled_depth_input=True`)

        Args:
            scaled_depth_input: If `True`, assumes the input depth is in
                the [0, 1]-normalized units. If `False` the input depth is in
                the world units.
        """

        # obtain the relevant transformation to ndc
        if world_coordinates:
            to_ndc_transform = self.get_full_projection_transform()
        else:
            to_ndc_transform = self.get_projection_transform()

        if scaled_depth_input:
            # the input is scaled depth, so we don't have to do anything
            xy_sdepth = xy_depth
        else:
            # parse out important values from the projection matrix
            K_matrix = self.get_projection_transform(**kwargs.copy()).get_matrix()
            # parse out f1, f2 from K_matrix
            unsqueeze_shape = [1] * xy_depth.dim()
            unsqueeze_shape[0] = K_matrix.shape[0]
            f1 = K_matrix[:, 2, 2].reshape(unsqueeze_shape)
            f2 = K_matrix[:, 3, 2].reshape(unsqueeze_shape)
            # get the scaled depth
            sdepth = (f1 * xy_depth[..., 2:3] + f2) / xy_depth[..., 2:3]
            # concatenate xy + scaled depth
            xy_sdepth = torch.cat((xy_depth[..., 0:2], sdepth), dim=-1)

        # unproject with inverse of the projection
        unprojection_transform = to_ndc_transform.inverse()
        return unprojection_transform.transform_points(xy_sdepth)

    def is_perspective(self):
        return True

    def in_ndc(self):
        return True

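
# Illustrative sketch (not part of the original module): the FoV projection
# maps camera-space depth z in [znear, zfar] to NDC depth in [0, 1].
def _example_fov_depth_range() -> None:
    cams = FoVPerspectiveCameras(znear=1.0, zfar=100.0)
    pts = torch.tensor([[[0.0, 0.0, 1.0], [0.0, 0.0, 100.0]]])
    ndc = cams.transform_points(pts)
    assert torch.allclose(ndc[..., 2], torch.tensor([[0.0, 1.0]]), atol=1e-5)
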
#######################################################################################
## ██████╗ ███████╗███████╗██╗███╗   ██╗██╗████████╗██╗ ██████╗ ███╗   ██╗███████╗  ##
## ██╔══██╗██╔════╝██╔════╝██║████╗  ██║██║╚══██╔══╝██║██╔═══██╗████╗  ██║██╔════╝  ##
## ██║  ██║█████╗  █████╗  ██║██╔██╗ ██║██║   ██║   ██║██║   ██║██╔██╗ ██║███████╗  ##
## ██║  ██║██╔══╝  ██╔══╝  ██║██║╚██╗██║██║   ██║   ██║██║   ██║██║╚██╗██║╚════██║  ##
## ██████╔╝███████╗██║     ██║██║ ╚████║██║   ██║   ██║╚██████╔╝██║ ╚████║███████║  ##
## ╚═════╝ ╚══════╝╚═╝     ╚═╝╚═╝  ╚═══╝╚═╝   ╚═╝   ╚═╝ ╚═════╝ ╚═╝  ╚═══╝╚══════╝  ##
#######################################################################################

def make_device(device: Device) -> torch.device:
    """
    Makes an actual torch.device object from the device specified as
    either a string or torch.device object. If the device is `cuda` without
    a specific index, the index of the current device is assigned.

    Args:
        device: Device (as str or torch.device)

    Returns:
        A matching torch.device object
    """
    device = torch.device(device) if isinstance(device, str) else device
    if device.type == "cuda" and device.index is None:  # pyre-ignore[16]
        # If cuda but with no index, then the current cuda device is indicated.
        # In that case, we fix to that device
        device = torch.device(f"cuda:{torch.cuda.current_device()}")
    return device

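
# Illustrative sketch (not part of the original module): make_device pins a
# bare "cuda" specification to the currently selected device index.
def _example_make_device() -> None:
    assert make_device("cpu") == torch.device("cpu")
    if torch.cuda.is_available():
        assert make_device("cuda").index == torch.cuda.current_device()
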
def get_device(x, device: Optional[Device] = None) -> torch.device:
    """
    Gets the device of the specified variable x if it is a tensor, or
    falls back to a default CPU device otherwise. Allows overriding by
    providing an explicit device.

    Args:
        x: a torch.Tensor to get the device from or another type
        device: Device (as str or torch.device) to fall back to

    Returns:
        A matching torch.device object
    """

    # User overrides device
    if device is not None:
        return make_device(device)

    # Set device based on input tensor
    if torch.is_tensor(x):
        return x.device

    # Default device is cpu
    return torch.device("cpu")

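
# Illustrative sketch (not part of the original module): device resolution
# order in get_device - explicit override, then tensor device, then CPU.
def _example_get_device() -> None:
    x = torch.zeros(2, 3)
    assert get_device(x) == x.device                      # from the tensor
    assert get_device([1.0, 2.0]) == torch.device("cpu")  # non-tensor fallback
    assert get_device(x, "cpu") == torch.device("cpu")    # explicit override
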
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Return the rotation matrices for one of the rotations about an axis
    of which Euler angles describe, for each value of the angle given.

    Args:
        axis: Axis label "X" or "Y" or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """

    cos = torch.cos(angle)
    sin = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    if axis == "X":
        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    elif axis == "Z":
        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
    else:
        raise ValueError("letter must be either X, Y or Z.")

    return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))

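
# Illustrative sketch (not part of the original module): a +90 degree rotation
# about Z maps the X axis onto the Y axis (counter clockwise, right-hand rule).
def _example_axis_angle_rotation() -> None:
    R = _axis_angle_rotation("Z", torch.tensor(math.pi / 2))
    x = torch.tensor([1.0, 0.0, 0.0])
    assert torch.allclose(R @ x, torch.tensor([0.0, 1.0, 0.0]), atol=1e-6)
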
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    matrices = [
        _axis_angle_rotation(c, e)
        for c, e in zip(convention, torch.unbind(euler_angles, -1))
    ]
    # return functools.reduce(torch.matmul, matrices)
    return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])

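
# Illustrative sketch (not part of the original module): the per-axis rotations
# compose in convention order, e.g. "XYZ" gives R = Rx @ Ry @ Rz.
def _example_euler_angles_to_matrix() -> None:
    angles = torch.tensor([0.1, 0.2, 0.3])
    R = euler_angles_to_matrix(angles, "XYZ")
    Rx = _axis_angle_rotation("X", angles[0])
    Ry = _axis_angle_rotation("Y", angles[1])
    Rz = _axis_angle_rotation("Z", angles[2])
    assert torch.allclose(R, Rx @ Ry @ Rz, atol=1e-6)
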
def _broadcast_bmm(a, b) -> torch.Tensor:
    """
    Batch multiply two matrices and broadcast if necessary.

    Args:
        a: torch tensor of shape (P, K) or (M, P, K)
        b: torch tensor of shape (N, K, K)

    Returns:
        a and b broadcast multiplied. The output batch dimension is max(N, M).

    To broadcast transforms across a batch dimension, if M != N then either
    M or N must be 1; the tensor with batch dimension 1 is expanded to match
    the other.
    """
    if a.dim() == 2:
        a = a[None]
    if len(a) != len(b):
        if not ((len(a) == 1) or (len(b) == 1)):
            msg = "Expected batch dim for bmm to be equal or 1; got %r, %r"
            raise ValueError(msg % (a.shape, b.shape))
        if len(a) == 1:
            a = a.expand(len(b), -1, -1)
        if len(b) == 1:
            b = b.expand(len(a), -1, -1)
    return a.bmm(b)
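
# Broadcasting sketch: a single 2D matrix `a` against a batch of 5 in `b`
# (shapes chosen for illustration; K = 3 here):
#
#     >>> a = torch.randn(4, 3)        # treated as (1, 4, 3)
#     >>> b = torch.randn(5, 3, 3)
#     >>> _broadcast_bmm(a, b).shape
#     torch.Size([5, 4, 3])
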
def _safe_det_3x3(t: torch.Tensor) -> torch.Tensor:
    """
    Fast determinant calculation for a batch of 3x3 matrices.

    Note: the result of this function might differ from `torch.det()` in
    the last significant digit.

    Args:
        t: Tensor of shape (N, 3, 3).

    Returns:
        Tensor of shape (N,) with determinants.
    """

    det = (
        t[..., 0, 0] * (t[..., 1, 1] * t[..., 2, 2] - t[..., 1, 2] * t[..., 2, 1])
        - t[..., 0, 1] * (t[..., 1, 0] * t[..., 2, 2] - t[..., 2, 0] * t[..., 1, 2])
        + t[..., 0, 2] * (t[..., 1, 0] * t[..., 2, 1] - t[..., 2, 0] * t[..., 1, 1])
    )

    return det
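
# Sanity-check sketch: the cofactor expansion above should agree with
# torch.det up to floating point rounding.
#
#     >>> t = torch.randn(8, 3, 3)
#     >>> torch.allclose(_safe_det_3x3(t), torch.det(t), atol=1e-5)
#     True
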
def get_world_to_view_transform(
    R: torch.Tensor = _R, T: torch.Tensor = _T
) -> Transform3d:
    """
    This function returns a Transform3d representing the transformation
    matrix to go from world space to view space by applying a rotation and
    a translation.

    PyTorch3D uses the same convention as Hartley & Zisserman.
    I.e., for camera extrinsic parameters R (rotation) and T (translation),
    we map a 3D point `X_world` in world coordinates to
    a point `X_cam` in camera coordinates with:
    `X_cam = X_world R + T`

    Args:
        R: (N, 3, 3) matrix representing the rotation.
        T: (N, 3) matrix representing the translation.

    Returns:
        a Transform3d object which represents the composed RT transformation.
    """
    # TODO: also support the case where RT is specified as one matrix
    # of shape (N, 4, 4).

    if T.shape[0] != R.shape[0]:
        msg = "Expected R, T to have the same batch dimension; got %r, %r"
        raise ValueError(msg % (R.shape[0], T.shape[0]))
    if T.dim() != 2 or T.shape[1:] != (3,):
        msg = "Expected T to have shape (N, 3); got %r"
        raise ValueError(msg % repr(T.shape))
    if R.dim() != 3 or R.shape[1:] != (3, 3):
        msg = "Expected R to have shape (N, 3, 3); got %r"
        raise ValueError(msg % repr(R.shape))

    # Create a Transform3d object
    T_ = Translate(T, device=T.device)
    R_ = Rotate(R, device=R.device)
    return R_.compose(T_)
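
# Usage sketch: an identity rotation with a translation of 5 units along +Z
# (values are illustrative). Following `X_cam = X_world R + T`, the world
# origin lands at (0, 0, 5) in view space:
#
#     >>> R = torch.eye(3)[None]                      # (1, 3, 3)
#     >>> T = torch.tensor([[0.0, 0.0, 5.0]])         # (1, 3)
#     >>> wtv = get_world_to_view_transform(R=R, T=T)
#     >>> wtv.transform_points(torch.zeros(1, 1, 3))
#     tensor([[[0., 0., 5.]]])
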
def _check_valid_rotation_matrix(R, tol: float = 1e-7) -> None:
    """
    Determine if R is a valid rotation matrix by checking it satisfies the
    following conditions:

    ``RR^T = I and det(R) = 1``

    Args:
        R: an (N, 3, 3) matrix

    Returns:
        None

    Emits a warning if R is an invalid rotation matrix.
    """
    N = R.shape[0]
    eye = torch.eye(3, dtype=R.dtype, device=R.device)
    eye = eye.view(1, 3, 3).expand(N, -1, -1)
    orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol)
    det_R = _safe_det_3x3(R)
    no_distortion = torch.allclose(det_R, torch.ones_like(det_R))
    if not (orthogonal and no_distortion):
        msg = "R is not a valid rotation matrix"
        warnings.warn(msg)
    return
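
# Illustration: the identity passes silently, while a scaled matrix fails
# both the orthogonality and det(R) = 1 checks and emits the warning.
#
#     >>> _check_valid_rotation_matrix(torch.eye(3)[None])       # no warning
#     >>> _check_valid_rotation_matrix(2 * torch.eye(3)[None])   # warns: det = 8
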
def format_tensor(
    input,
    dtype: torch.dtype = torch.float32,
    device: Device = "cpu",
) -> torch.Tensor:
    """
    Helper function for converting a scalar value to a tensor.

    Args:
        input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor
        dtype: data type for the input
        device: Device (as str or torch.device) on which the tensor should be placed.

    Returns:
        input_vec: torch tensor with optional added batch dimension.
    """
    device_ = make_device(device)
    if not torch.is_tensor(input):
        input = torch.tensor(input, dtype=dtype, device=device_)
    elif not input.device.type.startswith("mps"):
        # Use .to() rather than re-wrapping with torch.tensor(), which would
        # copy the data and detach it from autograd; .to() is a no-op when
        # the dtype and device already match. MPS tensors are left untouched.
        input = input.to(dtype=dtype, device=device_)

    if input.dim() == 0:
        input = input.view(1)

    if input.device == device_:
        return input

    input = input.to(device=device_)
    return input
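
# Usage sketch: scalars gain a batch dimension, sequences convert directly
# (outputs assume the default cpu device and float32 dtype):
#
#     >>> format_tensor(2.5)
#     tensor([2.5000])
#     >>> format_tensor([1.0, 2.0, 3.0]).shape
#     torch.Size([3])
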
def convert_to_tensors_and_broadcast(
    *args,
    dtype: torch.dtype = torch.float32,
    device: Device = "cpu",
):
    """
    Helper function to handle parsing an arbitrary number of inputs (*args)
    which all need to have the same batch dimension.
    The output is a list of tensors.

    Args:
        *args: an arbitrary number of inputs
            Each of the values in `args` can be one of the following
                - Python scalar
                - Torch scalar
                - Torch tensor of shape (N, K_i) or (1, K_i) where K_i are
                  an arbitrary number of dimensions which can vary for each
                  value in args. In this case each input is broadcast to a
                  tensor of shape (N, K_i)
        dtype: data type to use when creating new tensors.
        device: torch device on which the tensors should be placed.

    Output:
        args: A list of tensors of shape (N, K_i)
    """
    # Convert all inputs to tensors with a batch dimension
    args_1d = [format_tensor(c, dtype, device) for c in args]

    # Find broadcast size
    sizes = [c.shape[0] for c in args_1d]
    N = max(sizes)

    args_Nd = []
    for c in args_1d:
        if c.shape[0] != 1 and c.shape[0] != N:
            msg = "Got non-broadcastable sizes %r" % sizes
            raise ValueError(msg)

        # Expand broadcast dim and keep non broadcast dims the same size
        expand_sizes = (N,) + (-1,) * len(c.shape[1:])
        args_Nd.append(c.expand(*expand_sizes))

    return args_Nd
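
# Usage sketch: a (4, 3) tensor broadcast together with a Python scalar;
# the scalar becomes a shape-(4,) tensor so both share batch size N = 4.
#
#     >>> x = torch.randn(4, 3)
#     >>> x_b, y_b = convert_to_tensors_and_broadcast(x, 2.0)
#     >>> x_b.shape, y_b.shape
#     (torch.Size([4, 3]), torch.Size([4]))
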
def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    """
    Helper function for _handle_input.

    Args:
        c: Python scalar, torch scalar, or 1D torch tensor
        dtype: data type for the output
        device: device on which the output should be placed

    Returns:
        c_vec: 1D torch tensor
    """
    if not torch.is_tensor(c):
        c = torch.tensor(c, dtype=dtype, device=device)
    if c.dim() == 0:
        c = c.view(1)
    if c.device != device or c.dtype != dtype:
        c = c.to(device=device, dtype=dtype)
    return c
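
# Illustration: scalars become 1D tensors; existing tensors are converted
# to the requested dtype/device only when needed.
#
#     >>> _handle_coord(1.0, torch.float32, torch.device("cpu"))
#     tensor([1.])
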
def _handle_input(
    x,
    y,
    z,
    dtype: torch.dtype,
    device: Optional[Device],
    name: str,
    allow_singleton: bool = False,
) -> torch.Tensor:
    """
    Helper function to handle parsing logic for building transforms. The output
    is always a tensor of shape (N, 3), but there are several types of allowed
    input.

    Case I: Single Matrix
        In this case x is a tensor of shape (N, 3), and y and z are None. Here
        just return x.

    Case II: Vectors and Scalars
        In this case each of x, y, and z can be one of the following
            - Python scalar
            - Torch scalar
            - Torch tensor of shape (N, 1) or (1, 1)
        In this case x, y and z are broadcast to tensors of shape (N, 1)
        and concatenated to a tensor of shape (N, 3)

    Case III: Singleton (only if allow_singleton=True)
        In this case y and z are None, and x can be one of the following:
            - Python scalar
            - Torch scalar
            - Torch tensor of shape (N, 1) or (1, 1)
        Here x will be duplicated 3 times, and we return a tensor of shape (N, 3)

    Returns:
        xyz: Tensor of shape (N, 3)
    """
    device_ = get_device(x, device)
    # If x is actually a tensor of shape (N, 3) then just return it
    if torch.is_tensor(x) and x.dim() == 2:
        if x.shape[1] != 3:
            msg = "Expected tensor of shape (N, 3); got %r (in %s)"
            raise ValueError(msg % (x.shape, name))
        if y is not None or z is not None:
            msg = "Expected y and z to be None (in %s)" % name
            raise ValueError(msg)
        return x.to(device=device_, dtype=dtype)

    if allow_singleton and y is None and z is None:
        y = x
        z = x

    # Convert all to 1D tensors
    xyz = [_handle_coord(c, dtype, device_) for c in [x, y, z]]

    # Broadcast and concatenate
    sizes = [c.shape[0] for c in xyz]
    N = max(sizes)
    for c in xyz:
        if c.shape[0] != 1 and c.shape[0] != N:
            msg = "Got non-broadcastable sizes %r (in %s)" % (sizes, name)
            raise ValueError(msg)
    xyz = [c.expand(N) for c in xyz]
    xyz = torch.stack(xyz, dim=1)
    return xyz
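
# Usage sketches for the three input cases ("Translate" and "Scale" are just
# illustrative caller names used in error messages):
#
#     >>> # Case I: a single (N, 3) matrix is returned as-is
#     >>> _handle_input(torch.randn(4, 3), None, None, torch.float32, None, "Translate").shape
#     torch.Size([4, 3])
#     >>> # Case II: per-axis scalars are broadcast to (N, 3); here N = 1
#     >>> _handle_input(1.0, 2.0, 3.0, torch.float32, None, "Translate")
#     tensor([[1., 2., 3.]])
#     >>> # Case III: one value duplicated across all three axes
#     >>> _handle_input(2.0, None, None, torch.float32, None, "Scale", allow_singleton=True)
#     tensor([[2., 2., 2.]])
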