Rename packed tensor accessor (pytorch#25654)
Summary:
Closes pytorch#19268

This does the renaming suggested by ezyang in pytorch#19268 (comment), except that the templated version of `packed_accessor` is also renamed to `generic_packed_accessor`.

Additionally, all of the users I could find in `ATen/native/cuda` are updated without changing their index types.
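For illustration, a minimal before/after sketch of that kind of call-site update (the kernel and function names below are hypothetical; only the accessor API comes from this PR):

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

// Before: the index type was picked via the template default of the old name:
//   fill_kernel<<<grid, block>>>(out.packed_accessor<float, 2>());
// After: the index width is explicit at the call site.
__global__ void fill_kernel(at::PackedTensorAccessor64<float, 2> out) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < out.size(0)) {
    out[i][0] = 0.0f;  // device-side operator[] yields a 1-D TensorAccessor, then a float&
  }
}

void fill_first_column(at::Tensor out) {
  const int64_t rows = out.size(0);
  if (rows == 0) return;
  const dim3 block(256);
  const dim3 grid((rows + block.x - 1) / block.x);
  fill_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
      out.packed_accessor64<float, 2>());
}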

The corresponding tutorial update is in pytorch/tutorials#644
Pull Request resolved: pytorch#25654

Differential Revision: D17259208

Pulled By: ezyang

fbshipit-source-id: 172a46f623d544ca16f7ed5077b6e4f57a3d1f21
peterbell10 authored and facebook-github-bot committed Sep 10, 2019
1 parent e8cc1fd commit 76ee02f
Showing 16 changed files with 292 additions and 223 deletions.
50 changes: 33 additions & 17 deletions aten/src/ATen/core/TensorAccessor.h
@@ -1,12 +1,13 @@
#pragma once

#include <c10/macros/Macros.h>
#include <c10/util/Deprecated.h>
#include <stdint.h>
#include <cstddef>

namespace at {

// The PtrTraits argument to the TensorAccessor/PackedTensorAccessor
// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor
// is used to enable the __restrict__ keyword/modifier for the data
// passed to cuda.
template <typename T>
@@ -62,7 +63,7 @@ class TensorAccessorBase {

// The `TensorAccessor` is typically instantiated for CPU `Tensor`s using
// `Tensor.accessor<T, N>()`.
// For CUDA `Tensor`s, `PackedTensorAccessor` is used on the host and only
// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only
// indexing on the device uses `TensorAccessor`s.
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class TensorAccessor : public TensorAccessorBase<T,N,PtrTraits,index_t> {
@@ -103,7 +104,7 @@ class TensorAccessor<T,1,PtrTraits,index_t> : public TensorAccessorBase<T,1,PtrT
};


// PackedTensorAccessorBase and PackedTensorAccessor are used on for CUDA `Tensor`s on the host
// GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used for CUDA `Tensor`s on the host.
// In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host)
// in order to transfer them to the device when calling kernels.
@@ -112,10 +113,10 @@ class TensorAccessor<T,1,PtrTraits,index_t> : public TensorAccessorBase<T,1,PtrT
// Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available
// on the device, so those functions are host only.
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class PackedTensorAccessorBase {
class GenericPackedTensorAccessorBase {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST PackedTensorAccessorBase(
C10_HOST GenericPackedTensorAccessorBase(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
@@ -126,7 +127,7 @@ class PackedTensorAccessorBase {

// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = typename std::enable_if<std::is_same<source_index_t, int64_t>::value>::type>
C10_HOST PackedTensorAccessorBase(
C10_HOST GenericPackedTensorAccessorBase(
PtrType data_,
const source_index_t* sizes_,
const source_index_t* strides_)
@@ -156,23 +157,23 @@ class PackedTensorAccessorBase {
};

template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class PackedTensorAccessor : public PackedTensorAccessorBase<T,N,PtrTraits,index_t> {
class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase<T,N,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;

C10_HOST PackedTensorAccessor(
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: PackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}

// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = typename std::enable_if<std::is_same<source_index_t, int64_t>::value>::type>
C10_HOST PackedTensorAccessor(
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const source_index_t* sizes_,
const source_index_t* strides_)
: PackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}

C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
index_t* new_sizes = this->sizes_ + 1;
@@ -188,22 +189,22 @@ class PackedTensorAccessor : public PackedTensorAccessorBase<T,N,PtrTraits,index
};

template<typename T, template <typename U> class PtrTraits, typename index_t>
class PackedTensorAccessor<T,1,PtrTraits,index_t> : public PackedTensorAccessorBase<T,1,PtrTraits,index_t> {
class GenericPackedTensorAccessor<T,1,PtrTraits,index_t> : public GenericPackedTensorAccessorBase<T,1,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST PackedTensorAccessor(
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: PackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}

// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = typename std::enable_if<std::is_same<source_index_t, int64_t>::value>::type>
C10_HOST PackedTensorAccessor(
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const source_index_t* sizes_,
const source_index_t* strides_)
: PackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}

C10_DEVICE T & operator[](index_t i) {
return this->data_[this->strides_[0] * i];
@@ -213,4 +214,19 @@ class PackedTensorAccessor<T,1,PtrTraits,index_t> : public PackedTensorAccessorB
}
};

}

// Can't put this directly into the macro function args because of commas
#define AT_X GenericPackedTensorAccessor<T, N, PtrTraits, index_t>

// Old name for `GenericPackedTensorAccessor`
template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
C10_DEFINE_DEPRECATED_USING(PackedTensorAccessor, AT_X)

#undef AT_X

template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
using PackedTensorAccessor32 = GenericPackedTensorAccessor<T, N, PtrTraits, int32_t>;

template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
using PackedTensorAccessor64 = GenericPackedTensorAccessor<T, N, PtrTraits, int64_t>;
} // namespace at
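For illustration, a short sketch of what the deprecation alias and the new typedefs above mean for existing code (the warning text is compiler-dependent and only indicative):

#include <ATen/core/TensorAccessor.h>

// The old spelling still compiles, but each use is now flagged as deprecated,
// e.g. "warning: 'PackedTensorAccessor' is deprecated" (exact wording varies by compiler).
using OldName = at::PackedTensorAccessor<float, 3>;

// New spellings added by this change; they fix the index width explicitly.
using Acc32 = at::PackedTensorAccessor32<float, 3>;  // index_t = int32_t
using Acc64 = at::PackedTensorAccessor64<float, 3>;  // index_t = int64_t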
36 changes: 30 additions & 6 deletions aten/src/ATen/core/TensorBody.h
@@ -11,6 +11,7 @@
#include <c10/core/TensorImpl.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>
#include <ATen/core/LegacyTypeDispatch.h>
@@ -317,19 +318,42 @@ class CAFFE2_API Tensor {
template<typename T, size_t N>
TensorAccessor<T,N> accessor() && = delete;

// Return a `PackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and
// Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and
// dimension. You can optionally specify RestrictPtrTraits as a template parameter to
// cast the data pointer to a __restrict__ pointer.
// In order to use this, your CUDA kernel has to take a corresponding PackedTensorAccessor
// In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor
// as an argument.
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
PackedTensorAccessor<T,N,PtrTraits,index_t> packed_accessor() const& {
GenericPackedTensorAccessor<T,N,PtrTraits,index_t> generic_packed_accessor() const& {
static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
TORCH_CHECK(dim() == N, "expected ", N, " dims but tensor has ", dim());
return PackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(data_ptr<T>()),sizes().data(),strides().data());
return GenericPackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(data_ptr<T>()),sizes().data(),strides().data());
}
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
PackedTensorAccessor<T,N> packed_accessor() && = delete;
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
GenericPackedTensorAccessor<T,N> generic_packed_accessor() && = delete;

template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() const& {
return generic_packed_accessor<T,N,PtrTraits,int32_t>();
}
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() && = delete;

template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() const& {
return generic_packed_accessor<T,N,PtrTraits,int64_t>();
}
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() && = delete;

template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead")
GenericPackedTensorAccessor<T,N,PtrTraits,index_t> packed_accessor() const & {
return generic_packed_accessor<T,N,PtrTraits,index_t>();
}
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead")
GenericPackedTensorAccessor<T,N,PtrTraits,index_t> packed_accessor() && = delete;

Tensor operator-() const;
Tensor& operator+=(const Tensor & other);
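To show how the two new `Tensor` methods above divide the work, here is a hedged host-side dispatch sketch (the kernel, the shapes, and the numel() threshold check are illustrative; ATen has its own helpers for deciding when 32-bit indexing is safe):

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <limits>

// Hypothetical kernel templated on the index width; only the accessor types come from this PR.
template <typename index_t>
__global__ void scale_rows(
    at::GenericPackedTensorAccessor<float, 2, at::RestrictPtrTraits, index_t> a,
    float factor) {
  index_t row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row < a.size(0)) {
    for (index_t col = 0; col < a.size(1); ++col) {
      a[row][col] *= factor;
    }
  }
}

void scale_rows_cuda(at::Tensor t, float factor) {
  const int64_t rows = t.size(0);
  if (rows == 0) return;
  const dim3 block(256);
  const dim3 grid((rows + block.x - 1) / block.x);
  auto stream = at::cuda::getCurrentCUDAStream();
  // 32-bit indexing is cheaper on the device, so prefer it whenever the tensor fits.
  if (t.numel() <= std::numeric_limits<int32_t>::max()) {
    scale_rows<int32_t><<<grid, block, 0, stream>>>(
        t.packed_accessor32<float, 2, at::RestrictPtrTraits>(), factor);
  } else {
    scale_rows<int64_t><<<grid, block, 0, stream>>>(
        t.packed_accessor64<float, 2, at::RestrictPtrTraits>(), factor);
  }
}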
40 changes: 20 additions & 20 deletions aten/src/ATen/native/cuda/AveragePool3d.cu
@@ -23,8 +23,8 @@ __device__ inline int max(int a, int b) {

template <typename scalar_t, typename accscalar_t>
__global__ void avg_pool3d_cuda_update_output(
PackedTensorAccessor<scalar_t, 4> input,
PackedTensorAccessor<scalar_t, 4> output,
PackedTensorAccessor64<scalar_t, 4> input,
PackedTensorAccessor64<scalar_t, 4> output,
int kT, int kH, int kW,
int dT, int dH, int dW,
int padT, int padH, int padW,
@@ -87,8 +87,8 @@ __global__ void avg_pool3d_cuda_update_output(
//
template<int KERNEL_WIDTH, typename scalar_t, typename accscalar_t>
__global__ void avg_pool3d_cuda_update_output(
PackedTensorAccessor<scalar_t, 4> input,
PackedTensorAccessor<scalar_t, 4> output,
PackedTensorAccessor64<scalar_t, 4> input,
PackedTensorAccessor64<scalar_t, 4> output,
int kT, int kH,
int dT, int dH, int dW,
int padT, int padH, int padW,
@@ -148,8 +148,8 @@ __global__ void avg_pool3d_cuda_update_output(

template <typename scalar_t, typename accscalar_t>
__global__ void avg_pool3d_single_backward_out_frame_stride1(
PackedTensorAccessor<scalar_t, 4> gradOutput,
PackedTensorAccessor<scalar_t, 4> gradInput,
PackedTensorAccessor64<scalar_t, 4> gradOutput,
PackedTensorAccessor64<scalar_t, 4> gradInput,
int kT, int kH, int kW,
accscalar_t normFactor,
int offsetZ)
@@ -193,8 +193,8 @@ __global__ void avg_pool3d_single_backward_out_frame_stride1(

template <typename scalar_t, typename accscalar_t>
__global__ void avg_pool3d_cuda_update_grad_input_atomic(
PackedTensorAccessor<scalar_t, 4> gradOutput,
PackedTensorAccessor<scalar_t, 4> gradInput,
PackedTensorAccessor64<scalar_t, 4> gradOutput,
PackedTensorAccessor64<scalar_t, 4> gradInput,
int kT, int kH, int kW,
int dT, int dH, int dW,
int padT, int padH, int padW,
@@ -251,8 +251,8 @@ __global__ void avg_pool3d_cuda_update_grad_input_atomic(

template <typename scalar_t, typename accscalar_t>
__global__ void avg_pool3d_cuda_update_grad_input(
PackedTensorAccessor<scalar_t, 4> gradOutput,
PackedTensorAccessor<scalar_t, 4> gradInput,
PackedTensorAccessor64<scalar_t, 4> gradOutput,
PackedTensorAccessor64<scalar_t, 4> gradInput,
int kT, int kH, int kW,
int dT, int dH, int dW,
int padT, int padH, int padW,
@@ -309,8 +309,8 @@ __global__ void avg_pool3d_cuda_update_grad_input(
#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
avg_pool3d_cuda_update_output<KW, scalar_t, accscalar_t> \
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>( \
work_input.packed_accessor<scalar_t, 4>(), \
work_output.packed_accessor<scalar_t, 4>(), \
work_input.packed_accessor64<scalar_t, 4>(), \
work_output.packed_accessor64<scalar_t, 4>(), \
kT, kH, \
dT, dH, dW, \
padT, padH, padW, \
@@ -425,8 +425,8 @@ void avg_pool3d_out_cuda_template(
default:
avg_pool3d_cuda_update_output<scalar_t, accscalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
work_input.packed_accessor<scalar_t, 4>(),
work_output.packed_accessor<scalar_t, 4>(),
work_input.packed_accessor64<scalar_t, 4>(),
work_output.packed_accessor64<scalar_t, 4>(),
kT, kH, kW,
dT, dH, dW,
padT, padH, padW,
@@ -567,8 +567,8 @@ void avg_pool3d_backward_out_cuda_template(

avg_pool3d_single_backward_out_frame_stride1<scalar_t, accscalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
work_grad_output.packed_accessor<scalar_t, 4>(),
work_grad_input.packed_accessor<scalar_t, 4>(),
work_grad_output.packed_accessor64<scalar_t, 4>(),
work_grad_input.packed_accessor64<scalar_t, 4>(),
kT, kH, kW,
1.0f/divide_factor,
offsetZ);
@@ -600,8 +600,8 @@ void avg_pool3d_backward_out_cuda_template(
if (kernelsOverlap) {
avg_pool3d_cuda_update_grad_input_atomic<scalar_t, accscalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
work_grad_output.packed_accessor<scalar_t, 4>(),
work_grad_input.packed_accessor<scalar_t, 4>(),
work_grad_output.packed_accessor64<scalar_t, 4>(),
work_grad_input.packed_accessor64<scalar_t, 4>(),
kT, kH, kW,
dT, dH, dW,
padT, padH, padW,
@@ -611,8 +611,8 @@
else {
avg_pool3d_cuda_update_grad_input<scalar_t, accscalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
work_grad_output.packed_accessor<scalar_t, 4>(),
work_grad_input.packed_accessor<scalar_t, 4>(),
work_grad_output.packed_accessor64<scalar_t, 4>(),
work_grad_input.packed_accessor64<scalar_t, 4>(),
kT, kH, kW,
dT, dH, dW,
padT, padH, padW,
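The `scalar_t` in the launch sites above is supplied by ATen's type-dispatch macros (not shown in this diff); a minimal sketch of that surrounding pattern with a hypothetical kernel, shown here with AT_DISPATCH_FLOATING_TYPES_AND_HALF (the exact macro used in these files may differ; only `packed_accessor64` and the accessor types are taken from this change):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>

// Illustrative kernel: one thread copies element [n][0][0][0] of a 4-D accessor.
template <typename scalar_t>
__global__ void copy_first_kernel(at::PackedTensorAccessor64<scalar_t, 4> in,
                                  at::PackedTensorAccessor64<scalar_t, 4> out) {
  int64_t n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n < in.size(0)) {
    out[n][0][0][0] = in[n][0][0][0];
  }
}

void copy_first_cuda(const at::Tensor& input, at::Tensor& output) {
  const int64_t n0 = input.size(0);
  if (n0 == 0) return;
  const dim3 block(256);
  const dim3 grid((n0 + block.x - 1) / block.x);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "copy_first_cuda", [&] {
    copy_first_kernel<scalar_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
        input.packed_accessor64<scalar_t, 4>(),
        output.packed_accessor64<scalar_t, 4>());
  });
}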
24 changes: 12 additions & 12 deletions aten/src/ATen/native/cuda/DilatedMaxPool3d.cu
@@ -20,8 +20,8 @@ __device__ inline int min(int a, int b) {
template <typename scalar_t>
__global__ static void max_pool3d_with_indices_single_out_frame(
scalar_t* inputData,
PackedTensorAccessor<scalar_t, 4> output,
PackedTensorAccessor<int64_t, 4> indices,
PackedTensorAccessor64<scalar_t, 4> output,
PackedTensorAccessor64<int64_t, 4> indices,
int itime, int iheight, int iwidth,
int kT, int kH, int kW,
int dT, int dH, int dW,
@@ -81,8 +81,8 @@ __global__ static void max_pool3d_with_indices_single_out_frame(
template <int KERNEL_WIDTH, typename scalar_t>
__global__ static void max_pool3d_with_indices_single_out_frame(
scalar_t* inputData,
PackedTensorAccessor<scalar_t, 4> output,
PackedTensorAccessor<int64_t, 4> indices,
PackedTensorAccessor64<scalar_t, 4> output,
PackedTensorAccessor64<int64_t, 4> indices,
int itime, int iheight, int iwidth,
int kT, int kH,
int dT, int dH, int dW,
@@ -143,8 +143,8 @@ __global__ static void max_pool3d_with_indices_single_out_frame(
max_pool3d_with_indices_single_out_frame<KW> \
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>( \
input_data, \
output.packed_accessor<scalar_t, 4>(), \
indices.packed_accessor<int64_t, 4>(), \
output.packed_accessor64<scalar_t, 4>(), \
indices.packed_accessor64<int64_t, 4>(), \
itime, iheight, iwidth, \
kT, kH, \
dT, dH, dW, \
@@ -185,8 +185,8 @@ void max_pool3d_with_indices_out_frame(
max_pool3d_with_indices_single_out_frame
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
input_data,
output.packed_accessor<scalar_t, 4>(),
indices.packed_accessor<int64_t, 4>(),
output.packed_accessor64<scalar_t, 4>(),
indices.packed_accessor64<int64_t, 4>(),
itime, iheight, iwidth,
kT, kH, kW,
dT, dH, dW,
@@ -209,8 +209,8 @@
template <typename scalar_t>
__global__ static void max_pool3d_with_indices_backward_single_out_frame(
scalar_t *gradInputData,
PackedTensorAccessor<scalar_t, 4> gradOutput,
PackedTensorAccessor<int64_t, 4> indices,
PackedTensorAccessor64<scalar_t, 4> gradOutput,
PackedTensorAccessor64<int64_t, 4> indices,
int itime, int iheight, int iwidth,
int dT, int dH, int dW,
int pT, int pH, int pW,
@@ -255,8 +255,8 @@ void max_pool3d_with_indices_backward_out_frame(
max_pool3d_with_indices_backward_single_out_frame
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
gradInputData,
gradOutput.packed_accessor<scalar_t, 4>(),
indices.packed_accessor<int64_t, 4>(),
gradOutput.packed_accessor64<scalar_t, 4>(),
indices.packed_accessor64<int64_t, 4>(),
itime, iheight, iwidth,
dT, dH, dW,
pT, pH, pW,
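Note the distinction visible in the `indices` accessors above: the first template argument is the element type (here `int64_t`, since pooling indices are stored as 64-bit values), while the 32/64 suffix only fixes the index type used for sizes and strides. A small sketch of the resulting types, using only the definitions added in this PR:

#include <ATen/core/TensorAccessor.h>
#include <type_traits>

// Element type and index type are independent choices:
using IdxAcc64 = at::PackedTensorAccessor64<int64_t, 4>;
using IdxAccGeneric =
    at::GenericPackedTensorAccessor<int64_t, 4, at::DefaultPtrTraits, int64_t>;

static_assert(std::is_same<IdxAcc64, IdxAccGeneric>::value,
              "PackedTensorAccessor64<T, N> is GenericPackedTensorAccessor<T, N, DefaultPtrTraits, int64_t>");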
Some of the 16 changed files are not shown above.
