[WIP] migrate scatter_ to ATen CPU (+multithreading, nondeterministic) (pytorch#33139)

Summary:
Fixes pytorch#24757, partially pytorch#33094. Uses the fix introduced in pytorch#33108 to avoid regressions for some compilers.
Pull Request resolved: pytorch#33139

Differential Revision: D19882462

Pulled By: ngimel

fbshipit-source-id: 5016f186a4aadc3cc32edcfd9abdea11786f27e9
nikitaved authored and facebook-github-bot committed Feb 20, 2020
1 parent 6cb9e6b commit 602ef0d
Showing 8 changed files with 127 additions and 86 deletions.
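
For context, the two overloads whose CPU dispatch changes in this PR are the in-place `Tensor.scatter_` with a Tensor `src` and with a Scalar `value`; for `dim == 0`, `self[index[i][j]][j] = src[i][j]`. A minimal libtorch sketch of both (illustrative values only, not taken from the PR's tests):

```cpp
// Illustrative libtorch usage of the two overloads rerouted on CPU by this PR.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto self  = torch::zeros({3, 5});
  auto src   = torch::ones({1, 5}) * 2.0;
  auto index = torch::tensor({0, 1, 2, 0, 1}, torch::kLong).reshape({1, 5});

  // scatter_.src: after this change dispatches to scatter_cpu_ on CPU.
  self.scatter_(/*dim=*/0, index, src);
  std::cout << self << "\n";

  // scatter_.value: after this change dispatches to scatter_fill_cpu_ on CPU.
  auto filled = torch::zeros({3, 5});
  filled.scatter_(/*dim=*/0, index, /*value=*/7);
  std::cout << filled << "\n";
  return 0;
}
```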
3 changes: 2 additions & 1 deletion aten/src/ATen/Declarations.cwrap
@@ -216,7 +216,8 @@
[[
name: _th_scatter_
return: argument 0
cpu_bool: True
backends:
- CUDA
cuda_bool: True
variants: function
options:
8 changes: 4 additions & 4 deletions aten/src/ATen/native/ScatterGatherShapeChecks.h
@@ -15,11 +15,11 @@ void gather_shape_check(const Tensor& self, int64_t dim, const Tensor& index);

// Used for `scatter` and `scatter_add`
// Tests:
// 1. index.size(d) <= src.size(d) for all d
// 2. index.size(d) <= self.size(d) for all d != dim
// 1. index.size(d) <= self.size(d) for all d != dim
// 2. index.size(d) <= src.size(d) for all d if src is a Tensor
void scatter_shape_check(
const Tensor& self, int64_t dim,
const Tensor& index, const Tensor& src
const Tensor& self, int64_t dim, const Tensor& index,
const c10::optional<Tensor>& src_opt = c10::nullopt
);

} // anonymous namespace
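
To make the revised contract concrete, here is a small standalone sketch of the two checks, with plain `std::vector` sizes and C++17 `std::optional` standing in for `Tensor` and `c10::optional` (this is not the ATen implementation, which also normalizes legacy empty sizes via `ensure_nonempty_size` and reports the offending sizes in its error message):

```cpp
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Sketch of the scatter shape rule only; empty-tensor handling and error
// reporting are simplified relative to the real scatter_shape_check.
void scatter_shape_check_sketch(
    const std::vector<int64_t>& self_sizes, int64_t dim,
    const std::vector<int64_t>& index_sizes,
    const std::optional<std::vector<int64_t>>& src_sizes = std::nullopt) {
  for (int64_t d = 0; d < static_cast<int64_t>(self_sizes.size()); ++d) {
    // 1. index.size(d) <= self.size(d) for all d != dim
    if (d != dim && index_sizes[d] > self_sizes[d]) {
      throw std::runtime_error("index is larger than self outside of dim");
    }
    // 2. index.size(d) <= src.size(d) for all d, but only when src is a Tensor;
    //    the Scalar `value` overload calls the check with no src at all.
    if (src_sizes.has_value() && index_sizes[d] > (*src_sizes)[d]) {
      throw std::runtime_error("index is larger than src");
    }
  }
}

int main() {
  scatter_shape_check_sketch({3, 5}, 0, {1, 5}, std::vector<int64_t>{1, 5});  // ok
  scatter_shape_check_sketch({3, 5}, 0, {1, 5});                              // ok: scalar-fill case
  return 0;
}
```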
12 changes: 12 additions & 0 deletions aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -72,6 +72,8 @@ DEFINE_DISPATCH(index_put_accum_stub);
REGISTER_NO_CPU_DISPATCH(index_put_accum_stub, index_put_accum_fn);

DEFINE_DISPATCH(gather_stub);
DEFINE_DISPATCH(scatter_stub);
DEFINE_DISPATCH(scatter_fill_stub);
DEFINE_DISPATCH(scatter_add_stub);

static bool all_strides_match(TensorList tensors) {
@@ -500,6 +502,16 @@ Tensor gather_cpu(const Tensor & self, int64_t dim, const Tensor & index, bool s
return gather_out_cpu(result, self, dim, index, sparse_grad);
}

Tensor & scatter_cpu_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) {
scatter_stub(self.device().type(), self, dim, index, src);
return self;
}

Tensor & scatter_fill_cpu_(Tensor & self, int64_t dim, const Tensor & index, Scalar src) {
scatter_fill_stub(self.device().type(), self, dim, index, src);
return self;
}

Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
return self.clone(at::MemoryFormat::Preserve).scatter_(dim, index, source);
}
4 changes: 4 additions & 0 deletions aten/src/ATen/native/TensorAdvancedIndexing.h
@@ -16,13 +16,17 @@ using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArr
using index_put_accum_fn = void(*)(Tensor &, TensorList , const Tensor &, bool unsafe);

using gather_fn = void (*)(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
using scatter_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
using scatter_fill_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, Scalar src);
using scatter_add_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);

DECLARE_DISPATCH(index_fn, index_stub);
DECLARE_DISPATCH(index_put_fn, index_put_stub);
DECLARE_DISPATCH(index_put_accum_fn, index_put_accum_stub);

DECLARE_DISPATCH(gather_fn, gather_stub);
DECLARE_DISPATCH(scatter_fn, scatter_stub);
DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);

}} // namespace at::native
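
For readers unfamiliar with the stub machinery, the sketch below is a deliberately simplified stand-in for what the `DECLARE_DISPATCH` / `DEFINE_DISPATCH` / `REGISTER_DISPATCH` triplet wires together: a per-device function-pointer slot filled in from the `native/cpu` translation unit. The real `DispatchStub` is more general (CPU capability variants, CUDA/HIP slots); all names below are invented for illustration.

```cpp
#include <cstdint>
#include <stdexcept>

enum class DeviceType { CPU, CUDA };

// Simplified signature; the real scatter_fn is
// void(*)(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src).
using scatter_fn = void (*)(float* self_data, int64_t dim);

// DEFINE_DISPATCH(scatter_stub): conceptually a small per-device pointer table.
struct ScatterStub {
  scatter_fn cpu = nullptr;
  scatter_fn cuda = nullptr;
  void operator()(DeviceType device, float* self_data, int64_t dim) const {
    scatter_fn fn = (device == DeviceType::CPU) ? cpu : cuda;
    if (fn == nullptr) {
      throw std::runtime_error("no kernel registered for this device type");
    }
    fn(self_data, dim);
  }
};
ScatterStub scatter_stub;

// REGISTER_DISPATCH(scatter_stub, &scatter_cpu_kernel) in native/cpu/ amounts to
// a static registration object that fills in the CPU slot at startup.
void scatter_cpu_kernel(float* /*self_data*/, int64_t /*dim*/) { /* kernel body */ }

struct RegisterScatterCpu {
  RegisterScatterCpu() { scatter_stub.cpu = &scatter_cpu_kernel; }
};
static RegisterScatterCpu register_scatter_cpu;

int main() {
  float data[4] = {0.f, 0.f, 0.f, 0.f};
  scatter_stub(DeviceType::CPU, data, /*dim=*/0);  // routed to scatter_cpu_kernel
  return 0;
}
```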
117 changes: 103 additions & 14 deletions aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
@@ -50,33 +50,53 @@ void gather_shape_check(const Tensor& self, int64_t dim, const Tensor& index) {

// Used for `scatter` and `scatter_add`
// Tests:
// 1. index.size(d) <= src.size(d) for all d
// 2. index.size(d) <= self.size(d) for all d != dim
// 1. index.size(d) <= self.size(d) for all d != dim
// 2. index.size(d) <= src.size(d) for all d if src is a Tensor
void scatter_shape_check(
const Tensor& self, int64_t dim,
const Tensor& index, const Tensor& src
const Tensor& self, int64_t dim, const Tensor& index,
const c10::optional<Tensor>& src_opt
) {
bool is_wrong_shape = false;
int64_t self_dims = ensure_nonempty_dim(self.dim());

// Check: index.size(d) <= self.size(d) for all d != dim
for (int64_t d = 0; d < self_dims; ++d) {
int64_t index_d_size = ensure_nonempty_size(index, d);
if (index_d_size > ensure_nonempty_size(src, d)) {
if (d == dim) continue;
if (index_d_size > ensure_nonempty_size(self, d)) {
is_wrong_shape = true;
break;
}
if (d != dim) {
if (index_d_size > ensure_nonempty_size(self, d)) {
}

// Check: index.size(d) <= src.size(d) for all d if src is Tensor
if (!is_wrong_shape && src_opt.has_value()) {
auto src = src_opt.value();
for (int64_t d = 0; d < self_dims; ++d) {
int64_t index_d_size = ensure_nonempty_size(index, d);
if (index_d_size > ensure_nonempty_size(src, d)) {
is_wrong_shape = true;
break;
}
}
}
TORCH_CHECK(!is_wrong_shape,
"Expected index [", index.sizes(), "]",
" to be smaller size than src [", src.sizes(), "]",
" and to be smaller than self [", self.sizes(), "]",
" apart from dimension ", dim
);

if (src_opt.has_value()) {
auto src = src_opt.value();
TORCH_CHECK(!is_wrong_shape,
"Expected index ", index.sizes(),
" to be smaller than self ", self.sizes(),
" apart from dimension ", dim,
" and to be smaller size than src ", src.sizes()
);
}
else {
TORCH_CHECK(!is_wrong_shape,
"Expected index ", index.sizes(),
" to be smaller than self ", self.sizes(),
" apart from dimension ", dim
);
}
}

static Tensor restride_dim(
@@ -187,6 +207,73 @@ void gather_cpu_kernel(Tensor& result, const Tensor& self, int64_t dim, const Te
);
}

void scatter_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
if (index.numel() == 0) {
return;
}

dim = maybe_wrap_dim(dim, self.dim());

scatter_shape_check(self, dim, index, src);

int64_t index_dim_size = ensure_nonempty_size(index, dim);
int64_t self_dim_size = ensure_nonempty_size(self, dim);

cpu_scatter_gather_base_kernel(
self, dim, index, src,
"scatter_cpu_", [&] (
auto* self_data, auto self_dim_stride,
const auto* index_data, auto index_dim_stride,
const auto* src_data, auto src_dim_stride
) {
for (int64_t i = 0; i < index_dim_size; ++i) {
int64_t idx_dim = index_data[i * index_dim_stride];
// we are not putting idx_dim in the error message because it disables
// loop optimization in clang-7
TORCH_CHECK(idx_dim >= 0 && idx_dim < self_dim_size,
"index ", index_data[i * index_dim_stride],
" is out of bounds for dimension ", dim,
" with size ", self_dim_size);
self_data[idx_dim * self_dim_stride] = src_data[i * src_dim_stride];
}
}, /*serial_exec=*/false
);
}

void scatter_fill_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, Scalar src) {
if (index.numel() == 0) {
return;
}

dim = maybe_wrap_dim(dim, self.dim());

scatter_shape_check(self, dim, index);

int64_t index_dim_size = ensure_nonempty_size(index, dim);
int64_t self_dim_size = ensure_nonempty_size(self, dim);

cpu_scatter_gather_base_kernel(
self, dim, index, self,
"scatter_fill_cpu_", [&] (
auto* self_data, auto self_dim_stride,
const auto* index_data, auto index_dim_stride,
const auto* src_data, auto src_dim_stride
) {
for (int64_t i = 0; i < index_dim_size; ++i) {
int64_t idx_dim = index_data[i * index_dim_stride];
// we are not putting idx_dim in the error message because it disables
// loop optimization in clang-7
TORCH_CHECK(idx_dim >= 0 && idx_dim < self_dim_size,
"index ", index_data[i * index_dim_stride],
" is out of bounds for dimension ", dim,
" with size ", self_dim_size);
using scalar_t = typename std::remove_pointer<decltype(self_data)>::type;
self_data[idx_dim * self_dim_stride] = src.to<scalar_t>();
}
}, /*serial_exec=*/false
);
}

void scatter_add_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
if (index.numel() == 0) {
return;
@@ -219,9 +306,11 @@ void scatter_add_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, cons
/*serial_exec=*/true);
}

} // anonymous napespace
} // anonymous namespace

REGISTER_DISPATCH(gather_stub, &gather_cpu_kernel);
REGISTER_DISPATCH(scatter_stub, &scatter_cpu_kernel);
REGISTER_DISPATCH(scatter_fill_stub, &scatter_fill_cpu_kernel);
REGISTER_DISPATCH(scatter_add_stub, &scatter_add_cpu_kernel);

}} // namespace at::native
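
The core of both new kernels is the strided inner loop above, `self_data[idx_dim * self_dim_stride] = src_data[i * src_dim_stride]`, executed once per slice orthogonal to `dim`. A self-contained, contiguous 2-D illustration of that loop with the same bounds check (simplified; not the ATen kernel, which walks arbitrary strides via `cpu_scatter_gather_base_kernel`):

```cpp
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

int main() {
  // self: 3x4, src/index: 2x4, dim == 0, row-major storage.
  const int64_t self_rows = 3, cols = 4, index_rows = 2;
  std::vector<float> self(self_rows * cols, 0.f);
  std::vector<float> src = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<int64_t> index = {0, 2, 1, 0, 2, 0, 0, 1};

  // For dim == 0 the dim-stride is the row length; the kernel iterates the
  // slices orthogonal to dim (here, columns) and loops over index.size(dim).
  const int64_t self_dim_stride = cols, src_dim_stride = cols, index_dim_stride = cols;
  for (int64_t col = 0; col < cols; ++col) {
    float* self_data = self.data() + col;
    const float* src_data = src.data() + col;
    const int64_t* index_data = index.data() + col;
    for (int64_t i = 0; i < index_rows; ++i) {
      int64_t idx_dim = index_data[i * index_dim_stride];
      if (idx_dim < 0 || idx_dim >= self_rows) {
        throw std::runtime_error("index is out of bounds for dimension 0");
      }
      self_data[idx_dim * self_dim_stride] = src_data[i * src_dim_stride];
    }
  }

  for (int64_t r = 0; r < self_rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) std::printf("%g ", self[r * cols + c]);
    std::printf("\n");
  }
  return 0;
}
```

Since `cpu_scatter_gather_base_kernel` is called with `/*serial_exec=*/false` for both scatter kernels (unlike `scatter_add`, which stays serial), the slices can be processed on multiple threads; together with the fact that colliding writes to duplicate destination indices simply overwrite each other with no specified winner, this is presumably what the "+multithreading, nondeterministic" note in the commit title refers to.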
4 changes: 2 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -3953,7 +3953,7 @@
- func: scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
variants: method
dispatch:
CPU: legacy::cpu::_th_scatter_
CPU: scatter_cpu_
CUDA: legacy::cuda::_th_scatter_

- func: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
@@ -3963,7 +3963,7 @@
- func: scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
variants: method
dispatch:
CPU: legacy::cpu::_th_scatter_
CPU: scatter_fill_cpu_
CUDA: legacy::cuda::_th_scatter_

- func: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
62 changes: 0 additions & 62 deletions aten/src/TH/generic/THTensorEvenMoreMath.cpp
@@ -119,35 +119,6 @@ void THTensor_(maskedSelectBool)(THTensor *tensor, THTensor *src, THBoolTensor *
});
}

void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val)
{
int64_t elems_per_row, i, idx;
dim = at::maybe_wrap_dim(dim, tensor);
int index_ndim_legacy_all = THLongTensor_nDimensionLegacyAll(index);

THArgCheck(dim < THTensor_(nDimensionLegacyAll)(tensor), 2, "Index dimension is out of bounds");
THArgCheck(index_ndim_legacy_all == 0 || index_ndim_legacy_all == THLongTensor_nDimensionLegacyAll(tensor), 3,
"Index tensor must either be empty or have same dimensions as output tensor");

// no-op if index is empty
if (index_ndim_legacy_all == 0)
return;

elems_per_row = THTensor_sizeLegacyNoScalars(index, dim);

TH_TENSOR_DIM_APPLY2(scalar_t, tensor, int64_t, index, dim,
for (i = 0; i < elems_per_row; ++i)
{
idx = *(index_data + i*index_stride);
if (idx < 0 || idx >= tensor_size)
{
THFree(TH_TENSOR_DIM_APPLY_counter);
THError("Invalid index in scatter");
}
tensor_data[idx * tensor_stride] = val;
})
}

void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value)
{
at::NoNamesGuard guard;
@@ -510,39 +481,6 @@ void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, scalar
THLongTensor_free(index);
}

void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src)
{
int64_t elems_per_row, i, idx;
dim = at::maybe_wrap_dim(dim, tensor);
int index_ndim_legacy_all = THTensor_nDimensionLegacyAll(index);

THArgCheck(dim < THTensor_(nDimensionLegacyNoScalars)(tensor), 2, "Index dimension is out of bounds");
THArgCheck(index_ndim_legacy_all == 0
|| THLongTensor_nDimensionLegacyNoScalars(index) == THTensor_(nDimensionLegacyNoScalars)(tensor), 3,
"Index tensor must be either empty or have same dimensions as output tensor");
THArgCheck(THTensor_(nDimensionLegacyNoScalars)(src) == THTensor_(nDimensionLegacyNoScalars)(tensor), 4,
"Input tensor must have same dimensions as output tensor");

// no-op if index is empty
if (index_ndim_legacy_all == 0)
return;

elems_per_row = THTensor_sizeLegacyNoScalars(index, dim);

TH_TENSOR_DIM_APPLY3(int64_t, index, scalar_t, tensor, scalar_t, src, dim,
TH_TENSOR_DIM_APPLY3_SIZE_SCATTER,
for (i = 0; i < elems_per_row; ++i)
{
idx = *(index_data + i*index_stride);
if (idx < 0 || idx >= tensor_size)
{
THFree(TH_TENSOR_DIM_APPLY_counter);
THError("Invalid index in scatter");
}
tensor_data[idx * tensor_stride] = *(src_data + i*src_stride);
})
}

#if !defined(TH_REAL_IS_BOOL)

accreal THTensor_(dot)(THTensor *tensor, THTensor *src)
3 changes: 0 additions & 3 deletions aten/src/TH/generic/THTensorMath.h
@@ -91,9 +91,6 @@ TH_API void THTensor_(take)(THTensor *tensor, THTensor *src, THLongTensor *index
TH_API void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int accumulate);
TH_API void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val);

TH_API void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
TH_API void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val);

TH_API void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension);
TH_API void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension);

