Skip to content

Commit

Permalink
upload pointnet2 package
Browse files Browse the repository at this point in the history
  • Loading branch information
chenxi-wang committed Nov 29, 2022
1 parent e8ab37b commit 41350fe
Show file tree
Hide file tree
Showing 24 changed files with 2,515 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
**/__pycache__/**
**/*.egg-info/
*.o
*.so
*.egg
.DS_Store
.TimeRecord
license
Empty file added pointnet2/pointnet2/__init__.py
Empty file.
10 changes: 10 additions & 0 deletions pointnet2/pointnet2/_ext_src/include/ball_query.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <torch/extension.h>

at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
const int nsample);
46 changes: 46 additions & 0 deletions pointnet2/pointnet2/_ext_src/include/cuda_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cmath>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

#define TOTAL_THREADS 512

inline int opt_n_threads(int work_size) {
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);

return max(min(1 << pow_2, TOTAL_THREADS), 1);
}

inline dim3 opt_block_config(int x, int y) {
const int x_threads = opt_n_threads(x);
const int y_threads =
max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
dim3 block_config(x_threads, y_threads, 1);

return block_config;
}

#define CUDA_CHECK_ERRORS() \
do { \
cudaError_t err = cudaGetLastError(); \
if (cudaSuccess != err) { \
fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
__FILE__); \
exit(-1); \
} \
} while (0)

#endif
10 changes: 10 additions & 0 deletions pointnet2/pointnet2/_ext_src/include/cylinder_query.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <torch/extension.h>

at::Tensor cylinder_query(at::Tensor new_xyz, at::Tensor xyz, at::Tensor rot, const float radius, const float hmin, const float hmax,
const int nsample);
10 changes: 10 additions & 0 deletions pointnet2/pointnet2/_ext_src/include/group_points.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <torch/extension.h>

at::Tensor group_points(at::Tensor points, at::Tensor idx);
at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
15 changes: 15 additions & 0 deletions pointnet2/pointnet2/_ext_src/include/interpolate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <torch/extension.h>
#include <vector>

std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
at::Tensor weight);
at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
at::Tensor weight, const int m);
11 changes: 11 additions & 0 deletions pointnet2/pointnet2/_ext_src/include/sampling.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <torch/extension.h>

at::Tensor gather_points(at::Tensor points, at::Tensor idx);
at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
30 changes: 30 additions & 0 deletions pointnet2/pointnet2/_ext_src/include/utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#define CHECK_CUDA(x) \
do { \
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \
} while (0)

#define CHECK_CONTIGUOUS(x) \
do { \
TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \
} while (0)

#define CHECK_IS_INT(x) \
do { \
TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
#x " must be an int tensor"); \
} while (0)

#define CHECK_IS_FLOAT(x) \
do { \
TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \
#x " must be a float tensor"); \
} while (0)
37 changes: 37 additions & 0 deletions pointnet2/pointnet2/_ext_src/src/ball_query.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#include "ball_query.h"
#include "utils.h"

void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
int nsample, const float *new_xyz,
const float *xyz, int *idx);

at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
const int nsample) {
CHECK_CONTIGUOUS(new_xyz);
CHECK_CONTIGUOUS(xyz);
CHECK_IS_FLOAT(new_xyz);
CHECK_IS_FLOAT(xyz);

if (new_xyz.type().is_cuda()) {
CHECK_CUDA(xyz);
}

at::Tensor idx =
torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
at::device(new_xyz.device()).dtype(at::ScalarType::Int));

if (new_xyz.type().is_cuda()) {
query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
radius, nsample, new_xyz.data<float>(),
xyz.data<float>(), idx.data<int>());
} else {
TORCH_CHECK(false, "CPU not supported");
}

return idx;
}
59 changes: 59 additions & 0 deletions pointnet2/pointnet2/_ext_src/src/ball_query_gpu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"

// input: new_xyz(b, m, 3) xyz(b, n, 3)
// output: idx(b, m, nsample)
__global__ void query_ball_point_kernel(int b, int n, int m, float radius,
int nsample,
const float *__restrict__ new_xyz,
const float *__restrict__ xyz,
int *__restrict__ idx) {
int batch_index = blockIdx.x;
xyz += batch_index * n * 3;
new_xyz += batch_index * m * 3;
idx += m * nsample * batch_index;

int index = threadIdx.x;
int stride = blockDim.x;

float radius2 = radius * radius;
for (int j = index; j < m; j += stride) {
float new_x = new_xyz[j * 3 + 0];
float new_y = new_xyz[j * 3 + 1];
float new_z = new_xyz[j * 3 + 2];
for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
float x = xyz[k * 3 + 0];
float y = xyz[k * 3 + 1];
float z = xyz[k * 3 + 2];
float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
(new_z - z) * (new_z - z);
if (d2 < radius2) {
if (cnt == 0) {
for (int l = 0; l < nsample; ++l) {
idx[j * nsample + l] = k;
}
}
idx[j * nsample + cnt] = k;
++cnt;
}
}
}
}

void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
int nsample, const float *new_xyz,
const float *xyz, int *idx) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
b, n, m, radius, nsample, new_xyz, xyz, idx);

CUDA_CHECK_ERRORS();
}
27 changes: 27 additions & 0 deletions pointnet2/pointnet2/_ext_src/src/bindings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#include "ball_query.h"
#include "group_points.h"
#include "interpolate.h"
#include "sampling.h"
#include "cylinder_query.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gather_points", &gather_points);
m.def("gather_points_grad", &gather_points_grad);
m.def("furthest_point_sampling", &furthest_point_sampling);

m.def("three_nn", &three_nn);
m.def("three_interpolate", &three_interpolate);
m.def("three_interpolate_grad", &three_interpolate_grad);

m.def("ball_query", &ball_query);

m.def("group_points", &group_points);
m.def("group_points_grad", &group_points_grad);

m.def("cylinder_query", &cylinder_query);
}
40 changes: 40 additions & 0 deletions pointnet2/pointnet2/_ext_src/src/cylinder_query.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#include "cylinder_query.h"
#include "utils.h"

void query_cylinder_point_kernel_wrapper(int b, int n, int m, float radius, float hmin, float hmax,
int nsample, const float *new_xyz,
const float *xyz, const float *rot, int *idx);

at::Tensor cylinder_query(at::Tensor new_xyz, at::Tensor xyz, at::Tensor rot, const float radius, const float hmin, const float hmax,
const int nsample) {
CHECK_CONTIGUOUS(new_xyz);
CHECK_CONTIGUOUS(xyz);
CHECK_CONTIGUOUS(rot);
CHECK_IS_FLOAT(new_xyz);
CHECK_IS_FLOAT(xyz);
CHECK_IS_FLOAT(rot);

if (new_xyz.type().is_cuda()) {
CHECK_CUDA(xyz);
CHECK_CUDA(rot);
}

at::Tensor idx =
torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
at::device(new_xyz.device()).dtype(at::ScalarType::Int));

if (new_xyz.type().is_cuda()) {
query_cylinder_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
radius, hmin, hmax, nsample, new_xyz.data<float>(),
xyz.data<float>(), rot.data<float>(), idx.data<int>());
} else {
TORCH_CHECK(false, "CPU not supported");
}

return idx;
}
72 changes: 72 additions & 0 deletions pointnet2/pointnet2/_ext_src/src/cylinder_query_gpu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"

// input: new_xyz(b, m, 3) xyz(b, n, 3) rot_c2w(b, m, 9)
// output: idx(b, m, nsample)
__global__ void query_cylinder_point_kernel(int b, int n, int m, float radius, float hmin, float hmax,
int nsample,
const float *__restrict__ new_xyz,
const float *__restrict__ xyz,
const float *__restrict__ rot,
int *__restrict__ idx) {
int batch_index = blockIdx.x;
xyz += batch_index * n * 3;
new_xyz += batch_index * m * 3;
rot += batch_index * m * 9;
idx += m * nsample * batch_index;

int index = threadIdx.x;
int stride = blockDim.x;

float radius2 = radius * radius;
for (int j = index; j < m; j += stride) {
float new_x = new_xyz[j * 3 + 0];
float new_y = new_xyz[j * 3 + 1];
float new_z = new_xyz[j * 3 + 2];
float r0 = rot[j * 9 + 0];
float r1 = rot[j * 9 + 1];
float r2 = rot[j * 9 + 2];
float r3 = rot[j * 9 + 3];
float r4 = rot[j * 9 + 4];
float r5 = rot[j * 9 + 5];
float r6 = rot[j * 9 + 6];
float r7 = rot[j * 9 + 7];
float r8 = rot[j * 9 + 8];
for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
float x = xyz[k * 3 + 0] - new_x;
float y = xyz[k * 3 + 1] - new_y;
float z = xyz[k * 3 + 2] - new_z;
float x_rot = r0 * x + r3 * y + r6 * z;
float y_rot = r1 * x + r4 * y + r7 * z;
float z_rot = r2 * x + r5 * y + r8 * z;
float d2 = y_rot * y_rot + z_rot * z_rot;
if (d2 < radius2 && x_rot > hmin && x_rot < hmax) {
if (cnt == 0) {
for (int l = 0; l < nsample; ++l) {
idx[j * nsample + l] = k;
}
}
idx[j * nsample + cnt] = k;
++cnt;
}
}
}
}

void query_cylinder_point_kernel_wrapper(int b, int n, int m, float radius, float hmin, float hmax,
int nsample, const float *new_xyz,
const float *xyz, const float *rot, int *idx) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
query_cylinder_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
b, n, m, radius, hmin, hmax, nsample, new_xyz, xyz, rot, idx);

CUDA_CHECK_ERRORS();
}
Loading

0 comments on commit 41350fe

Please sign in to comment.