diff --git a/pytorch3d/implicitron/models/renderer/base.py b/pytorch3d/implicitron/models/renderer/base.py index 41d9ddbbf..3e891bf76 100644 --- a/pytorch3d/implicitron/models/renderer/base.py +++ b/pytorch3d/implicitron/models/renderer/base.py @@ -6,8 +6,6 @@ from __future__ import annotations -import dataclasses - from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum @@ -29,7 +27,6 @@ class RenderSamplingMode(Enum): FULL_GRID = "full_grid" -@dataclasses.dataclass class ImplicitronRayBundle: """ Parametrizes points along projection rays by storing ray `origins`, @@ -69,53 +66,58 @@ class ImplicitronRayBundle: lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`. pixel_radii_2d: An optional tensor of shape `(..., 1)` base radii of the conical frustums. + + Raises: + ValueError: If either bins or lengths are not provided. + ValueError: If bins is provided and the last dim is inferior or equal to 1. """ - origins: torch.Tensor - directions: torch.Tensor - lengths: torch.Tensor - xys: torch.Tensor - camera_ids: Optional[torch.LongTensor] = None - camera_counts: Optional[torch.LongTensor] = None - bins: Optional[torch.Tensor] = None - pixel_radii_2d: Optional[torch.Tensor] = None - - @classmethod - def from_bins( - cls, + def __init__( + self, origins: torch.Tensor, directions: torch.Tensor, - bins: torch.Tensor, + lengths: Optional[torch.Tensor], xys: torch.Tensor, - **kwargs, - ) -> "ImplicitronRayBundle": - """ - Creates a new instance from bins instead of lengths. - - Attributes: - origins: A tensor of shape `(..., 3)` denoting the - origins of the sampling rays in world coords. - directions: A tensor of shape `(..., 3)` containing the direction - vectors of sampling rays in world coords. They don't have to be normalized; - they define unit vectors in the respective 1D coordinate systems; see - documentation for :func:`ray_bundle_to_ray_points` for the conversion formula. - bins: A tensor of shape `(..., num_points_per_ray + 1)` - containing the bins at which the rays are sampled. In this case - lengths is equal to the midpoints of bins `(..., num_points_per_ray)`. - xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels - kwargs: Additional arguments passed to the constructor of ImplicitronRayBundle - Returns: - An instance of ImplicitronRayBundle. - """ - - if bins.shape[-1] <= 1: + camera_ids: Optional[torch.LongTensor] = None, + camera_counts: Optional[torch.LongTensor] = None, + bins: Optional[torch.Tensor] = None, + pixel_radii_2d: Optional[torch.Tensor] = None, + ): + if bins is not None and bins.shape[-1] <= 1: raise ValueError( "The last dim of bins must be at least superior or equal to 2." ) - # equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient - lengths = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5) - return cls(origins, directions, lengths, xys, bins=bins, **kwargs) + if bins is None and lengths is None: + raise ValueError( + "Please set either bins or lengths to initialize an ImplicitronRayBundle." + ) + + self.origins = origins + self.directions = directions + self._lengths = lengths if bins is None else None + self.xys = xys + self.bins = bins + self.pixel_radii_2d = pixel_radii_2d + self.camera_ids = camera_ids + self.camera_counts = camera_counts + + @property + def lengths(self) -> torch.Tensor: + if self.bins is not None: + # equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient + # pyre-ignore + return torch.lerp(self.bins[..., :-1], self.bins[..., 1:], 0.5) + return self._lengths + + @lengths.setter + def lengths(self, value): + if self.bins is not None: + raise ValueError( + "If the bins attribute is not None you cannot set the lengths attribute." + ) + else: + self._lengths = value def is_packed(self) -> bool: """ diff --git a/pytorch3d/implicitron/models/renderer/lstm_renderer.py b/pytorch3d/implicitron/models/renderer/lstm_renderer.py index 8702d4e40..19848ed6e 100644 --- a/pytorch3d/implicitron/models/renderer/lstm_renderer.py +++ b/pytorch3d/implicitron/models/renderer/lstm_renderer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import dataclasses +import copy import logging from typing import List, Optional, Tuple @@ -102,12 +102,11 @@ def forward( ) # jitter the initial depths - ray_bundle_t = dataclasses.replace( - ray_bundle, - lengths=( - ray_bundle.lengths - + torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std - ), + + ray_bundle_t = copy.copy(ray_bundle) + ray_bundle_t.lengths = ( + ray_bundle.lengths + + torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std ) states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None] diff --git a/pytorch3d/implicitron/models/renderer/ray_point_refiner.py b/pytorch3d/implicitron/models/renderer/ray_point_refiner.py index 0266f939d..b71574d23 100644 --- a/pytorch3d/implicitron/models/renderer/ray_point_refiner.py +++ b/pytorch3d/implicitron/models/renderer/ray_point_refiner.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import copy + import torch from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields @@ -106,14 +108,13 @@ def forward( z_vals = z_samples # Resort by depth. z_vals, _ = torch.sort(z_vals, dim=-1) - - kwargs_ray = dict(vars(input_ray_bundle)) + ray_bundle = copy.copy(input_ray_bundle) if input_ray_bundle.bins is None: - kwargs_ray["lengths"] = z_vals - return ImplicitronRayBundle(**kwargs_ray) - kwargs_ray["bins"] = z_vals - del kwargs_ray["lengths"] - return ImplicitronRayBundle.from_bins(**kwargs_ray) + ray_bundle.lengths = z_vals + else: + ray_bundle.bins = z_vals + + return ray_bundle def apply_blurpool_on_weights(weights) -> torch.Tensor: diff --git a/pytorch3d/implicitron/models/renderer/ray_sampler.py b/pytorch3d/implicitron/models/renderer/ray_sampler.py index d3f1c6b33..fe464f670 100644 --- a/pytorch3d/implicitron/models/renderer/ray_sampler.py +++ b/pytorch3d/implicitron/models/renderer/ray_sampler.py @@ -236,11 +236,12 @@ def forward( elif self.cast_ray_bundle_as_cone: pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width) pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw) - return ImplicitronRayBundle.from_bins( + return ImplicitronRayBundle( directions=ray_bundle.directions, origins=ray_bundle.origins, - bins=ray_bundle.lengths, + lengths=None, xys=ray_bundle.xys, + bins=ray_bundle.lengths, pixel_radii_2d=pixel_radii_2d, ) diff --git a/tests/implicitron/test_models_renderer_base.py b/tests/implicitron/test_models_renderer_base.py index 4b7827b1c..7c1f978d0 100644 --- a/tests/implicitron/test_models_renderer_base.py +++ b/tests/implicitron/test_models_renderer_base.py @@ -25,23 +25,62 @@ class TestRendererBase(TestCaseMixin, unittest.TestCase): def test_implicitron_from_bins(self) -> None: bins = torch.randn(2, 3, 4, 5) - ray_bundle = ImplicitronRayBundle.from_bins( + ray_bundle = ImplicitronRayBundle( origins=None, directions=None, + lengths=None, xys=None, bins=bins, ) self.assertClose(ray_bundle.lengths, 0.5 * (bins[..., 1:] + bins[..., :-1])) self.assertClose(ray_bundle.bins, bins) + def test_implicitron_raise_value_error_bins_is_set_and_try_to_set_lengths( + self, + ) -> None: + with self.assertRaises(ValueError) as context: + ray_bundle = ImplicitronRayBundle( + origins=torch.rand(2, 3, 4, 3), + directions=torch.rand(2, 3, 4, 3), + lengths=None, + xys=torch.rand(2, 3, 4, 2), + bins=torch.rand(2, 3, 4, 1), + ) + ray_bundle.lengths = torch.empty(2) + self.assertEqual( + str(context.exception), + "If the bins attribute is not None you cannot set the lengths attribute.", + ) + def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None: - with self.assertRaises(ValueError): - ImplicitronRayBundle.from_bins( + with self.assertRaises(ValueError) as context: + ImplicitronRayBundle( origins=torch.rand(2, 3, 4, 3), directions=torch.rand(2, 3, 4, 3), + lengths=None, xys=torch.rand(2, 3, 4, 2), bins=torch.rand(2, 3, 4, 1), ) + self.assertEqual( + str(context.exception), + "The last dim of bins must be at least superior or equal to 2.", + ) + + def test_implicitron_raise_value_error_if_neither_bins_or_lengths_provided( + self, + ) -> None: + with self.assertRaises(ValueError) as context: + ImplicitronRayBundle( + origins=torch.rand(2, 3, 4, 3), + directions=torch.rand(2, 3, 4, 3), + lengths=None, + xys=torch.rand(2, 3, 4, 2), + bins=None, + ) + self.assertEqual( + str(context.exception), + "Please set either bins or lengths to initialize an ImplicitronRayBundle.", + ) def test_conical_frustum_to_gaussian(self) -> None: origins = torch.zeros(3, 3, 3)