Skip to content

Commit a335324

Browse files
ncassereaurflamary
andauthored
[MRG] Backend for gromov (#294)
* bregman: small correction * gromov backend first draft * Removing decorators * Reworked casting method * Bug solve * Removing casting * Bug solve * toarray renamed todense ; expand_dims removed * Warning (jax not supporting sparse matrix) moved * Mistake corrected * test backend * Sparsity test for older versions of pytorch * Trying pytorch/1.10 * Attempt to correct torch sparse bug * Backend version of gromov tests * Random state introduced for remaining gromov functions * review changes * code coverage * Docs (first draft, to be continued) * Gromov docs * Prettified docs * mistake corrected in the docs * little change Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 0cb2b2e commit a335324

File tree

8 files changed

+1289
-766
lines changed

8 files changed

+1289
-766
lines changed

ot/backend.py

Lines changed: 212 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import numpy as np
2828
import scipy.special as scipy
29+
from scipy.sparse import issparse, coo_matrix, csr_matrix
2930

3031
try:
3132
import torch
@@ -539,6 +540,86 @@ def reshape(self, a, shape):
539540
"""
540541
raise NotImplementedError()
541542

543+
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
544+
r"""
545+
Creates a sparse tensor in COOrdinate format.
546+
547+
This function follows the api from :any:`scipy.sparse.coo_matrix`
548+
549+
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html
550+
"""
551+
raise NotImplementedError()
552+
553+
def issparse(self, a):
554+
r"""
555+
Checks whether or not the input tensor is a sparse tensor.
556+
557+
This function follows the api from :any:`scipy.sparse.issparse`
558+
559+
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html
560+
"""
561+
raise NotImplementedError()
562+
563+
def tocsr(self, a):
564+
r"""
565+
Converts this matrix to Compressed Sparse Row format.
566+
567+
This function follows the api from :any:`scipy.sparse.coo_matrix.tocsr`
568+
569+
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html
570+
"""
571+
raise NotImplementedError()
572+
573+
def eliminate_zeros(self, a, threshold=0.):
574+
r"""
575+
Removes entries smaller than the given threshold from the sparse tensor.
576+
577+
This function follows the api from :any:`scipy.sparse.csr_matrix.eliminate_zeros`
578+
579+
See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html
580+
"""
581+
raise NotImplementedError()
582+
583+
def todense(self, a):
584+
r"""
585+
Converts a sparse tensor to a dense tensor.
586+
587+
This function follows the api from :any:`scipy.sparse.csr_matrix.toarray`
588+
589+
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html
590+
"""
591+
raise NotImplementedError()
592+
593+
def where(self, condition, x, y):
594+
r"""
595+
Returns elements chosen from x or y depending on condition.
596+
597+
This function follows the api from :any:`numpy.where`
598+
599+
See: https://numpy.org/doc/stable/reference/generated/numpy.where.html
600+
"""
601+
raise NotImplementedError()
602+
603+
def copy(self, a):
604+
r"""
605+
Returns a copy of the given tensor.
606+
607+
This function follows the api from :any:`numpy.copy`
608+
609+
See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html
610+
"""
611+
raise NotImplementedError()
612+
613+
def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
614+
r"""
615+
Returns True if two arrays are element-wise equal within a tolerance.
616+
617+
This function follows the api from :any:`numpy.allclose`
618+
619+
See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
620+
"""
621+
raise NotImplementedError()
622+
542623

543624
class NumpyBackend(Backend):
544625
"""
@@ -712,6 +793,46 @@ def stack(self, arrays, axis=0):
712793
def reshape(self, a, shape):
713794
return np.reshape(a, shape)
714795

796+
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
797+
if type_as is None:
798+
return coo_matrix((data, (rows, cols)), shape=shape)
799+
else:
800+
return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype)
801+
802+
def issparse(self, a):
803+
return issparse(a)
804+
805+
def tocsr(self, a):
806+
if self.issparse(a):
807+
return a.tocsr()
808+
else:
809+
return csr_matrix(a)
810+
811+
def eliminate_zeros(self, a, threshold=0.):
812+
if threshold > 0:
813+
if self.issparse(a):
814+
a.data[self.abs(a.data) <= threshold] = 0
815+
else:
816+
a[self.abs(a) <= threshold] = 0
817+
if self.issparse(a):
818+
a.eliminate_zeros()
819+
return a
820+
821+
def todense(self, a):
822+
if self.issparse(a):
823+
return a.toarray()
824+
else:
825+
return a
826+
827+
def where(self, condition, x, y):
828+
return np.where(condition, x, y)
829+
830+
def copy(self, a):
831+
return a.copy()
832+
833+
def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
834+
return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
835+
715836

716837
class JaxBackend(Backend):
717838
"""
@@ -889,6 +1010,48 @@ def stack(self, arrays, axis=0):
8891010
def reshape(self, a, shape):
8901011
return jnp.reshape(a, shape)
8911012

1013+
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
1014+
# Currently, JAX does not support sparse matrices
1015+
data = self.to_numpy(data)
1016+
rows = self.to_numpy(rows)
1017+
cols = self.to_numpy(cols)
1018+
nx = NumpyBackend()
1019+
coo_matrix = nx.coo_matrix(data, rows, cols, shape=shape, type_as=type_as)
1020+
matrix = nx.todense(coo_matrix)
1021+
return self.from_numpy(matrix)
1022+
1023+
def issparse(self, a):
1024+
# Currently, JAX does not support sparse matrices
1025+
return False
1026+
1027+
def tocsr(self, a):
1028+
# Currently, JAX does not support sparse matrices
1029+
return a
1030+
1031+
def eliminate_zeros(self, a, threshold=0.):
1032+
# Currently, JAX does not support sparse matrices
1033+
if threshold > 0:
1034+
return self.where(
1035+
self.abs(a) <= threshold,
1036+
self.zeros((1,), type_as=a),
1037+
a
1038+
)
1039+
return a
1040+
1041+
def todense(self, a):
1042+
# Currently, JAX does not support sparse matrices
1043+
return a
1044+
1045+
def where(self, condition, x, y):
1046+
return jnp.where(condition, x, y)
1047+
1048+
def copy(self, a):
1049+
# No need to copy, JAX arrays are immutable
1050+
return a
1051+
1052+
def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
1053+
return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
1054+
8921055

8931056
class TorchBackend(Backend):
8941057
"""
@@ -999,7 +1162,7 @@ def maximum(self, a, b):
9991162
a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
10001163
if isinstance(b, int) or isinstance(b, float):
10011164
b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
1002-
if torch.__version__ >= '1.7.0':
1165+
if hasattr(torch, "maximum"):
10031166
return torch.maximum(a, b)
10041167
else:
10051168
return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
@@ -1009,7 +1172,7 @@ def minimum(self, a, b):
10091172
a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
10101173
if isinstance(b, int) or isinstance(b, float):
10111174
b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
1012-
if torch.__version__ >= '1.7.0':
1175+
if hasattr(torch, "minimum"):
10131176
return torch.minimum(a, b)
10141177
else:
10151178
return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
@@ -1129,3 +1292,50 @@ def stack(self, arrays, axis=0):
11291292

11301293
def reshape(self, a, shape):
11311294
return torch.reshape(a, shape)
1295+
1296+
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
1297+
if type_as is None:
1298+
return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape)
1299+
else:
1300+
return torch.sparse_coo_tensor(
1301+
torch.stack([rows, cols]), data, size=shape,
1302+
dtype=type_as.dtype, device=type_as.device
1303+
)
1304+
1305+
def issparse(self, a):
1306+
return getattr(a, "is_sparse", False) or getattr(a, "is_sparse_csr", False)
1307+
1308+
def tocsr(self, a):
1309+
# Versions older than 1.9 do not support CSR tensors. PyTorch 1.9 and 1.10 offer a very limited support
1310+
return self.todense(a)
1311+
1312+
def eliminate_zeros(self, a, threshold=0.):
1313+
if self.issparse(a):
1314+
if threshold > 0:
1315+
mask = self.abs(a) <= threshold
1316+
mask = ~mask
1317+
mask = mask.nonzero()
1318+
else:
1319+
mask = a._values().nonzero()
1320+
nv = a._values().index_select(0, mask.view(-1))
1321+
ni = a._indices().index_select(1, mask.view(-1))
1322+
return self.coo_matrix(nv, ni[0], ni[1], shape=a.shape, type_as=a)
1323+
else:
1324+
if threshold > 0:
1325+
a[self.abs(a) <= threshold] = 0
1326+
return a
1327+
1328+
def todense(self, a):
1329+
if self.issparse(a):
1330+
return a.to_dense()
1331+
else:
1332+
return a
1333+
1334+
def where(self, condition, x, y):
1335+
return torch.where(condition, x, y)
1336+
1337+
def copy(self, a):
1338+
return torch.clone(a)
1339+
1340+
def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
1341+
return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

0 commit comments

Comments
 (0)