|
9 | 9 | #include <ATen/ATen.h>
|
10 | 10 | #include <ATen/cuda/CUDAContext.h>
|
11 | 11 | #include <c10/cuda/CUDAGuard.h>
|
12 |
| -#include <torch/extension.h> |
13 | 12 |
|
14 |
| -using torch::PackedTensorAccessor64; |
15 |
| -using torch::RestrictPtrTraits; |
| 13 | +using at::PackedTensorAccessor64; |
| 14 | +using at::RestrictPtrTraits; |
16 | 15 |
|
17 | 16 | // A chunk of work is blocksize-many points.
|
18 | 17 | // There are N clouds in the batch, and P points in each cloud.
|
@@ -117,12 +116,12 @@ __global__ void PointsToVolumesForwardKernel(
|
117 | 116 | }
|
118 | 117 |
|
119 | 118 | void PointsToVolumesForwardCuda(
|
120 |
| - const torch::Tensor& points_3d, |
121 |
| - const torch::Tensor& points_features, |
122 |
| - const torch::Tensor& volume_densities, |
123 |
| - const torch::Tensor& volume_features, |
124 |
| - const torch::Tensor& grid_sizes, |
125 |
| - const torch::Tensor& mask, |
| 119 | + const at::Tensor& points_3d, |
| 120 | + const at::Tensor& points_features, |
| 121 | + const at::Tensor& volume_densities, |
| 122 | + const at::Tensor& volume_features, |
| 123 | + const at::Tensor& grid_sizes, |
| 124 | + const at::Tensor& mask, |
126 | 125 | const float point_weight,
|
127 | 126 | const bool align_corners,
|
128 | 127 | const bool splat) {
|
@@ -285,17 +284,17 @@ __global__ void PointsToVolumesBackwardKernel(
|
285 | 284 | }
|
286 | 285 |
|
287 | 286 | void PointsToVolumesBackwardCuda(
|
288 |
| - const torch::Tensor& points_3d, |
289 |
| - const torch::Tensor& points_features, |
290 |
| - const torch::Tensor& grid_sizes, |
291 |
| - const torch::Tensor& mask, |
| 287 | + const at::Tensor& points_3d, |
| 288 | + const at::Tensor& points_features, |
| 289 | + const at::Tensor& grid_sizes, |
| 290 | + const at::Tensor& mask, |
292 | 291 | const float point_weight,
|
293 | 292 | const bool align_corners,
|
294 | 293 | const bool splat,
|
295 |
| - const torch::Tensor& grad_volume_densities, |
296 |
| - const torch::Tensor& grad_volume_features, |
297 |
| - const torch::Tensor& grad_points_3d, |
298 |
| - const torch::Tensor& grad_points_features) { |
| 294 | + const at::Tensor& grad_volume_densities, |
| 295 | + const at::Tensor& grad_volume_features, |
| 296 | + const at::Tensor& grad_points_3d, |
| 297 | + const at::Tensor& grad_points_features) { |
299 | 298 | // Check inputs are on the same device
|
300 | 299 | at::TensorArg points_3d_t{points_3d, "points_3d", 1},
|
301 | 300 | points_features_t{points_features, "points_features", 2},
|
|
0 commit comments