Skip to content

Commit 06ad1fb

Browse files
bottlerfacebook-github-bot
authored andcommitted
KNN return order documentation
Summary: Fix documentation of KNN, issue facebookresearch#180 Reviewed By: gkioxari Differential Revision: D21384761 fbshipit-source-id: 2b36ee496f2060d17827d2fd66c490cdfa766866
1 parent 0eca74f commit 06ad1fb

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

pytorch3d/ops/knn.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,18 @@ def knn_points(
115115
return_nn: If set to True returns the K nearest neighors in p2 for each point in p1.
116116
117117
Returns:
118-
p1_idx: LongTensor of shape (N, P1, K) giving the indices of the
118+
dists: Tensor of shape (N, P1, K) giving the squared distances to
119+
the nearest neighbors. This is padded with zeros both where a cloud in p2
120+
has fewer than K points and where a cloud in p1 has fewer than P1 points.
121+
122+
idx: LongTensor of shape (N, P1, K) giving the indices of the
119123
K nearest neighbors from points in p1 to points in p2.
120124
Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
121125
neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
122126
in p2 has fewer than K points and where a cloud in p1 has fewer than P1
123127
points.
124128
125-
p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
126-
the nearest neighbors. This is padded with zeros both where a cloud in p2
127-
has fewer than K points and where a cloud in p1 has fewer than P1 points.
128-
129-
p2_nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for
129+
nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for
130130
each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th nearest neighbor
131131
for `p1[n, i]`. Returned if `return_nn` is True.
132132
The nearest neighbors are collected using `knn_gather`

pytorch3d/ops/points_alignment.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def iterative_closest_point(
158158
for iteration in range(max_iterations):
159159
Xt_nn_points = knn_points(
160160
Xt, Yt, lengths1=num_points_X, lengths2=num_points_Y, K=1, return_nn=True
161-
)[2][:, :, 0, :]
161+
).knn[:, :, 0, :]
162162

163163
# get the alignment of the nearest neighbors from Yt with Xt_init
164164
R, T, s = corresponding_points_alignment(

pytorch3d/ops/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,14 @@ def get_point_covariances(
126126
of shape `(minibatch, num_points, neighborhood_size, dim)`.
127127
"""
128128
# get K nearest neighbor idx for each point in the point cloud
129-
_, _, k_nearest_neighbors = knn_points(
129+
k_nearest_neighbors = knn_points(
130130
points_padded,
131131
points_padded,
132132
lengths1=num_points_per_cloud,
133133
lengths2=num_points_per_cloud,
134134
K=neighborhood_size,
135135
return_nn=True,
136-
)
136+
).knn
137137
# obtain the mean of the neighborhood
138138
pt_mean = k_nearest_neighbors.mean(2, keepdim=True)
139139
# compute the diff of the neighborhood and the mean of the neighborhood

0 commit comments

Comments
 (0)