From 602ef0d9d06882bde1a5ba7914599bbbfeb537bb Mon Sep 17 00:00:00 2001
From: Nik Ved
Date: Wed, 19 Feb 2020 18:14:19 -0800
Subject: [PATCH] [WIP] migrate scatter_ to ATen CPU (+multithreading, nondeterministic) (#33139)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/24757, partially https://github.com/pytorch/pytorch/issues/33094.
Uses the fix introduced in https://github.com/pytorch/pytorch/issues/33108 to avoid regressions for some compilers.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/33139

Differential Revision: D19882462

Pulled By: ngimel

fbshipit-source-id: 5016f186a4aadc3cc32edcfd9abdea11786f27e9
---
 aten/src/ATen/Declarations.cwrap              |   3 +-
 .../ATen/native/ScatterGatherShapeChecks.h    |   8 +-
 .../ATen/native/TensorAdvancedIndexing.cpp    |  12 ++
 aten/src/ATen/native/TensorAdvancedIndexing.h |   4 +
 .../ATen/native/cpu/ScatterGatherKernel.cpp   | 117 +++++++++++++++---
 aten/src/ATen/native/native_functions.yaml    |   4 +-
 aten/src/TH/generic/THTensorEvenMoreMath.cpp  |  62 ----------
 aten/src/TH/generic/THTensorMath.h            |   3 -
 8 files changed, 127 insertions(+), 86 deletions(-)

diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 6568e1f7d9da1..f8517aa4b60ae 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -216,7 +216,8 @@
 [[
   name: _th_scatter_
   return: argument 0
-  cpu_bool: True
+  backends:
+    - CUDA
   cuda_bool: True
   variants: function
   options:
diff --git a/aten/src/ATen/native/ScatterGatherShapeChecks.h b/aten/src/ATen/native/ScatterGatherShapeChecks.h
index a501ac41ea294..57e4070e7d33d 100644
--- a/aten/src/ATen/native/ScatterGatherShapeChecks.h
+++ b/aten/src/ATen/native/ScatterGatherShapeChecks.h
@@ -15,11 +15,11 @@ void gather_shape_check(const Tensor& self, int64_t dim, const Tensor& index);
 
 // Used for `scatter` and `scatter_add`
 // Tests:
-// 1. index.size(d) <= src.size(d) for all d
-// 2. index.size(d) <= self.size(d) for all d != dim
+// 1. index.size(d) <= self.size(d) for all d != dim
+// 2. index.size(d) <= src.size(d) for all d if src is a Tensor
 void scatter_shape_check(
-  const Tensor& self, int64_t dim,
-  const Tensor& index, const Tensor& src
+  const Tensor& self, int64_t dim, const Tensor& index,
+  const c10::optional<Tensor>& src_opt = c10::nullopt
 );
 
 } // anonymous namespace
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 82b4853683fbb..58de7e951ca80 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -72,6 +72,8 @@ DEFINE_DISPATCH(index_put_accum_stub);
 REGISTER_NO_CPU_DISPATCH(index_put_accum_stub, index_put_accum_fn);
 
 DEFINE_DISPATCH(gather_stub);
+DEFINE_DISPATCH(scatter_stub);
+DEFINE_DISPATCH(scatter_fill_stub);
 DEFINE_DISPATCH(scatter_add_stub);
 
 static bool all_strides_match(TensorList tensors) {
@@ -500,6 +502,16 @@ Tensor gather_cpu(const Tensor & self, int64_t dim, const Tensor & index, bool s
   return gather_out_cpu(result, self, dim, index, sparse_grad);
 }
 
+Tensor & scatter_cpu_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) {
+  scatter_stub(self.device().type(), self, dim, index, src);
+  return self;
+}
+
+Tensor & scatter_fill_cpu_(Tensor & self, int64_t dim, const Tensor & index, Scalar src) {
+  scatter_fill_stub(self.device().type(), self, dim, index, src);
+  return self;
+}
+
 Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
   return self.clone(at::MemoryFormat::Preserve).scatter_(dim, index, source);
 }
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.h b/aten/src/ATen/native/TensorAdvancedIndexing.h
index c9e0196742528..b5afc95db7be4 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.h
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.h
@@ -16,6 +16,8 @@ using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArr
 using index_put_accum_fn = void(*)(Tensor &, TensorList , const Tensor &, bool unsafe);
 
 using gather_fn = void (*)(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
+using scatter_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
+using scatter_fill_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, Scalar src);
 using scatter_add_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
 
 DECLARE_DISPATCH(index_fn, index_stub);
@@ -23,6 +25,8 @@ DECLARE_DISPATCH(index_put_fn, index_put_stub);
 DECLARE_DISPATCH(index_put_accum_fn, index_put_accum_stub);
 
 DECLARE_DISPATCH(gather_fn, gather_stub);
+DECLARE_DISPATCH(scatter_fn, scatter_stub);
+DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
 DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
index cfd499898c4bb..0761d89b1aaa3 100644
--- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
+++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
@@ -50,33 +50,53 @@ void gather_shape_check(const Tensor& self, int64_t dim, const Tensor& index) {
 
 // Used for `scatter` and `scatter_add`
 // Tests:
-// 1. index.size(d) <= src.size(d) for all d
-// 2. index.size(d) <= self.size(d) for all d != dim
+// 1. index.size(d) <= self.size(d) for all d != dim
+// 2. index.size(d) <= src.size(d) for all d if src is a Tensor
 void scatter_shape_check(
-  const Tensor& self, int64_t dim,
-  const Tensor& index, const Tensor& src
+  const Tensor& self, int64_t dim, const Tensor& index,
+  const c10::optional<Tensor>& src_opt
 ) {
   bool is_wrong_shape = false;
   int64_t self_dims = ensure_nonempty_dim(self.dim());
+
+  // Check: index.size(d) <= self.size(d) for all d != dim
   for (int64_t d = 0; d < self_dims; ++d) {
     int64_t index_d_size = ensure_nonempty_size(index, d);
-    if (index_d_size > ensure_nonempty_size(src, d)) {
+    if (d == dim) continue;
+    if (index_d_size > ensure_nonempty_size(self, d)) {
       is_wrong_shape = true;
       break;
     }
-    if (d != dim) {
-      if (index_d_size > ensure_nonempty_size(self, d)) {
+  }
+
+  // Check: index.size(d) <= src.size(d) for all d if src is Tensor
+  if (!is_wrong_shape && src_opt.has_value()) {
+    auto src = src_opt.value();
+    for (int64_t d = 0; d < self_dims; ++d) {
+      int64_t index_d_size = ensure_nonempty_size(index, d);
+      if (index_d_size > ensure_nonempty_size(src, d)) {
         is_wrong_shape = true;
         break;
       }
     }
   }
-  TORCH_CHECK(!is_wrong_shape,
-    "Expected index [", index.sizes(), "]",
-    " to be smaller size than src [", src.sizes(), "]",
-    " and to be smaller than self [", self.sizes(), "]",
-    " apart from dimension ", dim
-  );
+
+  if (src_opt.has_value()) {
+    auto src = src_opt.value();
+    TORCH_CHECK(!is_wrong_shape,
+      "Expected index ", index.sizes(),
+      " to be smaller than self ", self.sizes(),
+      " apart from dimension ", dim,
+      " and to be smaller size than src ", src.sizes()
+    );
+  }
+  else {
+    TORCH_CHECK(!is_wrong_shape,
+      "Expected index ", index.sizes(),
+      " to be smaller than self ", self.sizes(),
+      " apart from dimension ", dim
+    );
+  }
 }
 
 static Tensor restride_dim(
@@ -187,6 +207,73 @@ void gather_cpu_kernel(Tensor& result, const Tensor& self, int64_t dim, const Te
   );
 }
+void scatter_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
+  if (index.numel() == 0) {
+    return;
+  }
+
+  dim = maybe_wrap_dim(dim, self.dim());
+
+  scatter_shape_check(self, dim, index, src);
+
+  int64_t index_dim_size = ensure_nonempty_size(index, dim);
+  int64_t self_dim_size = ensure_nonempty_size(self, dim);
+
+  cpu_scatter_gather_base_kernel(
+    self, dim, index, src,
+    "scatter_cpu_", [&] (
+      auto* self_data, auto self_dim_stride,
+      const auto* index_data, auto index_dim_stride,
+      const auto* src_data, auto src_dim_stride
+    ) {
+      for (int64_t i = 0; i < index_dim_size; ++i) {
+        int64_t idx_dim = index_data[i * index_dim_stride];
+        // we are not putting idx_dim in the error message because it disables
+        // loop optimization in clang-7
+        TORCH_CHECK(idx_dim >= 0 && idx_dim < self_dim_size,
+          "index ", index_data[i * index_dim_stride],
+          " is out of bounds for dimension ", dim,
+          " with size ", self_dim_size);
+        self_data[idx_dim * self_dim_stride] = src_data[i * src_dim_stride];
+      }
+    }, /*serial_exec=*/false
+  );
+}
+
+void scatter_fill_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, Scalar src) {
+  if (index.numel() == 0) {
+    return;
+  }
+
+  dim = maybe_wrap_dim(dim, self.dim());
+
+  scatter_shape_check(self, dim, index);
+
+  int64_t index_dim_size = ensure_nonempty_size(index, dim);
+  int64_t self_dim_size = ensure_nonempty_size(self, dim);
+
+  cpu_scatter_gather_base_kernel(
+    self, dim, index, self,
+    "scatter_fill_cpu_", [&] (
+      auto* self_data, auto self_dim_stride,
+      const auto* index_data, auto index_dim_stride,
+      const auto* src_data, auto src_dim_stride
+    ) {
+      for (int64_t i = 0; i < index_dim_size; ++i) {
+        int64_t idx_dim = index_data[i * index_dim_stride];
+        // we are not putting idx_dim in the error message because it disables
+        // loop optimization in clang-7
+        TORCH_CHECK(idx_dim >= 0 && idx_dim < self_dim_size,
+          "index ", index_data[i * index_dim_stride],
+          " is out of bounds for dimension ", dim,
+          " with size ", self_dim_size);
+        using scalar_t = typename std::remove_pointer<decltype(self_data)>::type;
+        self_data[idx_dim * self_dim_stride] = src.to<scalar_t>();
+      }
+    }, /*serial_exec=*/false
+  );
+}
+
 void scatter_add_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
   if (index.numel() == 0) {
     return;
   }
@@ -219,9 +306,11 @@ void scatter_add_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, cons
   /*serial_exec=*/true);
 }
 
-} // anonymous napespace
+} // anonymous namespace
 
 REGISTER_DISPATCH(gather_stub, &gather_cpu_kernel);
+REGISTER_DISPATCH(scatter_stub, &scatter_cpu_kernel);
+REGISTER_DISPATCH(scatter_fill_stub, &scatter_fill_cpu_kernel);
 REGISTER_DISPATCH(scatter_add_stub, &scatter_add_cpu_kernel);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index c8ff06d57a3ec..5558dc7d4b9df 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3953,7 +3953,7 @@
 - func: scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
   variants: method
   dispatch:
-    CPU: legacy::cpu::_th_scatter_
+    CPU: scatter_cpu_
     CUDA: legacy::cuda::_th_scatter_
 
 - func: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
@@ -3963,7 +3963,7 @@
 - func: scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
   variants: method
   dispatch:
-    CPU: legacy::cpu::_th_scatter_
+    CPU: scatter_fill_cpu_
     CUDA: legacy::cuda::_th_scatter_
 
 - func: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
index a871cffb793ff..c738f0b3f8486 100644
--- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
@@ -119,35 +119,6 @@ void THTensor_(maskedSelectBool)(THTensor *tensor, THTensor *src, THBoolTensor *
   });
 }
 
-void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val)
-{
-  int64_t elems_per_row, i, idx;
-  dim = at::maybe_wrap_dim(dim, tensor);
-  int index_ndim_legacy_all = THLongTensor_nDimensionLegacyAll(index);
-
-  THArgCheck(dim < THTensor_(nDimensionLegacyAll)(tensor), 2, "Index dimension is out of bounds");
-  THArgCheck(index_ndim_legacy_all == 0 || index_ndim_legacy_all == THLongTensor_nDimensionLegacyAll(tensor), 3,
-             "Index tensor must either be empty or have same dimensions as output tensor");
-
-  // no-op if index is empty
-  if (index_ndim_legacy_all == 0)
-    return;
-
-  elems_per_row = THTensor_sizeLegacyNoScalars(index, dim);
-
-  TH_TENSOR_DIM_APPLY2(scalar_t, tensor, int64_t, index, dim,
-                       for (i = 0; i < elems_per_row; ++i)
-                       {
-                         idx = *(index_data + i*index_stride);
-                         if (idx < 0 || idx >= tensor_size)
-                         {
-                           THFree(TH_TENSOR_DIM_APPLY_counter);
-                           THError("Invalid index in scatter");
-                         }
-                         tensor_data[idx * tensor_stride] = val;
-                       })
-}
-
 void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value)
 {
   at::NoNamesGuard guard;
@@ -510,39 +481,6 @@ void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, scalar
   THLongTensor_free(index);
 }
 
-void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src)
-{
-  int64_t elems_per_row, i, idx;
-  dim = at::maybe_wrap_dim(dim, tensor);
-  int index_ndim_legacy_all = THTensor_nDimensionLegacyAll(index);
-
-  THArgCheck(dim < THTensor_(nDimensionLegacyNoScalars)(tensor), 2, "Index dimension is out of bounds");
-  THArgCheck(index_ndim_legacy_all == 0
-             || THLongTensor_nDimensionLegacyNoScalars(index) == THTensor_(nDimensionLegacyNoScalars)(tensor), 3,
-             "Index tensor must be either empty or have same dimensions as output tensor");
-  THArgCheck(THTensor_(nDimensionLegacyNoScalars)(src) == THTensor_(nDimensionLegacyNoScalars)(tensor), 4,
-             "Input tensor must have same dimensions as output tensor");
-
-  // no-op if index is empty
-  if (index_ndim_legacy_all == 0)
-    return;
-
-  elems_per_row = THTensor_sizeLegacyNoScalars(index, dim);
-
-  TH_TENSOR_DIM_APPLY3(int64_t, index, scalar_t, tensor, scalar_t, src, dim,
-                       TH_TENSOR_DIM_APPLY3_SIZE_SCATTER,
-                       for (i = 0; i < elems_per_row; ++i)
-                       {
-                         idx = *(index_data + i*index_stride);
-                         if (idx < 0 || idx >= tensor_size)
-                         {
-                           THFree(TH_TENSOR_DIM_APPLY_counter);
-                           THError("Invalid index in scatter");
-                         }
-                         tensor_data[idx * tensor_stride] = *(src_data + i*src_stride);
-                       })
-}
-
 #if !defined(TH_REAL_IS_BOOL)
 
 accreal THTensor_(dot)(THTensor *tensor, THTensor *src)
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index 06dd99ff534eb..552bf3904851b 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -91,9 +91,6 @@ TH_API void THTensor_(take)(THTensor *tensor, THTensor *src, THLongTensor *index
 TH_API void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int accumulate);
 TH_API void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val);
 
-TH_API void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
-TH_API void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val);
-
 TH_API void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension);
 TH_API void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension);
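
Note for reviewers (not part of the patch): the snippet below is a minimal, hypothetical C++ sketch that exercises the two new dispatch paths through the ATen frontend, which this change routes from scatter_.src to scatter_cpu_/scatter_stub and from scatter_.value to scatter_fill_cpu_/scatter_fill_stub on CPU. The shapes follow the rules enforced by scatter_shape_check above: index.size(d) <= self.size(d) for all d != dim, and index.size(d) <= src.size(d) for all d when src is a Tensor. Per the "(+multithreading, nondeterministic)" note in the subject, no particular ordering should be assumed when index maps several source elements to the same destination; the example is only meant to illustrate the semantics.

// Hypothetical usage sketch, assuming a standard ATen build; not part of this patch.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // self: 3x5 zeros; index: 1x5 with entries in [0, 3); src: 2x5.
  at::Tensor self = at::zeros({3, 5}, at::kFloat);
  at::Tensor index = at::arange(5, at::kLong).remainder(3).view({1, 5});  // [[0, 1, 2, 0, 1]]
  at::Tensor src = at::arange(10, at::kFloat).view({2, 5});

  // Tensor-src overload (scatter_.src -> scatter_cpu_ -> scatter_stub):
  // for dim == 0, self[index[i][j]][j] = src[i][j].
  self.scatter_(0, index, src);
  std::cout << self << std::endl;

  // Scalar overload (scatter_.value -> scatter_fill_cpu_ -> scatter_fill_stub):
  // for dim == 0, self[index[i][j]][j] = 42.
  self.scatter_(0, index, 42.0);
  std::cout << self << std::endl;
  return 0;
}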