26
26
27
27
import numpy as np
28
28
import scipy .special as scipy
29
+ from scipy .sparse import issparse , coo_matrix , csr_matrix
29
30
30
31
try :
31
32
import torch
@@ -539,6 +540,86 @@ def reshape(self, a, shape):
539
540
"""
540
541
raise NotImplementedError ()
541
542
543
+ def coo_matrix (self , data , rows , cols , shape = None , type_as = None ):
544
+ r"""
545
+ Creates a sparse tensor in COOrdinate format.
546
+
547
+ This function follows the api from :any:`scipy.sparse.coo_matrix`
548
+
549
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html
550
+ """
551
+ raise NotImplementedError ()
552
+
553
+ def issparse (self , a ):
554
+ r"""
555
+ Checks whether or not the input tensor is a sparse tensor.
556
+
557
+ This function follows the api from :any:`scipy.sparse.issparse`
558
+
559
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html
560
+ """
561
+ raise NotImplementedError ()
562
+
563
+ def tocsr (self , a ):
564
+ r"""
565
+ Converts this matrix to Compressed Sparse Row format.
566
+
567
+ This function follows the api from :any:`scipy.sparse.coo_matrix.tocsr`
568
+
569
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html
570
+ """
571
+ raise NotImplementedError ()
572
+
573
+ def eliminate_zeros (self , a , threshold = 0. ):
574
+ r"""
575
+ Removes entries smaller than the given threshold from the sparse tensor.
576
+
577
+ This function follows the api from :any:`scipy.sparse.csr_matrix.eliminate_zeros`
578
+
579
+ See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html
580
+ """
581
+ raise NotImplementedError ()
582
+
583
+ def todense (self , a ):
584
+ r"""
585
+ Converts a sparse tensor to a dense tensor.
586
+
587
+ This function follows the api from :any:`scipy.sparse.csr_matrix.toarray`
588
+
589
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html
590
+ """
591
+ raise NotImplementedError ()
592
+
593
+ def where (self , condition , x , y ):
594
+ r"""
595
+ Returns elements chosen from x or y depending on condition.
596
+
597
+ This function follows the api from :any:`numpy.where`
598
+
599
+ See: https://numpy.org/doc/stable/reference/generated/numpy.where.html
600
+ """
601
+ raise NotImplementedError ()
602
+
603
+ def copy (self , a ):
604
+ r"""
605
+ Returns a copy of the given tensor.
606
+
607
+ This function follows the api from :any:`numpy.copy`
608
+
609
+ See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html
610
+ """
611
+ raise NotImplementedError ()
612
+
613
+ def allclose (self , a , b , rtol = 1e-05 , atol = 1e-08 , equal_nan = False ):
614
+ r"""
615
+ Returns True if two arrays are element-wise equal within a tolerance.
616
+
617
+ This function follows the api from :any:`numpy.allclose`
618
+
619
+ See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
620
+ """
621
+ raise NotImplementedError ()
622
+
542
623
543
624
class NumpyBackend (Backend ):
544
625
"""
@@ -712,6 +793,46 @@ def stack(self, arrays, axis=0):
712
793
def reshape (self , a , shape ):
713
794
return np .reshape (a , shape )
714
795
796
+ def coo_matrix (self , data , rows , cols , shape = None , type_as = None ):
797
+ if type_as is None :
798
+ return coo_matrix ((data , (rows , cols )), shape = shape )
799
+ else :
800
+ return coo_matrix ((data , (rows , cols )), shape = shape , dtype = type_as .dtype )
801
+
802
+ def issparse (self , a ):
803
+ return issparse (a )
804
+
805
+ def tocsr (self , a ):
806
+ if self .issparse (a ):
807
+ return a .tocsr ()
808
+ else :
809
+ return csr_matrix (a )
810
+
811
+ def eliminate_zeros (self , a , threshold = 0. ):
812
+ if threshold > 0 :
813
+ if self .issparse (a ):
814
+ a .data [self .abs (a .data ) <= threshold ] = 0
815
+ else :
816
+ a [self .abs (a ) <= threshold ] = 0
817
+ if self .issparse (a ):
818
+ a .eliminate_zeros ()
819
+ return a
820
+
821
+ def todense (self , a ):
822
+ if self .issparse (a ):
823
+ return a .toarray ()
824
+ else :
825
+ return a
826
+
827
+ def where (self , condition , x , y ):
828
+ return np .where (condition , x , y )
829
+
830
+ def copy (self , a ):
831
+ return a .copy ()
832
+
833
+ def allclose (self , a , b , rtol = 1e-05 , atol = 1e-08 , equal_nan = False ):
834
+ return np .allclose (a , b , rtol = rtol , atol = atol , equal_nan = equal_nan )
835
+
715
836
716
837
class JaxBackend (Backend ):
717
838
"""
@@ -889,6 +1010,48 @@ def stack(self, arrays, axis=0):
889
1010
def reshape (self , a , shape ):
890
1011
return jnp .reshape (a , shape )
891
1012
1013
+ def coo_matrix (self , data , rows , cols , shape = None , type_as = None ):
1014
+ # Currently, JAX does not support sparse matrices
1015
+ data = self .to_numpy (data )
1016
+ rows = self .to_numpy (rows )
1017
+ cols = self .to_numpy (cols )
1018
+ nx = NumpyBackend ()
1019
+ coo_matrix = nx .coo_matrix (data , rows , cols , shape = shape , type_as = type_as )
1020
+ matrix = nx .todense (coo_matrix )
1021
+ return self .from_numpy (matrix )
1022
+
1023
+ def issparse (self , a ):
1024
+ # Currently, JAX does not support sparse matrices
1025
+ return False
1026
+
1027
+ def tocsr (self , a ):
1028
+ # Currently, JAX does not support sparse matrices
1029
+ return a
1030
+
1031
+ def eliminate_zeros (self , a , threshold = 0. ):
1032
+ # Currently, JAX does not support sparse matrices
1033
+ if threshold > 0 :
1034
+ return self .where (
1035
+ self .abs (a ) <= threshold ,
1036
+ self .zeros ((1 ,), type_as = a ),
1037
+ a
1038
+ )
1039
+ return a
1040
+
1041
+ def todense (self , a ):
1042
+ # Currently, JAX does not support sparse matrices
1043
+ return a
1044
+
1045
+ def where (self , condition , x , y ):
1046
+ return jnp .where (condition , x , y )
1047
+
1048
+ def copy (self , a ):
1049
+ # No need to copy, JAX arrays are immutable
1050
+ return a
1051
+
1052
+ def allclose (self , a , b , rtol = 1e-05 , atol = 1e-08 , equal_nan = False ):
1053
+ return jnp .allclose (a , b , rtol = rtol , atol = atol , equal_nan = equal_nan )
1054
+
892
1055
893
1056
class TorchBackend (Backend ):
894
1057
"""
@@ -999,7 +1162,7 @@ def maximum(self, a, b):
999
1162
a = torch .tensor ([float (a )], dtype = b .dtype , device = b .device )
1000
1163
if isinstance (b , int ) or isinstance (b , float ):
1001
1164
b = torch .tensor ([float (b )], dtype = a .dtype , device = a .device )
1002
- if torch . __version__ >= '1.7.0' :
1165
+ if hasattr ( torch , "maximum" ) :
1003
1166
return torch .maximum (a , b )
1004
1167
else :
1005
1168
return torch .max (torch .stack (torch .broadcast_tensors (a , b )), axis = 0 )[0 ]
@@ -1009,7 +1172,7 @@ def minimum(self, a, b):
1009
1172
a = torch .tensor ([float (a )], dtype = b .dtype , device = b .device )
1010
1173
if isinstance (b , int ) or isinstance (b , float ):
1011
1174
b = torch .tensor ([float (b )], dtype = a .dtype , device = a .device )
1012
- if torch . __version__ >= '1.7.0' :
1175
+ if hasattr ( torch , "minimum" ) :
1013
1176
return torch .minimum (a , b )
1014
1177
else :
1015
1178
return torch .min (torch .stack (torch .broadcast_tensors (a , b )), axis = 0 )[0 ]
@@ -1129,3 +1292,50 @@ def stack(self, arrays, axis=0):
1129
1292
1130
1293
def reshape (self , a , shape ):
1131
1294
return torch .reshape (a , shape )
1295
+
1296
+ def coo_matrix (self , data , rows , cols , shape = None , type_as = None ):
1297
+ if type_as is None :
1298
+ return torch .sparse_coo_tensor (torch .stack ([rows , cols ]), data , size = shape )
1299
+ else :
1300
+ return torch .sparse_coo_tensor (
1301
+ torch .stack ([rows , cols ]), data , size = shape ,
1302
+ dtype = type_as .dtype , device = type_as .device
1303
+ )
1304
+
1305
+ def issparse (self , a ):
1306
+ return getattr (a , "is_sparse" , False ) or getattr (a , "is_sparse_csr" , False )
1307
+
1308
+ def tocsr (self , a ):
1309
+ # Versions older than 1.9 do not support CSR tensors. PyTorch 1.9 and 1.10 offer a very limited support
1310
+ return self .todense (a )
1311
+
1312
+ def eliminate_zeros (self , a , threshold = 0. ):
1313
+ if self .issparse (a ):
1314
+ if threshold > 0 :
1315
+ mask = self .abs (a ) <= threshold
1316
+ mask = ~ mask
1317
+ mask = mask .nonzero ()
1318
+ else :
1319
+ mask = a ._values ().nonzero ()
1320
+ nv = a ._values ().index_select (0 , mask .view (- 1 ))
1321
+ ni = a ._indices ().index_select (1 , mask .view (- 1 ))
1322
+ return self .coo_matrix (nv , ni [0 ], ni [1 ], shape = a .shape , type_as = a )
1323
+ else :
1324
+ if threshold > 0 :
1325
+ a [self .abs (a ) <= threshold ] = 0
1326
+ return a
1327
+
1328
+ def todense (self , a ):
1329
+ if self .issparse (a ):
1330
+ return a .to_dense ()
1331
+ else :
1332
+ return a
1333
+
1334
+ def where (self , condition , x , y ):
1335
+ return torch .where (condition , x , y )
1336
+
1337
+ def copy (self , a ):
1338
+ return torch .clone (a )
1339
+
1340
+ def allclose (self , a , b , rtol = 1e-05 , atol = 1e-08 , equal_nan = False ):
1341
+ return torch .allclose (a , b , rtol = rtol , atol = atol , equal_nan = equal_nan )
0 commit comments