From 7c15c3cdac46e9653ecf29a5eda6af45432fd85e Mon Sep 17 00:00:00 2001 From: Feng Shi Date: Mon, 18 Oct 2021 19:13:14 +0800 Subject: [PATCH] Add function for the addition of two matrices (#177) * Create spadd.py Hi, Maybe it's trivial to have this function, but I still think it'll be helpful and it looks neat when applying matrix addition, i.e., C = A + B. Thanks * update * update * fix jit Co-authored-by: rusty1s --- test/test_add.py | 33 +++++++++++++++++++ torch_sparse/__init__.py | 2 ++ torch_sparse/add.py | 71 ++++++++++++++++++++++++++++++---------- torch_sparse/spadd.py | 18 ++++++++++ torch_sparse/storage.py | 2 +- 5 files changed, 107 insertions(+), 19 deletions(-) create mode 100644 test/test_add.py create mode 100644 torch_sparse/spadd.py diff --git a/test/test_add.py b/test/test_add.py new file mode 100644 index 00000000..e4839220 --- /dev/null +++ b/test/test_add.py @@ -0,0 +1,33 @@ +from itertools import product + +import pytest +import torch +from torch_sparse import SparseTensor, add + +from .utils import dtypes, devices, tensor + + +@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) +def test_add(dtype, device): + rowA = torch.tensor([0, 0, 1, 2, 2], device=device) + colA = torch.tensor([0, 2, 1, 0, 1], device=device) + valueA = tensor([1, 2, 4, 1, 3], dtype, device) + A = SparseTensor(row=rowA, col=colA, value=valueA) + + rowB = torch.tensor([0, 0, 1, 2, 2], device=device) + colB = torch.tensor([1, 2, 2, 1, 2], device=device) + valueB = tensor([2, 3, 1, 2, 4], dtype, device) + B = SparseTensor(row=rowB, col=colB, value=valueB) + + C = A + B + rowC, colC, valueC = C.coo() + + assert rowC.tolist() == [0, 0, 0, 1, 1, 2, 2, 2] + assert colC.tolist() == [0, 1, 2, 1, 2, 0, 1, 2] + assert valueC.tolist() == [1, 2, 5, 4, 1, 1, 5, 4] + + @torch.jit.script + def jit_add(A: SparseTensor, B: SparseTensor) -> SparseTensor: + return add(A, B) + + jit_add(A, B) diff --git a/torch_sparse/__init__.py b/torch_sparse/__init__.py index 3661832d..de60d65a 100644 --- a/torch_sparse/__init__.py +++ b/torch_sparse/__init__.py @@ -65,6 +65,7 @@ from .eye import eye # noqa from .spmm import spmm # noqa from .spspmm import spspmm # noqa +from .spadd import spadd # noqa __all__ = [ 'SparseStorage', @@ -111,5 +112,6 @@ 'eye', 'spmm', 'spspmm', + 'spadd', '__version__', ] diff --git a/torch_sparse/add.py b/torch_sparse/add.py index ea6a6e64..91aa41d7 100644 --- a/torch_sparse/add.py +++ b/torch_sparse/add.py @@ -1,34 +1,69 @@ from typing import Optional import torch +from torch import Tensor from torch_scatter import gather_csr from torch_sparse.tensor import SparseTensor -def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor: - rowptr, col, value = src.csr() - if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise... - other = gather_csr(other.squeeze(1), rowptr) - pass - elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise... - other = other.squeeze(0)[col] - else: - raise ValueError( - f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or ' - f'(1, {src.size(1)}, ...), but got size {other.size()}.') - if value is not None: - value = other.to(value.dtype).add_(value) +@torch.jit._overload # noqa: F811 +def add(src, other): # noqa: F811 + # type: (SparseTensor, Tensor) -> SparseTensor + pass + + +@torch.jit._overload # noqa: F811 +def add(src, other): # noqa: F811 + # type: (SparseTensor, SparseTensor) -> SparseTensor + pass + + +def add(src, other): # noqa: F811 + if isinstance(other, Tensor): + rowptr, col, value = src.csr() + if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise. + other = gather_csr(other.squeeze(1), rowptr) + elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise. + other = other.squeeze(0)[col] + else: + raise ValueError( + f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or ' + f'(1, {src.size(1)}, ...), but got size {other.size()}.') + if value is not None: + value = other.to(value.dtype).add_(value) + else: + value = other.add_(1) + return src.set_value(value, layout='coo') + + elif isinstance(other, SparseTensor): + rowA, colA, valueA = src.coo() + rowB, colB, valueB = other.coo() + + row = torch.cat([rowA, rowB], dim=0) + col = torch.cat([colA, colB], dim=0) + + value: Optional[Tensor] = None + if valueA is not None and valueB is not None: + value = torch.cat([valueA, valueB], dim=0) + + M = max(src.size(0), other.size(0)) + N = max(src.size(1), other.size(1)) + sparse_sizes = (M, N) + + out = SparseTensor(row=row, col=col, value=value, + sparse_sizes=sparse_sizes) + out = out.coalesce(reduce='sum') + return out + else: - value = other.add_(1) - return src.set_value(value, layout='coo') + raise NotImplementedError def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor: rowptr, col, value = src.csr() - if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise... + if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise. other = gather_csr(other.squeeze(1), rowptr) - pass - elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise... + elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise. other = other.squeeze(0)[col] else: raise ValueError( diff --git a/torch_sparse/spadd.py b/torch_sparse/spadd.py new file mode 100644 index 00000000..b459dec4 --- /dev/null +++ b/torch_sparse/spadd.py @@ -0,0 +1,18 @@ +import torch +from torch_sparse import coalesce + + +def spadd(indexA, valueA, indexB, valueB, m, n): + """Matrix addition of two sparse matrices. + + Args: + indexA (:class:`LongTensor`): The index tensor of first sparse matrix. + valueA (:class:`Tensor`): The value tensor of first sparse matrix. + indexB (:class:`LongTensor`): The index tensor of second sparse matrix. + valueB (:class:`Tensor`): The value tensor of second sparse matrix. + m (int): The first dimension of the sparse matrices. + n (int): The second dimension of the sparse matrices. + """ + index = torch.cat([indexA, indexB], dim=-1) + value = torch.cat([valueA, valueB], dim=0) + return coalesce(index=index, value=value, m=m, n=n, op='add') diff --git a/torch_sparse/storage.py b/torch_sparse/storage.py index 55e8c98b..c860238b 100644 --- a/torch_sparse/storage.py +++ b/torch_sparse/storage.py @@ -292,7 +292,7 @@ def sparse_reshape(self, num_rows: int, num_cols: int): idx = self.sparse_size(1) * self.row() + self.col() - row = idx // num_cols + row = torch.div(idx, num_cols, rounding_mode='floor') col = idx % num_cols assert row.dtype == torch.long and col.dtype == torch.long