Implement parallel scatter reductions for CPU (pytorch#36447)
Summary:
This PR implements pytorchgh-33389.

As a result of this PR, users can specify a reduction mode for scatter operations via the `reduce` keyword argument. Currently `add`, `subtract`, `multiply`, and `divide` are implemented, and adding new ones is straightforward.
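
For example, a minimal usage sketch (the call signature matches the benchmark scripts below; tensor names and shapes here are illustrative, not from the PR):

``` python
import torch

dest = torch.zeros(3, 5)   # destination must be floating or complex when reduce= is used
src = torch.ones(2, 5)
index = torch.tensor([[0, 1, 2, 0, 0],
                      [2, 0, 0, 1, 2]])

# Rows of src are scattered into dest along dim 0; entries that land on the
# same destination element are combined with the chosen reduction ("add").
dest.scatter_(0, index, src, reduce="add")
print(dest)
```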

While we now allow dynamic runtime selection of the reduction mode, performance matches that of the `scatter_add_` method in the master branch. This can be seen in the graph below, which compares `scatter_add_` on master (blue) against `scatter_(reduce="add")` from this PR (orange).
![scatter-regression py csv](https://user-images.githubusercontent.com/2629909/82671491-e5e22380-9c79-11ea-95d6-6344760c8578.png)

The script used for benchmarking is as follows:
``` python
import os
import sys
import torch
import time
import numpy
from IPython import get_ipython

Ms=256
Ns=512
dim = 0
top_power = 2
ipython = get_ipython()

plot_name = os.path.basename(__file__)
branch = sys.argv[1]
fname = open(plot_name + ".csv", "a+")

for pM in range(top_power):
    M = Ms * (2 ** pM)
    for pN in range(top_power):
        N = Ns * (2 ** pN)
        input_one = torch.rand(M, N)
        index = torch.tensor(numpy.random.randint(0, M, (M, N)))
        res = torch.randn(M, N)

        test_case = f"{M}x{N}"
        print(test_case)
        tobj = ipython.magic("timeit -o res.scatter_(dim, index, input_one, reduce=\"add\")")

        fname.write(f"{test_case},{branch},{tobj.average},{tobj.stdev}\n")

fname.close()
```

Additionally, the reduction modes take comparable time to execute (`divide` is roughly twice as slow as the others):
```
op: add
70.6 µs ± 27.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
26.1 µs ± 26.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
op: subtract
71 µs ± 20.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
26.4 µs ± 34.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
op: multiply
70.9 µs ± 31.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
27.4 µs ± 29.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
op: divide
164 µs ± 48.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
52.3 µs ± 132 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
Script:
``` python
import torch
import time
import numpy
from IPython import get_ipython

ipython = get_ipython()

nrows = 3000
ncols = 10000
dims = [nrows, ncols]

res = torch.randint(5, 10, dims)
idx1 = torch.randint(dims[0], (1, dims[1])).long()
src1 = torch.randint(5, 10, (1, dims[1]))
idx2 = torch.randint(dims[1], (dims[0], 1)).long()
src2 = torch.randint(5, 10, (dims[0], 1))

for op in ["add", "subtract", "multiply", "divide"]:
    print(f"op: {op}")
    ipython.magic("timeit res.scatter_(0, idx1, src1, reduce=op)")
    ipython.magic("timeit res.scatter_(1, idx2, src2, reduce=op)")
```
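
The diff below also wires up a scalar path (`scatter_scalar_reduce_stub`), so the same `reduce` keyword should apply when the source is a scalar rather than a tensor. A hedged sketch, assuming the scalar overload of `scatter_` accepts `reduce` in the same way:

``` python
import torch

res = torch.ones(4, 6)
idx = torch.randint(4, (1, 6)).long()

# Multiply the entries of res selected by idx (along dim 0) by the scalar 2.0.
res.scatter_(0, idx, 2.0, reduce="multiply")
```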
Pull Request resolved: pytorch#36447

Differential Revision: D22272631

Pulled By: ngimel

fbshipit-source-id: 3cdb46510f9bb0e135a5c03d6d4aa5de9402ee90
v0dro authored and facebook-github-bot committed Jun 29, 2020
1 parent 11a74a5 commit 9ca4a46
Showing 5 changed files with 456 additions and 56 deletions.
49 changes: 49 additions & 0 deletions aten/src/ATen/native/TensorAdvancedIndexing.cpp

``` cpp
@@ -78,6 +78,8 @@ DEFINE_DISPATCH(gather_stub);
DEFINE_DISPATCH(scatter_stub);
DEFINE_DISPATCH(scatter_fill_stub);
DEFINE_DISPATCH(scatter_add_stub);
DEFINE_DISPATCH(scatter_reduce_stub);
DEFINE_DISPATCH(scatter_scalar_reduce_stub);

static bool all_strides_match(TensorList tensors) {
  TORCH_CHECK(tensors.size() >= 1);

@@ -533,15 +535,60 @@ Tensor gather(const Tensor & self, int64_t dim, const Tensor & index, bool spars
}

Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long,
                    "scatter_(): Expected dtype int64 for index.");
  scatter_stub(self.device().type(), self, dim, index, source);
  return self;
}

Tensor & scatter_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar source) {
  TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long,
                    "scatter_(): Expected dtype int64 for index.");
  scatter_fill_stub(self.device().type(), self, dim, index, source);
  return self;
}

SCATTER_GATHER_OP get_operator_enum(const std::string& reduce) {
  if (reduce == "add") {
    return SCATTER_GATHER_OP::REDUCE_ADD;
  }
  else if (reduce == "subtract") {
    return SCATTER_GATHER_OP::REDUCE_SUBTRACT;
  }
  else if (reduce == "multiply") {
    return SCATTER_GATHER_OP::REDUCE_MULTIPLY;
  }
  else if (reduce == "divide") {
    return SCATTER_GATHER_OP::REDUCE_DIVIDE;
  }
  else {
    TORCH_CHECK(false,
                "reduce argument must be either of add, subtract, multiply or divide.");
  }
}

Tensor& scatter_cpu_scalar_reduce_(Tensor& self, const int64_t dim, const Tensor& index,
                                   Scalar value, const std::string reduce) {
  TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long,
                    "scatter_(): Expected dtype int64 for index.");
  TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
              "scatter_(): Expected floating or complex type for self.");
  SCATTER_GATHER_OP op = get_operator_enum(reduce);
  scatter_scalar_reduce_stub(self.device().type(), self, dim, index, value, op);
  return self;
}

Tensor & scatter_cpu_reduce_(Tensor & self, const int64_t dim, const Tensor & index,
                             const Tensor & src, const std::string reduce) {
  TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long,
                    "scatter_(): Expected dtype int64 for index");
  TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
              "scatter_(): Expected floating or complex type for self.");
  SCATTER_GATHER_OP op = get_operator_enum(reduce);
  scatter_reduce_stub(self.device().type(), self, dim, index, src, op);
  return self;
}

Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  return self.clone(at::MemoryFormat::Preserve).scatter_(dim, index, source);
}

@@ -551,6 +598,8 @@ Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, Scalar so
}

Tensor & scatter_add_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) {
  TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long,
                    "scatter_(): Expected dtype int64 for index.");
  scatter_add_stub(self.device().type(), self, dim, index, src);
  return self;
}
```
10 changes: 9 additions & 1 deletion aten/src/ATen/native/TensorAdvancedIndexing.h

``` cpp
@@ -11,6 +11,8 @@ namespace at {

namespace at { namespace native {

enum class SCATTER_GATHER_OP: uint8_t {REDUCE_ADD, REDUCE_SUBTRACT, REDUCE_MULTIPLY, REDUCE_DIVIDE};

using index_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
using index_put_accum_fn = void(*)(Tensor &, TensorList , const Tensor &, bool unsafe);

@@ -21,7 +23,11 @@ using gather_fn = void (*)(Tensor & result, const Tensor & self, int64_t dim, co
using scatter_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
using scatter_fill_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, Scalar src);
using scatter_add_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);

using scatter_reduce_fn = void(*)(Tensor& self, const int64_t dim, const Tensor& index,
                                  const Tensor& src, const SCATTER_GATHER_OP& reduce);
using scatter_scalar_reduce_fn = void(*)(Tensor& self, const int64_t dim, const Tensor& index,
                                         Scalar& value, const SCATTER_GATHER_OP& reduce);

DECLARE_DISPATCH(index_fn, index_stub);
DECLARE_DISPATCH(index_put_fn, index_put_stub);
DECLARE_DISPATCH(index_put_accum_fn, index_put_accum_stub);

@@ -33,6 +39,8 @@ DECLARE_DISPATCH(gather_fn, gather_stub);
DECLARE_DISPATCH(scatter_fn, scatter_stub);
DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);

TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices);
```