Skip to content

Commit d4a1051

Browse files
Chris Lambertfacebook-github-bot
Chris Lambert
authored andcommitted
Remove pytorch3d's wrappers for eigh, solve, lstsq, qr
Summary: Remove the compat functions eigh, solve, lstsq, and qr. Migrate callers to use torch.linalg directly. Reviewed By: bottler Differential Revision: D39172949 fbshipit-source-id: 484230a553237808f06ee5cdfde64651cba91c4c
1 parent 9a1213e commit d4a1051

File tree

9 files changed

+11
-67
lines changed

9 files changed

+11
-67
lines changed

pytorch3d/common/compat.py

-47
Original file line numberDiff line numberDiff line change
@@ -14,53 +14,6 @@
1414
"""
1515

1616

17-
def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
18-
"""
19-
Like torch.linalg.solve, tries to return X
20-
such that AX=B, with A square.
21-
"""
22-
if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
23-
# PyTorch version >= 1.8.0
24-
return torch.linalg.solve(A, B)
25-
26-
# pyre-fixme[16]: `Tuple` has no attribute `solution`.
27-
return torch.solve(B, A).solution
28-
29-
30-
def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
31-
"""
32-
Like torch.linalg.lstsq, tries to return X
33-
such that AX=B.
34-
"""
35-
if hasattr(torch, "linalg") and hasattr(torch.linalg, "lstsq"):
36-
# PyTorch version >= 1.9
37-
return torch.linalg.lstsq(A, B).solution
38-
39-
solution = torch.lstsq(B, A).solution
40-
if A.shape[1] < A.shape[0]:
41-
return solution[: A.shape[1]]
42-
return solution
43-
44-
45-
def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
46-
"""
47-
Like torch.linalg.qr.
48-
"""
49-
if hasattr(torch, "linalg") and hasattr(torch.linalg, "qr"):
50-
# PyTorch version >= 1.9
51-
return torch.linalg.qr(A)
52-
return torch.qr(A)
53-
54-
55-
def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
56-
"""
57-
Like torch.linalg.eigh, assuming the argument is a symmetric real matrix.
58-
"""
59-
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
60-
return torch.linalg.eigh(A)
61-
return torch.symeig(A, eigenvectors=True)
62-
63-
6417
def meshgrid_ij(
6518
*A: Union[torch.Tensor, Sequence[torch.Tensor]]
6619
) -> Tuple[torch.Tensor, ...]: # pragma: no cover

pytorch3d/implicitron/tools/circle_fitting.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from typing import Optional
1111

1212
import torch
13-
from pytorch3d.common.compat import eigh, lstsq
1413

1514

1615
def _get_rotation_to_best_fit_xy(
@@ -28,7 +27,7 @@ def _get_rotation_to_best_fit_xy(
2827
(3,3) tensor rotation matrix
2928
"""
3029
points_centered = points - centroid[None]
31-
return eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]
30+
return torch.linalg.eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]
3231

3332

3433
def _signed_area(path: torch.Tensor) -> torch.Tensor:
@@ -106,9 +105,8 @@ def fit_circle_in_2d(
106105
n_provided = points2d.shape[0]
107106
if n_provided < 3:
108107
raise ValueError(f"{n_provided} points are not enough to determine a circle")
109-
solution = lstsq(design, rhs[:, None])
108+
solution = torch.linalg.lstsq(design, rhs[:, None]).solution
110109
center = solution[:2, 0] / 2
111-
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
112110
radius = torch.sqrt(solution[2, 0] + (center**2).sum())
113111
if n_points > 0:
114112
if angles is not None:

pytorch3d/implicitron/tools/eval_video_trajectory.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from typing import Optional, Tuple
1010

1111
import torch
12-
from pytorch3d.common.compat import eigh
1312
from pytorch3d.implicitron.tools import utils
1413
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
1514
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
@@ -205,7 +204,7 @@ def _disambiguate_normal(normal, up):
205204
def _fit_plane(x):
206205
x = x - x.mean(dim=0)[None]
207206
cov = (x.t() @ x) / x.shape[0]
208-
_, e_vec = eigh(cov)
207+
_, e_vec = torch.linalg.eigh(cov)
209208
return e_vec
210209

211210

pytorch3d/ops/perspective_n_points.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import torch
1818
import torch.nn.functional as F
19-
from pytorch3d.common.compat import eigh
2019
from pytorch3d.ops import points_alignment, utils as oputil
2120

2221

@@ -106,7 +105,7 @@ def _null_space(m, kernel_dim):
106105
kernel vectors, of size B x kernel_dim
107106
"""
108107
mTm = torch.bmm(m.transpose(1, 2), m)
109-
s, v = eigh(mTm)
108+
s, v = torch.linalg.eigh(mTm)
110109
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]
111110

112111

pytorch3d/ops/points_normals.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import Tuple, TYPE_CHECKING, Union
88

99
import torch
10-
from pytorch3d.common.compat import eigh
1110
from pytorch3d.common.workaround import symeig3x3
1211

1312
from .utils import convert_pointclouds_to_tensor, get_point_covariances
@@ -147,7 +146,7 @@ def estimate_pointcloud_local_coord_frames(
147146
if use_symeig_workaround:
148147
curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
149148
else:
150-
curvatures, local_coord_frames = eigh(cov)
149+
curvatures, local_coord_frames = torch.linalg.eigh(cov)
151150

152151
# disambiguate the directions of individual principal vectors
153152
if disambiguate_directions:

pytorch3d/transforms/se3.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8-
from pytorch3d.common.compat import solve
98

109
from .so3 import _so3_exp_map, hat, so3_log_map
1110

@@ -174,7 +173,7 @@ def se3_log_map(
174173
# log_translation is V^-1 @ T
175174
T = transform[:, 3, :3]
176175
V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps)
177-
log_translation = solve(V, T[:, :, None])[:, :, 0]
176+
log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0]
178177

179178
return torch.cat((log_translation, log_rotation), dim=1)
180179

tests/test_acos_linear_extrapolation.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import numpy as np
1111
import torch
12-
from pytorch3d.common.compat import lstsq
1312
from pytorch3d.transforms import acos_linear_extrapolation
1413

1514
from .common_testing import TestCaseMixin
@@ -66,7 +65,7 @@ def _test_acos_outside_bounds(self, x, y, dydx, bound):
6665
bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype)
6766
# fit a line: slope * x + bias = y
6867
x_1 = torch.stack([x, torch.ones_like(x)], dim=-1)
69-
slope, bias = lstsq(x_1, y[:, None]).view(-1)[:2]
68+
slope, bias = torch.linalg.lstsq(x_1, y[:, None]).solution.view(-1)[:2]
7069
desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t**2)
7170
# test that the desired slope is the same as the fitted one
7271
self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2)

tests/test_se3.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import numpy as np
1111
import torch
12-
from pytorch3d.common.compat import qr
1312
from pytorch3d.transforms.rotation_conversions import random_rotations
1413
from pytorch3d.transforms.se3 import se3_exp_map, se3_log_map
1514
from pytorch3d.transforms.so3 import so3_exp_map, so3_log_map, so3_rotation_angle
@@ -199,7 +198,7 @@ def test_se3_log_singularity(self, batch_size: int = 100):
199198
r = [identity, rot180]
200199
r.extend(
201200
[
202-
qr(identity + torch.randn_like(identity) * 1e-6)[0]
201+
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-6)[0]
203202
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-8
204203
# this adds random noise to the second half
205204
# of the random orthogonal matrices to generate

tests/test_so3.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import numpy as np
1313
import torch
14-
from pytorch3d.common.compat import qr
1514
from pytorch3d.transforms.so3 import (
1615
hat,
1716
so3_exp_map,
@@ -49,7 +48,7 @@ def init_rot(batch_size: int = 10):
4948
# TODO(dnovotny): replace with random_rotation from random_rotation.py
5049
rot = []
5150
for _ in range(batch_size):
52-
r = qr(torch.randn((3, 3), device=device))[0]
51+
r = torch.linalg.qr(torch.randn((3, 3), device=device))[0]
5352
f = torch.randint(2, (3,), device=device, dtype=torch.float32)
5453
if f.sum() % 2 == 0:
5554
f = 1 - f
@@ -145,7 +144,7 @@ def test_so3_log_singularity(self, batch_size: int = 100):
145144
# add random rotations and random almost orthonormal matrices
146145
r.extend(
147146
[
148-
qr(identity + torch.randn_like(identity) * 1e-4)[0]
147+
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-4)[0]
149148
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
150149
# this adds random noise to the second half
151150
# of the random orthogonal matrices to generate
@@ -245,7 +244,7 @@ def test_so3_cos_bound(self, batch_size: int = 100):
245244
r = [identity, rot180]
246245
r.extend(
247246
[
248-
qr(identity + torch.randn_like(identity) * 1e-4)[0]
247+
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-4)[0]
249248
for _ in range(batch_size - 2)
250249
]
251250
)

0 commit comments

Comments
 (0)