Implement scatter reductions (CUDA), remove divide/subtract (pytorch#41977)

Summary:
Fixes pytorch#33394.

This PR does two things:
1. Implements CUDA scatter reductions using revamped GPU atomic operations.
2. Removes support for the divide and subtract reductions on CPU, as discussed with ngimel.

I've also updated the docs to reflect that only add and multiply are supported.
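
For context, a minimal sketch of the user-facing behavior after this change (assuming the Tensor.scatter_ Python binding with the keyword-only reduce argument described in the updated docs; values shown are illustrative):

    import torch

    index = torch.tensor([0, 1, 2])
    src = torch.tensor([2., 3., 4.])

    # "add": self[index[i]] += src[i] along dim 0
    torch.ones(3).scatter_(0, index, src, reduce="add")       # tensor([3., 4., 5.])

    # "multiply": self[index[i]] *= src[i] along dim 0
    torch.ones(3).scatter_(0, index, src, reduce="multiply")  # tensor([2., 3., 4.])

    # The scalar variant takes a single value instead of a src tensor:
    torch.ones(3).scatter_(0, index, 2., reduce="multiply")   # tensor([2., 2., 2.])

    # "subtract" and "divide" are no longer accepted and raise an error.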

Pull Request resolved: pytorch#41977

Reviewed By: mruberry

Differential Revision: D23748888

Pulled By: ngimel

fbshipit-source-id: ea643c0da03c9058e433de96db02b503514c4e9c
v0dro authored and facebook-github-bot committed Sep 17, 2020
1 parent fdeee74 commit e18a221
Showing 13 changed files with 503 additions and 176 deletions.
10 changes: 2 additions & 8 deletions aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -587,22 +587,16 @@ SCATTER_GATHER_OP get_operator_enum(const std::string& reduce) {
if (reduce == "add") {
return SCATTER_GATHER_OP::REDUCE_ADD;
}
else if (reduce == "subtract") {
return SCATTER_GATHER_OP::REDUCE_SUBTRACT;
}
else if (reduce == "multiply") {
return SCATTER_GATHER_OP::REDUCE_MULTIPLY;
}
else if (reduce == "divide") {
return SCATTER_GATHER_OP::REDUCE_DIVIDE;
}
else {
TORCH_CHECK(false,
"reduce argument must be either of add, subtract, multiply or divide.");
}
}

Tensor& scatter_cpu_scalar_reduce_(Tensor& self, const int64_t dim, const Tensor& index,
Tensor& scatter_scalar_reduce_(Tensor& self, const int64_t dim, const Tensor& index,
Scalar value, const std::string reduce) {
TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long,
"scatter_(): Expected dtype int64 for index.");
@@ -613,7 +607,7 @@ Tensor& scatter_cpu_scalar_reduce_(Tensor& self, const int64_t dim, const Tensor
return self;
}

Tensor & scatter_cpu_reduce_(Tensor & self, const int64_t dim, const Tensor & index,
Tensor & scatter_reduce_(Tensor & self, const int64_t dim, const Tensor & index,
const Tensor & src, const std::string reduce) {
TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long,
"scatter_(): Expected dtype int64 for index");
2 changes: 1 addition & 1 deletion aten/src/ATen/native/TensorAdvancedIndexing.h
@@ -11,7 +11,7 @@ namespace at {

namespace at { namespace native {

enum class SCATTER_GATHER_OP: uint8_t {REDUCE_ADD, REDUCE_SUBTRACT, REDUCE_MULTIPLY, REDUCE_DIVIDE};
enum class SCATTER_GATHER_OP: uint8_t {REDUCE_ADD, REDUCE_MULTIPLY};

using index_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
34 changes: 1 addition & 33 deletions aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
@@ -7,7 +7,7 @@
namespace at { namespace native {

namespace {

// Implement as functors since lambdas don't get optimized.
class ReduceMultiply {
public:
@@ -31,24 +31,6 @@ class ReduceAdd {
};
static ReduceAdd reduce_add;

class ReduceSubtract {
public:
template <typename scalar_t>
constexpr void operator() (scalar_t * self_data, scalar_t * src_data) const {
*self_data -= *src_data;
}
};
static ReduceSubtract reduce_subtract;

class ReduceDivide {
public:
template <typename scalar_t>
constexpr void operator() (scalar_t * self_data, scalar_t * src_data) const {
*self_data /= *src_data;
}
};
static ReduceDivide reduce_divide;

class TensorAssign {
public:
template <typename scalar_t>
@@ -348,17 +330,10 @@ void scatter_reduce_cpu_kernel(Tensor& self, const int64_t dim, const Tensor& in
cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
"scatter_reduce_add_", reduce_add);
break;
case SCATTER_GATHER_OP::REDUCE_SUBTRACT :
cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
"scatter_reduce_subtract_", reduce_subtract);
break;
case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
"scatter_reduce_multiply_", reduce_multiply);
break;
case SCATTER_GATHER_OP::REDUCE_DIVIDE :
cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
"scatter_reduce_divide_", reduce_divide);
}
}

@@ -369,17 +344,10 @@ void scatter_scalar_reduce_cpu_kernel(Tensor& self, const int64_t dim, const Ten
cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
"scatter_scalar_reduce_add_", reduce_add);
break;
case SCATTER_GATHER_OP::REDUCE_SUBTRACT :
cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
"scatter_scalar_reduce_subtract_", reduce_subtract);
break;
case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
"scatter_scalar_reduce_multiply_", reduce_multiply);
break;
case SCATTER_GATHER_OP::REDUCE_DIVIDE :
cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
"scatter_scalar_reduce_divide_", reduce_divide);
}
}

207 changes: 189 additions & 18 deletions aten/src/ATen/native/cuda/ScatterGatherKernel.cu
@@ -15,6 +15,34 @@

namespace at { namespace native {

// Implement as functors since lambdas don't get optimized.
class ReduceMultiply {
public:
template <typename scalar_t>
constexpr C10_DEVICE void operator() (scalar_t * self_data, const scalar_t * src_data) const {
gpuAtomicMul(self_data, *src_data);
}
};
static ReduceMultiply reduce_multiply;

class ReduceAdd {
public:
template <typename scalar_t>
constexpr C10_DEVICE void operator() (scalar_t * self_data, const scalar_t * src_data) const {
gpuAtomicAdd(self_data, *src_data);
}
};
static ReduceAdd reduce_add;

class TensorAssign {
public:
template <typename scalar_t>
constexpr C10_DEVICE void operator() (scalar_t * self_data, const scalar_t * src_data) const {
*self_data = *src_data;
}
};
static TensorAssign tensor_assign;

// The kernels are implemented on an opaque,
// self-aligned type of the correct size,
// to avoid redundant kernels for different types
@@ -160,6 +188,7 @@ struct cuda_scatter_gather_base_kernel {
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;


AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
iter.dtype(),
@@ -173,6 +202,78 @@ }
}
);
}

void operator()(
Tensor& self, int64_t dim,
const Tensor& index, const Tensor& src,
const std::string& method_name,
const ReduceMultiply& f
) {
// no-op if index is empty
if (index.numel() == 0) {
return;
}
at::assert_no_internal_overlap(self);

dim = maybe_wrap_dim(dim, self.dim());

scatter_gather_dtype_check(method_name, self, index, src);
if (is_scatter_like) {
scatter_shape_check(self, dim, index, src);
}
else {
gather_shape_check(self, dim, index, src);
}

auto index_sizes = ensure_nonempty_vec(index.sizes().vec());
auto self_strides = ensure_nonempty_vec(self.strides().vec());
auto src_strides = ensure_nonempty_vec(src.strides().vec());

// restride self and src such that
// self.shape = src.shape = index.shape
//
// restride stride[dim] such that
// if (is_scatter_like) self.stride[dim] = 0
// else src.stride[dim] = 0
auto self_restrided = is_scatter_like ?
restride_dim(self, dim, index_sizes)
: self.as_strided(index_sizes, self_strides);
auto src_restrided = is_scatter_like ?
src.as_strided(index_sizes, src_strides)
: restride_dim(src, dim, index_sizes);

auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.add_output(self_restrided)
.add_input(src_restrided)
.add_input(index)
.build();

auto self_dim_stride = ensure_nonempty_stride(self, dim);
auto self_dim_size = ensure_nonempty_size(self, dim);

auto src_dim_stride = ensure_nonempty_stride(src, dim);
auto src_dim_size = ensure_nonempty_size(src, dim);

auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;


AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
iter.dtype(),
method_name, [&] {
using dtype = typename std::conditional<cast_to_opaque,
OpaqueType<sizeof(scalar_t)>, scalar_t>::type;

_cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
iter, index_size, index_stride, f
);
}
);
}
}; // struct cuda_scatter_gather_base_kernel

template <typename scalar_t>
Expand Down Expand Up @@ -214,7 +315,7 @@ struct _cuda_scatter_fill_internal_kernel {

f(
(scalar_t*)self_data + idx_dim * index_stride,
&src_val
(scalar_t*)&src_val
);

};
@@ -276,35 +377,77 @@ struct cuda_scatter_fill_base_kernel {
);
}
);
}

void operator()(
Tensor& self, int64_t dim,
const Tensor& index, Scalar src,
const std::string& method_name,
const ReduceMultiply& f
) {
// no-op if index is empty
if (index.numel() == 0) {
return;
}
at::assert_no_internal_overlap(self);

dim = maybe_wrap_dim(dim, self.dim());

scatter_gather_dtype_check(method_name, self, index);
scatter_shape_check(self, dim, index);

auto index_sizes = ensure_nonempty_vec(index.sizes().vec());

// restride self such that
// self.shape = index.shape and
// self.stride[dim] = 0
auto self_restrided = restride_dim(self, dim, index_sizes);

auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.add_output(self_restrided)
.add_input(index)
.build();

auto index_size = ensure_nonempty_size(self, dim);
auto index_stride = ensure_nonempty_stride(self, dim);

AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
iter.dtype(),
method_name, [&] {
using dtype = typename std::conditional<cast_to_opaque,
OpaqueType<sizeof(scalar_t)>, scalar_t>::type;

auto src_scalar_val = src.to<scalar_t>();
auto src_val = *(dtype*)&src_scalar_val;

_cuda_scatter_fill_internal_kernel<dtype>()(
iter, src_val, index_size, index_stride, f
);
}
);
}
}; // struct cuda_scatter_fill_base_kernel

void gather_cuda_kernel(Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) {
cuda_scatter_gather_base_kernel</*is_scatter_like=*/false>()(
result, dim, index, self,
"gather_out_cuda", []C10_DEVICE(auto* lhs, const auto* rhs) {
*lhs = *rhs;
}
);
"gather_out_cuda", tensor_assign);
}

void scatter_cuda_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
cuda_scatter_gather_base_kernel<>()(
self, dim, index, src,
"scatter_cuda_", []C10_DEVICE(auto* lhs, const auto* rhs) {
*lhs = *rhs;
}
);
"scatter_cuda_", tensor_assign);
}

void scatter_fill_cuda_kernel(Tensor& self, int64_t dim, const Tensor& index, Scalar src) {
cuda_scatter_fill_base_kernel<>()(
self, dim, index, src,
"scatter_fill_cuda_", []C10_DEVICE(auto* lhs, const auto* rhs) {
*lhs = *rhs;
}
);
"scatter_fill_cuda_", tensor_assign);
}

void scatter_add_cuda_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
@@ -313,15 +456,43 @@ void scatter_add_cuda_kernel(Tensor& self, int64_t dim, const Tensor& index, con
globalContext().alertNotDeterministic("scatter_add_cuda_kernel");
cuda_scatter_gather_base_kernel</*is_scatter_like=*/true, /*cast_to_opaque=*/false>()(
self, dim, index, src,
"scatter_add_cuda_", []C10_DEVICE(auto* lhs, const auto* rhs) {
gpuAtomicAdd(lhs, *rhs);
}
);
"scatter_add_cuda_", reduce_add);
}

void scatter_reduce_cuda_kernel(Tensor& self, const int64_t dim, const Tensor& index,
const Tensor& src, const SCATTER_GATHER_OP& reduce) {
switch (reduce) {
case SCATTER_GATHER_OP::REDUCE_ADD :
cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
"scatter_reduce_cuda_add_", reduce_add);
break;
case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
"scatter_reduce_cuda_multiply_", reduce_multiply);
break;
}
}

void scatter_scalar_reduce_cuda_kernel(Tensor& self, const int64_t dim, const Tensor& index,
Scalar& value, const SCATTER_GATHER_OP& reduce) {
switch (reduce) {
case SCATTER_GATHER_OP::REDUCE_ADD :
cuda_scatter_fill_base_kernel<false>()(self, dim, index, value,
"scatter_fill_cuda_add_", reduce_add);
break;
case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
cuda_scatter_fill_base_kernel<false>()(self, dim, index, value,
"scatter_fill_cuda_multiply_", reduce_multiply);
break;
}
}


REGISTER_DISPATCH(gather_stub, &gather_cuda_kernel);
REGISTER_DISPATCH(scatter_stub, &scatter_cuda_kernel);
REGISTER_DISPATCH(scatter_fill_stub, &scatter_fill_cuda_kernel);
REGISTER_DISPATCH(scatter_add_stub, &scatter_add_cuda_kernel);

REGISTER_DISPATCH(scatter_reduce_stub, &scatter_reduce_cuda_kernel);
REGISTER_DISPATCH(scatter_scalar_reduce_stub, &scatter_scalar_reduce_cuda_kernel);

}} // namespace at::native
4 changes: 2 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -4541,13 +4541,13 @@
use_c10_dispatcher: full
variants: method
dispatch:
CPU: scatter_cpu_reduce_
CPU, CUDA: scatter_reduce_

- func: scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!)
use_c10_dispatcher: full
variants: method
dispatch:
CPU: scatter_cpu_scalar_reduce_
CPU, CUDA: scatter_scalar_reduce_

- func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
use_c10_dispatcher: full