
Commit 9a14f54

patricklabatut authored and facebook-github-bot committed
Fix circular import
Summary: This fixes a recently introduced circular import: the problem went unnoticed by having `pytorch3d.renderer` imported first...

Reviewed By: bottler

Differential Revision: D29686235

fbshipit-source-id: 4b9f2faecec2cc8347ee259cfc359dc9e4f67784
1 parent 5eec5e2 commit 9a14f54
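
For context: before this change, `pytorch3d/renderer/points/pulsar/unified.py` imported `pulsar_from_cameras_projection` from `pytorch3d.utils`, while `pytorch3d/utils/camera_conversions.py` imports `PerspectiveCameras` from `pytorch3d.renderer`, so the two packages imported each other. The sketch below (illustrative only, assuming a pre-fix checkout; not part of the commit) shows why the import order mattered.

# Pre-fix dependency cycle (sketch, not part of this commit):
#   pytorch3d.renderer.points.pulsar.unified -> pytorch3d.utils     (pulsar_from_cameras_projection)
#   pytorch3d.utils.camera_conversions       -> pytorch3d.renderer  (PerspectiveCameras)

# This import order happened to work: pytorch3d.renderer is fully initialized
# before pytorch3d.utils imports it back.
import pytorch3d.renderer  # noqa: F401
import pytorch3d.utils  # noqa: F401

# Importing pytorch3d.utils first could fail on a pre-fix checkout instead:
# it starts importing pytorch3d.renderer, whose pulsar module imports
# pytorch3d.utils back while that module is still only partially initialized.
# import pytorch3d.utils

After this commit the cycle is gone: the conversion helpers live in `pytorch3d.renderer.camera_conversions`, and the public functions in `pytorch3d.utils.camera_conversions` are thin wrappers around them, so imports flow one way (utils -> renderer).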

File tree

4 files changed: +189 -139

pytorch3d/renderer/camera_conversions.py

+176

@@ -0,0 +1,176 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import Tuple
+
+import torch
+
+from ..transforms import matrix_to_rotation_6d
+from .cameras import PerspectiveCameras
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+def _cameras_from_opencv_projection(
+    R: torch.Tensor,
+    tvec: torch.Tensor,
+    camera_matrix: torch.Tensor,
+    image_size: torch.Tensor,
+) -> PerspectiveCameras:
+    focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
+    principal_point = camera_matrix[:, :2, 2]
+
+    # Retype the image_size correctly and flip to width, height.
+    image_size_wh = image_size.to(R).flip(dims=(1,))
+
+    # Get the PyTorch3D focal length and principal point.
+    focal_pytorch3d = focal_length / (0.5 * image_size_wh)
+    p0_pytorch3d = -(principal_point / (0.5 * image_size_wh) - 1)
+
+    # For R, T we flip x, y axes (opencv screen space has an opposite
+    # orientation of screen axes).
+    # We also transpose R (opencv multiplies points from the opposite=left side).
+    R_pytorch3d = R.clone().permute(0, 2, 1)
+    T_pytorch3d = tvec.clone()
+    R_pytorch3d[:, :, :2] *= -1
+    T_pytorch3d[:, :2] *= -1
+
+    return PerspectiveCameras(
+        R=R_pytorch3d,
+        T=T_pytorch3d,
+        focal_length=focal_pytorch3d,
+        principal_point=p0_pytorch3d,
+    )
+
+
+def _opencv_from_cameras_projection(
+    cameras: PerspectiveCameras,
+    image_size: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    R_pytorch3d = cameras.R.clone()  # pyre-ignore
+    T_pytorch3d = cameras.T.clone()  # pyre-ignore
+    focal_pytorch3d = cameras.focal_length
+    p0_pytorch3d = cameras.principal_point
+    T_pytorch3d[:, :2] *= -1
+    R_pytorch3d[:, :, :2] *= -1
+    tvec = T_pytorch3d
+    R = R_pytorch3d.permute(0, 2, 1)
+
+    # Retype the image_size correctly and flip to width, height.
+    image_size_wh = image_size.to(R).flip(dims=(1,))
+
+    principal_point = (-p0_pytorch3d + 1.0) * (0.5 * image_size_wh)  # pyre-ignore
+    focal_length = focal_pytorch3d * (0.5 * image_size_wh)
+
+    camera_matrix = torch.zeros_like(R)
+    camera_matrix[:, :2, 2] = principal_point
+    camera_matrix[:, 2, 2] = 1.0
+    camera_matrix[:, 0, 0] = focal_length[:, 0]
+    camera_matrix[:, 1, 1] = focal_length[:, 1]
+    return R, tvec, camera_matrix
+
+
+def _pulsar_from_opencv_projection(
+    R: torch.Tensor,
+    tvec: torch.Tensor,
+    camera_matrix: torch.Tensor,
+    image_size: torch.Tensor,
+    znear: float = 0.1,
+) -> torch.Tensor:
+    assert len(camera_matrix.size()) == 3, "This function requires batched inputs!"
+    assert len(R.size()) == 3, "This function requires batched inputs!"
+    assert len(tvec.size()) in (2, 3), "This function requires batched inputs!"
+
+    # Validate parameters.
+    image_size_wh = image_size.to(R).flip(dims=(1,))
+    assert torch.all(
+        image_size_wh > 0
+    ), "height and width must be positive but min is: %s" % (
+        str(image_size_wh.min().item())
+    )
+    assert (
+        camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3
+    ), "Incorrect camera matrix shape: expected 3x3 but got %dx%d" % (
+        camera_matrix.size(1),
+        camera_matrix.size(2),
+    )
+    assert (
+        R.size(1) == 3 and R.size(2) == 3
+    ), "Incorrect R shape: expected 3x3 but got %dx%d" % (
+        R.size(1),
+        R.size(2),
+    )
+    if len(tvec.size()) == 2:
+        tvec = tvec.unsqueeze(2)
+    assert (
+        tvec.size(1) == 3 and tvec.size(2) == 1
+    ), "Incorrect tvec shape: expected 3x1 but got %dx%d" % (
+        tvec.size(1),
+        tvec.size(2),
+    )
+    # Check batch size.
+    batch_size = camera_matrix.size(0)
+    assert R.size(0) == batch_size, "Expected R to have batch size %d. Has size %d." % (
+        batch_size,
+        R.size(0),
+    )
+    assert (
+        tvec.size(0) == batch_size
+    ), "Expected tvec to have batch size %d. Has size %d." % (
+        batch_size,
+        tvec.size(0),
+    )
+    # Check image sizes.
+    image_w = image_size_wh[0, 0]
+    image_h = image_size_wh[0, 1]
+    assert torch.all(
+        image_size_wh[:, 0] == image_w
+    ), "All images in a batch must have the same width!"
+    assert torch.all(
+        image_size_wh[:, 1] == image_h
+    ), "All images in a batch must have the same height!"
+    # Focal length.
+    fx = camera_matrix[:, 0, 0].unsqueeze(1)
+    fy = camera_matrix[:, 1, 1].unsqueeze(1)
+    # Check that we introduce less than 1% error by averaging the focal lengths.
+    fx_y = fx / fy
+    if torch.any(fx_y > 1.01) or torch.any(fx_y < 0.99):
+        LOGGER.warning(
+            "Pulsar only supports a single focal length. For converting OpenCV "
+            "focal lengths, we average them for x and y directions. "
+            "The focal lengths for x and y you provided differ by more than 1%, "
+            "which means this could introduce a noticeable error."
+        )
+    f = (fx + fy) / 2
+    # Normalize f into normalized device coordinates.
+    focal_length_px = f / image_w
+    # Transfer into focal_length and sensor_width.
+    focal_length = torch.tensor([znear - 1e-5], dtype=torch.float32, device=R.device)
+    focal_length = focal_length[None, :].repeat(batch_size, 1)
+    sensor_width = focal_length / focal_length_px
+    # Principal point.
+    cx = camera_matrix[:, 0, 2].unsqueeze(1)
+    cy = camera_matrix[:, 1, 2].unsqueeze(1)
+    # Transfer principal point offset into centered offset.
+    cx = -(cx - image_w / 2)
+    cy = cy - image_h / 2
+    # Concatenate to final vector.
+    param = torch.cat([focal_length, sensor_width, cx, cy], dim=1)
+    R_trans = R.permute(0, 2, 1)
+    cam_pos = -torch.bmm(R_trans, tvec).squeeze(2)
+    cam_rot = matrix_to_rotation_6d(R_trans)
+    cam_params = torch.cat([cam_pos, cam_rot, param], dim=1)
+    return cam_params


def _pulsar_from_cameras_projection(
+    cameras: PerspectiveCameras,
+    image_size: torch.Tensor,
+) -> torch.Tensor:
+    opencv_R, opencv_T, opencv_K = _opencv_from_cameras_projection(cameras, image_size)
+    return _pulsar_from_opencv_projection(opencv_R, opencv_T, opencv_K, image_size)
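
These helpers are the same conversions as before, only re-homed under `pytorch3d.renderer` with a leading underscore; the public functions in `pytorch3d.utils` (last file below) now just delegate to them. A minimal usage sketch of the unchanged public entry point, with made-up intrinsics and image sizes for illustration:

import torch

from pytorch3d.utils import cameras_from_opencv_projection

# Illustrative batch of N = 2 identical OpenCV cameras (all values made up).
N = 2
R = torch.eye(3).expand(N, 3, 3)   # (N, 3, 3) OpenCV rotation matrices
tvec = torch.zeros(N, 3)           # (N, 3) OpenCV translation vectors
camera_matrix = torch.tensor(      # (N, 3, 3) intrinsics K
    [[[500.0, 0.0, 320.0],
      [0.0, 500.0, 240.0],
      [0.0, 0.0, 1.0]]]
).expand(N, 3, 3)
image_size = torch.tensor([[480, 640]]).expand(N, 2)  # (N, 2) as (height, width)

cameras = cameras_from_opencv_projection(R, tvec, camera_matrix, image_size)
print(cameras.R.shape, cameras.T.shape)  # torch.Size([2, 3, 3]) torch.Size([2, 3])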

pytorch3d/renderer/points/pulsar/unified.py

+2 -2

@@ -11,7 +11,7 @@
 import torch
 import torch.nn as nn

-from ....utils import pulsar_from_cameras_projection
+from ...camera_conversions import _pulsar_from_cameras_projection
 from ...cameras import (
     FoVOrthographicCameras,
     FoVPerspectiveCameras,
@@ -378,7 +378,7 @@ def _extract_extrinsics(
         size_tensor = torch.tensor(
             [[self.renderer._renderer.height, self.renderer._renderer.width]]
         )
-        pulsar_cam = pulsar_from_cameras_projection(tmp_cams, size_tensor)
+        pulsar_cam = _pulsar_from_cameras_projection(tmp_cams, size_tensor)
         cam_pos = pulsar_cam[0, :3]
         cam_rot = pulsar_cam[0, 3:9]
         return cam_pos, cam_rot

pytorch3d/utils/__init__.py

+1 -1

@@ -7,8 +7,8 @@
 from .camera_conversions import (
     cameras_from_opencv_projection,
     opencv_from_cameras_projection,
-    pulsar_from_opencv_projection,
     pulsar_from_cameras_projection,
+    pulsar_from_opencv_projection,
 )
 from .ico_sphere import ico_sphere
 from .torus import torus

pytorch3d/utils/camera_conversions.py

+10 -136

@@ -4,16 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import logging
 from typing import Tuple

 import torch

 from ..renderer import PerspectiveCameras
-from ..transforms import matrix_to_rotation_6d
-
-
-LOGGER = logging.getLogger(__name__)
+from ..renderer.camera_conversions import (
+    _cameras_from_opencv_projection,
+    _opencv_from_cameras_projection,
+    _pulsar_from_cameras_projection,
+    _pulsar_from_opencv_projection,
+)


 def cameras_from_opencv_projection(
@@ -58,30 +59,7 @@ def cameras_from_opencv_projection(
     Returns:
         cameras_pytorch3d: A batch of `N` cameras in the PyTorch3D convention.
     """
-    focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
-    principal_point = camera_matrix[:, :2, 2]
-
-    # Retype the image_size correctly and flip to width, height.
-    image_size_wh = image_size.to(R).flip(dims=(1,))
-
-    # Get the PyTorch3D focal length and principal point.
-    focal_pytorch3d = focal_length / (0.5 * image_size_wh)
-    p0_pytorch3d = -(principal_point / (0.5 * image_size_wh) - 1)
-
-    # For R, T we flip x, y axes (opencv screen space has an opposite
-    # orientation of screen axes).
-    # We also transpose R (opencv multiplies points from the opposite=left side).
-    R_pytorch3d = R.clone().permute(0, 2, 1)
-    T_pytorch3d = tvec.clone()
-    R_pytorch3d[:, :, :2] *= -1
-    T_pytorch3d[:, :2] *= -1
-
-    return PerspectiveCameras(
-        R=R_pytorch3d,
-        T=T_pytorch3d,
-        focal_length=focal_pytorch3d,
-        principal_point=p0_pytorch3d,
-    )
+    return _cameras_from_opencv_projection(R, tvec, camera_matrix, image_size)


 def opencv_from_cameras_projection(
@@ -114,27 +92,7 @@ def opencv_from_cameras_projection(
         tvec: A batch of translation vectors of shape `(N, 3)`.
         camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
     """
-    R_pytorch3d = cameras.R.clone()  # pyre-ignore
-    T_pytorch3d = cameras.T.clone()  # pyre-ignore
-    focal_pytorch3d = cameras.focal_length
-    p0_pytorch3d = cameras.principal_point
-    T_pytorch3d[:, :2] *= -1
-    R_pytorch3d[:, :, :2] *= -1
-    tvec = T_pytorch3d
-    R = R_pytorch3d.permute(0, 2, 1)
-
-    # Retype the image_size correctly and flip to width, height.
-    image_size_wh = image_size.to(R).flip(dims=(1,))
-
-    principal_point = (-p0_pytorch3d + 1.0) * (0.5 * image_size_wh)  # pyre-ignore
-    focal_length = focal_pytorch3d * (0.5 * image_size_wh)
-
-    camera_matrix = torch.zeros_like(R)
-    camera_matrix[:, :2, 2] = principal_point
-    camera_matrix[:, 2, 2] = 1.0
-    camera_matrix[:, 0, 0] = focal_length[:, 0]
-    camera_matrix[:, 1, 1] = focal_length[:, 1]
-    return R, tvec, camera_matrix
+    return _opencv_from_cameras_projection(cameras, image_size)


 def pulsar_from_opencv_projection(
@@ -170,90 +128,7 @@ def pulsar_from_opencv_projection(
         convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width,
         c_x, c_y).
     """
-    assert len(camera_matrix.size()) == 3, "This function requires batched inputs!"
-    assert len(R.size()) == 3, "This function requires batched inputs!"
-    assert len(tvec.size()) in (2, 3), "This function requires batched inputs!"
-
-    # Validate parameters.
-    image_size_wh = image_size.to(R).flip(dims=(1,))
-    assert torch.all(
-        image_size_wh > 0
-    ), "height and width must be positive but min is: %s" % (
-        str(image_size_wh.min().item())
-    )
-    assert (
-        camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3
-    ), "Incorrect camera matrix shape: expected 3x3 but got %dx%d" % (
-        camera_matrix.size(1),
-        camera_matrix.size(2),
-    )
-    assert (
-        R.size(1) == 3 and R.size(2) == 3
-    ), "Incorrect R shape: expected 3x3 but got %dx%d" % (
-        R.size(1),
-        R.size(2),
-    )
-    if len(tvec.size()) == 2:
-        tvec = tvec.unsqueeze(2)
-    assert (
-        tvec.size(1) == 3 and tvec.size(2) == 1
-    ), "Incorrect tvec shape: expected 3x1 but got %dx%d" % (
-        tvec.size(1),
-        tvec.size(2),
-    )
-    # Check batch size.
-    batch_size = camera_matrix.size(0)
-    assert R.size(0) == batch_size, "Expected R to have batch size %d. Has size %d." % (
-        batch_size,
-        R.size(0),
-    )
-    assert (
-        tvec.size(0) == batch_size
-    ), "Expected tvec to have batch size %d. Has size %d." % (
-        batch_size,
-        tvec.size(0),
-    )
-    # Check image sizes.
-    image_w = image_size_wh[0, 0]
-    image_h = image_size_wh[0, 1]
-    assert torch.all(
-        image_size_wh[:, 0] == image_w
-    ), "All images in a batch must have the same width!"
-    assert torch.all(
-        image_size_wh[:, 1] == image_h
-    ), "All images in a batch must have the same height!"
-    # Focal length.
-    fx = camera_matrix[:, 0, 0].unsqueeze(1)
-    fy = camera_matrix[:, 1, 1].unsqueeze(1)
-    # Check that we introduce less than 1% error by averaging the focal lengths.
-    fx_y = fx / fy
-    if torch.any(fx_y > 1.01) or torch.any(fx_y < 0.99):
-        LOGGER.warning(
-            "Pulsar only supports a single focal length. For converting OpenCV "
-            "focal lengths, we average them for x and y directions. "
-            "The focal lengths for x and y you provided differ by more than 1%, "
-            "which means this could introduce a noticeable error."
-        )
-    f = (fx + fy) / 2
-    # Normalize f into normalized device coordinates.
-    focal_length_px = f / image_w
-    # Transfer into focal_length and sensor_width.
-    focal_length = torch.tensor([znear - 1e-5], dtype=torch.float32, device=R.device)
-    focal_length = focal_length[None, :].repeat(batch_size, 1)
-    sensor_width = focal_length / focal_length_px
-    # Principal point.
-    cx = camera_matrix[:, 0, 2].unsqueeze(1)
-    cy = camera_matrix[:, 1, 2].unsqueeze(1)
-    # Transfer principal point offset into centered offset.
-    cx = -(cx - image_w / 2)
-    cy = cy - image_h / 2
-    # Concatenate to final vector.
-    param = torch.cat([focal_length, sensor_width, cx, cy], dim=1)
-    R_trans = R.permute(0, 2, 1)
-    cam_pos = -torch.bmm(R_trans, tvec).squeeze(2)
-    cam_rot = matrix_to_rotation_6d(R_trans)
-    cam_params = torch.cat([cam_pos, cam_rot, param], dim=1)
-    return cam_params
+    return _pulsar_from_opencv_projection(R, tvec, camera_matrix, image_size, znear)


 def pulsar_from_cameras_projection(
@@ -281,5 +156,4 @@ def pulsar_from_cameras_projection(
         convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width,
         c_x, c_y).
     """
-    opencv_R, opencv_T, opencv_K = opencv_from_cameras_projection(cameras, image_size)
-    return pulsar_from_opencv_projection(opencv_R, opencv_T, opencv_K, image_size)
+    return _pulsar_from_cameras_projection(cameras, image_size)
