From 72c3a0ebe59a7bc058c8f3a081ed9a2b7cb132bf Mon Sep 17 00:00:00 2001 From: Darijan Gudelj Date: Mon, 5 Sep 2022 06:26:06 -0700 Subject: [PATCH] raybundle input to ImplicitFunctions -> api unification Summary: Currently some implicit functions in implicitron take a raybundle, others take ray_points_world. raybundle is what they really need. However, the raybundle is going to become a bit more flexible later, as it will contain different numbers of rays for each camera. Reviewed By: bottler Differential Revision: D39173751 fbshipit-source-id: ebc038e426d22e831e67a18ba64655d8a61e1eb9 --- .../models/implicit_function/base.py | 1 + .../implicit_function/idr_feature_field.py | 15 ++++++---- .../neural_radiance_field.py | 1 + .../scene_representation_networks.py | 2 ++ .../models/implicit_function/utils.py | 30 +++++++++++++++++++ .../models/renderer/lstm_renderer.py | 2 +- .../models/renderer/multipass_ea.py | 2 +- .../models/renderer/sdf_renderer.py | 20 +++++++------ tests/implicitron/test_srn.py | 6 ++-- 9 files changed, 60 insertions(+), 19 deletions(-) diff --git a/pytorch3d/implicitron/models/implicit_function/base.py b/pytorch3d/implicitron/models/implicit_function/base.py index 742fde169..2e0c77984 100644 --- a/pytorch3d/implicitron/models/implicit_function/base.py +++ b/pytorch3d/implicitron/models/implicit_function/base.py @@ -19,6 +19,7 @@ def __init__(self): @abstractmethod def forward( self, + *, ray_bundle: RayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, diff --git a/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py b/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py index 09a9fa304..557ba1387 100644 --- a/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py +++ b/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py @@ -3,14 +3,15 @@ # implicit_differentiable_renderer.py # Copyright (c) 2020 Lior Yariv import math -from typing import Tuple +from typing import Optional, Tuple import torch from pytorch3d.implicitron.tools.config import registry -from pytorch3d.renderer.implicit import HarmonicEmbedding +from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle from torch import nn from .base import ImplicitFunctionBase +from .utils import get_rays_points_world @registry.register @@ -125,14 +126,16 @@ def __post_init__(self): # inconsistently. def forward( self, - # ray_bundle: RayBundle, - rays_points_world: torch.Tensor, # TODO: unify the APIs + *, + ray_bundle: Optional[RayBundle] = None, + rays_points_world: Optional[torch.Tensor] = None, fun_viewpool=None, global_code=None, + **kwargs, ): # this field only uses point locations - # rays_points_world = ray_bundle_to_ray_points(ray_bundle) # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3] + rays_points_world = get_rays_points_world(ray_bundle, rays_points_world) if rays_points_world.numel() == 0 or ( self.embed_fn is None and fun_viewpool is None and global_code is None @@ -179,4 +182,4 @@ def forward( # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. x = self.softplus(x) - return x # TODO: unify the APIs + return x diff --git a/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py index 78a3c8a44..d325c798c 100644 --- a/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py +++ b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py @@ -129,6 +129,7 @@ def allows_multiple_passes() -> bool: def forward( self, + *, ray_bundle: RayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, diff --git a/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py b/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py index b7f10a95a..c701c54c0 100644 --- a/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py +++ b/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py @@ -349,6 +349,7 @@ def raymarch_function_tweak_args(cls, type, args: DictConfig) -> None: def forward( self, + *, ray_bundle: RayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, @@ -408,6 +409,7 @@ def hypernet_tweak_args(cls, type, args: DictConfig) -> None: def forward( self, + *, ray_bundle: RayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, diff --git a/pytorch3d/implicitron/models/implicit_function/utils.py b/pytorch3d/implicitron/models/implicit_function/utils.py index 9a26aff3f..9b401c489 100644 --- a/pytorch3d/implicitron/models/implicit_function/utils.py +++ b/pytorch3d/implicitron/models/implicit_function/utils.py @@ -10,7 +10,9 @@ import torch.nn.functional as F from pytorch3d.common.compat import prod +from pytorch3d.renderer import ray_bundle_to_ray_points from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.renderer.implicit import RayBundle def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor): @@ -185,3 +187,31 @@ def interpolate_volume( **kwargs, ) return out[:, :, :, 0, 0].permute(0, 2, 1) + + +def get_rays_points_world( + ray_bundle: Optional[RayBundle] = None, + rays_points_world: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Converts the ray_bundle to rays_points_world if rays_points_world is not defined + and raises error if both are defined. + + Args: + ray_bundle: A RayBundle object or None + rays_points_world: A torch.Tensor representing ray points converted to + world coordinates + Returns: + A torch.Tensor representing ray points converted to world coordinates + of shape [minibatch x ... x pts_per_ray x 3]. + """ + if rays_points_world is not None and ray_bundle is not None: + raise ValueError( + "Cannot define both rays_points_world and ray_bundle," + + " one has to be None." + ) + if rays_points_world is not None: + return rays_points_world + if ray_bundle is not None: + return ray_bundle_to_ray_points(ray_bundle) + raise ValueError("ray_bundle and rays_points_world cannot both be None") diff --git a/pytorch3d/implicitron/models/renderer/lstm_renderer.py b/pytorch3d/implicitron/models/renderer/lstm_renderer.py index 64cd89cf0..c5ce094f5 100644 --- a/pytorch3d/implicitron/models/renderer/lstm_renderer.py +++ b/pytorch3d/implicitron/models/renderer/lstm_renderer.py @@ -118,7 +118,7 @@ def forward( # eval the raymarching function raymarch_features, _ = implicit_function( - ray_bundle_t, + ray_bundle=ray_bundle_t, raymarch_features=None, ) if self.verbose: diff --git a/pytorch3d/implicitron/models/renderer/multipass_ea.py b/pytorch3d/implicitron/models/renderer/multipass_ea.py index 89bceae17..61cf0d4c3 100644 --- a/pytorch3d/implicitron/models/renderer/multipass_ea.py +++ b/pytorch3d/implicitron/models/renderer/multipass_ea.py @@ -148,7 +148,7 @@ def _run_raymarcher( ) output = self.raymarcher( - *implicit_functions[0](ray_bundle), + *implicit_functions[0](ray_bundle=ray_bundle), ray_lengths=ray_bundle.lengths, density_noise_std=density_noise_std, ) diff --git a/pytorch3d/implicitron/models/renderer/sdf_renderer.py b/pytorch3d/implicitron/models/renderer/sdf_renderer.py index 15e07dc97..2f0e626c9 100644 --- a/pytorch3d/implicitron/models/renderer/sdf_renderer.py +++ b/pytorch3d/implicitron/models/renderer/sdf_renderer.py @@ -101,7 +101,7 @@ def forward( object_mask = object_mask.bool() implicit_function = implicit_functions[0] - implicit_function_gradient = functools.partial(gradient, implicit_function) + implicit_function_gradient = functools.partial(_gradient, implicit_function) # object_mask: silhouette of the object batch_size, *spatial_size, _ = ray_bundle.lengths.shape @@ -113,7 +113,7 @@ def forward( with torch.no_grad(), evaluating(implicit_function): points, network_object_mask, dists = self.ray_tracer( - sdf=lambda x: implicit_function(x)[ + sdf=lambda x: implicit_function(rays_points_world=x)[ :, 0 ], # TODO: get rid of this wrapper cam_loc=cam_loc, @@ -125,7 +125,7 @@ def forward( depth = dists.reshape(batch_size, num_pixels, 1) points = (cam_loc + depth * ray_dirs).reshape(-1, 3) - sdf_output = implicit_function(points)[:, 0:1] + sdf_output = implicit_function(rays_points_world=points)[:, 0:1] # NOTE most of the intermediate variables are flattened for # no apparent reason (here and in the ray tracer) ray_dirs = ray_dirs.reshape(-1, 3) @@ -157,7 +157,7 @@ def forward( points_all = torch.cat([surface_points, eikonal_points], dim=0) - output = implicit_function(surface_points) + output = implicit_function(rays_points_world=surface_points) surface_sdf_values = output[ :N, 0:1 ].detach() # how is it different from sdf_output? @@ -181,7 +181,9 @@ def forward( grad_theta = None empty_render = differentiable_surface_points.shape[0] == 0 - features = implicit_function(differentiable_surface_points)[None, :, 1:] + features = implicit_function(rays_points_world=differentiable_surface_points)[ + None, :, 1: + ] normals_full = features.new_zeros( batch_size, *spatial_size, 3, requires_grad=empty_render ) @@ -260,13 +262,13 @@ def _sample_network( @torch.enable_grad() -def gradient(module, x): - x.requires_grad_(True) - y = module.forward(x)[:, :1] +def _gradient(module, rays_points_world): + rays_points_world.requires_grad_(True) + y = module.forward(rays_points_world=rays_points_world)[:, :1] d_output = torch.ones_like(y, requires_grad=False, device=y.device) gradients = torch.autograd.grad( outputs=y, - inputs=x, + inputs=rays_points_world, grad_outputs=d_output, create_graph=True, retain_graph=True, diff --git a/tests/implicitron/test_srn.py b/tests/implicitron/test_srn.py index a50c341c3..f6905ef4d 100644 --- a/tests/implicitron/test_srn.py +++ b/tests/implicitron/test_srn.py @@ -44,7 +44,7 @@ def test_srn_implicit_function(self): implicit_function = SRNImplicitFunction() device = torch.device("cpu") bundle = self._get_bundle(device=device) - rays_densities, rays_colors = implicit_function(bundle) + rays_densities, rays_colors = implicit_function(ray_bundle=bundle) out_features = implicit_function.raymarch_function.out_features self.assertEqual( rays_densities.shape, @@ -62,7 +62,9 @@ def test_srn_hypernet_implicit_function(self): implicit_function.to(device) global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device) bundle = self._get_bundle(device=device) - rays_densities, rays_colors = implicit_function(bundle, global_code=global_code) + rays_densities, rays_colors = implicit_function( + ray_bundle=bundle, global_code=global_code + ) out_features = implicit_function.hypernet.out_features self.assertEqual( rays_densities.shape,