diff --git a/aten/src/ATen/core/TensorAccessor.h b/aten/src/ATen/core/TensorAccessor.h index 0116964e036225..95f37fcb09510c 100644 --- a/aten/src/ATen/core/TensorAccessor.h +++ b/aten/src/ATen/core/TensorAccessor.h @@ -1,12 +1,13 @@ #pragma once #include +#include #include #include namespace at { -// The PtrTraits argument to the TensorAccessor/PackedTensorAccessor +// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor // is used to enable the __restrict__ keyword/modifier for the data // passed to cuda. template @@ -62,7 +63,7 @@ class TensorAccessorBase { // The `TensorAccessor` is typically instantiated for CPU `Tensor`s using // `Tensor.accessor()`. -// For CUDA `Tensor`s, `PackedTensorAccessor` is used on the host and only +// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only // indexing on the device uses `TensorAccessor`s. template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> class TensorAccessor : public TensorAccessorBase { @@ -103,7 +104,7 @@ class TensorAccessor : public TensorAccessorBase : public TensorAccessorBase class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> -class PackedTensorAccessorBase { +class GenericPackedTensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; - C10_HOST PackedTensorAccessorBase( + C10_HOST GenericPackedTensorAccessorBase( PtrType data_, const index_t* sizes_, const index_t* strides_) @@ -126,7 +127,7 @@ class PackedTensorAccessorBase { // if index_t is not int64_t, we want to have an int64_t constructor template ::value>::type> - C10_HOST PackedTensorAccessorBase( + C10_HOST GenericPackedTensorAccessorBase( PtrType data_, const source_index_t* sizes_, const source_index_t* strides_) @@ -156,23 +157,23 @@ class PackedTensorAccessorBase { }; template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> -class PackedTensorAccessor : public PackedTensorAccessorBase { +class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; - C10_HOST PackedTensorAccessor( + C10_HOST GenericPackedTensorAccessor( PtrType data_, const index_t* sizes_, const index_t* strides_) - : PackedTensorAccessorBase(data_, sizes_, strides_) {} + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} // if index_t is not int64_t, we want to have an int64_t constructor template ::value>::type> - C10_HOST PackedTensorAccessor( + C10_HOST GenericPackedTensorAccessor( PtrType data_, const source_index_t* sizes_, const source_index_t* strides_) - : PackedTensorAccessorBase(data_, sizes_, strides_) {} + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} C10_DEVICE TensorAccessor operator[](index_t i) { index_t* new_sizes = this->sizes_ + 1; @@ -188,22 +189,22 @@ class PackedTensorAccessor : public PackedTensorAccessorBase class PtrTraits, typename index_t> -class PackedTensorAccessor : public PackedTensorAccessorBase { +class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; - C10_HOST PackedTensorAccessor( + C10_HOST GenericPackedTensorAccessor( PtrType data_, const index_t* sizes_, const index_t* strides_) - : PackedTensorAccessorBase(data_, sizes_, strides_) {} + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} // if index_t is not int64_t, we want to have an int64_t constructor template ::value>::type> - C10_HOST PackedTensorAccessor( + C10_HOST GenericPackedTensorAccessor( PtrType data_, const 
source_index_t* sizes_, const source_index_t* strides_) - : PackedTensorAccessorBase(data_, sizes_, strides_) {} + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} C10_DEVICE T & operator[](index_t i) { return this->data_[this->strides_[0] * i]; @@ -213,4 +214,19 @@ class PackedTensorAccessor : public PackedTensorAccessorB } }; -} + +// Can't put this directly into the macro function args because of commas +#define AT_X GenericPackedTensorAccessor + +// Old name for `GenericPackedTensorAccessor` +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +C10_DEFINE_DEPRECATED_USING(PackedTensorAccessor, AT_X) + +#undef AT_X + +template class PtrTraits = DefaultPtrTraits> +using PackedTensorAccessor32 = GenericPackedTensorAccessor; + +template class PtrTraits = DefaultPtrTraits> +using PackedTensorAccessor64 = GenericPackedTensorAccessor; +} // namespace at diff --git a/aten/src/ATen/core/TensorBody.h b/aten/src/ATen/core/TensorBody.h index b8e45c7cc41ab0..a7c7a764f11b78 100644 --- a/aten/src/ATen/core/TensorBody.h +++ b/aten/src/ATen/core/TensorBody.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -317,19 +318,42 @@ class CAFFE2_API Tensor { template TensorAccessor accessor() && = delete; - // Return a `PackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and + // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and // dimension. You can optionally specify RestrictPtrTraits as a template parameter to // cast the data pointer to a __restrict__ pointer. - // In order to use this, your CUDA kernel has to take a corresponding PackedTensorAccessor + // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor // as an argument. 
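As an illustration of the accessor methods described in the comment above (a minimal sketch, not part of this patch; the kernel, helper name, and 2-D float tensor are hypothetical):

    // Sketch: a 2-D float CUDA tensor consumed through the new accessor API.
    #include <ATen/ATen.h>

    // The kernel takes the accessor type matching what the host constructs.
    __global__ void scale_rows(at::PackedTensorAccessor32<float, 2, at::RestrictPtrTraits> a) {
      int row = blockIdx.x;
      for (int col = threadIdx.x; col < a.size(1); col += blockDim.x) {
        a[row][col] *= 2.0f;  // indexing arithmetic uses int32_t sizes/strides
      }
    }

    void run(const at::Tensor& t) {  // t: 2-D float tensor on a CUDA device
      // New explicit-width accessors introduced by this patch:
      auto a32 = t.packed_accessor32<float, 2, at::RestrictPtrTraits>();
      auto a64 = t.packed_accessor64<float, 2>();  // same tensor, 64-bit indexing
      (void)a64;
      // t.packed_accessor<float, 2>() still compiles but now warns that it is deprecated.
      scale_rows<<<t.size(0), 256>>>(a32);
    }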
template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - PackedTensorAccessor packed_accessor() const& { + GenericPackedTensorAccessor generic_packed_accessor() const& { static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); TORCH_CHECK(dim() == N, "expected ", N, " dims but tensor has ", dim()); - return PackedTensorAccessor(static_cast::PtrType>(data_ptr()),sizes().data(),strides().data()); + return GenericPackedTensorAccessor(static_cast::PtrType>(data_ptr()),sizes().data(),strides().data()); } - template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - PackedTensorAccessor packed_accessor() && = delete; + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + GenericPackedTensorAccessor generic_packed_accessor() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() const& { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() const& { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() && = delete; + + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() const & { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() && = delete; Tensor operator-() const; Tensor& operator+=(const Tensor & other); diff --git a/aten/src/ATen/native/cuda/AveragePool3d.cu b/aten/src/ATen/native/cuda/AveragePool3d.cu index 17fd342878871f..214e08d92bbf02 100644 --- a/aten/src/ATen/native/cuda/AveragePool3d.cu +++ b/aten/src/ATen/native/cuda/AveragePool3d.cu @@ -23,8 +23,8 @@ __device__ inline int max(int a, int b) { template __global__ void avg_pool3d_cuda_update_output( - PackedTensorAccessor input, - PackedTensorAccessor output, + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, int kT, int kH, int kW, int dT, int dH, int dW, int padT, int padH, int padW, @@ -87,8 +87,8 @@ __global__ void avg_pool3d_cuda_update_output( // template __global__ void avg_pool3d_cuda_update_output( - PackedTensorAccessor input, - PackedTensorAccessor output, + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, int kT, int kH, int dT, int dH, int dW, int padT, int padH, int padW, @@ -148,8 +148,8 @@ __global__ void avg_pool3d_cuda_update_output( template __global__ void avg_pool3d_single_backward_out_frame_stride1( - PackedTensorAccessor gradOutput, - PackedTensorAccessor gradInput, + PackedTensorAccessor64 gradOutput, + PackedTensorAccessor64 gradInput, int kT, int kH, int kW, accscalar_t normFactor, int offsetZ) @@ -193,8 +193,8 @@ __global__ void avg_pool3d_single_backward_out_frame_stride1( template __global__ void avg_pool3d_cuda_update_grad_input_atomic( - PackedTensorAccessor gradOutput, - PackedTensorAccessor gradInput, + PackedTensorAccessor64 gradOutput, + PackedTensorAccessor64 gradInput, int kT, int kH, int kW, int dT, int dH, int dW, int padT, int padH, int padW, @@ -251,8 +251,8 @@ __global__ void 
avg_pool3d_cuda_update_grad_input_atomic( template __global__ void avg_pool3d_cuda_update_grad_input( - PackedTensorAccessor gradOutput, - PackedTensorAccessor gradInput, + PackedTensorAccessor64 gradOutput, + PackedTensorAccessor64 gradInput, int kT, int kH, int kW, int dT, int dH, int dW, int padT, int padH, int padW, @@ -309,8 +309,8 @@ __global__ void avg_pool3d_cuda_update_grad_input( #define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \ avg_pool3d_cuda_update_output \ <<>>( \ - work_input.packed_accessor(), \ - work_output.packed_accessor(), \ + work_input.packed_accessor64(), \ + work_output.packed_accessor64(), \ kT, kH, \ dT, dH, dW, \ padT, padH, padW, \ @@ -425,8 +425,8 @@ void avg_pool3d_out_cuda_template( default: avg_pool3d_cuda_update_output <<>>( - work_input.packed_accessor(), - work_output.packed_accessor(), + work_input.packed_accessor64(), + work_output.packed_accessor64(), kT, kH, kW, dT, dH, dW, padT, padH, padW, @@ -567,8 +567,8 @@ void avg_pool3d_backward_out_cuda_template( avg_pool3d_single_backward_out_frame_stride1 <<>>( - work_grad_output.packed_accessor(), - work_grad_input.packed_accessor(), + work_grad_output.packed_accessor64(), + work_grad_input.packed_accessor64(), kT, kH, kW, 1.0f/divide_factor, offsetZ); @@ -600,8 +600,8 @@ void avg_pool3d_backward_out_cuda_template( if (kernelsOverlap) { avg_pool3d_cuda_update_grad_input_atomic <<>>( - work_grad_output.packed_accessor(), - work_grad_input.packed_accessor(), + work_grad_output.packed_accessor64(), + work_grad_input.packed_accessor64(), kT, kH, kW, dT, dH, dW, padT, padH, padW, @@ -611,8 +611,8 @@ void avg_pool3d_backward_out_cuda_template( else { avg_pool3d_cuda_update_grad_input <<>>( - work_grad_output.packed_accessor(), - work_grad_input.packed_accessor(), + work_grad_output.packed_accessor64(), + work_grad_input.packed_accessor64(), kT, kH, kW, dT, dH, dW, padT, padH, padW, diff --git a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu index f33cced00f6997..0e9dee088897dd 100644 --- a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu +++ b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu @@ -20,8 +20,8 @@ __device__ inline int min(int a, int b) { template __global__ static void max_pool3d_with_indices_single_out_frame( scalar_t* inputData, - PackedTensorAccessor output, - PackedTensorAccessor indices, + PackedTensorAccessor64 output, + PackedTensorAccessor64 indices, int itime, int iheight, int iwidth, int kT, int kH, int kW, int dT, int dH, int dW, @@ -81,8 +81,8 @@ __global__ static void max_pool3d_with_indices_single_out_frame( template __global__ static void max_pool3d_with_indices_single_out_frame( scalar_t* inputData, - PackedTensorAccessor output, - PackedTensorAccessor indices, + PackedTensorAccessor64 output, + PackedTensorAccessor64 indices, int itime, int iheight, int iwidth, int kT, int kH, int dT, int dH, int dW, @@ -143,8 +143,8 @@ __global__ static void max_pool3d_with_indices_single_out_frame( max_pool3d_with_indices_single_out_frame \ <<>>( \ input_data, \ - output.packed_accessor(), \ - indices.packed_accessor(), \ + output.packed_accessor64(), \ + indices.packed_accessor64(), \ itime, iheight, iwidth, \ kT, kH, \ dT, dH, dW, \ @@ -185,8 +185,8 @@ void max_pool3d_with_indices_out_frame( max_pool3d_with_indices_single_out_frame <<>>( input_data, - output.packed_accessor(), - indices.packed_accessor(), + output.packed_accessor64(), + indices.packed_accessor64(), itime, iheight, iwidth, kT, kH, kW, dT, dH, dW, @@ -209,8 +209,8 @@ void 
max_pool3d_with_indices_out_frame( template __global__ static void max_pool3d_with_indices_backward_single_out_frame( scalar_t *gradInputData, - PackedTensorAccessor gradOutput, - PackedTensorAccessor indices, + PackedTensorAccessor64 gradOutput, + PackedTensorAccessor64 indices, int itime, int iheight, int iwidth, int dT, int dH, int dW, int pT, int pH, int pW, @@ -255,8 +255,8 @@ void max_pool3d_with_indices_backward_out_frame( max_pool3d_with_indices_backward_single_out_frame <<>>( gradInputData, - gradOutput.packed_accessor(), - indices.packed_accessor(), + gradOutput.packed_accessor64(), + indices.packed_accessor64(), itime, iheight, iwidth, dT, dH, dW, pT, pH, pW, diff --git a/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu b/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu index c44b49c004d4ee..ecd7188b273fdb 100644 --- a/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu +++ b/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu @@ -40,10 +40,10 @@ __device__ inline int64_t get_intervals( template __global__ void fractional_max_pool3d_out_frame( - PackedTensorAccessor input, - PackedTensorAccessor output, - PackedTensorAccessor indices, - PackedTensorAccessor samples, + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, + PackedTensorAccessor64 indices, + PackedTensorAccessor64 samples, int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) { using accscalar_t = at::acc_type; // Output (t, h, w) point that this thread is responsible for @@ -109,9 +109,9 @@ __global__ void fractional_max_pool3d_out_frame( template __global__ void fractional_max_pool3d_backward_out_frame( - PackedTensorAccessor gradInput, - PackedTensorAccessor gradOutput, - PackedTensorAccessor indices) { + PackedTensorAccessor64 gradInput, + PackedTensorAccessor64 gradOutput, + PackedTensorAccessor64 indices) { // Output (h, w) point that this thread is responsible for int64_t ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x; int64_t plane = blockIdx.y; @@ -236,10 +236,10 @@ void fractional_max_pool3d_out_cuda_template( [&]{ fractional_max_pool3d_out_frame <<>>( - input_.packed_accessor(), - output_.packed_accessor(), - indices_.packed_accessor(), - randomSamples.packed_accessor(), + input_.packed_accessor64(), + output_.packed_accessor64(), + indices_.packed_accessor64(), + randomSamples.packed_accessor64(), poolSizeT, poolSizeH, poolSizeW ); } @@ -326,9 +326,9 @@ void fractional_max_pool3d_backward_out_cuda_template( [&] { fractional_max_pool3d_backward_out_frame <<>>( - gradInput_.packed_accessor(), - gradOutput_.packed_accessor(), - indices_.packed_accessor() + gradInput_.packed_accessor64(), + gradOutput_.packed_accessor64(), + indices_.packed_accessor64() ); } ); diff --git a/aten/src/ATen/native/cuda/MaxUnpooling.cu b/aten/src/ATen/native/cuda/MaxUnpooling.cu index 1db0afd8b3afea..e4131c701bbcd5 100644 --- a/aten/src/ATen/native/cuda/MaxUnpooling.cu +++ b/aten/src/ATen/native/cuda/MaxUnpooling.cu @@ -38,8 +38,8 @@ __global__ void max_unpooling2d_forward_kernel( template __global__ void max_unpooling3d_forward_kernel( - PackedTensorAccessor input, - PackedTensorAccessor indices, + PackedTensorAccessor64 input, + PackedTensorAccessor64 indices, T* output, const int64_t oT, const int64_t oH, @@ -82,8 +82,8 @@ __global__ void max_unpooling3d_backward_kernel( int64_t oT, int64_t oH, int64_t oW, - PackedTensorAccessor indices, - PackedTensorAccessor gradInput, + PackedTensorAccessor64 indices, + PackedTensorAccessor64 gradInput, int offsetZ) { int iColumn = blockIdx.x * blockDim.x + threadIdx.x; int 
iRow = blockIdx.y * blockDim.y + threadIdx.y; @@ -339,8 +339,8 @@ Tensor& max_unpooling3d_forward_out_cuda( block, 0, at::cuda::getCurrentCUDAStream()>>>( - self.packed_accessor(), - indices.packed_accessor(), + self.packed_accessor64(), + indices.packed_accessor64(), output.data_ptr(), oT, oH, @@ -558,8 +558,8 @@ at::Tensor& max_unpooling3d_backward_out_cuda( oT, oH, oW, - indices.packed_accessor(), - grad_input_reshaped.packed_accessor(), + indices.packed_accessor64(), + grad_input_reshaped.packed_accessor64(), offsetZ); TORCH_CHECK( cudaGetLastError() == cudaSuccess, diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index cba05589d5fe88..414db1ac9e37e2 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -157,12 +157,12 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) { template __global__ void batch_norm_transform_input_kernel( - const PackedTensorAccessor input, - PackedTensorAccessor output, - const PackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> mean_, - const PackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> var_or_invstd, - const PackedTensorAccessor weight, - const PackedTensorAccessor bias, + const GenericPackedTensorAccessor input, + GenericPackedTensorAccessor output, + const GenericPackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> mean_, + const GenericPackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> var_or_invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor bias, stat_accscalar_t epsilon) { index_t plane = blockIdx.x; @@ -214,13 +214,13 @@ struct Var { template class VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t> __global__ void batch_norm_collect_statistics_kernel( - const PackedTensorAccessor input, + const GenericPackedTensorAccessor input, const stat_accscalar_t epsilon, const stat_accscalar_t momentum, - PackedTensorAccessor running_mean, - PackedTensorAccessor running_var, - PackedTensorAccessor save_mean, - PackedTensorAccessor save_transformed_var) { + GenericPackedTensorAccessor running_mean, + GenericPackedTensorAccessor running_var, + GenericPackedTensorAccessor save_mean, + GenericPackedTensorAccessor save_transformed_var) { __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE]; @@ -310,16 +310,16 @@ __global__ void batch_norm_collect_statistics_kernel( template __global__ void batch_norm_backward_kernel( - const PackedTensorAccessor input, - const PackedTensorAccessor grad_output, - PackedTensorAccessor grad_input, - PackedTensorAccessor grad_weight, - PackedTensorAccessor grad_bias, - const PackedTensorAccessor weight, - const PackedTensorAccessor running_mean, - const PackedTensorAccessor running_var, - const PackedTensorAccessor save_mean, - const PackedTensorAccessor save_invstd, + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + GenericPackedTensorAccessor grad_input, + GenericPackedTensorAccessor grad_weight, + GenericPackedTensorAccessor grad_bias, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor running_mean, + const GenericPackedTensorAccessor running_var, + const GenericPackedTensorAccessor save_mean, + const GenericPackedTensorAccessor save_invstd, bool train, stat_accscalar_t epsilon) { @@ -341,9 +341,9 @@ __global__ void batch_norm_backward_kernel( // Compute two values across (batch, x/y/z) in one pass: // 1. 
Sum(grad_output) // 2. DotProduct(input - mean, grad_output) - GradOp> g(mean, input, grad_output); + GradOp> g(mean, input, grad_output); Float2 res = reduce, GradOp>>(g, grad_output, plane); + GenericPackedTensorAccessor>>(g, grad_output, plane); stat_accscalar_t grad_output_sum = res.v1; stat_accscalar_t dot_p = res.v2; @@ -381,15 +381,15 @@ __global__ void batch_norm_backward_kernel( template __global__ void batch_norm_reduce_statistics_kernel( - const PackedTensorAccessor vec_mean, - const PackedTensorAccessor vec_invstd, - PackedTensorAccessor mean, - PackedTensorAccessor invstd, - PackedTensorAccessor running_mean, - PackedTensorAccessor running_var, + const GenericPackedTensorAccessor vec_mean, + const GenericPackedTensorAccessor vec_invstd, + GenericPackedTensorAccessor mean, + GenericPackedTensorAccessor invstd, + GenericPackedTensorAccessor running_mean, + GenericPackedTensorAccessor running_var, const accscalar_t epsilon, const accscalar_t momentum, - const PackedTensorAccessor counts) { + const GenericPackedTensorAccessor counts) { int feature_size = vec_mean.size(1); int world_size = vec_mean.size(0); @@ -427,14 +427,14 @@ __global__ void batch_norm_reduce_statistics_kernel( template __global__ void batch_norm_backward_reduce_kernel( - const PackedTensorAccessor input, - const PackedTensorAccessor grad_output, - PackedTensorAccessor mean, - PackedTensorAccessor invstd, - PackedTensorAccessor mean_dy, - PackedTensorAccessor mean_dy_xmu, - PackedTensorAccessor grad_weight, - PackedTensorAccessor grad_bias) { + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + GenericPackedTensorAccessor mean, + GenericPackedTensorAccessor invstd, + GenericPackedTensorAccessor mean_dy, + GenericPackedTensorAccessor mean_dy_xmu, + GenericPackedTensorAccessor grad_weight, + GenericPackedTensorAccessor grad_bias) { index_t plane = blockIdx.x; index_t N = input.size(0) * input.size(2); @@ -442,9 +442,9 @@ __global__ void batch_norm_backward_reduce_kernel( stat_accscalar_t r_mean = mean[plane]; stat_accscalar_t factor = invstd[plane]; - GradOp> g(r_mean, input, grad_output); + GradOp> g(r_mean, input, grad_output); Float2 res = reduce, GradOp>>(g, grad_output, plane); + GenericPackedTensorAccessor>>(g, grad_output, plane); stat_accscalar_t norm = stat_accscalar_t(1) / N; if (threadIdx.x == 0) { @@ -465,14 +465,14 @@ __global__ void batch_norm_backward_reduce_kernel( template __global__ void batch_norm_backward_elemt_kernel( - const PackedTensorAccessor input, - const PackedTensorAccessor grad_output, - const PackedTensorAccessor mean, - const PackedTensorAccessor invstd, - const PackedTensorAccessor weight, - const PackedTensorAccessor mean_dy, - const PackedTensorAccessor mean_dy_xmu, - PackedTensorAccessor grad_input) { + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor mean_dy, + const GenericPackedTensorAccessor mean_dy_xmu, + GenericPackedTensorAccessor grad_input) { index_t plane = blockIdx.x; @@ -502,12 +502,12 @@ __global__ void batch_norm_backward_elemt_kernel( } template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> -static PackedTensorAccessor packed_accessor_or_dummy(const Tensor& t) { +static GenericPackedTensorAccessor packed_accessor_or_dummy(const Tensor& t) { if (! 
t.defined()) { const std::vector zeros(dim); - return PackedTensorAccessor(nullptr, zeros.data(), zeros.data()); + return GenericPackedTensorAccessor(nullptr, zeros.data(), zeros.data()); } - return t.packed_accessor(); + return t.generic_packed_accessor(); } template @@ -532,7 +532,7 @@ std::tuple batch_norm_cuda_template(const Tensor& input_ auto bs = input_reshaped.size(0); auto features = input_reshaped.size(2); - auto input = input_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); auto input_options = input_.options(); if (input_.scalar_type() == at::ScalarType::Half) { input_options = input_options.dtype(ScalarType::Float); @@ -544,13 +544,13 @@ std::tuple batch_norm_cuda_template(const Tensor& input_ save_mean_ = at::empty({0}, input_options); save_invstd_ = at::empty({0}, input_options); } - auto output = output_reshaped.packed_accessor(); + auto output = output_reshaped.generic_packed_accessor(); auto weight = packed_accessor_or_dummy(weight_); auto bias = packed_accessor_or_dummy(bias_); auto running_mean = packed_accessor_or_dummy(running_mean_); auto running_var = packed_accessor_or_dummy(running_var_); - auto save_mean = save_mean_.packed_accessor(); - auto save_invstd = save_invstd_.packed_accessor(); + auto save_mean = save_mean_.generic_packed_accessor(); + auto save_invstd = save_invstd_.generic_packed_accessor(); auto stream = at::cuda::getCurrentCUDAStream(); // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean, @@ -606,8 +606,8 @@ std::tuple batch_norm_backward_cuda_template(const Tenso grad_bias_ = at::empty_like(weight_); } - auto input = input_reshaped.packed_accessor(); - auto grad_output = grad_output_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); + auto grad_output = grad_output_reshaped.generic_packed_accessor(); auto grad_input = packed_accessor_or_dummy(grad_input_reshaped); auto weight = packed_accessor_or_dummy(weight_); auto grad_weight = packed_accessor_or_dummy(grad_weight_); @@ -643,7 +643,7 @@ std::tuple batch_norm_stats_cuda_template(const Tensor& input_, auto bs = input_reshaped.size(0); auto features = input_reshaped.size(2); - auto input = input_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); auto input_options = input_.options(); dummy_mean_ = at::empty({0}, input_options); dummy_var_ = at::empty({0}, input_options); @@ -655,8 +655,8 @@ std::tuple batch_norm_stats_cuda_template(const Tensor& input_, invstd_ = at::empty({n_input}, input_options); auto mean = packed_accessor_or_dummy(mean_); auto invstd = packed_accessor_or_dummy(invstd_); - auto dummy_mean = dummy_mean_.packed_accessor(); - auto dummy_invstd = dummy_var_.packed_accessor(); + auto dummy_mean = dummy_mean_.generic_packed_accessor(); + auto dummy_invstd = dummy_var_.generic_packed_accessor(); auto stream = at::cuda::getCurrentCUDAStream(); dim3 blocks(input.size(1)); @@ -680,12 +680,12 @@ Tensor batch_norm_elemt_cuda_template(const Tensor& input_, const Tensor& weight auto bs = input_reshaped.size(0); auto features = input_reshaped.size(2); - auto input = input_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); auto input_options = input_.options(); if (input_.scalar_type() == at::ScalarType::Half) { input_options = input_options.dtype(ScalarType::Float); } - auto output = output_reshaped.packed_accessor(); + auto output = output_reshaped.generic_packed_accessor(); auto weight = 
packed_accessor_or_dummy(weight_); auto bias = packed_accessor_or_dummy(bias_); auto mean = packed_accessor_or_dummy(mean_); @@ -730,8 +730,8 @@ std::tuple batch_norm_gather_stats_cuda_template(const Tensor& m auto running_var = packed_accessor_or_dummy(running_var_); auto counts = packed_accessor_or_dummy(counts_); - auto save_mean = save_mean_.packed_accessor(); - auto save_invstd = save_invstd_.packed_accessor(); + auto save_mean = save_mean_.generic_packed_accessor(); + auto save_invstd = save_invstd_.generic_packed_accessor(); auto stream = at::cuda::getCurrentCUDAStream(); int block = getNumThreads(features); @@ -767,8 +767,8 @@ std::tuple batch_norm_backward_reduce_cuda_templ grad_bias_ = at::empty({n_input}, weight_.options()); } - auto input = input_reshaped.packed_accessor(); - auto grad_output = grad_output_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); + auto grad_output = grad_output_reshaped.generic_packed_accessor(); auto grad_weight = packed_accessor_or_dummy(grad_weight_); auto grad_bias = packed_accessor_or_dummy(grad_bias_); auto mean = packed_accessor_or_dummy(mean_); @@ -806,9 +806,9 @@ Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Te auto bs = input_reshaped.size(0); auto features = input_reshaped.size(2); - auto input = input_reshaped.packed_accessor(); - auto grad_input = grad_input_reshaped.packed_accessor(); - auto grad_output = grad_output_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); + auto grad_input = grad_input_reshaped.generic_packed_accessor(); + auto grad_output = grad_output_reshaped.generic_packed_accessor(); auto mean = packed_accessor_or_dummy(mean_); auto invstd = packed_accessor_or_dummy(invstd_); auto weight = packed_accessor_or_dummy(weight_); @@ -848,11 +848,11 @@ std::tuple batch_norm_update_stats_cuda_template( Tensor save_mean_ = at::empty({n_channels}, input_options); Tensor save_var_ = at::empty({n_channels}, input_options); - auto input = input_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); auto running_mean = packed_accessor_or_dummy(running_mean_); auto running_var = packed_accessor_or_dummy(running_var_); - auto save_mean = save_mean_.packed_accessor(); - auto save_var = save_var_.packed_accessor(); + auto save_mean = save_mean_.generic_packed_accessor(); + auto save_var = save_var_.generic_packed_accessor(); auto stream = at::cuda::getCurrentCUDAStream(); // for the reduction, we cannot use blocks for the batch dim, but if we have few threads in diff --git a/aten/src/ATen/native/cuda/ReplicationPadding.cu b/aten/src/ATen/native/cuda/ReplicationPadding.cu index c9da8f440b7297..ba51fc2105350c 100644 --- a/aten/src/ATen/native/cuda/ReplicationPadding.cu +++ b/aten/src/ATen/native/cuda/ReplicationPadding.cu @@ -27,8 +27,8 @@ __host__ __device__ __forceinline__ int imax(int a, int b) { namespace { template __global__ void replication_pad_forward_kernel1d( - PackedTensorAccessor input, - PackedTensorAccessor output, + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, int padL, int padR) { int outputPointId = threadIdx.x + blockIdx.x * blockDim.x; @@ -50,8 +50,8 @@ __global__ void replication_pad_forward_kernel1d( template __global__ void replication_pad_backward_kernel( - PackedTensorAccessor gradInput, - PackedTensorAccessor gradOutput, + PackedTensorAccessor64 gradInput, + PackedTensorAccessor64 gradOutput, int padL, int padR) { int outputPointId = threadIdx.x + blockIdx.x * 
blockDim.x; @@ -73,8 +73,8 @@ __global__ void replication_pad_backward_kernel( template __global__ void replication_pad_forward_kernel2d( - PackedTensorAccessor input, - PackedTensorAccessor output, + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, int padT, int padB, int padL, int padR) { int outputPointId = threadIdx.x + blockIdx.x * blockDim.x; @@ -100,8 +100,8 @@ __global__ void replication_pad_forward_kernel2d( template __global__ void replication_pad_backward_kernel( - PackedTensorAccessor gradInput, - PackedTensorAccessor gradOutput, + PackedTensorAccessor64 gradInput, + PackedTensorAccessor64 gradOutput, int padT, int padB, int padL, int padR) { int outputPointId = threadIdx.x + blockIdx.x * blockDim.x; @@ -127,8 +127,8 @@ __global__ void replication_pad_backward_kernel( template __global__ void replication_pad_forward_kernel3d( - PackedTensorAccessor input, - PackedTensorAccessor output, + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, int pfront, int pback, int ptop, int pbottom, int pleft, int pright) { int outputPointId = threadIdx.x + blockIdx.x * blockDim.x; @@ -163,8 +163,8 @@ __global__ void replication_pad_forward_kernel3d( template __global__ void replication_pad_backward_kernel( - PackedTensorAccessor gradInput, - PackedTensorAccessor gradOutput, + PackedTensorAccessor64 gradInput, + PackedTensorAccessor64 gradOutput, int pfront, int pback, int ptop, int pbottom, int pleft, int pright) { int outputPointId = threadIdx.x + blockIdx.x * blockDim.x; int plane = blockIdx.y; @@ -242,8 +242,8 @@ void replication_pad1d_out_cuda_template( output.resize_({numPlanes, outputW}); auto input_ = input.unsqueeze(0); auto output_ = output.unsqueeze(0); - auto devInput = input_.packed_accessor(); - auto devOutput = output_.packed_accessor(); + auto devInput = input_.packed_accessor64(); + auto devOutput = output_.packed_accessor64(); int outputPlaneSize = devOutput.size(2); dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), @@ -255,8 +255,8 @@ void replication_pad1d_out_cuda_template( at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput, padL, padR); } else { output.resize_({numBatch, numPlanes, outputW}); - auto devInput = input.packed_accessor(); - auto devOutput = output.packed_accessor(); + auto devInput = input.packed_accessor64(); + auto devOutput = output.packed_accessor64(); int outputPlaneSize = devOutput.size(2); dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), @@ -314,8 +314,8 @@ void replication_pad1d_backward_out_cuda_template( gradInput_ = gradInput.unsqueeze(0); gradOutput_ = gradOutput.unsqueeze(0); } - auto devGradInput = gradInput_.packed_accessor(); - auto devGradOutput = gradOutput_.packed_accessor(); + auto devGradInput = gradInput_.packed_accessor64(); + auto devGradOutput = gradOutput_.packed_accessor64(); int outputPlaneSize = devGradOutput.size(2); dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), @@ -379,8 +379,8 @@ void replication_pad2d_out_cuda_template( output.resize_({numPlanes, outputH, outputW}); auto input_ = input.unsqueeze(0); auto output_ = output.unsqueeze(0); - auto devInput = input_.packed_accessor(); - auto devOutput = output_.packed_accessor(); + auto devInput = input_.packed_accessor64(); + auto devOutput = output_.packed_accessor64(); int outputPlaneSize = devOutput.size(2) * devOutput.size(3); dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), @@ -393,8 +393,8 @@ void replication_pad2d_out_cuda_template( devInput, devOutput, padT, padB, padL, padR); } else { output.resize_({numBatch, numPlanes, outputH, outputW}); - 
auto devInput = input.packed_accessor(); - auto devOutput = output.packed_accessor(); + auto devInput = input.packed_accessor64(); + auto devOutput = output.packed_accessor64(); int outputPlaneSize = devOutput.size(2) * devOutput.size(3); dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), @@ -462,8 +462,8 @@ void replication_pad2d_backward_out_cuda_template( gradInput_ = gradInput.unsqueeze(0); gradOutput_ = gradOutput.unsqueeze(0); } - auto devGradInput = gradInput_.packed_accessor(); - auto devGradOutput = gradOutput_.packed_accessor(); + auto devGradInput = gradInput_.packed_accessor64(); + auto devGradOutput = gradOutput_.packed_accessor64(); int outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3); dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), @@ -614,8 +614,8 @@ void replication_pad3d_out_cuda_template( output.resize_({numPlanes, outputD, outputH, outputW}); auto input_ = input.unsqueeze(0); auto output_ = output.unsqueeze(0); - auto devInput = input_.packed_accessor(); - auto devOutput = output_.packed_accessor(); + auto devInput = input_.packed_accessor64(); + auto devOutput = output_.packed_accessor64(); int outputPlaneSize = devOutput.size(2) * devOutput.size(3) * devOutput.size(4); @@ -629,8 +629,8 @@ void replication_pad3d_out_cuda_template( devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright); } else { output.resize_({numBatch, numPlanes, outputD, outputH, outputW}); - auto devInput = input.packed_accessor(); - auto devOutput = output.packed_accessor(); + auto devInput = input.packed_accessor64(); + auto devOutput = output.packed_accessor64(); int outputPlaneSize = devOutput.size(2) * devOutput.size(3) * devOutput.size(4); @@ -689,8 +689,8 @@ void replication_pad3d_backward_out_cuda_template( gradInput_ = gradInput.unsqueeze(0); gradOutput_ = gradOutput.unsqueeze(0); } - auto devGradInput = gradInput_.packed_accessor(); - auto devGradOutput = gradOutput_.packed_accessor(); + auto devGradInput = gradInput_.packed_accessor64(); + auto devGradOutput = gradOutput_.packed_accessor64(); int outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3) * devGradOutput.size(4); diff --git a/aten/src/ATen/native/cuda/UpSample.cuh b/aten/src/ATen/native/cuda/UpSample.cuh index 3b398e27cb6e57..0bde9149136a32 100644 --- a/aten/src/ATen/native/cuda/UpSample.cuh +++ b/aten/src/ATen/native/cuda/UpSample.cuh @@ -166,7 +166,7 @@ __device__ __forceinline__ static int nearest_neighbor_compute_source_index( /* Used by UpSampleBicubic2d.cu */ template __device__ __forceinline__ static scalar_t upsample_get_value_bounded( - const PackedTensorAccessor& data, + const PackedTensorAccessor64& data, int batch, int channel, int height, @@ -181,7 +181,7 @@ __device__ __forceinline__ static scalar_t upsample_get_value_bounded( /* Used by UpSampleBicubic2d.cu */ template __device__ __forceinline__ static void upsample_increment_value_bounded( - PackedTensorAccessor& data, + PackedTensorAccessor64& data, int batch, int channel, int height, diff --git a/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu b/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu index 443e88ec078b80..cd030051728906 100644 --- a/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu @@ -18,8 +18,8 @@ __global__ void upsample_bicubic2d_out_frame( const accscalar_t height_scale, const accscalar_t width_scale, const bool align_corners, - const PackedTensorAccessor idata, - PackedTensorAccessor odata) { + const PackedTensorAccessor64 idata, + PackedTensorAccessor64 odata) { 
int index = threadIdx.x + blockIdx.x * blockDim.x; const int batchsize = idata.size(0); @@ -93,8 +93,8 @@ __global__ void upsample_bicubic2d_backward_out_frame( const accscalar_t height_scale, const accscalar_t width_scale, const bool align_corners, - PackedTensorAccessor idata, - const PackedTensorAccessor odata) { + PackedTensorAccessor64 idata, + const PackedTensorAccessor64 odata) { int index = threadIdx.x + blockIdx.x * blockDim.x; const int batchsize = idata.size(0); @@ -206,8 +206,8 @@ static void upsample_bicubic2d_out_cuda_template( input.scalar_type(), "upsample_bicubic2d_out_frame", [&] { using accscalar_t = at::acc_type; - auto idata = input.packed_accessor(); - auto odata = output.packed_accessor(); + auto idata = input.packed_accessor64(); + auto odata = output.packed_accessor64(); // Get scaling factors const accscalar_t rheight = area_pixel_compute_scale( @@ -285,8 +285,8 @@ static void upsample_bicubic2d_backward_out_cuda_template( grad_output.scalar_type(), "upsample_bicubic2d_backward_out_frame", [&] { using accscalar_t = at::acc_type; - auto idata = grad_input.packed_accessor(); - auto odata = grad_output.packed_accessor(); + auto idata = grad_input.packed_accessor64(); + auto odata = grad_output.packed_accessor64(); const accscalar_t rheight = area_pixel_compute_scale( input_height, output_height, align_corners); diff --git a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu index 1f3f566893cc66..d8a8ed8904fa32 100644 --- a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu @@ -197,8 +197,8 @@ static void upsample_bilinear2d_out_cuda_template( input.scalar_type(), "upsample_bilinear2d_out_frame", [&] { using accscalar_t = at::acc_type; - auto idata = input.packed_accessor(); - auto odata = output.packed_accessor(); + auto idata = input.packed_accessor64(); + auto odata = output.packed_accessor64(); const accscalar_t rheight = area_pixel_compute_scale( input_height, output_height, align_corners); diff --git a/aten/src/ATen/native/cuda/UpSampleLinear1d.cu b/aten/src/ATen/native/cuda/UpSampleLinear1d.cu index 0f70b57344cb6e..b4fc8d5a5afd9a 100644 --- a/aten/src/ATen/native/cuda/UpSampleLinear1d.cu +++ b/aten/src/ATen/native/cuda/UpSampleLinear1d.cu @@ -21,8 +21,8 @@ __global__ void upsample_linear1d_out_frame( const int n, const accscalar_t rwidth, const bool align_corners, - const PackedTensorAccessor idata, - PackedTensorAccessor odata) { + const PackedTensorAccessor64 idata, + PackedTensorAccessor64 odata) { int index = threadIdx.x + blockIdx.x * blockDim.x; const int batchsize = idata.size(0); @@ -70,8 +70,8 @@ __global__ void upsample_linear1d_out_frame_backward( const int n, const accscalar_t rwidth, const bool align_corners, - PackedTensorAccessor idata, - const PackedTensorAccessor odata) { + PackedTensorAccessor64 idata, + const PackedTensorAccessor64 odata) { int index = threadIdx.x + blockIdx.x * blockDim.x; const int batchsize = idata.size(0); @@ -147,8 +147,8 @@ static void upsample_linear1d_out_cuda_template( input.scalar_type(), "upsample_linear1d_out_frame", [&] { using accscalar_t = at::acc_type; - auto idata = input.packed_accessor(); - auto odata = output.packed_accessor(); + auto idata = input.packed_accessor64(); + auto odata = output.packed_accessor64(); const accscalar_t rwidth = area_pixel_compute_scale( input_width, output_width, align_corners); @@ -207,8 +207,8 @@ static void upsample_linear1d_backward_out_cuda_template( grad_output.scalar_type(), 
"upsample_linear1d_out_frame_backward", [&] { using accscalar_t = at::acc_type; - auto idata = grad_input.packed_accessor(); - auto odata = grad_output.packed_accessor(); + auto idata = grad_input.packed_accessor64(); + auto odata = grad_output.packed_accessor64(); const accscalar_t rwidth = area_pixel_compute_scale( input_width, output_width, align_corners); diff --git a/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu index 683860e8a466b7..73799b088a64e2 100644 --- a/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu +++ b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu @@ -21,8 +21,8 @@ __global__ void upsample_trilinear3d_out_frame( const accscalar_t rheight, const accscalar_t rwidth, const bool align_corners, - const PackedTensorAccessor idata, - PackedTensorAccessor odata) { + const PackedTensorAccessor64 idata, + PackedTensorAccessor64 odata) { int index = threadIdx.x + blockIdx.x * blockDim.x; const int batchsize = idata.size(0); @@ -105,8 +105,8 @@ __global__ void upsample_trilinear3d_backward_out_frame( const accscalar_t rheight, const accscalar_t rwidth, const bool align_corners, - PackedTensorAccessor idata, - const PackedTensorAccessor odata) { + PackedTensorAccessor64 idata, + const PackedTensorAccessor64 odata) { int index = threadIdx.x + blockIdx.x * blockDim.x; const int batchsize = idata.size(0); @@ -245,8 +245,8 @@ static void upsample_trilinear3d_out_cuda_template( input.scalar_type(), "upsample_trilinear3d_out_frame", [&] { using accscalar_t = at::acc_type; - auto idata = input.packed_accessor(); - auto odata = output.packed_accessor(); + auto idata = input.packed_accessor64(); + auto odata = output.packed_accessor64(); const accscalar_t rdepth = area_pixel_compute_scale( input_depth, output_depth, align_corners); @@ -332,8 +332,8 @@ static void upsample_trilinear3d_backward_out_cuda_template( [&] { using accscalar_t = at::acc_type; - auto idata = grad_input.packed_accessor(); - auto odata = grad_output.packed_accessor(); + auto idata = grad_input.packed_accessor64(); + auto odata = grad_output.packed_accessor64(); const accscalar_t rdepth = area_pixel_compute_scale( input_depth, output_depth, align_corners); diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 49c23e695fda72..1fd7c4e16542f6 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -317,19 +318,42 @@ class CAFFE2_API Tensor { template TensorAccessor accessor() && = delete; - // Return a `PackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and + // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and // dimension. You can optionally specify RestrictPtrTraits as a template parameter to // cast the data pointer to a __restrict__ pointer. - // In order to use this, your CUDA kernel has to take a corresponding PackedTensorAccessor + // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor // as an argument. 
template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - PackedTensorAccessor packed_accessor() const& { + GenericPackedTensorAccessor generic_packed_accessor() const& { static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); TORCH_CHECK(dim() == N, "expected ", N, " dims but tensor has ", dim()); - return PackedTensorAccessor(static_cast::PtrType>(data_ptr()),sizes().data(),strides().data()); + return GenericPackedTensorAccessor(static_cast::PtrType>(data_ptr()),sizes().data(),strides().data()); } - template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - PackedTensorAccessor packed_accessor() && = delete; + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + GenericPackedTensorAccessor generic_packed_accessor() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() const& { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() const& { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() && = delete; + + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() const & { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() && = delete; Tensor operator-() const; Tensor& operator+=(const Tensor & other); diff --git a/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu b/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu index ff0c0c4eb89e5e..12d3b3d9731f1b 100644 --- a/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu +++ b/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu @@ -9,9 +9,9 @@ using namespace at; __global__ void test_tensor_packed_accessor_kernel( - PackedTensorAccessor resa, - PackedTensorAccessor t1a, - PackedTensorAccessor t2a) { + PackedTensorAccessor64 resa, + PackedTensorAccessor64 t1a, + PackedTensorAccessor64 t2a) { for (int64_t i = 0; i < resa.size(0); i++) { float val = 0.0f; for (int64_t j = 0; j < t1a.size(1); j++) { @@ -21,7 +21,7 @@ __global__ void test_tensor_packed_accessor_kernel( } } -// test PackedTensorAccessor and Tensor.packed_accessor +// test GenericPackedTensorAccessor and Tensor.generic_packed_accessor TEST(PackedtensoraccessorTest, PackedtensoraccessorTestCUDA) { if (!at::cuda::is_available()) return; manual_seed(123); @@ -30,9 +30,9 @@ TEST(PackedtensoraccessorTest, PackedtensoraccessorTestCUDA) { Tensor t2 = rand({4}, CUDA(kFloat)); Tensor res = empty({4}, CUDA(kFloat)); - auto t1a = t1.packed_accessor(); - auto t2a = t2.packed_accessor(); - auto resa = res.packed_accessor(); + auto t1a = t1.packed_accessor64(); + auto t2a = t2.packed_accessor64(); + auto resa = res.packed_accessor64(); auto stream = at::cuda::getCurrentCUDAStream(); diff --git a/docs/cpp/source/notes/tensor_basics.rst b/docs/cpp/source/notes/tensor_basics.rst index 5d25efcf68de95..09032546a3a9ae 100644 --- a/docs/cpp/source/notes/tensor_basics.rst +++ 
b/docs/cpp/source/notes/tensor_basics.rst
@@ -76,20 +76,25 @@ CUDA accessors
 .. code-block:: cpp
 
   __global__ void packed_accessor_kernel(
-      PackedTensorAccessor<float, 2> foo,
+      PackedTensorAccessor64<float, 2> foo,
       float* trace) {
     int i=threadIdx.x
     atomicAdd(trace, foo[i][i])
   }
- 
+
   torch::Tensor foo = torch::rand({12, 12});
 
   // assert foo is 2-dimensional and holds floats.
-  auto foo_a = foo.packed_accessor<float, 2>();
+  auto foo_a = foo.packed_accessor64<float, 2>();
   float trace = 0;
 
   packed_accessor_kernel<<<1, 12>>>(foo_a, &trace);
 
+In addition to ``PackedTensorAccessor64`` and ``packed_accessor64`` there are
+also the corresponding ``PackedTensorAccessor32`` and ``packed_accessor32``
+which use 32-bit integers for indexing. This can be quite a bit faster on CUDA
+but may lead to overflows in the indexing calculations.
+
 Note that the template can hold other parameters such as the pointer
 restriction and the integer type for indexing. See documentation for a
 thorough template description of *accessors* and *packed accessors*.
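To complement the new paragraph on 32-bit indexing above, the docs example could also be written against the 32-bit accessor. This is a minimal sketch, not part of the patch; the ``numel()`` check is a simplified assumption about how a caller might stay within the 32-bit index range, and ``device_trace`` is assumed to point to device memory:

.. code-block:: cpp

  #include <ATen/ATen.h>
  #include <limits>

  __global__ void packed_accessor32_kernel(
      at::PackedTensorAccessor32<float, 2> foo,
      float* trace) {
    int i = threadIdx.x;
    atomicAdd(trace, foo[i][i]);
  }

  void trace_sum(const at::Tensor& foo, float* device_trace) {
    // 32-bit indexing is only valid while the tensor's extents fit in int32_t.
    TORCH_CHECK(foo.numel() <= std::numeric_limits<int32_t>::max(),
                "tensor too large for packed_accessor32");
    auto foo_a = foo.packed_accessor32<float, 2>();
    packed_accessor32_kernel<<<1, 12>>>(foo_a, device_trace);
  }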