Skip to content

Commit

Permalink
Add function for the addition of two matrices (rusty1s#177)
Browse files Browse the repository at this point in the history
* Create spadd.py

Hi,
Maybe it's trivial to have this function, but I still think it'll be helpful and it looks neat when applying matrix addition, i.e., C = A + B.
Thanks

* update

* update

* fix jit

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
shi27feng and rusty1s authored Oct 18, 2021
1 parent 28f1295 commit 7c15c3c
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 19 deletions.
33 changes: 33 additions & 0 deletions test/test_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from itertools import product

import pytest
import torch
from torch_sparse import SparseTensor, add

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_add(dtype, device):
rowA = torch.tensor([0, 0, 1, 2, 2], device=device)
colA = torch.tensor([0, 2, 1, 0, 1], device=device)
valueA = tensor([1, 2, 4, 1, 3], dtype, device)
A = SparseTensor(row=rowA, col=colA, value=valueA)

rowB = torch.tensor([0, 0, 1, 2, 2], device=device)
colB = torch.tensor([1, 2, 2, 1, 2], device=device)
valueB = tensor([2, 3, 1, 2, 4], dtype, device)
B = SparseTensor(row=rowB, col=colB, value=valueB)

C = A + B
rowC, colC, valueC = C.coo()

assert rowC.tolist() == [0, 0, 0, 1, 1, 2, 2, 2]
assert colC.tolist() == [0, 1, 2, 1, 2, 0, 1, 2]
assert valueC.tolist() == [1, 2, 5, 4, 1, 1, 5, 4]

@torch.jit.script
def jit_add(A: SparseTensor, B: SparseTensor) -> SparseTensor:
return add(A, B)

jit_add(A, B)
2 changes: 2 additions & 0 deletions torch_sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from .eye import eye # noqa
from .spmm import spmm # noqa
from .spspmm import spspmm # noqa
from .spadd import spadd # noqa

__all__ = [
'SparseStorage',
Expand Down Expand Up @@ -111,5 +112,6 @@
'eye',
'spmm',
'spspmm',
'spadd',
'__version__',
]
71 changes: 53 additions & 18 deletions torch_sparse/add.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,69 @@
from typing import Optional

import torch
from torch import Tensor
from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor


def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
else:
raise ValueError(
f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
f'(1, {src.size(1)}, ...), but got size {other.size()}.')
if value is not None:
value = other.to(value.dtype).add_(value)
@torch.jit._overload # noqa: F811
def add(src, other): # noqa: F811
# type: (SparseTensor, Tensor) -> SparseTensor
pass


@torch.jit._overload # noqa: F811
def add(src, other): # noqa: F811
# type: (SparseTensor, SparseTensor) -> SparseTensor
pass


def add(src, other): # noqa: F811
if isinstance(other, Tensor):
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise.
other = gather_csr(other.squeeze(1), rowptr)
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise.
other = other.squeeze(0)[col]
else:
raise ValueError(
f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
f'(1, {src.size(1)}, ...), but got size {other.size()}.')
if value is not None:
value = other.to(value.dtype).add_(value)
else:
value = other.add_(1)
return src.set_value(value, layout='coo')

elif isinstance(other, SparseTensor):
rowA, colA, valueA = src.coo()
rowB, colB, valueB = other.coo()

row = torch.cat([rowA, rowB], dim=0)
col = torch.cat([colA, colB], dim=0)

value: Optional[Tensor] = None
if valueA is not None and valueB is not None:
value = torch.cat([valueA, valueB], dim=0)

M = max(src.size(0), other.size(0))
N = max(src.size(1), other.size(1))
sparse_sizes = (M, N)

out = SparseTensor(row=row, col=col, value=value,
sparse_sizes=sparse_sizes)
out = out.coalesce(reduce='sum')
return out

else:
value = other.add_(1)
return src.set_value(value, layout='coo')
raise NotImplementedError


def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise.
other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise.
other = other.squeeze(0)[col]
else:
raise ValueError(
Expand Down
18 changes: 18 additions & 0 deletions torch_sparse/spadd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch
from torch_sparse import coalesce


def spadd(indexA, valueA, indexB, valueB, m, n):
"""Matrix addition of two sparse matrices.
Args:
indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
valueA (:class:`Tensor`): The value tensor of first sparse matrix.
indexB (:class:`LongTensor`): The index tensor of second sparse matrix.
valueB (:class:`Tensor`): The value tensor of second sparse matrix.
m (int): The first dimension of the sparse matrices.
n (int): The second dimension of the sparse matrices.
"""
index = torch.cat([indexA, indexB], dim=-1)
value = torch.cat([valueA, valueB], dim=0)
return coalesce(index=index, value=value, m=m, n=n, op='add')
2 changes: 1 addition & 1 deletion torch_sparse/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def sparse_reshape(self, num_rows: int, num_cols: int):

idx = self.sparse_size(1) * self.row() + self.col()

row = idx // num_cols
row = torch.div(idx, num_cols, rounding_mode='floor')
col = idx % num_cols
assert row.dtype == torch.long and col.dtype == torch.long

Expand Down

0 comments on commit 7c15c3c

Please sign in to comment.