@@ -27,37 +27,41 @@ def test_shapes(method):
    assert keras.ops.shape(oy) == keras.ops.shape(y)


-def test_transport_cost_improves():
+@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"])
+def test_transport_cost_improves(method):
    x = keras.random.normal((128, 2), seed=0)
    y = keras.random.normal((128, 2), seed=1)

    before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))

-    x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=1000)
+    x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=1000, method=method)

    after_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))

    assert after_cost < before_cost


-@pytest.mark.skip(reason="too unreliable")
-def test_assignment_is_optimal():
-    x = keras.random.normal((16, 2), seed=0)
-    p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0)
-    optimal_assignments = keras.ops.argsort(p)
+@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"])
+def test_assignment_is_optimal(method):
+    y = keras.random.normal((16, 2), seed=0)
+    p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(y)[0]), seed=0)

-    y = x[p]
+    x = keras.ops.take(y, p, axis=0)

-    x, y, assignments = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=10_000, return_assignments=True)
+    _, _, assignments = optimal_transport(
+        x, y, regularization=0.1, seed=0, max_steps=10_000, method=method, return_assignments=True
+    )

-    assert_allclose(assignments, optimal_assignments)
+    # transport is stochastic, so it is expected that a small fraction of assignments do not match
+    assert keras.ops.sum(assignments == p) > 14


def test_assignment_aligns_with_pot():
    try:
        from ot.bregman import sinkhorn_log
    except (ImportError, ModuleNotFoundError):
        pytest.skip("Need to install POT to run this test.")
+        return

    x = keras.random.normal((16, 2), seed=0)
    p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0)
@@ -68,10 +72,112 @@ def test_assignment_aligns_with_pot():
    M = x[:, None] - y[None, :]
    M = keras.ops.norm(M, axis=-1)

-    pot_plan = sinkhorn_log(a, b, M, reg=1e-3, numItermax=10_000, stopThr=1e-99)
-    pot_assignments = keras.random.categorical(pot_plan, num_samples=1, seed=0)
+    pot_plan = sinkhorn_log(a, b, M, numItermax=10_000, reg=1e-3, stopThr=1e-7)
+    pot_assignments = keras.random.categorical(keras.ops.log(pot_plan), num_samples=1, seed=0)
    pot_assignments = keras.ops.squeeze(pot_assignments, axis=-1)

    _, _, assignments = optimal_transport(x, y, regularization=1e-3, seed=0, max_steps=10_000, return_assignments=True)

    assert_allclose(pot_assignments, assignments)
+
+
+def test_sinkhorn_plan_correct_marginals():
+    from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan
+
+    x1 = keras.random.normal((10, 2), seed=0)
+    x2 = keras.random.normal((20, 2), seed=1)
+
+    assert keras.ops.all(keras.ops.isclose(keras.ops.sum(sinkhorn_plan(x1, x2), axis=0), 0.05, atol=1e-6))
+    assert keras.ops.all(keras.ops.isclose(keras.ops.sum(sinkhorn_plan(x1, x2), axis=1), 0.1, atol=1e-6))
+
+
+def test_sinkhorn_plan_aligns_with_pot():
+    try:
+        from ot.bregman import sinkhorn
+    except (ImportError, ModuleNotFoundError):
+        pytest.skip("Need to install POT to run this test.")
+
+    from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan
+    from bayesflow.utils.optimal_transport.euclidean import euclidean
+
+    x1 = keras.random.normal((10, 3), seed=0)
+    x2 = keras.random.normal((20, 3), seed=1)
+
+    a = keras.ops.ones(10) / 10
+    b = keras.ops.ones(20) / 20
+    M = euclidean(x1, x2)
+
+    pot_result = sinkhorn(a, b, M, 0.1, stopThr=1e-8)
+    our_result = sinkhorn_plan(x1, x2, regularization=0.1, rtol=1e-7)
+
+    assert_allclose(pot_result, our_result)
+
+
+def test_sinkhorn_plan_matches_analytical_result():
+    from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan
+
+    x1 = keras.ops.ones(16)
+    x2 = keras.ops.ones(64)
+
+    marginal_x1 = keras.ops.ones(16) / 16
+    marginal_x2 = keras.ops.ones(64) / 64
+
+    result = sinkhorn_plan(x1, x2, regularization=0.1)
+
+    # If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals
+    expected = keras.ops.outer(marginal_x1, marginal_x2)
+
+    assert_allclose(result, expected)
+
+
+def test_log_sinkhorn_plan_correct_marginals():
+    from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan
+
+    x1 = keras.random.normal((10, 2), seed=0)
+    x2 = keras.random.normal((20, 2), seed=1)
+
+    assert keras.ops.all(
+        keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=0), -keras.ops.log(20), atol=1e-3)
+    )
+    assert keras.ops.all(
+        keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=1), -keras.ops.log(10), atol=1e-3)
+    )
+
+
+def test_log_sinkhorn_plan_aligns_with_pot():
+    try:
+        from ot.bregman import sinkhorn_log
+    except (ImportError, ModuleNotFoundError):
+        pytest.skip("Need to install POT to run this test.")
+
+    from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan
+    from bayesflow.utils.optimal_transport.euclidean import euclidean
+
+    x1 = keras.random.normal((100, 3), seed=0)
+    x2 = keras.random.normal((200, 3), seed=1)
+
+    a = keras.ops.ones(100) / 100
+    b = keras.ops.ones(200) / 200
+    M = euclidean(x1, x2)
+
+    pot_result = keras.ops.log(sinkhorn_log(a, b, M, 0.1, stopThr=1e-7))  # sinkhorn_log returns probabilities
+    our_result = log_sinkhorn_plan(x1, x2, regularization=0.1)
+
+    assert_allclose(pot_result, our_result)
+
+
+def test_log_sinkhorn_plan_matches_analytical_result():
+    from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan
+
+    x1 = keras.ops.ones(16)
+    x2 = keras.ops.ones(64)
+
+    marginal_x1 = keras.ops.ones(16) / 16
+    marginal_x2 = keras.ops.ones(64) / 64
+
+    result = keras.ops.exp(log_sinkhorn_plan(x1, x2, regularization=0.1))
+
+    # If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals
+    expected = keras.ops.outer(marginal_x1, marginal_x2)
+
+    assert_allclose(result, expected)
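For context on what the new *_correct_marginals tests assert: an entropically regularized transport plan between a size-n and a size-m point cloud with uniform weights has rows summing to 1/n and columns summing to 1/m (equivalently, row and column logsumexp of -log(n) and -log(m) in the log domain). The sketch below is a standalone NumPy illustration of log-domain Sinkhorn iterations, not the BayesFlow log_sinkhorn_plan; the function name, step count, and regularization value are chosen only for the example.

# Illustrative sketch, not the BayesFlow implementation.
import numpy as np
from scipy.special import logsumexp


def log_sinkhorn_plan_sketch(x1, x2, regularization=1.0, steps=1000):
    """Log of an entropic OT plan between two point clouds with uniform marginals."""
    n, m = len(x1), len(x2)
    cost = np.linalg.norm(x1[:, None, :] - x2[None, :, :], axis=-1)  # pairwise Euclidean costs
    log_kernel = -cost / regularization                              # log of the Gibbs kernel
    log_a = np.full(n, -np.log(n))                                   # log uniform source marginal
    log_b = np.full(m, -np.log(m))                                   # log uniform target marginal
    f = np.zeros(n)                                                  # dual potentials (log scalings)
    g = np.zeros(m)
    for _ in range(steps):
        # alternately enforce the row and column marginal constraints in log space
        f = log_a - logsumexp(log_kernel + g[None, :], axis=1)
        g = log_b - logsumexp(log_kernel + f[:, None], axis=0)
    return f[:, None] + log_kernel + g[None, :]                      # log of the plan


rng = np.random.default_rng(0)
log_plan = log_sinkhorn_plan_sketch(rng.normal(size=(10, 2)), rng.normal(size=(20, 2)))

# rows logsumexp to -log(10) and columns to -log(20), mirroring the checks in
# test_log_sinkhorn_plan_correct_marginals above
assert np.allclose(logsumexp(log_plan, axis=1), -np.log(10), atol=1e-4)
assert np.allclose(logsumexp(log_plan, axis=0), -np.log(20), atol=1e-4)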