Skip to content

Commit 3953de4

Browse files
bottlerfacebook-github-bot
authored andcommitted
remove torch from cuda
Summary: Keep using at:: instead of torch:: so we don't need torch/extension.h and can keep other compilers happy. Reviewed By: patricklabatut Differential Revision: D31688436 fbshipit-source-id: 1825503da0104acaf1558d17300c02ef663bf538
1 parent 1a7442a commit 3953de4

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

pytorch3d/csrc/points_to_volumes/points_to_volumes.cu

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@
99
#include <ATen/ATen.h>
1010
#include <ATen/cuda/CUDAContext.h>
1111
#include <c10/cuda/CUDAGuard.h>
12-
#include <torch/extension.h>
1312

14-
using torch::PackedTensorAccessor64;
15-
using torch::RestrictPtrTraits;
13+
using at::PackedTensorAccessor64;
14+
using at::RestrictPtrTraits;
1615

1716
// A chunk of work is blocksize-many points.
1817
// There are N clouds in the batch, and P points in each cloud.
@@ -117,12 +116,12 @@ __global__ void PointsToVolumesForwardKernel(
117116
}
118117

119118
void PointsToVolumesForwardCuda(
120-
const torch::Tensor& points_3d,
121-
const torch::Tensor& points_features,
122-
const torch::Tensor& volume_densities,
123-
const torch::Tensor& volume_features,
124-
const torch::Tensor& grid_sizes,
125-
const torch::Tensor& mask,
119+
const at::Tensor& points_3d,
120+
const at::Tensor& points_features,
121+
const at::Tensor& volume_densities,
122+
const at::Tensor& volume_features,
123+
const at::Tensor& grid_sizes,
124+
const at::Tensor& mask,
126125
const float point_weight,
127126
const bool align_corners,
128127
const bool splat) {
@@ -285,17 +284,17 @@ __global__ void PointsToVolumesBackwardKernel(
285284
}
286285

287286
void PointsToVolumesBackwardCuda(
288-
const torch::Tensor& points_3d,
289-
const torch::Tensor& points_features,
290-
const torch::Tensor& grid_sizes,
291-
const torch::Tensor& mask,
287+
const at::Tensor& points_3d,
288+
const at::Tensor& points_features,
289+
const at::Tensor& grid_sizes,
290+
const at::Tensor& mask,
292291
const float point_weight,
293292
const bool align_corners,
294293
const bool splat,
295-
const torch::Tensor& grad_volume_densities,
296-
const torch::Tensor& grad_volume_features,
297-
const torch::Tensor& grad_points_3d,
298-
const torch::Tensor& grad_points_features) {
294+
const at::Tensor& grad_volume_densities,
295+
const at::Tensor& grad_volume_features,
296+
const at::Tensor& grad_points_3d,
297+
const at::Tensor& grad_points_features) {
299298
// Check inputs are on the same device
300299
at::TensorArg points_3d_t{points_3d, "points_3d", 1},
301300
points_features_t{points_features, "points_features", 2},

0 commit comments

Comments
 (0)