Skip to content

Commit dc28b61

Browse files
davnov134facebook-github-bot
authored and committed
Generation of test camera trajectories
Summary: Implements methods for generating trajectories of test cameras. Reviewed By: nikhilaravi Differential Revision: D26100869 fbshipit-source-id: cf2b61a34d4c749cd8cba881e97f6c388e57d1f8
1 parent 9751f1f commit dc28b61

File tree

1 file changed

+152
-0
lines changed

1 file changed

+152
-0
lines changed
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
import math
3+
from typing import Tuple
4+
5+
import torch
6+
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
7+
8+
9+
def generate_eval_video_cameras(
    train_dataset,
    n_eval_cams: int = 100,
    trajectory_type: str = "figure_eight",
    trajectory_scale: float = 0.2,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
) -> list:
    """
    Generate a camera trajectory for visualizing a NeRF model.

    Args:
        train_dataset: The training dataset object. It is iterated and indexed
            with an integer; every entry is read as `entry["camera"]`, from which
            `get_camera_center()`, `get_world_to_view_transform()`,
            `focal_length` and `principal_point` are used.
        n_eval_cams: Number of cameras in the trajectory.
        trajectory_type: The type of the camera trajectory. Can be one of:
            circular: Rotating around the center of the scene at a fixed radius.
            figure_eight: Figure-of-8 trajectory around the center of the
                central camera of the training dataset.
            trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
                of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
            figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
                of a figure-eight knot
                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
        trajectory_scale: The extent of the trajectory. Only used by the
            knot-type trajectories, where it multiplies the mean per-axis
            standard deviation of the training camera centers.
        scene_center: The center of the scene in world coordinates which all
            the cameras from the generated trajectory look at.
        up: The "up" vector of the scene (=the normal of the scene floor).
            It is passed to `look_at_view_transform` for every trajectory type.
            For `trajectory_type="circular"` it is additionally used (after
            normalization) as the normal of the plane the circle lies in; if it
            is `None`, that normal is fitted to the camera centers instead.

    Returns:
        A list of dictionaries, one per generated camera, each with the keys
        "image" (always `None`), "camera" (a `PerspectiveCameras` instance) and
        "camera_idx" (the position of the camera along the trajectory).

    Raises:
        ValueError: If `trajectory_type` is not one of the supported names.
    """
    knot_generators = {
        "figure_eight": _figure_eight,
        "trefoil_knot": _trefoil_knot,
        "figure_eight_knot": _figure_eight_knot,
    }

    # validate before touching the dataset so a bad name fails fast
    if trajectory_type != "circular" and trajectory_type not in knot_generators:
        raise ValueError(f"Unknown trajectory_type {trajectory_type}.")

    cam_centers = torch.cat([e["camera"].get_camera_center() for e in train_dataset])
    # create every helper tensor with the dtype/device of the camera centers
    # so that they can be combined without device/dtype mismatches
    tensor_kwargs = {"dtype": cam_centers.dtype, "device": cam_centers.device}

    if trajectory_type in knot_generators:
        # get the nearest camera center to the mean of centers;
        # converted to a plain int so that any indexable dataset accepts it
        mean_camera_idx = int(
            ((cam_centers - cam_centers.mean(dim=0)[None]) ** 2)
            .sum(dim=1)
            .min(dim=0)
            .indices
        )

        # generate the knot trajectory in canonical coords; the end point 2*pi
        # is dropped because it coincides with the start of the closed curve
        time = torch.linspace(0, 2 * math.pi, n_eval_cams + 1, **tensor_kwargs)
        time = time[:n_eval_cams]
        traj = knot_generators[trajectory_type](time)
        # shift the curve along its 3rd axis so that its maximum there is 0
        traj[:, 2] -= traj[:, 2].max()

        # transform the canonical knot to the coord frame of the mean camera
        traj_trans = (
            train_dataset[mean_camera_idx]["camera"]
            .get_world_to_view_transform()
            .inverse()
        )
        traj_trans = traj_trans.scale(cam_centers.std(dim=0).mean() * trajectory_scale)
        traj = traj_trans.transform_points(traj)

    else:  # trajectory_type == "circular"
        # fit plane to the camera centers
        plane_mean = cam_centers.mean(dim=0)
        cam_centers_c = cam_centers - plane_mean[None]

        if up is not None:
            # use the up vector instead of the plane through the camera centers;
            # the projection below is only valid for a unit-length normal
            plane_normal = torch.nn.functional.normalize(
                torch.as_tensor(up, **tensor_kwargs), dim=0
            )
        else:
            # NOTE(review): `up=None` is still forwarded to
            # `look_at_view_transform` below, which presumably rejects it -
            # confirm whether this branch is reachable end-to-end.
            cov = (cam_centers_c.t() @ cam_centers_c) / cam_centers_c.shape[0]
            # eigenvalues come in ascending order, so column 0 is the direction
            # of least variance, i.e. the normal of the fitted plane
            _, e_vec = torch.linalg.eigh(cov, UPLO="U")
            plane_normal = e_vec[:, 0]

        # project the centered camera centers onto the plane
        plane_dist = (plane_normal[None] * cam_centers_c).sum(dim=-1)
        cam_centers_on_plane = cam_centers_c - plane_dist[:, None] * plane_normal[None]

        cov = (
            cam_centers_on_plane.t() @ cam_centers_on_plane
        ) / cam_centers_on_plane.shape[0]
        # `torch.symeig` was removed from PyTorch; `torch.linalg.eigh` with
        # UPLO="U" reads the same (upper) triangle and returns the same ordering
        _, e_vec = torch.linalg.eigh(cov, UPLO="U")

        # the circle radius is the mean distance of the projected centers
        traj_radius = (cam_centers_on_plane ** 2).sum(dim=1).sqrt().mean()
        # NOTE: both end points 0 and 2*pi are included here, so the first and
        # the last camera coincide; kept as is for backward compatibility.
        angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams, **tensor_kwargs)
        # the first coordinate (paired with the smallest-eigenvalue direction,
        # i.e. the plane normal) is zero, so the circle stays inside the plane
        traj = traj_radius * torch.stack(
            (torch.zeros_like(angle), angle.cos(), angle.sin()), dim=-1
        )
        traj = traj @ e_vec.t() + plane_mean[None]

    # point all cameras towards the center of the scene
    R, T = look_at_view_transform(
        eye=traj,
        at=(scene_center,),  # (1, 3)
        up=(up,),  # (1, 3)
        device=traj.device,
    )

    # get the average focal length and principal point
    focal = torch.cat([e["camera"].focal_length for e in train_dataset]).mean(dim=0)
    p0 = torch.cat([e["camera"].principal_point for e in train_dataset]).mean(dim=0)

    # assemble the dataset
    test_dataset = [
        {
            "image": None,
            "camera": PerspectiveCameras(
                focal_length=focal[None],
                principal_point=p0[None],
                R=R_[None],
                T=T_[None],
            ),
            "camera_idx": i,
        }
        for i, (R_, T_) in enumerate(zip(R, T))
    ]

    return test_dataset
132+
133+
134+
def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):
    """
    Evaluate the figure-eight knot curve at the parameter values `t`.

    Args:
        t: Tensor of curve parameters.
        z_scale: Multiplier applied to the z coordinate of the curve.

    Returns:
        Tensor with the x, y, z coordinates stacked along a new last dimension.
    """
    # the x and y coordinates share the same radial term
    radial = 2 + torch.cos(2 * t)
    triple_t = 3 * t
    coords = [
        radial * torch.cos(triple_t),
        radial * torch.sin(triple_t),
        torch.sin(4 * t) * z_scale,
    ]
    return torch.stack(coords, dim=-1)
139+
140+
141+
def _trefoil_knot(t: torch.Tensor, z_scale: float = 0.5):
    """
    Evaluate the trefoil knot curve at the parameter values `t`.

    Args:
        t: Tensor of curve parameters.
        z_scale: Multiplier applied to the z coordinate of the curve.

    Returns:
        Tensor with the x, y, z coordinates stacked along a new last dimension.
    """
    double_t = 2 * t
    coords = [
        torch.sin(t) + 2 * torch.sin(double_t),
        torch.cos(t) - 2 * torch.cos(double_t),
        torch.neg(torch.sin(3 * t)) * z_scale,
    ]
    return torch.stack(coords, dim=-1)
146+
147+
148+
def _figure_eight(t: torch.Tensor, z_scale: float = 0.5):
    """
    Evaluate the figure-of-8 curve at the parameter values `t`.

    Args:
        t: Tensor of curve parameters.
        z_scale: Multiplier applied to the z coordinate of the curve.

    Returns:
        Tensor with the x, y, z coordinates stacked along a new last dimension.
    """
    coords = [
        torch.cos(t),
        torch.sin(2 * t) / 2,
        torch.sin(t) * z_scale,
    ]
    return torch.stack(coords, dim=-1)

0 commit comments

Comments
 (0)