4
4
import numpy as np
5
5
import numba
6
6
7
+ from pynndescent .optimal_transport import (
8
+ allocate_graph_structures ,
9
+ initialize_graph_structures ,
10
+ initialize_supply ,
11
+ initialize_cost ,
12
+ network_simplex_core ,
13
+ total_cost ,
14
+ ProblemStatus ,
15
+ )
16
+
7
17
_mock_identity = np .eye (2 , dtype = np .float32 )
8
18
_mock_ones = np .ones (2 , dtype = np .float32 )
19
+ _dummy_cost = np .zeros ((2 , 2 ), dtype = np .float64 )
9
20
10
21
11
22
@numba .njit (fastmath = True , cache = True )
@@ -524,6 +535,34 @@ def spearmanr(x, y):
524
535
return rs [1 , 0 ]
525
536
526
537
538
+ @numba .njit ()
539
+ def kantorovich_distance (x , y , cost = _dummy_cost , max_iter = 100000 ):
540
+ if cost is _dummy_cost :
541
+ raise ValueError ("Kantorovich distance requires a cost matrix to be supplied." )
542
+ node_arc_data , spanning_tree , graph = allocate_graph_structures (
543
+ x .shape [0 ], y .shape [0 ], False ,
544
+ )
545
+ initialize_supply (x , - y , graph , node_arc_data .supply )
546
+ initialize_cost (cost , graph , node_arc_data .cost )
547
+ init_status = initialize_graph_structures (graph , node_arc_data , spanning_tree )
548
+ if init_status == False :
549
+ raise ValueError (
550
+ "Kantorovich distance inputs must be valid probability " "distributions."
551
+ )
552
+ solve_status = network_simplex_core (node_arc_data , spanning_tree , graph , max_iter ,)
553
+ if solve_status == ProblemStatus .MAX_ITER_REACHED :
554
+ print ("WARNING: RESULT MIGHT BE INACURATE\n Max number of iteration reached!" )
555
+ elif solve_status == ProblemStatus .INFEASIBLE :
556
+ raise ValueError (
557
+ "Optimal transport problem was INFEASIBLE. Please check " "inputs."
558
+ )
559
+ elif solve_status == ProblemStatus .UNBOUNDED :
560
+ raise ValueError (
561
+ "Optimal transport problem was UNBOUNDED. Please check " "inputs."
562
+ )
563
+ return total_cost (node_arc_data .flow , node_arc_data .cost )
564
+
565
+
527
566
named_distances = {
528
567
# general minkowski distances
529
568
"euclidean" : euclidean ,
@@ -550,6 +589,8 @@ def spearmanr(x, y):
550
589
"haversine" : haversine ,
551
590
"braycurtis" : bray_curtis ,
552
591
"spearmanr" : spearmanr ,
592
+ "kantorovich" : kantorovich_distance ,
593
+ "wasserstein" : kantorovich_distance ,
553
594
# Binary distances
554
595
"hamming" : hamming ,
555
596
"jaccard" : jaccard ,
0 commit comments