diff --git a/csrc/scatter.cpp b/csrc/scatter.cpp index 4a7ba1ff..33ae35ad 100644 --- a/csrc/scatter.cpp +++ b/csrc/scatter.cpp @@ -74,6 +74,38 @@ class ScatterSum : public torch::autograd::Function { } }; +class ScatterMul : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size) { + dim = dim < 0 ? src.dim() + dim : dim; + ctx->saved_data["dim"] = dim; + ctx->saved_data["src_shape"] = src.sizes(); + index = broadcast(index, src, dim); + auto result = scatter_fw(src, index, dim, optional_out, dim_size, "mul"); + auto out = std::get<0>(result); + ctx->save_for_backward({src, index, out}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto src = saved[0]; + auto index = saved[1]; + auto out = saved[2]; + auto dim = ctx->saved_data["dim"].toInt(); + auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); + auto grad_in = torch::gather(grad_out * out, dim, index, false).div_(src); + grad_in.masked_fill_(grad_in.isnan(), 0); + return {grad_in, Variable(), Variable(), Variable(), Variable()}; + } +}; + class ScatterMean : public torch::autograd::Function { public: static variable_list forward(AutogradContext *ctx, Variable src, @@ -201,6 +233,12 @@ torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0]; } +torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size) { + return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0]; +} + torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, torch::optional optional_out, torch::optional dim_size) { @@ -225,6 +263,7 @@ scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim, static auto registry = torch::RegisterOperators() .op("torch_scatter::scatter_sum", &scatter_sum) + .op("torch_scatter::scatter_mul", &scatter_mul) .op("torch_scatter::scatter_mean", &scatter_mean) .op("torch_scatter::scatter_min", &scatter_min) .op("torch_scatter::scatter_max", &scatter_max); diff --git a/test/test_scatter.py b/test/test_scatter.py index edec96ac..af295fa9 100644 --- a/test/test_scatter.py +++ b/test/test_scatter.py @@ -7,6 +7,8 @@ from .utils import reductions, tensor, dtypes, devices +reductions = reductions + ['mul'] + tests = [ { 'src': [1, 3, 2, 4, 5, 6], @@ -14,6 +16,7 @@ 'dim': 0, 'sum': [3, 12, 0, 6], 'add': [3, 12, 0, 6], + 'mul': [2, 60, 1, 6], 'mean': [1.5, 4, 0, 6], 'min': [1, 3, 0, 6], 'arg_min': [0, 1, 6, 5], @@ -26,6 +29,7 @@ 'dim': 0, 'sum': [[4, 6], [21, 24], [0, 0], [11, 12]], 'add': [[4, 6], [21, 24], [0, 0], [11, 12]], + 'mul': [[1 * 3, 2 * 4], [5 * 7 * 9, 6 * 8 * 10], [1, 1], [11, 12]], 'mean': [[2, 3], [7, 8], [0, 0], [11, 12]], 'min': [[1, 2], [5, 6], [0, 0], [11, 12]], 'arg_min': [[0, 0], [1, 1], [6, 6], [5, 5]], @@ -38,6 +42,7 @@ 'dim': 1, 'sum': [[4, 21, 0, 11], [12, 18, 12, 0]], 'add': [[4, 21, 0, 11], [12, 18, 12, 0]], + 'mul': [[1 * 3, 5 * 7 * 9, 1, 11], [2 * 4 * 6, 8 * 10, 12, 1]], 'mean': [[2, 7, 0, 11], [4, 9, 12, 0]], 'min': [[1, 5, 0, 11], [2, 8, 12, 0]], 'arg_min': [[0, 1, 6, 5], [0, 2, 5, 6]], @@ -50,6 +55,7 @@ 'dim': 1, 'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], 'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], + 'mul': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 11 * 13]]], 'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]], 'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]], 'arg_min': [[[0, 0], [1, 1], [3, 3]], [[1, 1], [3, 3], [0, 0]]], @@ -62,6 +68,7 @@ 'dim': 1, 'sum': [[4], [6]], 'add': [[4], [6]], + 'mul': [[3], [8]], 'mean': [[2], [3]], 'min': [[1], [2]], 'arg_min': [[0], [0]], @@ -74,6 +81,7 @@ 'dim': 1, 'sum': [[[4, 4]], [[6, 6]]], 'add': [[[4, 4]], [[6, 6]]], + 'mul': [[[3, 3]], [[8, 8]]], 'mean': [[[2, 2]], [[3, 3]]], 'min': [[[1, 1]], [[2, 2]]], 'arg_min': [[[0, 0]], [[0, 0]]], @@ -125,6 +133,8 @@ def test_out(test, reduce, dtype, device): if reduce == 'sum' or reduce == 'add': expected = expected - 2 + elif reduce == 'mul': + expected = out # We can not really test this here. elif reduce == 'mean': expected = out # We can not really test this here. elif reduce == 'min': diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py index 149e99f1..2b40fa9d 100644 --- a/torch_scatter/__init__.py +++ b/torch_scatter/__init__.py @@ -19,6 +19,9 @@ from .placeholder import cuda_version_placeholder torch.ops.torch_scatter.cuda_version = cuda_version_placeholder + from .placeholder import scatter_placeholder + torch.ops.torch_scatter.scatter_mul = scatter_placeholder + from .placeholder import scatter_arg_placeholder torch.ops.torch_scatter.scatter_min = scatter_arg_placeholder torch.ops.torch_scatter.scatter_max = scatter_arg_placeholder @@ -52,7 +55,7 @@ major, minor = int(str(cuda_version)[0:2]), int(str(cuda_version)[3]) t_major, t_minor = [int(x) for x in torch.version.cuda.split('.')] - if t_major != major or t_minor != minor: + if t_major != major: raise RuntimeError( f'Detected that PyTorch and torch_scatter were compiled with ' f'different CUDA versions. PyTorch has CUDA version ' @@ -60,8 +63,8 @@ f'{major}.{minor}. Please reinstall the torch_scatter that ' f'matches your PyTorch install.') -from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min, - scatter_max, scatter) # noqa +from .scatter import (scatter_sum, scatter_add, scatter_mul, scatter_mean, + scatter_min, scatter_max, scatter) # noqa from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr, segment_min_csr, segment_max_csr, segment_csr, gather_csr) # noqa @@ -74,6 +77,7 @@ __all__ = [ 'scatter_sum', 'scatter_add', + 'scatter_mul', 'scatter_mean', 'scatter_min', 'scatter_max', diff --git a/torch_scatter/scatter.py b/torch_scatter/scatter.py index ec393488..5a391a28 100644 --- a/torch_scatter/scatter.py +++ b/torch_scatter/scatter.py @@ -31,6 +31,13 @@ def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, return scatter_sum(src, index, dim, out, dim_size) +@torch.jit.script +def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size) + + @torch.jit.script def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, out: Optional[torch.Tensor] = None, @@ -127,8 +134,8 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, with size :attr:`dim_size` at dimension :attr:`dim`. If :attr:`dim_size` is not given, a minimal sized output tensor according to :obj:`index.max() + 1` is returned. - :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`, - :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`) + :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`, + :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`) :rtype: :class:`Tensor` @@ -150,6 +157,8 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, """ if reduce == 'sum' or reduce == 'add': return scatter_sum(src, index, dim, out, dim_size) + if reduce == 'mul': + return scatter_mul(src, index, dim, out, dim_size) elif reduce == 'mean': return scatter_mean(src, index, dim, out, dim_size) elif reduce == 'min': diff --git a/torch_scatter/segment_csr.py b/torch_scatter/segment_csr.py index ff6a6309..c6638e02 100644 --- a/torch_scatter/segment_csr.py +++ b/torch_scatter/segment_csr.py @@ -22,16 +22,18 @@ def segment_mean_csr(src: torch.Tensor, indptr: torch.Tensor, @torch.jit.script -def segment_min_csr(src: torch.Tensor, indptr: torch.Tensor, - out: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: +def segment_min_csr( + src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: return torch.ops.torch_scatter.segment_min_csr(src, indptr, out) @torch.jit.script -def segment_max_csr(src: torch.Tensor, indptr: torch.Tensor, - out: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: +def segment_max_csr( + src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: return torch.ops.torch_scatter.segment_max_csr(src, indptr, out) @@ -51,9 +53,9 @@ def segment_csr(src: torch.Tensor, indptr: torch.Tensor, Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and :math:`m`-dimensional tensors with size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and - :math:`(x_0, ..., x_{m-1}, y)`, respectively, then :attr:`out` must be an + :math:`(x_0, ..., x_{m-2}, y)`, respectively, then :attr:`out` must be an :math:`n`-dimensional tensor with size - :math:`(x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})`. + :math:`(x_0, ..., x_{m-2}, y - 1, x_{m}, ..., x_{n-1})`. Moreover, the values of :attr:`indptr` must be between :math:`0` and :math:`x_m` in ascending order. The :attr:`indptr` tensor supports broadcasting in case its dimensions do @@ -64,7 +66,7 @@ def segment_csr(src: torch.Tensor, indptr: torch.Tensor, .. math:: \mathrm{out}_i = - \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+i]}~\mathrm{src}_j. + \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j. Due to the use of index pointers, :meth:`segment_csr` is the fastest method to apply for grouped reductions.