diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 6963a02107be73..c06c98e85aaf4f 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -78,6 +78,8 @@ DEFINE_DISPATCH(gather_stub);
 DEFINE_DISPATCH(scatter_stub);
 DEFINE_DISPATCH(scatter_fill_stub);
 DEFINE_DISPATCH(scatter_add_stub);
+DEFINE_DISPATCH(scatter_reduce_stub);
+DEFINE_DISPATCH(scatter_scalar_reduce_stub);

 static bool all_strides_match(TensorList tensors) {
   TORCH_CHECK(tensors.size() >= 1);
@@ -533,15 +535,60 @@ Tensor gather(const Tensor & self, int64_t dim, const Tensor & index, bool spars
 }

 Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
+  TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long,
+                    "scatter_(): Expected dtype int64 for index.");
   scatter_stub(self.device().type(), self, dim, index, source);
   return self;
 }

 Tensor & scatter_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar source) {
+  TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long,
+                    "scatter_(): Expected dtype int64 for index.");
   scatter_fill_stub(self.device().type(), self, dim, index, source);
   return self;
 }

+SCATTER_GATHER_OP get_operator_enum(const std::string& reduce) {
+  if (reduce == "add") {
+    return SCATTER_GATHER_OP::REDUCE_ADD;
+  }
+  else if (reduce == "subtract") {
+    return SCATTER_GATHER_OP::REDUCE_SUBTRACT;
+  }
+  else if (reduce == "multiply") {
+    return SCATTER_GATHER_OP::REDUCE_MULTIPLY;
+  }
+  else if (reduce == "divide") {
+    return SCATTER_GATHER_OP::REDUCE_DIVIDE;
+  }
+  else {
+    TORCH_CHECK(false,
+                "reduce argument must be one of add, subtract, multiply or divide.");
+  }
+}
+
+Tensor& scatter_cpu_scalar_reduce_(Tensor& self, const int64_t dim, const Tensor& index,
+                                   Scalar value, const std::string reduce) {
+  TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long,
+                    "scatter_(): Expected dtype int64 for index.");
+  TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
+              "scatter_(): Expected floating or complex type for self.");
+  SCATTER_GATHER_OP op = get_operator_enum(reduce);
+  scatter_scalar_reduce_stub(self.device().type(), self, dim, index, value, op);
+  return self;
+}
+
+Tensor & scatter_cpu_reduce_(Tensor & self, const int64_t dim, const Tensor & index,
+                             const Tensor & src, const std::string reduce) {
+  TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long,
+                    "scatter_(): Expected dtype int64 for index.");
+  TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
+              "scatter_(): Expected floating or complex type for self.");
+  SCATTER_GATHER_OP op = get_operator_enum(reduce);
+  scatter_reduce_stub(self.device().type(), self, dim, index, src, op);
+  return self;
+}
+
 Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
   return self.clone(at::MemoryFormat::Preserve).scatter_(dim, index, source);
 }
@@ -551,6 +598,8 @@ Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, Scalar so
 }

 Tensor & scatter_add_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) {
+  TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long,
+                    "scatter_(): Expected dtype int64 for index.");
   scatter_add_stub(self.device().type(), self, dim, index, src);
   return self;
 }
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.h b/aten/src/ATen/native/TensorAdvancedIndexing.h
index 69b17a83673d56..870112c0865cdb 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.h
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.h
@@ -11,6 +11,8 @@ namespace at {

 namespace at { namespace native {

+enum class SCATTER_GATHER_OP: uint8_t {REDUCE_ADD, REDUCE_SUBTRACT, REDUCE_MULTIPLY, REDUCE_DIVIDE};
+
 using index_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
 using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
 using index_put_accum_fn = void(*)(Tensor &, TensorList , const Tensor &, bool unsafe);
@@ -21,7 +23,11 @@ using gather_fn = void (*)(Tensor & result, const Tensor & self, int64_t dim, co
 using scatter_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
 using scatter_fill_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, Scalar src);
 using scatter_add_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
-
+using scatter_reduce_fn = void(*)(Tensor& self, const int64_t dim, const Tensor& index,
+                                  const Tensor& src, const SCATTER_GATHER_OP& reduce);
+using scatter_scalar_reduce_fn = void(*)(Tensor& self, const int64_t dim, const Tensor& index,
+                                         Scalar& value, const SCATTER_GATHER_OP& reduce);
+
 DECLARE_DISPATCH(index_fn, index_stub);
 DECLARE_DISPATCH(index_put_fn, index_put_stub);
 DECLARE_DISPATCH(index_put_accum_fn, index_put_accum_stub);
@@ -33,6 +39,8 @@ DECLARE_DISPATCH(gather_fn, gather_stub);
 DECLARE_DISPATCH(scatter_fn, scatter_stub);
 DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
 DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
+DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
+DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);

 TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices);

diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
index 4a9d811730f4bc..473ba09fbdfabf 100644
--- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
+++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
@@ -1,12 +1,64 @@
 #include
 #include
 #include
+#include
 #include
+#include

 namespace at { namespace native {

 namespace {

+// Implement as functors since lambdas don't get optimized.
+class ReduceMultiply {
+public:
+  template <typename scalar_t>
+  constexpr void operator() (scalar_t * self_data, scalar_t * src_data) const {
+    *self_data *= *src_data;
+  };
+
+  constexpr void operator() (bool * self_data, bool * src_data) const {
+    *self_data = *self_data && *src_data;
+  };
+};
+ReduceMultiply reduce_multiply;
+
+class ReduceAdd {
+public:
+  template <typename scalar_t>
+  constexpr void operator() (scalar_t * self_data, scalar_t * src_data) const {
+    *self_data += *src_data;
+  };
+};
+ReduceAdd reduce_add;
+
+class ReduceSubtract {
+public:
+  template <typename scalar_t>
+  constexpr void operator() (scalar_t * self_data, scalar_t * src_data) const {
+    *self_data -= *src_data;
+  };
+};
+ReduceSubtract reduce_subtract;
+
+class ReduceDivide {
+public:
+  template <typename scalar_t>
+  constexpr void operator() (scalar_t * self_data, scalar_t * src_data) const {
+    *self_data /= *src_data;
+  };
+};
+ReduceDivide reduce_divide;
+
+class TensorAssign {
+public:
+  template <typename scalar_t>
+  constexpr void operator() (scalar_t * self_data, scalar_t * src_data) const {
+    *self_data = *src_data;
+  };
+};
+TensorAssign tensor_assign;
+
 template <bool is_scatter_like = true>
 struct _cpu_scatter_gather_dim_loop {
   template <typename scalar_t, typename func_t>
@@ -16,7 +68,7 @@ struct _cpu_scatter_gather_dim_loop {
     scalar_t* src_data, int64_t src_dim_stride,
     int64_t dim, int64_t index_dim_size,
     int64_t index_upper_bound,
-    const func_t& f
+    func_t& f
   ) {

     for (int64_t i = 0; i < index_dim_size; ++i) {
@@ -35,17 +87,141 @@ struct _cpu_scatter_gather_dim_loop {
       );
     }
   }
+
+  template <typename scalar_t, typename func_t>
+  void operator()(
+    scalar_t* self_data, int64_t self_dim_stride,
+    int64_t* index_data, int64_t index_dim_stride,
+    Scalar value,
+    int64_t dim, int64_t index_dim_size,
+    int64_t index_upper_bound,
+    func_t& f
+  ) {
+
+    for (int64_t i = 0; i < index_dim_size; ++i) {
+      int64_t idx_dim = index_data[i * index_dim_stride];
+      // we are not putting idx_dim in the error message because it disables
+      // loop optimization in clang-7
+      TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
+        "index ", index_data[i * index_dim_stride],
+        " is out of bounds for dimension ", dim,
+        " with size ", index_upper_bound
+      );
+      auto temp = value.to<scalar_t>();
+      f(
+        self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, &temp
+      );
+    }
+  }
 };

 template <bool is_scatter_like = true>
 struct cpu_scatter_gather_base_kernel {
   template <typename func_t>
-  void operator()(
-    Tensor& self, int64_t dim,
+  void operator()(Tensor& self, int64_t dim,
+    const Tensor& index, Scalar& value,
+    const std::string& method_name, func_t& kernel_func) {
+    // no-op if index is empty
+    if (index.numel() == 0) {
+      return;
+    }
+
+    dim = maybe_wrap_dim(dim, self.dim());
+
+    if (is_scatter_like) {
+      scatter_shape_check(self, dim, index, self);
+    }
+    else {
+      gather_shape_check(self, dim, index, self);
+    }
+
+    auto index_sizes = ensure_nonempty_vec(index.sizes().vec());
+    auto index_strides = ensure_nonempty_vec(index.strides().vec());
+
+    // `dim` is traversed in the kernel,
+    // that is why index.stride(dim) = 0 and index.size(dim) = 1.
+    // Also, index.size(dim) = 1 makes sure that TensorIterator.DimCounter
+    // has the following form : (i_1,..., i_{dim-1}, 0, i_{dim+1},...,i_n).
+    index_sizes[dim] = 1;
+    index_strides[dim] = 0;
+
+    auto iter = TensorIteratorConfig()
+      .check_all_same_dtype(false)
+      .resize_outputs(false)
+      .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
+      .add_output(self)
+      .add_input(index)
+      .build();
+
+    auto self_dim_stride = ensure_nonempty_stride(self, dim);
+    auto self_dim_size = ensure_nonempty_size(self, dim);
+
+    auto index_dim_stride = ensure_nonempty_stride(index, dim);
+    auto index_dim_size = ensure_nonempty_size(index, dim);
+
+    auto index_upper_bound = self_dim_size;
+
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
+      ScalarType::Bool, ScalarType::Half, iter.dtype(),
+      method_name, [&] {
+        constexpr auto SELF_ITER_STRIDE_IDX = 0;
+        constexpr auto INDEX_ITER_STRIDE_IDX = 1;
+
+        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
+          auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX];
+          auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX];
+          // we change the order of TensorIterator-dim loop
+          // vs dim-TensorIterator loop order depending on
+          // whether dim is the last dimension and/or
+          // whether `n` is smaller than `index_dim_size`
+
+          if ((dim == self.dim() - 1) || (n < index_dim_size)) {
+            for (int64_t nelem = 0; nelem < n; ++nelem) {
+              // dim loop is a separate code block
+              // for better performance
+              _cpu_scatter_gather_dim_loop<is_scatter_like>()(
+                (scalar_t*)self_data_bytes, self_dim_stride,
+                (int64_t*)index_data_bytes, index_dim_stride,
+                value, dim, index_dim_size, index_upper_bound,
+                kernel_func);
+
+              self_data_bytes += strides[SELF_ITER_STRIDE_IDX];
+              index_data_bytes += strides[INDEX_ITER_STRIDE_IDX];
+            }
+          }
+          else {
+            for (int64_t i = 0; i < index_dim_size; ++i) {
+              auto* self_data = self_data_bytes;
+              auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
+              for (int64_t nelem = 0; nelem < n; ++nelem) {
+                int64_t idx_dim = *(int64_t*)index_data;
+                // we are not putting idx_dim in the error message because it disables
+                // loop optimization in clang-7
+                TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
+                  "index ", *(int64_t*)index_data,
+                  " is out of bounds for dimension ", dim,
+                  " with size ", index_upper_bound);
+
+                auto temp = value.to<scalar_t>();
+                kernel_func((scalar_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, &temp);
+
+                self_data += strides[SELF_ITER_STRIDE_IDX];
+                index_data += strides[INDEX_ITER_STRIDE_IDX];
+              }
+            }
+          }
+        };
+        iter.for_each(loop);
+      }
+    );
+  }
+
+  template <typename func_t>
+  void operator()(Tensor& self, int64_t dim,
     const Tensor& index, const Tensor& src,
-    const std::string& method_name,
-    const func_t& f
-  ) {
+    const std::string& method_name, func_t& kernel_func) {
+
     // no-op if index is empty
     if (index.numel() == 0) {
       return;
@@ -75,7 +251,7 @@ struct cpu_scatter_gather_base_kernel {

     auto index_dim_stride = ensure_nonempty_stride(index, dim);
     auto index_dim_size = ensure_nonempty_size(index, dim);
-
+
     auto src_dim_stride = ensure_nonempty_stride(src, dim);
     auto src_dim_size = ensure_nonempty_size(src, dim);

@@ -84,30 +260,28 @@ struct cpu_scatter_gather_base_kernel {
     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
       ScalarType::Bool, ScalarType::Half, iter.dtype(),
       method_name, [&] {
+        constexpr auto SELF_ITER_STRIDE_IDX = 0;
+        constexpr auto INDEX_ITER_STRIDE_IDX = 2;
+        constexpr auto SRC_ITER_STRIDE_IDX = 1;
         auto loop = [&](char** data, const int64_t* strides, int64_t n) {
-          constexpr auto SELF_ITER_STRIDE_IDX = 0;
-          constexpr auto INDEX_ITER_STRIDE_IDX = 2;
-          constexpr auto SRC_ITER_STRIDE_IDX = 1;
-
           auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX];
           auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX];
           auto* src_data_bytes = data[SRC_ITER_STRIDE_IDX];
-
           // we change the order of TensorIterator-dim loop
           // vs dim-TensorIterator loop order depending on
           // whether dim is the last dimension and/or
           // whether `n` is smaller than `index_dim_size`
-
           if ((dim == self.dim() - 1) || (n < index_dim_size)) {
             for (int64_t nelem = 0; nelem < n; ++nelem) {
               // dim loop is a separate code block
               // for better performance
-              _cpu_scatter_gather_dim_loop<is_scatter_like>()(
-                 (scalar_t*)self_data_bytes, self_dim_stride,
-                 (int64_t*)index_data_bytes, index_dim_stride,
-                 (scalar_t*)src_data_bytes, src_dim_stride,
-                 dim, index_dim_size, index_upper_bound,
-                 f
-              );
+              _cpu_scatter_gather_dim_loop<is_scatter_like>()(
+                (scalar_t*)self_data_bytes, self_dim_stride,
+                (int64_t*)index_data_bytes, index_dim_stride,
+                (scalar_t*)src_data_bytes, src_dim_stride,
+                dim, index_dim_size, index_upper_bound,
+                kernel_func
+              );

               self_data_bytes += strides[SELF_ITER_STRIDE_IDX];
               index_data_bytes += strides[INDEX_ITER_STRIDE_IDX];
@@ -124,15 +298,13 @@ struct cpu_scatter_gather_base_kernel {
               // we are not putting idx_dim in the error message because it disables
               // loop optimization in clang-7
               TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
-                "index ", *(int64_t*)index_data,
-                " is out of bounds for dimension ", dim,
-                " with size ", index_upper_bound
-              );
+                "index ", *(int64_t*)index_data,
+                " is out of bounds for dimension ", dim,
+                " with size ", index_upper_bound);

-              f(
+              kernel_func(
                 (scalar_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride,
-                (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride
-              );
+                (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride);

               self_data += strides[SELF_ITER_STRIDE_IDX];
               index_data += strides[INDEX_ITER_STRIDE_IDX];
@@ -140,50 +312,76 @@ struct cpu_scatter_gather_base_kernel {
             }
           }
         }
-
       };
-
       iter.for_each(loop);
     }
   );
   }
-}; // struct cpu_scatter_gather_base_kernel
+};

 void gather_cpu_kernel(Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) {
   cpu_scatter_gather_base_kernel</*is_scatter_like=*/false>()(
     result, dim, index, self,
-    "gather_out_cpu", [] (auto* lhs, const auto* rhs) {
-      *lhs = *rhs;
-    }
-  );
+    "gather_out_cpu", tensor_assign);
 }

 void scatter_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
   cpu_scatter_gather_base_kernel<>()(
-    self, dim, index, src,
-    "scatter_cpu_", [] (auto* lhs, const auto* rhs) {
-      *lhs = *rhs;
-    }
-  );
+    self, dim, index, src, "scatter_cpu_", tensor_assign);
 }

-void scatter_fill_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, Scalar src) {
+void scatter_fill_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, Scalar value) {
   cpu_scatter_gather_base_kernel<>()(
-    self, dim, index, self,
-    "scatter_fill_cpu_", [src] (auto* lhs, const auto* rhs) {
-      using scalar_t = typename std::remove_pointer<decltype(lhs)>::type;
-      *lhs = src.to<scalar_t>();
-    }
-  );
+    self, dim, index, value, "scatter_fill_cpu_", tensor_assign);
 }

 void scatter_add_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
   cpu_scatter_gather_base_kernel<>()(
     self, dim, index, src,
-    "scatter_add_", [] (auto* lhs, const auto* rhs) {
-      *lhs += *rhs;
-    }
-  );
+    "scatter_add_", reduce_add);
+
+}
+
+void scatter_reduce_cpu_kernel(Tensor& self, const int64_t dim, const Tensor& index,
+                               const Tensor& src, const SCATTER_GATHER_OP& reduce) {
+  switch (reduce) {
+    case SCATTER_GATHER_OP::REDUCE_ADD :
+      cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
+                                         "scatter_reduce_add_", reduce_add);
+      break;
+    case SCATTER_GATHER_OP::REDUCE_SUBTRACT :
+      cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
+                                         "scatter_reduce_subtract_", reduce_subtract);
+      break;
+    case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
+      cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
+                                         "scatter_reduce_multiply_", reduce_multiply);
+      break;
+    case SCATTER_GATHER_OP::REDUCE_DIVIDE :
+      cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
+                                         "scatter_reduce_divide_", reduce_divide);
+  }
+}
+
+void scatter_scalar_reduce_cpu_kernel(Tensor& self, const int64_t dim, const Tensor& index,
+                                      Scalar& value, const SCATTER_GATHER_OP& reduce) {
+  switch (reduce) {
+    case SCATTER_GATHER_OP::REDUCE_ADD :
+      cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
+                                         "scatter_scalar_reduce_add_", reduce_add);
+      break;
+    case SCATTER_GATHER_OP::REDUCE_SUBTRACT :
+      cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
+                                         "scatter_scalar_reduce_subtract_", reduce_subtract);
+      break;
+    case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
+      cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
+                                         "scatter_scalar_reduce_multiply_", reduce_multiply);
+      break;
+    case SCATTER_GATHER_OP::REDUCE_DIVIDE :
+      cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
+                                         "scatter_scalar_reduce_divide_", reduce_divide);
+  }
+}

 } // anonymous namespace

@@ -192,5 +390,7 @@ REGISTER_DISPATCH(gather_stub, &gather_cpu_kernel);
 REGISTER_DISPATCH(scatter_stub, &scatter_cpu_kernel);
 REGISTER_DISPATCH(scatter_fill_stub, &scatter_fill_cpu_kernel);
 REGISTER_DISPATCH(scatter_add_stub, &scatter_add_cpu_kernel);
+REGISTER_DISPATCH(scatter_reduce_stub, &scatter_reduce_cpu_kernel);
+REGISTER_DISPATCH(scatter_scalar_reduce_stub, &scatter_scalar_reduce_cpu_kernel);

 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2549d0c5766abd..57e6b1d0a7f961 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3914,6 +3914,16 @@
 - func: scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor
   variants: function, method

+- func: scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!)
+  variants: method
+  dispatch:
+    CPU: scatter_cpu_reduce_
+
+- func: scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!)
+  variants: method
+  dispatch:
+    CPU: scatter_cpu_scalar_reduce_
+
 - func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
   variants: method
   dispatch:
diff --git a/test/test_torch.py b/test/test_torch.py
index 194b8d6dd847d2..949d28557a771f 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2676,7 +2676,7 @@ def test_scatter_add_mult_index(self):
         self._test_scatter_add_mult_index_base(self, lambda t: t)

     @staticmethod
-    def _test_scatter_base(self, cast, method, is_scalar=False, test_bounds=True, *, test_complex=False):
+    def _test_scatter_base(self, cast, method, is_scalar=False, test_bounds=True, reduction=None, *, test_complex=False):
         if test_complex:
             dtypes = [torch.complex64, torch.complex128]
         else:
@@ -2699,7 +2699,10 @@ def _test_scatter_base(self, cast, method, is_scalar=False, test_bounds=True, *,
             src = cast(torch.randn(src_size, dtype=dtype))
             base = cast(torch.randn(m, n, o, dtype=dtype))

-            actual = getattr(base.clone(), method)(dim, idx, src)
+            if reduction:
+                actual = getattr(base.clone(), method)(dim, idx, src, reduce=reduction)
+            else:
+                actual = getattr(base.clone(), method)(dim, idx, src)
             expected = base.clone()
             for i in range(idx_size[0]):
                 for j in range(idx_size[1]):
@@ -2707,7 +2710,17 @@ def _test_scatter_base(self, cast, method, is_scalar=False, test_bounds=True, *,
                         ii = [i, j, k]
                         ii[dim] = idx[i, j, k]
                         if method == 'scatter_' and not is_scalar:
-                            expected[tuple(ii)] = src[i, j, k]
+                            if reduction:
+                                if reduction == "add":
+                                    expected[tuple(ii)] += src[i, j, k]
+                                elif reduction == "subtract":
+                                    expected[tuple(ii)] -= src[i, j, k]
+                                elif reduction == "multiply":
+                                    expected[tuple(ii)] *= src[i, j, k]
+                                elif reduction == "divide":
+                                    expected[tuple(ii)] /= src[i, j, k]
+                            else:
+                                expected[tuple(ii)] = src[i, j, k]
                         elif method == 'scatter_add_':
                             expected[tuple(ii)] += src[i, j, k]
                         else:
@@ -2725,17 +2738,23 @@ def _test_scatter_base(self, cast, method, is_scalar=False, test_bounds=True, *,
                getattr(base.clone(), method)(dim, idx, src.type(torch.int))

             # should throw an error when index dtype is not long
-            with self.assertRaisesRegex(RuntimeError, 'Expected dtype int64 for index'):
+            with self.assertRaisesRegex(IndexError, 'Expected dtype int64 for index'):
                 getattr(base.clone(), method)(dim, idx.type(torch.int), src)

             if test_bounds:
                 idx[0][0][0] = 34
                 with self.assertRaises(RuntimeError):
-                    getattr(base.clone(), method)(dim, idx, src)
+                    if reduction:
+                        getattr(base.clone(), method)(dim, idx, src, reduce=reduction)
+                    else:
+                        getattr(base.clone(), method)(dim, idx, src)

             # test for empty index, should be a no-op
             idx = cast(torch.LongTensor())
-            actual = getattr(base.clone(), method)(dim, idx, src)
+            if reduction:
+                actual = getattr(base.clone(), method)(dim, idx, src, reduce=reduction)
+            else:
+                actual = getattr(base.clone(), method)(dim, idx, src)
             self.assertEqual(actual, base, atol=0, rtol=0)

     def test_scatter(self):
@@ -2747,6 +2766,10 @@ def test_scatterAdd(self):
     def test_scatterFill(self):
         self._test_scatter_base(self, lambda t: t, 'scatter_', True)

+    def test_scatterReduce(self):
+        for method in ["add", "subtract", "multiply", "divide"]:
+            self._test_scatter_base(self, lambda t: t, 'scatter_', reduction=method)
+
     def test_masked_scatter(self):
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")
@@ -12540,6 +12563,116 @@ def test_put_empty(self, device):
             src = torch.randn(indices_shape, device=device)
             self.assertEqual(dst, dst.put_(indices, src, accumulate=accumulate))

+    @onlyCPU
+    def test_scatter_reduce_operations_to_large_input(self, device):
+        index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
+        test_data = [
+            (torch.zeros(4, 4, device=device, dtype=torch.float32),
+             torch.ones(2, 2, device=device, dtype=torch.float32),
+             torch.tensor([[0, 0, 0, 0],
+                           [1, 0, 0, 0],
+                           [1, 0, 0, 0],
+                           [0, 0, 0, 0]],
+                          device=device, dtype=torch.float32), "add"),
+            (torch.zeros(4, 4, device=device, dtype=torch.float32),
+             torch.ones(2, 2, device=device, dtype=torch.float32),
+             torch.tensor([[0, 0, 0, 0],
+                           [-1, 0, 0, 0],
+                           [-1, 0, 0, 0],
+                           [0, 0, 0, 0]], device=device, dtype=torch.float32), "subtract"),
+            (torch.tensor([2], device=device, dtype=torch.float32).repeat(4, 4),
+             torch.tensor([2], device=device, dtype=torch.float32).repeat(2, 2),
+             torch.tensor([[2, 2, 2, 2],
+                           [4, 2, 2, 2],
+                           [4, 2, 2, 2],
+                           [2, 2, 2, 2]], device=device, dtype=torch.float32), "multiply"),
+            (torch.tensor([2], device=device, dtype=torch.float32).repeat(4, 4),
+             torch.tensor([2], device=device, dtype=torch.float32).repeat(2, 2),
+             torch.tensor([[2, 2, 2, 2],
+                           [1, 2, 2, 2],
+                           [1, 2, 2, 2],
+                           [2, 2, 2, 2]], device=device, dtype=torch.float32), "divide")
+        ]
+
+        for input, src, result, operation in test_data:
+            input.scatter_(0, index, src, reduce=operation)
+            self.assertEqual(input, result)
+
+    @onlyCPU
+    def test_scatter_reduce_scalar(self, device):
+        index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
+        test_data = [
+            (torch.zeros(4, 4, device=device, dtype=torch.float32), 1,
+             torch.tensor([[0, 0, 0, 0],
+                           [1, 0, 0, 0],
+                           [1, 0, 0, 0],
+                           [0, 0, 0, 0]],
+                          device=device, dtype=torch.float32), "add"),
+            (torch.zeros(4, 4, device=device, dtype=torch.float32), 1,
+             torch.tensor([[0, 0, 0, 0],
+                           [-1, 0, 0, 0],
+                           [-1, 0, 0, 0],
+                           [0, 0, 0, 0]], device=device, dtype=torch.float32), "subtract"),
+            (torch.tensor([2], device=device, dtype=torch.float32).repeat(4, 4), 2,
+             torch.tensor([[2, 2, 2, 2],
+                           [4, 2, 2, 2],
+                           [4, 2, 2, 2],
+                           [2, 2, 2, 2]], device=device, dtype=torch.float32), "multiply"),
+            (torch.tensor([2], device=device, dtype=torch.float32).repeat(4, 4), 2,
+             torch.tensor([[2, 2, 2, 2],
+                           [1, 2, 2, 2],
+                           [1, 2, 2, 2],
+                           [2, 2, 2, 2]], device=device, dtype=torch.float32), "divide")
+        ]
+
+        for input, src, result, operation in test_data:
+            input.scatter_(0, index, src, reduce=operation)
+            self.assertEqual(input, result)
+
+    # TODO: remove this after scatter_add_ is deprecated.
+    def test_scatter_add_non_unique_index(self, device):
+        height = 2
+        width = 65536
+        input = torch.ones(height, width, device=device)
+        index = torch.zeros(height, width, dtype=torch.long, device=device)
+        src = torch.ones(height, width, device=device)
+        input.scatter_add_(0, index, src)
+
+        self.assertEqual(input,
+                         torch.tensor([[3], [1]], device=device,
+                                      dtype=torch.float32).repeat(1, width))
+
+    @onlyCPU
+    def test_scatter_reduce_non_unique_index(self, device):
+        height = 2
+        width = 2
+        index = torch.zeros(height, width, dtype=torch.long, device=device)
+        test_data = [
+            (torch.ones(height, width, device=device, dtype=torch.float32),
+             torch.ones(height, width, device=device, dtype=torch.float32),
+             torch.tensor([[3], [1]], device=device, dtype=torch.float32).repeat(1, width), "add"),
+
+            (torch.ones(height, width, device=device, dtype=torch.float32),
+             torch.ones(height, width, device=device, dtype=torch.float32),
+             torch.tensor([[-1], [1]], device=device,
+                          dtype=torch.float32).repeat(1, width), "subtract"),
+
+            (torch.tensor([2], device=device, dtype=torch.float32).repeat(height, width),
+             torch.tensor([2], device=device, dtype=torch.float32).repeat(height, width),
+             torch.tensor([[8], [2]], device=device,
+                          dtype=torch.float32).repeat(1, width), "multiply"),
+
+            (torch.tensor([2], device=device, dtype=torch.float32).repeat(height, width),
+             torch.tensor([2], device=device, dtype=torch.float32).repeat(height, width),
+             torch.tensor([[0.5], [2]], device=device,
+                          dtype=torch.float32).repeat(1, width), "divide"),
+        ]
+
+        for input, src, result, operation in test_data:
+            input.scatter_(0, index, src, reduce=operation)
+            self.assertEqual(input, result)
+
     def test_scatter_to_large_input(self, device):
         input = torch.zeros(4, 4, device=device)
         src = torch.ones(2, 2, device=device)
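
As a quick illustration of the API this diff adds, here is a minimal usage sketch. The tensor names, shapes, and values are made up for illustration and are not taken from the patch; per the checks above, the reduce overloads are CPU-only and require a floating or complex `self` and an int64 `index`.

    import torch

    # Tensor variant (scatter_.reduce): every element of `index` targets row 0,
    # so each column of row 0 is multiplied by 2 twice: 1 * 2 * 2 = 4.
    self_t = torch.ones(2, 4)                      # float32, as required by the dtype check
    index = torch.zeros(2, 4, dtype=torch.long)    # int64 index, as required
    src = torch.full((2, 4), 2.0)
    self_t.scatter_(0, index, src, reduce="multiply")
    # tensor([[4., 4., 4., 4.],
    #         [1., 1., 1., 1.]])

    # Scalar variant (scatter_.value_reduce): subtract 3.0 at the indexed positions,
    # mirroring the test_scatter_reduce_scalar cases above.
    out = torch.zeros(4, 4)
    out.scatter_(0, torch.tensor([[1], [2]]), 3.0, reduce="subtract")
    # column 0 of rows 1 and 2 becomes -3.0; everything else stays 0.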