Skip to content

Commit c1f63b8

Browse files
committed
Add support for kantorovich/wasserstein distances
1 parent db77f47 commit c1f63b8

File tree

3 files changed

+960
-1
lines changed

3 files changed

+960
-1
lines changed

pynndescent/distances.py

+41
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,19 @@
44
import numpy as np
55
import numba
66

7+
from pynndescent.optimal_transport import (
8+
allocate_graph_structures,
9+
initialize_graph_structures,
10+
initialize_supply,
11+
initialize_cost,
12+
network_simplex_core,
13+
total_cost,
14+
ProblemStatus,
15+
)
16+
717
_mock_identity = np.eye(2, dtype=np.float32)
818
_mock_ones = np.ones(2, dtype=np.float32)
19+
_dummy_cost = np.zeros((2, 2), dtype=np.float64)
920

1021

1122
@numba.njit(fastmath=True, cache=True)
@@ -524,6 +535,34 @@ def spearmanr(x, y):
524535
return rs[1, 0]
525536

526537

538+
@numba.njit()
def kantorovich_distance(x, y, cost=_dummy_cost, max_iter=100000):
    """Kantorovich (Wasserstein) distance between ``x`` and ``y``.

    Builds an optimal transport problem in which ``x`` is passed as the
    supply and ``-y`` as the demand, solves it with the network simplex
    routine, and returns the total cost of the resulting flow.

    Parameters
    ----------
    x, y : arrays
        The two distributions to compare. Only ``x.shape[0]`` and
        ``y.shape[0]`` are used to size the problem.
        NOTE(review): presumably 1-d, non-negative and of equal total mass;
        confirm against ``initialize_graph_structures``.
    cost : array
        Ground cost matrix handed to ``initialize_cost``. The default is a
        placeholder only; a real matrix must be supplied.
        NOTE(review): presumably of shape ``(x.shape[0], y.shape[0])``;
        confirm against ``initialize_cost``.
    max_iter : int
        Iteration limit forwarded to ``network_simplex_core``.

    Returns
    -------
    The value of ``total_cost`` evaluated on the solved flow and arc costs.

    Raises
    ------
    ValueError
        If no cost matrix was supplied, if the graph structures could not
        be initialised from the inputs, or if the solver reports the
        problem as infeasible or unbounded.
    """
    # The module-level placeholder doubles as a "no cost matrix given" sentinel.
    if cost is _dummy_cost:
        raise ValueError("Kantorovich distance requires a cost matrix to be supplied.")

    node_arc_data, spanning_tree, graph = allocate_graph_structures(
        x.shape[0], y.shape[0], False
    )
    # x acts as supply, y (negated) as demand.
    initialize_supply(x, -y, graph, node_arc_data.supply)
    initialize_cost(cost, graph, node_arc_data.cost)

    init_status = initialize_graph_structures(graph, node_arc_data, spanning_tree)
    if not init_status:
        raise ValueError(
            "Kantorovich distance inputs must be valid probability distributions."
        )

    solve_status = network_simplex_core(node_arc_data, spanning_tree, graph, max_iter)
    if solve_status == ProblemStatus.MAX_ITER_REACHED:
        # Not fatal: the flow found so far is still used for the result below.
        print("WARNING: RESULT MIGHT BE INACCURATE\nMax number of iteration reached!")
    elif solve_status == ProblemStatus.INFEASIBLE:
        raise ValueError(
            "Optimal transport problem was INFEASIBLE. Please check inputs."
        )
    elif solve_status == ProblemStatus.UNBOUNDED:
        raise ValueError(
            "Optimal transport problem was UNBOUNDED. Please check inputs."
        )

    return total_cost(node_arc_data.flow, node_arc_data.cost)
564+
565+
527566
named_distances = {
528567
# general minkowski distances
529568
"euclidean": euclidean,
@@ -550,6 +589,8 @@ def spearmanr(x, y):
550589
"haversine": haversine,
551590
"braycurtis": bray_curtis,
552591
"spearmanr": spearmanr,
592+
"kantorovich": kantorovich_distance,
593+
"wasserstein": kantorovich_distance,
553594
# Binary distances
554595
"hamming": hamming,
555596
"jaccard": jaccard,

0 commit comments

Comments
 (0)