4
4
#include < float.h>
5
5
6
6
template <typename scalar_t >
7
- __device__ void warp_reduce (
7
+ __device__ void WarpReduce (
8
8
volatile scalar_t * min_dists,
9
9
volatile long * min_idxs,
10
10
const size_t tid) {
@@ -54,7 +54,7 @@ __device__ void warp_reduce(
54
54
// is aligned.
55
55
//
56
56
template <typename scalar_t >
57
- __global__ void nearest_neighbor_kernel (
57
+ __global__ void NearestNeighborKernel (
58
58
const scalar_t * __restrict__ points1,
59
59
const scalar_t * __restrict__ points2,
60
60
long * __restrict__ idx,
@@ -123,7 +123,7 @@ __global__ void nearest_neighbor_kernel(
123
123
// Unroll the last 6 iterations of the loop since they will happen
124
124
// synchronized within a single warp.
125
125
if (tid < 32 )
126
- warp_reduce <scalar_t >(min_dists, min_idxs, tid);
126
+ WarpReduce <scalar_t >(min_dists, min_idxs, tid);
127
127
128
128
// Finally thread 0 writes the result to the output buffer.
129
129
if (tid == 0 ) {
@@ -144,7 +144,7 @@ __global__ void nearest_neighbor_kernel(
144
144
// P2: Number of points in points2.
145
145
//
146
146
template <typename scalar_t >
147
- __global__ void nearest_neighbor_kernel_D3 (
147
+ __global__ void NearestNeighborKernelD3 (
148
148
const scalar_t * __restrict__ points1,
149
149
const scalar_t * __restrict__ points2,
150
150
long * __restrict__ idx,
@@ -204,15 +204,15 @@ __global__ void nearest_neighbor_kernel_D3(
204
204
// Unroll the last 6 iterations of the loop since they will happen
205
205
// synchronized within a single warp.
206
206
if (tid < 32 )
207
- warp_reduce <scalar_t >(min_dists, min_idxs, tid);
207
+ WarpReduce <scalar_t >(min_dists, min_idxs, tid);
208
208
209
209
// Finally thread 0 writes the result to the output buffer.
210
210
if (tid == 0 ) {
211
211
idx[n * P1 + i] = min_idxs[0 ];
212
212
}
213
213
}
214
214
215
- at::Tensor nn_points_idx_cuda (at::Tensor p1, at::Tensor p2) {
215
+ at::Tensor NearestNeighborIdxCuda (at::Tensor p1, at::Tensor p2) {
216
216
const auto N = p1.size (0 );
217
217
const auto P1 = p1.size (1 );
218
218
const auto P2 = p2.size (1 );
@@ -231,7 +231,7 @@ at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2) {
231
231
AT_DISPATCH_FLOATING_TYPES (p1.type (), " nearest_neighbor_v3_cuda" , ([&] {
232
232
size_t shared_size = threads * sizeof (size_t ) +
233
233
threads * sizeof (long );
234
- nearest_neighbor_kernel_D3 <scalar_t >
234
+ NearestNeighborKernelD3 <scalar_t >
235
235
<<<blocks, threads, shared_size>>> (
236
236
p1.data_ptr <scalar_t >(),
237
237
p2.data_ptr <scalar_t >(),
@@ -249,7 +249,7 @@ at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2) {
249
249
size_t D_2 = D + (D % 2 );
250
250
size_t shared_size = (D_2 + threads) * sizeof (size_t );
251
251
shared_size += threads * sizeof (long );
252
- nearest_neighbor_kernel <scalar_t ><<<blocks, threads, shared_size>>> (
252
+ NearestNeighborKernel <scalar_t ><<<blocks, threads, shared_size>>> (
253
253
p1.data_ptr <scalar_t >(),
254
254
p2.data_ptr <scalar_t >(),
255
255
idx.data_ptr <long >(),
0 commit comments