Merge branch 'master' into build-cpu-and-gpu-versions
rusty1s authored Mar 2, 2021
2 parents 8a07c86 + ab4cd9d commit cc8aa66
Showing 5 changed files with 78 additions and 14 deletions.
39 changes: 39 additions & 0 deletions csrc/scatter.cpp
@@ -74,6 +74,38 @@ class ScatterSum : public torch::autograd::Function<ScatterSum> {
}
};

class ScatterMul : public torch::autograd::Function<ScatterMul> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "mul");
auto out = std::get<0>(result);
ctx->save_for_backward({src, index, out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}

static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto src = saved[0];
auto index = saved[1];
auto out = saved[2];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::gather(grad_out * out, dim, index, false).div_(src);
grad_in.masked_fill_(grad_in.isnan(), 0);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
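
For reference, the backward pass above follows from the product rule: for out_i = prod_{j : index_j = i} src_j, the derivative is d out_i / d src_j = out_i / src_j, which is exactly what torch::gather(grad_out * out, dim, index).div_(src) computes; the masked_fill_ call zeroes out the 0/0 NaNs that arise when an input entry is zero. A minimal gradient check in Python (a sketch, assuming a torch_scatter build that already includes this commit):

import torch
from torch_scatter import scatter_mul

# d out[i] / d src[j] = out[i] / src[j] for every j with index[j] == i.
src = torch.tensor([2., 3., 4., 5.], requires_grad=True)
index = torch.tensor([0, 0, 1, 1])

out = scatter_mul(src, index, dim=0)  # tensor([ 6., 20.])
out.sum().backward()

print(src.grad)  # tensor([3., 2., 5., 4.]) == out[index] / src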

class ScatterMean : public torch::autograd::Function<ScatterMean> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
@@ -201,6 +233,12 @@ torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
}

torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0];
}

torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
@@ -225,6 +263,7 @@ scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,

static auto registry = torch::RegisterOperators()
.op("torch_scatter::scatter_sum", &scatter_sum)
.op("torch_scatter::scatter_mul", &scatter_mul)
.op("torch_scatter::scatter_mean", &scatter_mean)
.op("torch_scatter::scatter_min", &scatter_min)
.op("torch_scatter::scatter_max", &scatter_max);
10 changes: 10 additions & 0 deletions test/test_scatter.py
@@ -7,13 +7,16 @@

from .utils import reductions, tensor, dtypes, devices

reductions = reductions + ['mul']
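
Note the expected 'mul' values below: output positions that receive no inputs keep the multiplicative identity. For the first test case, src = [1, 3, 2, 4, 5, 6] with index = [0, 1, 0, 1, 1, 3] yields [1 * 2, 3 * 4 * 5, 1, 6] = [2, 60, 1, 6].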

tests = [
{
'src': [1, 3, 2, 4, 5, 6],
'index': [0, 1, 0, 1, 1, 3],
'dim': 0,
'sum': [3, 12, 0, 6],
'add': [3, 12, 0, 6],
'mul': [2, 60, 1, 6],
'mean': [1.5, 4, 0, 6],
'min': [1, 3, 0, 6],
'arg_min': [0, 1, 6, 5],
@@ -26,6 +29,7 @@
'dim': 0,
'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
'mul': [[1 * 3, 2 * 4], [5 * 7 * 9, 6 * 8 * 10], [1, 1], [11, 12]],
'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
'arg_min': [[0, 0], [1, 1], [6, 6], [5, 5]],
@@ -38,6 +42,7 @@
'dim': 1,
'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
'mul': [[1 * 3, 5 * 7 * 9, 1, 11], [2 * 4 * 6, 8 * 10, 12, 1]],
'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
'arg_min': [[0, 1, 6, 5], [0, 2, 5, 6]],
@@ -50,6 +55,7 @@
'dim': 1,
'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
'mul': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 11 * 13]]],
'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
'arg_min': [[[0, 0], [1, 1], [3, 3]], [[1, 1], [3, 3], [0, 0]]],
@@ -62,6 +68,7 @@
'dim': 1,
'sum': [[4], [6]],
'add': [[4], [6]],
'mul': [[3], [8]],
'mean': [[2], [3]],
'min': [[1], [2]],
'arg_min': [[0], [0]],
@@ -74,6 +81,7 @@
'dim': 1,
'sum': [[[4, 4]], [[6, 6]]],
'add': [[[4, 4]], [[6, 6]]],
'mul': [[[3, 3]], [[8, 8]]],
'mean': [[[2, 2]], [[3, 3]]],
'min': [[[1, 1]], [[2, 2]]],
'arg_min': [[[0, 0]], [[0, 0]]],
@@ -125,6 +133,8 @@ def test_out(test, reduce, dtype, device):

if reduce == 'sum' or reduce == 'add':
expected = expected - 2
    elif reduce == 'mul':
        expected = out  # We cannot really test this here.
    elif reduce == 'mean':
        expected = out  # We cannot really test this here.
elif reduce == 'min':
10 changes: 7 additions & 3 deletions torch_scatter/__init__.py
@@ -19,6 +19,9 @@
from .placeholder import cuda_version_placeholder
torch.ops.torch_scatter.cuda_version = cuda_version_placeholder

from .placeholder import scatter_placeholder
torch.ops.torch_scatter.scatter_mul = scatter_placeholder

from .placeholder import scatter_arg_placeholder
torch.ops.torch_scatter.scatter_min = scatter_arg_placeholder
torch.ops.torch_scatter.scatter_max = scatter_arg_placeholder
@@ -52,16 +55,16 @@
major, minor = int(str(cuda_version)[0:2]), int(str(cuda_version)[3])
t_major, t_minor = [int(x) for x in torch.version.cuda.split('.')]

if t_major != major or t_minor != minor:
if t_major != major:
raise RuntimeError(
f'Detected that PyTorch and torch_scatter were compiled with '
f'different CUDA versions. PyTorch has CUDA version '
f'{t_major}.{t_minor} and torch_scatter has CUDA version '
f'{major}.{minor}. Please reinstall the torch_scatter that '
f'matches your PyTorch install.')
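
For context on the parsing above: cuda_version is an integer following CUDA's major * 1000 + minor * 10 encoding (an assumption based on the CUDA_VERSION macro convention), so the string slicing recovers the two components:

# Hypothetical value for CUDA 11.1 under the major * 1000 + minor * 10 encoding.
cuda_version = 11010
major, minor = int(str(cuda_version)[0:2]), int(str(cuda_version)[3])
assert (major, minor) == (11, 1)

The relaxed check now raises only on a major-version mismatch, so e.g. a torch_scatter built against CUDA 11.0 no longer errors out under a PyTorch built against CUDA 11.1.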

from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min,
scatter_max, scatter) # noqa
from .scatter import (scatter_sum, scatter_add, scatter_mul, scatter_mean,
scatter_min, scatter_max, scatter) # noqa
from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
segment_min_csr, segment_max_csr, segment_csr,
gather_csr) # noqa
@@ -74,6 +77,7 @@
__all__ = [
'scatter_sum',
'scatter_add',
'scatter_mul',
'scatter_mean',
'scatter_min',
'scatter_max',
13 changes: 11 additions & 2 deletions torch_scatter/scatter.py
@@ -31,6 +31,13 @@ def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
return scatter_sum(src, index, dim, out, dim_size)


@torch.jit.script
def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)


@torch.jit.script
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
@@ -127,8 +134,8 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor
according to :obj:`index.max() + 1` is returned.
:param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`,
:obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
:param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`,
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
:rtype: :class:`Tensor`
@@ -150,6 +157,8 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
"""
if reduce == 'sum' or reduce == 'add':
return scatter_sum(src, index, dim, out, dim_size)
if reduce == 'mul':
return scatter_mul(src, index, dim, out, dim_size)
elif reduce == 'mean':
return scatter_mean(src, index, dim, out, dim_size)
elif reduce == 'min':
20 changes: 11 additions & 9 deletions torch_scatter/segment_csr.py
@@ -22,16 +22,18 @@ def segment_mean_csr(src: torch.Tensor, indptr: torch.Tensor,


@torch.jit.script
def segment_min_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
def segment_min_csr(
src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.segment_min_csr(src, indptr, out)


@torch.jit.script
def segment_max_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
def segment_max_csr(
src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.segment_max_csr(src, indptr, out)


@@ -51,9 +53,9 @@ def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and
:math:`m`-dimensional tensors with
size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
:math:`(x_0, ..., x_{m-1}, y)`, respectively, then :attr:`out` must be an
:math:`(x_0, ..., x_{m-2}, y)`, respectively, then :attr:`out` must be an
:math:`n`-dimensional tensor with size
:math:`(x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})`.
:math:`(x_0, ..., x_{m-2}, y - 1, x_{m}, ..., x_{n-1})`.
Moreover, the values of :attr:`indptr` must be between :math:`0` and
:math:`x_m` in ascending order.
The :attr:`indptr` tensor supports broadcasting in case its dimensions do
@@ -64,7 +66,7 @@ def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
.. math::
\mathrm{out}_i =
\sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+i]}~\mathrm{src}_j.
\sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.
Due to the use of index pointers, :meth:`segment_csr` is the fastest
method to apply for grouped reductions.
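
A minimal sketch of the CSR semantics described above (assuming a working torch_scatter install):

import torch
from torch_scatter import segment_csr

src = torch.tensor([1., 2., 3., 4., 5.])
indptr = torch.tensor([0, 2, 5])  # two segments: src[0:2] and src[2:5]

print(segment_csr(src, indptr, reduce='sum'))  # tensor([ 3., 12.])

With y = 3 index pointers the output has size y - 1 = 2, matching the corrected shape formula in the docstring.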
