-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e8ab37b
commit 41350fe
Showing
24 changed files
with
2,515 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
**/__pycache__/** | ||
**/*.egg-info/ | ||
*.o | ||
*.so | ||
*.egg | ||
.DS_Store | ||
.TimeRecord | ||
license |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#pragma once | ||
#include <torch/extension.h> | ||
|
||
// Host entry point of the ball-query op (registered with pybind as "ball_query"). | ||
// The definition requires contiguous float tensors and supports CUDA inputs only | ||
// (a CPU new_xyz is rejected with "CPU not supported"). | ||
// Returns an int tensor of shape (new_xyz.size(0), new_xyz.size(1), nsample). | ||
at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, | ||
const int nsample); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#ifndef _CUDA_UTILS_H | ||
#define _CUDA_UTILS_H | ||
|
||
#include <ATen/ATen.h> | ||
#include <ATen/cuda/CUDAContext.h> | ||
#include <cmath> | ||
|
||
#include <cuda.h> | ||
#include <cuda_runtime.h> | ||
|
||
#include <vector> | ||
|
||
#define TOTAL_THREADS 512

// Returns the largest power of two that does not exceed work_size, clamped to
// the range [1, TOTAL_THREADS].
//
// Pure integer arithmetic is used on purpose: the previous log()-ratio
// formulation was undefined for work_size <= 0 (log(0) is -inf and log(<0) is
// NaN, both then converted to int and used as a shift count) and depended on
// floating-point rounding for exact powers of two. Non-positive work sizes now
// simply yield 1.
inline int opt_n_threads(int work_size) {
  int n_threads = 1;
  // Compare against work_size / 2 so the doubling can never overflow int.
  while (n_threads < TOTAL_THREADS && n_threads <= work_size / 2) {
    n_threads *= 2;
  }
  return n_threads;
}
|
||
// Picks a 2-D thread-block shape for an x-by-y workload.
// The x extent is chosen first with opt_n_threads(); the y extent is then
// limited to whatever share of the TOTAL_THREADS budget x left over, and is
// never smaller than 1. The z extent is always 1.
inline dim3 opt_block_config(int x, int y) {
  const int threads_x = opt_n_threads(x);
  const int y_budget = TOTAL_THREADS / threads_x;
  const int threads_y = max(min(opt_n_threads(y), y_budget), 1);
  return dim3(threads_x, threads_y, 1);
}
|
||
// Reads (and clears) the last CUDA runtime error and, if one is pending,
// raises a c10::Error through TORCH_CHECK carrying the CUDA error string and
// the source location of the macro expansion.
//
// This used to print to stderr and call exit(-1), which tears down the whole
// host process (including a Python interpreter that loaded the extension).
// Raising instead lets the error surface to the caller as an exception.
// The local uses a macro-specific name so it cannot shadow a caller's `err`.
#define CUDA_CHECK_ERRORS()                                                \
  do {                                                                     \
    const cudaError_t cuda_check_errors_status_ = cudaGetLastError();      \
    TORCH_CHECK(cudaSuccess == cuda_check_errors_status_,                  \
                "CUDA kernel failed : ",                                   \
                cudaGetErrorString(cuda_check_errors_status_), " at L:",   \
                __LINE__, " in ", __FILE__);                               \
  } while (0)
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#pragma once | ||
#include <torch/extension.h> | ||
|
||
// Host entry point of the cylinder-query op (registered with pybind as "cylinder_query"). | ||
// The definition requires new_xyz, xyz and rot to be contiguous float tensors and | ||
// supports CUDA inputs only (a CPU new_xyz is rejected with "CPU not supported"). | ||
// Returns an int tensor of shape (new_xyz.size(0), new_xyz.size(1), nsample). | ||
at::Tensor cylinder_query(at::Tensor new_xyz, at::Tensor xyz, at::Tensor rot, const float radius, const float hmin, const float hmax, | ||
const int nsample); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#pragma once | ||
#include <torch/extension.h> | ||
|
||
// Declarations of the grouping ops that the pybind module registers as | ||
// "group_points" and "group_points_grad". | ||
// NOTE(review): the definitions are not visible here; presumably group_points | ||
// selects entries of `points` by `idx` and group_points_grad is its backward | ||
// pass with `n` as the original point count - confirm against the implementation. | ||
at::Tensor group_points(at::Tensor points, at::Tensor idx); | ||
at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#pragma once | ||
|
||
#include <torch/extension.h> | ||
#include <vector> | ||
|
||
// Declarations of the interpolation ops that the pybind module registers as | ||
// "three_nn", "three_interpolate" and "three_interpolate_grad". | ||
// NOTE(review): the definitions are not visible here; the names suggest a | ||
// three-nearest-neighbour search followed by weighted interpolation and its | ||
// backward pass - confirm tensor shapes and the meaning of `m` against the | ||
// implementation. | ||
std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows); | ||
at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, | ||
at::Tensor weight); | ||
at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, | ||
at::Tensor weight, const int m); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#pragma once | ||
#include <torch/extension.h> | ||
|
||
// Declarations of the sampling ops that the pybind module registers as | ||
// "gather_points", "gather_points_grad" and "furthest_point_sampling". | ||
// NOTE(review): the definitions are not visible here; presumably gather_points | ||
// indexes `points` by `idx`, gather_points_grad is its backward pass with `n` | ||
// as the original point count, and furthest_point_sampling returns `nsamples` | ||
// indices per batch - confirm against the implementation. | ||
at::Tensor gather_points(at::Tensor points, at::Tensor idx); | ||
at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); | ||
at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#pragma once | ||
#include <ATen/cuda/CUDAContext.h> | ||
#include <torch/extension.h> | ||
|
||
// Raises (via TORCH_CHECK) unless tensor expression `x` is a CUDA tensor.
// Queries Tensor::is_cuda() directly rather than going through the deprecated
// Tensor::type(), and parenthesises the argument so that expression arguments
// (e.g. a conditional) are evaluated as a whole before the member call.
#define CHECK_CUDA(x)                                        \
  do {                                                       \
    TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor"); \
  } while (0)
|
||
// Argument-validation helpers. Each raises through TORCH_CHECK with a message
// naming the offending argument (the macro argument is stringified with #x).
// The argument is parenthesised in every expansion so that passing an
// expression rather than a plain identifier cannot change operator binding.

// Raises unless `x` is laid out contiguously in memory.
#define CHECK_CONTIGUOUS(x)                                             \
  do {                                                                  \
    TORCH_CHECK((x).is_contiguous(), #x " must be a contiguous tensor"); \
  } while (0)

// Raises unless the element type of `x` is at::ScalarType::Int.
#define CHECK_IS_INT(x)                                  \
  do {                                                   \
    TORCH_CHECK((x).scalar_type() == at::ScalarType::Int, \
                #x " must be an int tensor");            \
  } while (0)

// Raises unless the element type of `x` is at::ScalarType::Float.
#define CHECK_IS_FLOAT(x)                                  \
  do {                                                     \
    TORCH_CHECK((x).scalar_type() == at::ScalarType::Float, \
                #x " must be a float tensor");             \
  } while (0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#include "ball_query.h" | ||
#include "utils.h" | ||
|
||
void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, | ||
int nsample, const float *new_xyz, | ||
const float *xyz, int *idx); | ||
|
||
at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, | ||
const int nsample) { | ||
CHECK_CONTIGUOUS(new_xyz); | ||
CHECK_CONTIGUOUS(xyz); | ||
CHECK_IS_FLOAT(new_xyz); | ||
CHECK_IS_FLOAT(xyz); | ||
|
||
if (new_xyz.type().is_cuda()) { | ||
CHECK_CUDA(xyz); | ||
} | ||
|
||
at::Tensor idx = | ||
torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, | ||
at::device(new_xyz.device()).dtype(at::ScalarType::Int)); | ||
|
||
if (new_xyz.type().is_cuda()) { | ||
query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), | ||
radius, nsample, new_xyz.data<float>(), | ||
xyz.data<float>(), idx.data<int>()); | ||
} else { | ||
TORCH_CHECK(false, "CPU not supported"); | ||
} | ||
|
||
return idx; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#include <math.h> | ||
#include <stdio.h> | ||
#include <stdlib.h> | ||
|
||
#include "cuda_utils.h" | ||
|
||
// Ball query.
//   input:  new_xyz(b, m, 3)  xyz(b, n, 3)
//   output: idx(b, m, nsample)
// Launch layout: blockIdx.x selects the batch element; the threads of a block
// split the m query points between them, blockDim.x apart.
// For each query point the candidates are scanned in index order and the first
// `nsample` whose squared distance is strictly below radius^2 are recorded.
// On the first hit the whole output row is pre-filled with that index, so
// unused trailing slots repeat it; a row without any hit is left untouched.
__global__ void query_ball_point_kernel(int b, int n, int m, float radius,
                                        int nsample,
                                        const float *__restrict__ new_xyz,
                                        const float *__restrict__ xyz,
                                        int *__restrict__ idx) {
  // Advance every pointer to the slice owned by this block's batch element.
  const int batch = blockIdx.x;
  new_xyz += batch * m * 3;
  xyz += batch * n * 3;
  idx += m * nsample * batch;

  const int first = threadIdx.x;
  const int step = blockDim.x;
  const float radius_sq = radius * radius;

  for (int q = first; q < m; q += step) {
    const float *center = new_xyz + q * 3;
    const float cx = center[0];
    const float cy = center[1];
    const float cz = center[2];
    int *out = idx + q * nsample;

    int found = 0;
    for (int p = 0; p < n && found < nsample; ++p) {
      const float dx = cx - xyz[p * 3 + 0];
      const float dy = cy - xyz[p * 3 + 1];
      const float dz = cz - xyz[p * 3 + 2];
      const float dist_sq = dx * dx + dy * dy + dz * dz;
      if (!(dist_sq < radius_sq)) {
        continue;
      }
      if (found == 0) {
        // First neighbour: use it as the padding value for the whole row.
        for (int s = 0; s < nsample; ++s) {
          out[s] = p;
        }
      }
      out[found] = p;
      ++found;
    }
  }
}
|
||
// Host-side launcher for query_ball_point_kernel.
// Uses one block per batch element, a block size derived from the number of
// query points via opt_n_threads(m), no dynamic shared memory, and the current
// ATen CUDA stream. Launch errors are reported through CUDA_CHECK_ERRORS().
void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
                                     int nsample, const float *new_xyz,
                                     const float *xyz, int *idx) {
  const int num_blocks = b;
  const int threads_per_block = opt_n_threads(m);
  cudaStream_t launch_stream = at::cuda::getCurrentCUDAStream();

  query_ball_point_kernel<<<num_blocks, threads_per_block, 0, launch_stream>>>(
      b, n, m, radius, nsample, new_xyz, xyz, idx);

  CUDA_CHECK_ERRORS();
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#include "ball_query.h" | ||
#include "group_points.h" | ||
#include "interpolate.h" | ||
#include "sampling.h" | ||
#include "cylinder_query.h" | ||
|
||
// Python bindings: every op is exposed under the same name as its C++ entry
// point. Registrations are grouped by the header that declares them.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // ball_query.h
  m.def("ball_query", &ball_query);

  // group_points.h
  m.def("group_points", &group_points);
  m.def("group_points_grad", &group_points_grad);

  // interpolate.h
  m.def("three_nn", &three_nn);
  m.def("three_interpolate", &three_interpolate);
  m.def("three_interpolate_grad", &three_interpolate_grad);

  // sampling.h
  m.def("gather_points", &gather_points);
  m.def("gather_points_grad", &gather_points_grad);
  m.def("furthest_point_sampling", &furthest_point_sampling);

  // cylinder_query.h
  m.def("cylinder_query", &cylinder_query);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#include "cylinder_query.h" | ||
#include "utils.h" | ||
|
||
void query_cylinder_point_kernel_wrapper(int b, int n, int m, float radius, float hmin, float hmax, | ||
int nsample, const float *new_xyz, | ||
const float *xyz, const float *rot, int *idx); | ||
|
||
at::Tensor cylinder_query(at::Tensor new_xyz, at::Tensor xyz, at::Tensor rot, const float radius, const float hmin, const float hmax, | ||
const int nsample) { | ||
CHECK_CONTIGUOUS(new_xyz); | ||
CHECK_CONTIGUOUS(xyz); | ||
CHECK_CONTIGUOUS(rot); | ||
CHECK_IS_FLOAT(new_xyz); | ||
CHECK_IS_FLOAT(xyz); | ||
CHECK_IS_FLOAT(rot); | ||
|
||
if (new_xyz.type().is_cuda()) { | ||
CHECK_CUDA(xyz); | ||
CHECK_CUDA(rot); | ||
} | ||
|
||
at::Tensor idx = | ||
torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, | ||
at::device(new_xyz.device()).dtype(at::ScalarType::Int)); | ||
|
||
if (new_xyz.type().is_cuda()) { | ||
query_cylinder_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), | ||
radius, hmin, hmax, nsample, new_xyz.data<float>(), | ||
xyz.data<float>(), rot.data<float>(), idx.data<int>()); | ||
} else { | ||
TORCH_CHECK(false, "CPU not supported"); | ||
} | ||
|
||
return idx; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#include <math.h> | ||
#include <stdio.h> | ||
#include <stdlib.h> | ||
|
||
#include "cuda_utils.h" | ||
|
||
// Cylinder query.
//   input:  new_xyz(b, m, 3)  xyz(b, n, 3)  rot_c2w(b, m, 9)
//   output: idx(b, m, nsample)
// Launch layout: blockIdx.x selects the batch element; the threads of a block
// split the m query points between them, blockDim.x apart.
// For each query point, every candidate's offset from the query point is
// combined with the 9 rotation values as
//   along  = r0*x + r3*y + r6*z
//   side_a = r1*x + r4*y + r7*z
//   side_b = r2*x + r5*y + r8*z
// and the candidate is accepted when side_a^2 + side_b^2 < radius^2 and
// hmin < along < hmax (all comparisons strict). The first `nsample` accepted
// indices are recorded in scan order. On the first hit the whole output row is
// pre-filled with that index, so unused trailing slots repeat it; a row
// without any hit is left untouched.
__global__ void query_cylinder_point_kernel(int b, int n, int m, float radius, float hmin, float hmax,
                                            int nsample,
                                            const float *__restrict__ new_xyz,
                                            const float *__restrict__ xyz,
                                            const float *__restrict__ rot,
                                            int *__restrict__ idx) {
  // Advance every pointer to the slice owned by this block's batch element.
  const int batch = blockIdx.x;
  new_xyz += batch * m * 3;
  xyz += batch * n * 3;
  rot += batch * m * 9;
  idx += m * nsample * batch;

  const int first = threadIdx.x;
  const int step = blockDim.x;
  const float radius_sq = radius * radius;

  for (int q = first; q < m; q += step) {
    const float *center = new_xyz + q * 3;
    const float cx = center[0];
    const float cy = center[1];
    const float cz = center[2];

    // Per-query rotation values, read once and reused for every candidate.
    float frame[9];
    for (int e = 0; e < 9; ++e) {
      frame[e] = rot[q * 9 + e];
    }

    int *out = idx + q * nsample;

    int found = 0;
    for (int p = 0; p < n && found < nsample; ++p) {
      const float dx = xyz[p * 3 + 0] - cx;
      const float dy = xyz[p * 3 + 1] - cy;
      const float dz = xyz[p * 3 + 2] - cz;

      const float along = frame[0] * dx + frame[3] * dy + frame[6] * dz;
      const float side_a = frame[1] * dx + frame[4] * dy + frame[7] * dz;
      const float side_b = frame[2] * dx + frame[5] * dy + frame[8] * dz;
      const float radial_sq = side_a * side_a + side_b * side_b;

      const bool inside =
          radial_sq < radius_sq && along > hmin && along < hmax;
      if (!inside) {
        continue;
      }
      if (found == 0) {
        // First neighbour: use it as the padding value for the whole row.
        for (int s = 0; s < nsample; ++s) {
          out[s] = p;
        }
      }
      out[found] = p;
      ++found;
    }
  }
}
|
||
// Host-side launcher for query_cylinder_point_kernel.
// Uses one block per batch element, a block size derived from the number of
// query points via opt_n_threads(m), no dynamic shared memory, and the current
// ATen CUDA stream. Launch errors are reported through CUDA_CHECK_ERRORS().
void query_cylinder_point_kernel_wrapper(int b, int n, int m, float radius, float hmin, float hmax,
                                         int nsample, const float *new_xyz,
                                         const float *xyz, const float *rot, int *idx) {
  const int num_blocks = b;
  const int threads_per_block = opt_n_threads(m);
  cudaStream_t launch_stream = at::cuda::getCurrentCUDAStream();

  query_cylinder_point_kernel<<<num_blocks, threads_per_block, 0, launch_stream>>>(
      b, n, m, radius, hmin, hmax, nsample, new_xyz, xyz, rot, idx);

  CUDA_CHECK_ERRORS();
}
Oops, something went wrong.