Skip to content

Commit

Permalink
RayBundle visualization
Browse files Browse the repository at this point in the history
Summary: Extends plotly_vis to visualize `RayBundle`s.

Reviewed By: patricklabatut

Differential Revision: D29014098

fbshipit-source-id: 4dee426510a1fa53d4afefbe1bcdd003684c9932
  • Loading branch information
davnov134 authored and facebook-github-bot committed Jul 2, 2021
1 parent 62ff77b commit 4426a9d
Showing 1 changed file with 214 additions and 25 deletions.
239 changes: 214 additions & 25 deletions pytorch3d/vis/plotly_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,22 @@
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from pytorch3d.renderer import TexturesVertex
from pytorch3d.renderer import TexturesVertex, RayBundle, ray_bundle_to_ray_points
from pytorch3d.renderer.camera_utils import camera_to_eye_at_up
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene


Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle]


def _get_struct_len(struct: Struct): # pragma: no cover
"""
Returns the length (usually corresponds to the batch size) of the input structure.
"""
return len(struct.directions) if isinstance(struct, RayBundle) else len(struct)


def get_camera_wireframe(scale: float = 0.3): # pragma: no cover
"""
Returns a wireframe of a 3D line-plot of a camera symbol.
Expand Down Expand Up @@ -55,18 +65,22 @@ class Lighting(NamedTuple): # pragma: no cover


def plot_scene(
plots: Dict[str, Dict[str, Union[Pointclouds, Meshes, CamerasBase]]],
plots: Dict[str, Dict[str, Struct]],
*,
viewpoint_cameras: Optional[CamerasBase] = None,
ncols: int = 1,
camera_scale: float = 0.3,
pointcloud_max_points: int = 20000,
pointcloud_marker_size: int = 1,
raybundle_max_rays: int = 20000,
raybundle_max_points_per_ray: int = 1000,
raybundle_ray_point_marker_size: int = 1,
raybundle_ray_line_width: int = 1,
**kwargs,
): # pragma: no cover
"""
Main function to visualize Meshes, Cameras and Pointclouds.
Plots input Pointclouds, Meshes, and Cameras data into named subplots,
Main function to visualize Cameras, Meshes, Pointclouds, and RayBundle.
Plots input Cameras, Meshes, Pointclouds, and RayBundle data into named subplots,
with named traces based on the dictionary keys. Cameras are
rendered at the camera center location using a wireframe.
Expand All @@ -87,6 +101,13 @@ def plot_scene(
pointcloud_max_points is used.
pointcloud_marker_size: the size of the points rendered by plotly
when plotting a pointcloud.
raybundle_max_rays: maximum number of rays of a RayBundle to visualize. Randomly
subsamples without replacement in case the number of rays is bigger than max_rays.
raybundle_max_points_per_ray: the maximum number of points per ray in RayBundle
to visualize. If more are present, a random sample of size
max_points_per_ray is used.
raybundle_ray_point_marker_size: the size of the ray points of a plotted RayBundle
raybundle_ray_line_width: the width of the plotted rays of a RayBundle
**kwargs: Accepts lighting (a Lighting object) and any of the args xaxis,
yaxis and zaxis which Plotly's scene accepts. Accepts axis_args,
which is an AxisArgs object that is applied to all 3 axes.
Expand Down Expand Up @@ -186,6 +207,18 @@ def plot_scene(
The above example will render one subplot with the mesh object
and two cameras.
RayBundle visualization is also supproted:
..code-block::python
cameras = PerspectiveCameras(...)
ray_bundle = RayBundle(origins=..., lengths=..., directions=..., xys=...)
fig = plot_scene({
"subplot1_title": {
"ray_bundle_trace_title": ray_bundle,
"cameras_trace_title": cameras,
},
})
fig.show()
For an example of using kwargs, see below:
..code-block::python
mesh = ...
Expand Down Expand Up @@ -264,11 +297,22 @@ def plot_scene(
_add_camera_trace(
fig, struct, trace_name, subplot_idx, ncols, camera_scale
)
elif isinstance(struct, RayBundle):
_add_ray_bundle_trace(
fig,
struct,
trace_name,
subplot_idx,
ncols,
raybundle_max_rays,
raybundle_max_points_per_ray,
raybundle_ray_point_marker_size,
raybundle_ray_line_width,
)
else:
raise ValueError(
"struct {} is not a Cameras, Meshes or Pointclouds object".format(
struct
)
"struct {} is not a Cameras, Meshes, Pointclouds,".format(struct)
+ " or RayBundle object."
)

# Ensure update for every subplot.
Expand Down Expand Up @@ -329,7 +373,8 @@ def plot_scene(

def plot_batch_individually(
batched_structs: Union[
List[Union[Meshes, Pointclouds, CamerasBase]], Meshes, Pointclouds, CamerasBase
List[Struct],
Struct,
],
*,
viewpoint_cameras: Optional[CamerasBase] = None,
Expand All @@ -340,26 +385,27 @@ def plot_batch_individually(
): # pragma: no cover
"""
This is a higher level plotting function than plot_scene, for plotting
Cameras, Meshes and Pointclouds in simple cases. The simplest use is to plot a
single Cameras, Meshes or Pointclouds object, where you just pass it in as a
one element list. This will plot each batch element in a separate subplot.
Cameras, Meshes, Pointclouds, and RayBundle in simple cases. The simplest use
is to plot a single Cameras, Meshes, Pointclouds, or a RayBundle object,
where you just pass it in as a one element list. This will plot each batch
element in a separate subplot.
More generally, you can supply multiple Cameras, Meshes or Pointclouds
More generally, you can supply multiple Cameras, Meshes, Pointclouds, or RayBundle
having the same batch size `n`. In this case, there will be `n` subplots,
each depicting the corresponding batch element of all the inputs.
In addition, you can include Cameras, Meshes and Pointclouds of size 1 in
In addition, you can include Cameras, Meshes, Pointclouds, or RayBundle of size 1 in
the input. These will either be rendered in the first subplot
(if extend_struct is False), or in every subplot.
Args:
batched_structs: a list of Cameras, Meshes and/or Pointclouds to be rendered.
Each structure's corresponding batch element will be plotted in
a single subplot, resulting in n subplots for a batch of size n.
batched_structs: a list of Cameras, Meshes, Pointclouds, and RayBundle
to be rendered. Each structure's corresponding batch element will be
plotted in a single subplot, resulting in n subplots for a batch of size n.
Every struct should either have the same batch size or be of batch size 1.
See extend_struct and the description above for how batch size 1 structs
are handled. Also accepts a single Cameras, Meshes or Pointclouds object,
which will have each individual element plotted in its own subplot.
are handled. Also accepts a single Cameras, Meshes, Pointclouds, and RayBundle
object, which will have each individual element plotted in its own subplot.
viewpoint_cameras: an instance of a Cameras object providing a location
to view the plotly plot from. If the batch size is equal
to the number of subplots, it is a one to one mapping.
Expand Down Expand Up @@ -407,13 +453,14 @@ def plot_batch_individually(
return
max_size = 0
if isinstance(batched_structs, list):
max_size = max(len(s) for s in batched_structs)
max_size = max(_get_struct_len(s) for s in batched_structs)
for struct in batched_structs:
if len(struct) not in (1, max_size):
msg = "invalid batch size {} provided: {}".format(len(struct), struct)
struct_len = _get_struct_len(struct)
if struct_len not in (1, max_size):
msg = "invalid batch size {} provided: {}".format(struct_len, struct)
raise ValueError(msg)
else:
max_size = len(batched_structs)
max_size = _get_struct_len(batched_structs)

if max_size == 0:
msg = "No data is provided with at least one element"
Expand All @@ -437,7 +484,8 @@ def plot_batch_individually(
if isinstance(batched_structs, list):
for i, batched_struct in enumerate(batched_structs):
# check for whether this struct needs to be extended
if i >= len(batched_struct) and not extend_struct:
batched_struct_len = _get_struct_len(batched_struct)
if i >= batched_struct_len and not extend_struct:
continue
_add_struct_from_batch(
batched_struct, scene_num, subplot_title, scene_dictionary, i + 1
Expand All @@ -453,10 +501,10 @@ def plot_batch_individually(


def _add_struct_from_batch(
batched_struct: Union[CamerasBase, Meshes, Pointclouds],
batched_struct: Struct,
scene_num: int,
subplot_title: str,
scene_dictionary: Dict[str, Dict[str, Union[CamerasBase, Meshes, Pointclouds]]],
scene_dictionary: Dict[str, Dict[str, Struct]],
trace_idx: int = 1,
): # pragma: no cover
"""
Expand Down Expand Up @@ -492,6 +540,15 @@ def _add_struct_from_batch(
# torch.Tensor, torch.nn.Module]` is not a function.
T = T[t_idx].unsqueeze(0)
struct = CamerasBase(device=batched_struct.device, R=R, T=T)
elif isinstance(batched_struct, RayBundle):
# for RayBundle we treat the 1st dim as the batch index
struct_idx = min(scene_num, len(batched_struct.lengths) - 1)
struct = RayBundle(
**{
attr: getattr(batched_struct, attr)[struct_idx]
for attr in ["origins", "directions", "lengths", "xys"]
}
)
else: # batched meshes and pointclouds are indexable
struct_idx = min(scene_num, len(batched_struct) - 1)
struct = batched_struct[struct_idx]
Expand Down Expand Up @@ -702,6 +759,138 @@ def _add_camera_trace(
_update_axes_bounds(verts_center, max_expand, current_layout)


def _add_ray_bundle_trace(
fig: go.Figure,
ray_bundle: RayBundle,
trace_name: str,
subplot_idx: int,
ncols: int,
max_rays: int,
max_points_per_ray: int,
marker_size: int,
line_width: int,
): # pragma: no cover
"""
Adds a trace rendering a RayBundle object to the passed in figure, with
a given name and in a specific subplot.
Args:
fig: plotly figure to add the trace within.
cameras: the Cameras object to render. It can be batched.
trace_name: name to label the trace with.
subplot_idx: identifies the subplot, with 0 being the top left.
ncols: the number of subplots per row.
max_rays: maximum number of plotted rays in total. Randomly subsamples
without replacement in case the number of rays is bigger than max_rays.
max_points_per_ray: maximum number of points plotted per ray.
marker_size: the size of the ray point markers.
line_width: the width of the ray lines.
"""

n_pts_per_ray = ray_bundle.lengths.shape[-1]
n_rays = ray_bundle.lengths.shape[:-1].numel() # pyre-ignore[16]

# flatten all batches of rays into a single big bundle
ray_bundle_flat = RayBundle(
**{
attr: torch.flatten(getattr(ray_bundle, attr), start_dim=0, end_dim=-2)
for attr in ["origins", "directions", "lengths", "xys"]
}
)

# subsample the rays (if needed)
if n_rays > max_rays:
indices_rays = torch.randperm(n_rays)[:max_rays]
ray_bundle_flat = RayBundle(
**{
attr: getattr(ray_bundle_flat, attr)[indices_rays]
for attr in ["origins", "directions", "lengths", "xys"]
}
)

# make ray line endpoints
min_max_ray_depth = torch.stack(
[
ray_bundle_flat.lengths.min(dim=1).values, # pyre-ignore[16]
ray_bundle_flat.lengths.max(dim=1).values,
],
dim=-1,
)
ray_lines_endpoints = ray_bundle_to_ray_points(
ray_bundle_flat._replace(lengths=min_max_ray_depth)
)

# make the ray lines for plotly plotting
nan_tensor = torch.Tensor([[float("NaN")] * 3])
ray_lines = torch.empty(size=(1, 3))
for ray_line in ray_lines_endpoints:
# We combine the ray lines into a single tensor to plot them in a
# single trace. The NaNs are inserted between sets of ray lines
# so that the lines drawn by Plotly are not drawn between
# lines that belong to different rays.
ray_lines = torch.cat((ray_lines, nan_tensor, ray_line))
x, y, z = ray_lines.detach().cpu().numpy().T.astype(float)
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
fig.add_trace(
go.Scatter3d(
x=x,
y=y,
z=z,
marker={"size": 0.1},
line={"width": line_width},
name=trace_name,
),
row=row,
col=col,
)

# subsample the ray points (if needed)
if n_pts_per_ray > max_points_per_ray:
indices_ray_pts = torch.cat(
[
torch.randperm(n_pts_per_ray)[:max_points_per_ray] + ri * n_pts_per_ray
for ri in range(ray_bundle_flat.lengths.shape[0])
]
)
ray_bundle_flat = ray_bundle_flat._replace(
lengths=ray_bundle_flat.lengths.reshape(-1)[indices_ray_pts].reshape(
ray_bundle_flat.lengths.shape[0], -1
)
)

# plot the ray points
ray_points = (
ray_bundle_to_ray_points(ray_bundle_flat)
.view(-1, 3)
.detach()
.cpu()
.numpy()
.astype(float)
)
fig.add_trace(
go.Scatter3d(
x=ray_points[:, 0],
y=ray_points[:, 1],
z=ray_points[:, 2],
mode="markers",
name=trace_name + "_points",
marker={"size": marker_size},
),
row=row,
col=col,
)

# Access the current subplot's scene configuration
plot_scene = "scene" + str(subplot_idx + 1)
current_layout = fig["layout"][plot_scene]

# update the bounds of the axes for the current trace
all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3)
ray_points_center = all_ray_points.mean(dim=0)
max_expand = (all_ray_points.max(0)[0] - all_ray_points.min(0)[0]).max().item()
_update_axes_bounds(ray_points_center, float(max_expand), current_layout)


def _gen_fig_with_subplots(
batch_size: int, ncols: int, subplot_titles: List[str]
): # pragma: no cover
Expand Down

0 comments on commit 4426a9d

Please sign in to comment.