Skip to content

Commit

Permalink
Avoid to keep in memory lengths and bins for ImplicitronRayBundle
Browse files Browse the repository at this point in the history
Summary:
Convert ImplicitronRayBundle to a "classic" class instead of a dataclass. This change is introduced as a way to preserve the ImplicitronRayBundle interface while allowing two outcomes:
- init lengths arguments is now a Optional[torch.Tensor] instead of torch.Tensor
- lengths is now a property which returns a `torch.Tensor`. The lengths property will either recompute lengths from bins or return the stored _lengths. `_lenghts` is None if bins is set. It saves us a bit of memory.

Reviewed By: shapovalov

Differential Revision: D46686094

fbshipit-source-id: 3c75c0947216476ebff542b6f552d311024a679b
  • Loading branch information
EmGarr authored and facebook-github-bot committed Jul 6, 2023
1 parent 3d011a9 commit 9446d91
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 60 deletions.
84 changes: 43 additions & 41 deletions pytorch3d/implicitron/models/renderer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from __future__ import annotations

import dataclasses

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
Expand All @@ -29,7 +27,6 @@ class RenderSamplingMode(Enum):
FULL_GRID = "full_grid"


@dataclasses.dataclass
class ImplicitronRayBundle:
"""
Parametrizes points along projection rays by storing ray `origins`,
Expand Down Expand Up @@ -69,53 +66,58 @@ class ImplicitronRayBundle:
lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`.
pixel_radii_2d: An optional tensor of shape `(..., 1)`
base radii of the conical frustums.
Raises:
ValueError: If either bins or lengths are not provided.
ValueError: If bins is provided and the last dim is inferior or equal to 1.
"""

origins: torch.Tensor
directions: torch.Tensor
lengths: torch.Tensor
xys: torch.Tensor
camera_ids: Optional[torch.LongTensor] = None
camera_counts: Optional[torch.LongTensor] = None
bins: Optional[torch.Tensor] = None
pixel_radii_2d: Optional[torch.Tensor] = None

@classmethod
def from_bins(
cls,
def __init__(
self,
origins: torch.Tensor,
directions: torch.Tensor,
bins: torch.Tensor,
lengths: Optional[torch.Tensor],
xys: torch.Tensor,
**kwargs,
) -> "ImplicitronRayBundle":
"""
Creates a new instance from bins instead of lengths.
Attributes:
origins: A tensor of shape `(..., 3)` denoting the
origins of the sampling rays in world coords.
directions: A tensor of shape `(..., 3)` containing the direction
vectors of sampling rays in world coords. They don't have to be normalized;
they define unit vectors in the respective 1D coordinate systems; see
documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
bins: A tensor of shape `(..., num_points_per_ray + 1)`
containing the bins at which the rays are sampled. In this case
lengths is equal to the midpoints of bins `(..., num_points_per_ray)`.
xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
kwargs: Additional arguments passed to the constructor of ImplicitronRayBundle
Returns:
An instance of ImplicitronRayBundle.
"""

if bins.shape[-1] <= 1:
camera_ids: Optional[torch.LongTensor] = None,
camera_counts: Optional[torch.LongTensor] = None,
bins: Optional[torch.Tensor] = None,
pixel_radii_2d: Optional[torch.Tensor] = None,
):
if bins is not None and bins.shape[-1] <= 1:
raise ValueError(
"The last dim of bins must be at least superior or equal to 2."
)
# equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
lengths = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5)

return cls(origins, directions, lengths, xys, bins=bins, **kwargs)
if bins is None and lengths is None:
raise ValueError(
"Please set either bins or lengths to initialize an ImplicitronRayBundle."
)

self.origins = origins
self.directions = directions
self._lengths = lengths if bins is None else None
self.xys = xys
self.bins = bins
self.pixel_radii_2d = pixel_radii_2d
self.camera_ids = camera_ids
self.camera_counts = camera_counts

@property
def lengths(self) -> torch.Tensor:
if self.bins is not None:
# equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
# pyre-ignore
return torch.lerp(self.bins[..., :-1], self.bins[..., 1:], 0.5)
return self._lengths

@lengths.setter
def lengths(self, value):
if self.bins is not None:
raise ValueError(
"If the bins attribute is not None you cannot set the lengths attribute."
)
else:
self._lengths = value

def is_packed(self) -> bool:
"""
Expand Down
13 changes: 6 additions & 7 deletions pytorch3d/implicitron/models/renderer/lstm_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
import copy
import logging
from typing import List, Optional, Tuple

Expand Down Expand Up @@ -102,12 +102,11 @@ def forward(
)

# jitter the initial depths
ray_bundle_t = dataclasses.replace(
ray_bundle,
lengths=(
ray_bundle.lengths
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
),

ray_bundle_t = copy.copy(ray_bundle)
ray_bundle_t.lengths = (
ray_bundle.lengths
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
)

states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
Expand Down
15 changes: 8 additions & 7 deletions pytorch3d/implicitron/models/renderer/ray_point_refiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
Expand Down Expand Up @@ -106,14 +108,13 @@ def forward(
z_vals = z_samples
# Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1)

kwargs_ray = dict(vars(input_ray_bundle))
ray_bundle = copy.copy(input_ray_bundle)
if input_ray_bundle.bins is None:
kwargs_ray["lengths"] = z_vals
return ImplicitronRayBundle(**kwargs_ray)
kwargs_ray["bins"] = z_vals
del kwargs_ray["lengths"]
return ImplicitronRayBundle.from_bins(**kwargs_ray)
ray_bundle.lengths = z_vals
else:
ray_bundle.bins = z_vals

return ray_bundle


def apply_blurpool_on_weights(weights) -> torch.Tensor:
Expand Down
5 changes: 3 additions & 2 deletions pytorch3d/implicitron/models/renderer/ray_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,12 @@ def forward(
elif self.cast_ray_bundle_as_cone:
pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
return ImplicitronRayBundle.from_bins(
return ImplicitronRayBundle(
directions=ray_bundle.directions,
origins=ray_bundle.origins,
bins=ray_bundle.lengths,
lengths=None,
xys=ray_bundle.xys,
bins=ray_bundle.lengths,
pixel_radii_2d=pixel_radii_2d,
)

Expand Down
45 changes: 42 additions & 3 deletions tests/implicitron/test_models_renderer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,62 @@
class TestRendererBase(TestCaseMixin, unittest.TestCase):
def test_implicitron_from_bins(self) -> None:
bins = torch.randn(2, 3, 4, 5)
ray_bundle = ImplicitronRayBundle.from_bins(
ray_bundle = ImplicitronRayBundle(
origins=None,
directions=None,
lengths=None,
xys=None,
bins=bins,
)
self.assertClose(ray_bundle.lengths, 0.5 * (bins[..., 1:] + bins[..., :-1]))
self.assertClose(ray_bundle.bins, bins)

def test_implicitron_raise_value_error_bins_is_set_and_try_to_set_lengths(
self,
) -> None:
with self.assertRaises(ValueError) as context:
ray_bundle = ImplicitronRayBundle(
origins=torch.rand(2, 3, 4, 3),
directions=torch.rand(2, 3, 4, 3),
lengths=None,
xys=torch.rand(2, 3, 4, 2),
bins=torch.rand(2, 3, 4, 1),
)
ray_bundle.lengths = torch.empty(2)
self.assertEqual(
str(context.exception),
"If the bins attribute is not None you cannot set the lengths attribute.",
)

def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
with self.assertRaises(ValueError):
ImplicitronRayBundle.from_bins(
with self.assertRaises(ValueError) as context:
ImplicitronRayBundle(
origins=torch.rand(2, 3, 4, 3),
directions=torch.rand(2, 3, 4, 3),
lengths=None,
xys=torch.rand(2, 3, 4, 2),
bins=torch.rand(2, 3, 4, 1),
)
self.assertEqual(
str(context.exception),
"The last dim of bins must be at least superior or equal to 2.",
)

def test_implicitron_raise_value_error_if_neither_bins_or_lengths_provided(
self,
) -> None:
with self.assertRaises(ValueError) as context:
ImplicitronRayBundle(
origins=torch.rand(2, 3, 4, 3),
directions=torch.rand(2, 3, 4, 3),
lengths=None,
xys=torch.rand(2, 3, 4, 2),
bins=None,
)
self.assertEqual(
str(context.exception),
"Please set either bins or lengths to initialize an ImplicitronRayBundle.",
)

def test_conical_frustum_to_gaussian(self) -> None:
origins = torch.zeros(3, 3, 3)
Expand Down

0 comments on commit 9446d91

Please sign in to comment.