Skip to content

Commit e290f87

Browse files
jcjohnson authored and facebook-github-bot committed
Add CPU implementation for nearest neighbor
Summary: Adds a CPU implementation for `pytorch3d.ops.nn_points_idx`. Also renames the associated C++ and CUDA functions to follow the `AllCaps` naming convention used in the other C++ / CUDA code.
Reviewed By: gkioxari
Differential Revision: D19670491
fbshipit-source-id: 1b6409404025bf05e6a93f5d847e35afc9062f05
1 parent 25c2f34 commit e290f87

6 files changed

+86 -26 lines changed

pytorch3d/csrc/ext.cpp

+1 -1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1212
m.def("face_areas_normals", &face_areas_normals);
1313
m.def("packed_to_padded_tensor", &packed_to_padded_tensor);
14-
m.def("nn_points_idx", &nn_points_idx);
14+
m.def("nn_points_idx", &NearestNeighborIdx);
1515
m.def("gather_scatter", &gather_scatter);
1616
m.def("rasterize_points", &RasterizePoints);
1717
m.def("rasterize_points_backward", &RasterizePointsBackward);

pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.cu

+8 -8
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
#include <float.h>
55

66
template <typename scalar_t>
7-
__device__ void warp_reduce(
7+
__device__ void WarpReduce(
88
volatile scalar_t* min_dists,
99
volatile long* min_idxs,
1010
const size_t tid) {
@@ -54,7 +54,7 @@ __device__ void warp_reduce(
5454
// is aligned.
5555
//
5656
template <typename scalar_t>
57-
__global__ void nearest_neighbor_kernel(
57+
__global__ void NearestNeighborKernel(
5858
const scalar_t* __restrict__ points1,
5959
const scalar_t* __restrict__ points2,
6060
long* __restrict__ idx,
@@ -123,7 +123,7 @@ __global__ void nearest_neighbor_kernel(
123123
// Unroll the last 6 iterations of the loop since they will happen
124124
// synchronized within a single warp.
125125
if (tid < 32)
126-
warp_reduce<scalar_t>(min_dists, min_idxs, tid);
126+
WarpReduce<scalar_t>(min_dists, min_idxs, tid);
127127

128128
// Finally thread 0 writes the result to the output buffer.
129129
if (tid == 0) {
@@ -144,7 +144,7 @@ __global__ void nearest_neighbor_kernel(
144144
// P2: Number of points in points2.
145145
//
146146
template <typename scalar_t>
147-
__global__ void nearest_neighbor_kernel_D3(
147+
__global__ void NearestNeighborKernelD3(
148148
const scalar_t* __restrict__ points1,
149149
const scalar_t* __restrict__ points2,
150150
long* __restrict__ idx,
@@ -204,15 +204,15 @@ __global__ void nearest_neighbor_kernel_D3(
204204
// Unroll the last 6 iterations of the loop since they will happen
205205
// synchronized within a single warp.
206206
if (tid < 32)
207-
warp_reduce<scalar_t>(min_dists, min_idxs, tid);
207+
WarpReduce<scalar_t>(min_dists, min_idxs, tid);
208208

209209
// Finally thread 0 writes the result to the output buffer.
210210
if (tid == 0) {
211211
idx[n * P1 + i] = min_idxs[0];
212212
}
213213
}
214214

215-
at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2) {
215+
at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2) {
216216
const auto N = p1.size(0);
217217
const auto P1 = p1.size(1);
218218
const auto P2 = p2.size(1);
@@ -231,7 +231,7 @@ at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2) {
231231
AT_DISPATCH_FLOATING_TYPES(p1.type(), "nearest_neighbor_v3_cuda", ([&] {
232232
size_t shared_size = threads * sizeof(size_t) +
233233
threads * sizeof(long);
234-
nearest_neighbor_kernel_D3<scalar_t>
234+
NearestNeighborKernelD3<scalar_t>
235235
<<<blocks, threads, shared_size>>>(
236236
p1.data_ptr<scalar_t>(),
237237
p2.data_ptr<scalar_t>(),
@@ -249,7 +249,7 @@ at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2) {
249249
size_t D_2 = D + (D % 2);
250250
size_t shared_size = (D_2 + threads) * sizeof(size_t);
251251
shared_size += threads * sizeof(long);
252-
nearest_neighbor_kernel<scalar_t><<<blocks, threads, shared_size>>>(
252+
NearestNeighborKernel<scalar_t><<<blocks, threads, shared_size>>>(
253253
p1.data_ptr<scalar_t>(),
254254
p2.data_ptr<scalar_t>(),
255255
idx.data_ptr<long>(),

pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h

+7 -4
Original file line number | Diff line number | Diff line change
@@ -19,19 +19,22 @@
1919
// to p1[n, i] in the cloud p2[n] is p2[n, j].
2020
//
2121

22+
// CPU implementation.
23+
at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2);
24+
2225
// Cuda implementation.
23-
at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2);
26+
at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2);
2427

2528
// Implementation which is exposed.
26-
at::Tensor nn_points_idx(at::Tensor p1, at::Tensor p2) {
29+
at::Tensor NearestNeighborIdx(at::Tensor p1, at::Tensor p2) {
2730
if (p1.type().is_cuda() && p2.type().is_cuda()) {
2831
#ifdef WITH_CUDA
2932
CHECK_CONTIGUOUS_CUDA(p1);
3033
CHECK_CONTIGUOUS_CUDA(p2);
31-
return nn_points_idx_cuda(p1, p2);
34+
return NearestNeighborIdxCuda(p1, p2);
3235
#else
3336
AT_ERROR("Not compiled with GPU support.");
3437
#endif
3538
}
36-
AT_ERROR("Not implemented on the CPU.");
39+
return NearestNeighborIdxCpu(p1, p2);
3740
};
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
#include <torch/extension.h>
4+
5+
at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2) {
6+
const int N = p1.size(0);
7+
const int P1 = p1.size(1);
8+
const int D = p1.size(2);
9+
const int P2 = p2.size(1);
10+
11+
auto long_opts = p1.options().dtype(torch::kInt64);
12+
torch::Tensor out = torch::empty({N, P1}, long_opts);
13+
14+
auto p1_a = p1.accessor<float, 3>();
15+
auto p2_a = p2.accessor<float, 3>();
16+
auto out_a = out.accessor<int64_t, 2>();
17+
18+
for (int n = 0; n < N; ++n) {
19+
for (int i1 = 0; i1 < P1; ++i1) {
20+
// TODO: support other floating-point types?
21+
float min_dist = -1;
22+
int64_t min_idx = -1;
23+
for (int i2 = 0; i2 < P2; ++i2) {
24+
float dist = 0;
25+
for (int d = 0; d < D; ++d) {
26+
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
27+
dist += diff * diff;
28+
}
29+
if (min_dist == -1 || dist < min_dist) {
30+
min_dist = dist;
31+
min_idx = i2;
32+
}
33+
}
34+
out_a[n][i1] = min_idx;
35+
}
36+
}
37+
return out;
38+
}

tests/bm_nearest_neighbor_points.py

+7
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,13 @@ def bm_nn_points() -> None:
2727
warmup_iters=1,
2828
)
2929

30+
benchmark(
31+
TestNearestNeighborPoints.bm_nn_points_cpu_with_init,
32+
"NN_CPU",
33+
kwargs_list,
34+
warmup_iters=1,
35+
)
36+
3037
if torch.cuda.is_available():
3138
benchmark(
3239
TestNearestNeighborPoints.bm_nn_points_cuda_with_init,

tests/test_nearest_neighbor_points.py

+25 -13
Original file line number | Diff line number | Diff line change
@@ -21,11 +21,7 @@ def nn_points_idx_naive(x, y):
2121
idx = dists2.argmin(2)
2222
return idx
2323

24-
def test_nn_cuda(self):
25-
"""
26-
Test cuda output vs naive python implementation.
27-
"""
28-
device = torch.device("cuda:0")
24+
def _test_nn_helper(self, device):
2925
for D in [3, 4]:
3026
for N in [1, 4]:
3127
for P1 in [1, 8, 64, 128]:
@@ -43,16 +39,32 @@ def test_nn_cuda(self):
4339
self.assertTrue(idx1.size(1) == P1)
4440
self.assertTrue(torch.all(idx1 == idx2))
4541

46-
def test_nn_cuda_error(self):
42+
def test_nn_cuda(self):
43+
"""
44+
Test cuda output vs naive python implementation.
45+
"""
46+
device = torch.device('cuda:0')
47+
self._test_nn_helper(device)
48+
49+
def test_nn_cpu(self):
4750
"""
48-
Check that nn_points_idx throws an error if cpu tensors
49-
are given as input.
51+
Test cpu output vs naive python implementation
5052
"""
51-
x = torch.randn(1, 1, 3)
52-
y = torch.randn(1, 1, 3)
53-
with self.assertRaises(Exception) as err:
54-
_C.nn_points_idx(x, y)
55-
self.assertTrue("Not implemented on the CPU" in str(err.exception))
53+
device = torch.device('cpu')
54+
self._test_nn_helper(device)
55+
56+
@staticmethod
57+
def bm_nn_points_cpu_with_init(
58+
N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
59+
):
60+
device = torch.device('cpu')
61+
x = torch.randn(N, P1, D, device=device)
62+
y = torch.randn(N, P2, D, device=device)
63+
64+
def nn_cpu():
65+
_C.nn_points_idx(x.contiguous(), y.contiguous())
66+
67+
return nn_cpu
5668

5769
@staticmethod
5870
def bm_nn_points_cuda_with_init(

0 commit comments

Comments
 (0)