porting scatter_add to ATen (CPU) (pytorch#31662)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/24758.
Pull Request resolved: pytorch#31662

Differential Revision: D19440824

Pulled By: ngimel

fbshipit-source-id: b13443cfcc8bcb9ec21f1cddb5c6fbc0ef4bb0f2
nikitaved authored and facebook-github-bot committed Jan 18, 2020
1 parent 5342968 commit 61ee8c9
Showing 9 changed files with 209 additions and 36 deletions.
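For context, the operation being moved off the TH backend is unchanged by this port: along `dim`, `scatter_add_` accumulates every element of `src` into `self` at the position named by `index`. A minimal Python sketch of the expected behavior (for `dim = 0`, `self[index[i][j]][j] += src[i][j]`):

```python
import torch

# scatter_add_ accumulates src into self at the rows named by index along dim 0:
# self[index[i][j]][j] += src[i][j]
self_t = torch.zeros(3, 5)
src = torch.ones(2, 5)
index = torch.tensor([[0, 1, 2, 0, 0],
                      [2, 0, 0, 1, 2]])  # entries must lie in [0, self_t.size(0))

self_t.scatter_add_(0, index, src)
print(self_t)
# tensor([[1., 1., 1., 1., 1.],
#         [0., 1., 0., 1., 0.],
#         [1., 0., 1., 0., 1.]])
```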
3 changes: 2 additions & 1 deletion aten/src/ATen/Declarations.cwrap
@@ -265,7 +265,8 @@
name: _th_scatter_add_
return: argument 0
cname: scatterAdd
cpu_bool: True
backends:
- CUDA
cuda_bool: True
variants: function
arguments:
22 changes: 22 additions & 0 deletions aten/src/ATen/native/ScatterGatherShapeChecks.h
@@ -0,0 +1,22 @@
#pragma once

#include <vector>
#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at { namespace native {

namespace {

// Used for `scatter` and `scatter_add`
// Tests:
// 1. index.size(d) <= src.size(d) for all d
// 2. index.size(d) <= self.size(d) for all d != dim
void scatter_shape_check(
const Tensor& self, int64_t dim,
const Tensor& index, const Tensor& src
);

} // anonymous namespace

}} // namespace at::native
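The contract declared here is the one the new CPU kernel enforces before dispatching: `index` may not be larger than `src` in any dimension, and may not be larger than `self` in any dimension other than `dim`. A rough Python restatement of the same check (the helper name and error text are illustrative, not part of ATen):

```python
def scatter_shape_check(self_t, dim, index, src):
    # Mirrors the C++ check: index.size(d) <= src.size(d) for every d,
    # and index.size(d) <= self.size(d) for every d != dim.
    for d in range(index.dim()):
        if index.size(d) > src.size(d):
            raise RuntimeError(f"index {list(index.shape)} must not be larger than src {list(src.shape)} in dim {d}")
        if d != dim and index.size(d) > self_t.size(d):
            raise RuntimeError(f"index {list(index.shape)} must not be larger than self {list(self_t.shape)} in dim {d}")
```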
7 changes: 7 additions & 0 deletions aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -71,6 +71,8 @@ DEFINE_DISPATCH(index_put_stub);
DEFINE_DISPATCH(index_put_accum_stub);
REGISTER_NO_CPU_DISPATCH(index_put_accum_stub, index_put_accum_fn);

DEFINE_DISPATCH(scatter_add_stub);

static bool all_strides_match(TensorList tensors) {
TORCH_CHECK(tensors.size() >= 1);
auto strides = tensors[0].strides();
@@ -494,6 +496,11 @@ Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, Scalar so
return self.clone(at::MemoryFormat::Preserve).scatter_(dim, index, source);
}

Tensor & scatter_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) {
scatter_add_stub(self.device().type(), self, dim, index, src);
return self;
}

Tensor scatter_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
return self.clone(at::MemoryFormat::Preserve).scatter_add_(dim, index, source);
}
4 changes: 4 additions & 0 deletions aten/src/ATen/native/TensorAdvancedIndexing.h
@@ -15,8 +15,12 @@ using index_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRe
using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
using index_put_accum_fn = void(*)(Tensor &, TensorList , const Tensor &, bool unsafe);

using scatter_add_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);

DECLARE_DISPATCH(index_fn, index_stub);
DECLARE_DISPATCH(index_put_fn, index_put_stub);
DECLARE_DISPATCH(index_put_accum_fn, index_put_accum_stub);

DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);

}} // namespace at::native
171 changes: 171 additions & 0 deletions aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
@@ -0,0 +1,171 @@
#include <ATen/native/ScatterGatherShapeChecks.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/Parallel.h>

namespace at { namespace native {

namespace {

static inline int64_t ensure_nonempty_dim(int64_t dim) {
return std::max<int64_t>(dim, 1);
}

static inline int64_t ensure_nonempty_size(const Tensor& t, int64_t dim) {
return t.dim() == 0 ? 1 : t.size(dim);
}

static inline int64_t ensure_nonempty_stride(const Tensor& t, int64_t dim) {
return t.dim() == 0 ? 1 : t.stride(dim);
}

using IdxVec = std::vector<int64_t>;
static inline IdxVec ensure_nonempty_vec(IdxVec vec) {
if (vec.size() == 0) {
vec.push_back(1);
}
return vec;
}

// Used for `scatter` and `scatter_add`
// Tests:
// 1. index.size(d) <= src.size(d) for all d
// 2. index.size(d) <= self.size(d) for all d != dim
void scatter_shape_check(
const Tensor& self, int64_t dim,
const Tensor& index, const Tensor& src
) {
bool is_wrong_shape = false;
int64_t self_dims = ensure_nonempty_dim(self.dim());
for (int64_t d = 0; d < self_dims; ++d) {
int64_t index_d_size = ensure_nonempty_size(index, d);
if (index_d_size > ensure_nonempty_size(src, d)) {
is_wrong_shape = true;
break;
}
if (d != dim) {
if (index_d_size > ensure_nonempty_size(self, d)) {
is_wrong_shape = true;
break;
}
}
}
TORCH_CHECK(!is_wrong_shape,
"Expected index [", index.sizes(), "]",
" to be smaller size than src [", src.sizes(), "]",
" and to be smaller than self [", self.sizes(), "]",
" apart from dimension ", dim
);
}

static Tensor restride_dim(
const Tensor& src, int64_t dim,
IntArrayRef replacement_shape
) {
auto strides = ensure_nonempty_vec(src.strides().vec());
strides[dim] = 0;
return src.as_strided(replacement_shape, strides);
}

template <typename func_t>
void cpu_scatter_gather_base_kernel(
Tensor& self, int64_t dim,
const Tensor& index, const Tensor& src,
const std::string& method_name,
const func_t& f,
bool serial_exec = true
) {
auto index_sizes = ensure_nonempty_vec(index.sizes().vec());
auto index_strides = ensure_nonempty_vec(index.strides().vec());

// `dim` is traversed in a kernel function `f`,
// that is why index.stride(dim) = 0 and index.size(dim) = 1.
// Also, index.size(dim) = 1 makes sure that TensorIterator.DimCounter
// has the following form : (i_1,..., i_{dim-1}, 0, i_{dim+1},...,i_n).
index_sizes[dim] = 1;
index_strides[dim] = 0;

// set self.shape = src.shape = index.shape,
// this defines the number of elements to iterate over,
// and set self.stride(dim) = src.stride(dim) = 0,
// because `dim` is traversed in a kernel function `f`.
auto self_restrided = restride_dim(self, dim, index_sizes);
auto index_restrided = index.as_strided(index_sizes, index_strides);
auto src_restrided = restride_dim(src, dim, index_sizes);

auto iter = TensorIterator();
iter.dont_compute_common_dtype();
iter.dont_resize_outputs();
iter.add_output(self_restrided);
iter.add_input(src_restrided, src.device(), src.scalar_type());
iter.add_input(index_restrided);
iter.build();

auto self_dim_stride = ensure_nonempty_stride(self, dim);
auto index_dim_stride = ensure_nonempty_stride(index, dim);
auto src_dim_stride = ensure_nonempty_stride(src, dim);

AT_DISPATCH_ALL_TYPES_AND2(
ScalarType::Bool, ScalarType::Half, iter.dtype(),
method_name, [&] {
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
auto* self_data_bytes = data[0];
const auto* index_data_bytes = data[2];
const auto* src_data_bytes = data[1];

for (int64_t i = 0; i < n; ++i) {
f(
(scalar_t*)self_data_bytes, self_dim_stride,
(int64_t*)index_data_bytes, index_dim_stride,
(scalar_t*)src_data_bytes, src_dim_stride
);

self_data_bytes += strides[0];
index_data_bytes += strides[2];
src_data_bytes += strides[1];
}
};
if (serial_exec) {
iter.serial_for_each(loop, {0, iter.numel()});
} else {
iter.for_each(loop);
}
}
);
}

void scatter_add_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
if (index.numel() == 0) {
return;
}

dim = maybe_wrap_dim(dim, self.dim());

scatter_shape_check(self, dim, index, src);

int64_t index_dim_size = ensure_nonempty_size(index, dim);
int64_t self_dim_size = ensure_nonempty_size(self, dim);

cpu_scatter_gather_base_kernel(
self, dim, index, src,
"scatter_add_", [&] (
auto* self_data, auto self_dim_stride,
const auto* index_data, auto index_dim_stride,
const auto* src_data, auto src_dim_stride
) {
for (int64_t i = 0; i < index_dim_size; ++i) {
int64_t idx_dim = index_data[i * index_dim_stride];
TORCH_CHECK(idx_dim >= 0 && idx_dim < self_dim_size,
"index ", idx_dim,
" is out of bounds for dimension ", dim,
" with size ", self_dim_size);
self_data[idx_dim * self_dim_stride] += src_data[i * src_dim_stride];
}
}, /*serial_exec=*/true
);
}

} // anonymous namespace

REGISTER_DISPATCH(scatter_add_stub, &scatter_add_cpu_kernel);

}} // namespace at::native
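For readers new to the TensorIterator pattern used above: the iterator is built over views whose size along `dim` is forced to 1 and whose stride along `dim` is 0, so the outer loop visits every position with the `dim` coordinate pinned at 0, and the inner lambda then walks along `dim` itself using the saved `*_dim_stride` values. A rough pure-Python rendering of that control flow (a sketch with an assumed helper name, skipping the empty-index and zero-dim special cases, not the ATen implementation):

```python
import itertools
import torch

def scatter_add_reference(self_t, dim, index, src):
    # Outer loop: every position with the `dim` coordinate held at 0
    # (what the TensorIterator sees after sizes[dim] = 1, strides[dim] = 0).
    outer_sizes = list(index.shape)
    outer_sizes[dim] = 1
    for pos in itertools.product(*(range(s) for s in outer_sizes)):
        # Inner loop: walk along `dim`, the dimension the kernel lambda
        # traverses via self/index/src_dim_stride.
        for i in range(index.size(dim)):
            src_pos = list(pos)
            src_pos[dim] = i
            dst_pos = list(pos)
            dst_pos[dim] = int(index[tuple(src_pos)])
            assert 0 <= dst_pos[dim] < self_t.size(dim), "index out of bounds"
            self_t[tuple(dst_pos)] += src[tuple(src_pos)]
    return self_t
```

For valid inputs this matches `Tensor.scatter_add_`, which makes it a convenient way to spot-check the loop structure against the shipped kernel.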
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -3926,7 +3926,7 @@
- func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
variants: method
dispatch:
CPU: legacy::cpu::_th_scatter_add_
CPU: scatter_add_cpu_
CUDA: legacy::cuda::_th_scatter_add_

- func: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
32 changes: 0 additions & 32 deletions aten/src/TH/generic/THTensorEvenMoreMath.cpp
@@ -566,38 +566,6 @@ void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor
})
}

void THTensor_(scatterAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src)
{
int64_t elems_per_row, i, idx;
int index_ndim_legacy_all = THTensor_nDimensionLegacyAll(index);

THArgCheck(dim < THTensor_(nDimensionLegacyNoScalars)(tensor), 2, "Index dimension is out of bounds");
THArgCheck(index_ndim_legacy_all == 0
|| THLongTensor_nDimensionLegacyNoScalars(index) == THTensor_(nDimensionLegacyNoScalars)(tensor), 3,
"Index tensor must have same dimensions as output tensor");
THArgCheck(THTensor_(nDimensionLegacyNoScalars)(src) == THTensor_(nDimensionLegacyNoScalars)(tensor), 4,
"Input tensor must have same dimensions as output tensor");

// no-op if index is empty
if (index_ndim_legacy_all == 0)
return;

elems_per_row = THTensor_sizeLegacyNoScalars(index, dim);

TH_TENSOR_DIM_APPLY3(int64_t, index, scalar_t, tensor, scalar_t, src, dim,
TH_TENSOR_DIM_APPLY3_SIZE_SCATTER,
for (i = 0; i < elems_per_row; ++i)
{
idx = *(index_data + i*index_stride);
if (idx < 0 || idx >= tensor_size)
{
THFree(TH_TENSOR_DIM_APPLY_counter);
THError("Invalid index in scatterAdd");
}
tensor_data[idx * tensor_stride] += *(src_data + i*src_stride);
})
}

#if !defined(TH_REAL_IS_BOOL)

accreal THTensor_(dot)(THTensor *tensor, THTensor *src)
1 change: 0 additions & 1 deletion aten/src/TH/generic/THTensorMath.h
@@ -93,7 +93,6 @@ TH_API void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index,

TH_API void THTensor_(gather)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index);
TH_API void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
TH_API void THTensor_(scatterAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
TH_API void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val);

TH_API void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension);
3 changes: 2 additions & 1 deletion test/test_torch.py
@@ -2899,10 +2899,11 @@ def _test_scatter_base(self, cast, method, is_scalar=False, test_bounds=True):
idx = cast(torch.LongTensor().resize_(*idx_size))
_TestTorchMixin._fill_indices(self, idx, dim, ([m, n, o])[dim], elems_per_row, m, n, o)

src_size = [random.randint(1, 5) + s for s in idx_size]
if is_scalar:
src = random.random()
else:
src = cast(torch.Tensor(*idx_size).normal_())
src = cast(torch.Tensor(*src_size).normal_())

base = cast(torch.randn(m, n, o))
actual = getattr(base.clone(), method)(dim, idx, src)
