Skip to content

Commit 47bf8e8

Browse files
Hotfix: numerical stability of non-log-stabilized sinkhorn plan (#531)
* fix numerical stability issues in sinkhorn plan * improve test suite * fix ultra-strict convergence criterion in log_sinkhorn_plan * update dependencies * add comment about convergence check * update docstring to reflect fixes * sinkhorn_plan now returns a transport plan with uniform marginal distributions * add unit test for sinkhorn_plan * fix sinkhorn function by sampling from the logits of the transpose of the plan, instead of the plan directly * sinkhorn(x1, x2) now samples from log(plan) to receive assignments such that x2[assignments] matches x1 * re-enable test_assignment_is_optimal() for method='sinkhorn' * log_sinkhorn now correctly uses log_plan instead of keras.ops.exp(log_plan), log_sinkhorn_plan returns logits of the transport plan * add unit tests for log_sinkhorn_plan * fix faulty indexing with tensor for tensorflow backend * re-add numItermax for ot pot test --------- Co-authored-by: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com>
1 parent 2a19d32 commit 47bf8e8

File tree

4 files changed

+155
-36
lines changed

4 files changed

+155
-36
lines changed

bayesflow/utils/optimal_transport/log_sinkhorn.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
99
"""
1010
Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn`.
11-
Significantly slower than the unstabilized version, so use only when you need numerical stability.
11+
About 50% slower than the unstabilized version, so use only when you need numerical stability.
1212
"""
1313
log_plan = log_sinkhorn_plan(x1, x2, **kwargs)
14-
assignments = keras.random.categorical(keras.ops.exp(log_plan), num_samples=1, seed=seed)
14+
assignments = keras.random.categorical(log_plan, num_samples=1, seed=seed)
1515
assignments = keras.ops.squeeze(assignments, axis=1)
1616

1717
return assignments
@@ -20,19 +20,25 @@ def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
2020
def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, max_steps=None):
2121
"""
2222
Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn_plan`.
23-
Significantly slower than the unstabilized version, so use only when you need numerical stability.
23+
About 50% slower than the unstabilized version, so use primarily when you need numerical stability.
2424
"""
2525
cost = euclidean(x1, x2)
26+
cost_scaled = -cost / regularization
2627

27-
log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
28+
# initialize transport plan from a gaussian kernel
29+
log_plan = cost_scaled - keras.ops.max(cost_scaled)
30+
n, m = keras.ops.shape(log_plan)
31+
32+
log_a = -keras.ops.log(n)
33+
log_b = -keras.ops.log(m)
2834

2935
def contains_nans(plan):
3036
return keras.ops.any(keras.ops.isnan(plan))
3137

3238
def is_converged(plan):
33-
# for convergence, the plan should be doubly stochastic
34-
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=rtol, atol=atol))
35-
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=rtol, atol=atol))
39+
# for convergence, the target marginals must match
40+
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), log_b, rtol=0.0, atol=rtol + atol))
41+
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), log_a, rtol=0.0, atol=rtol + atol))
3642
return conv0 & conv1
3743

3844
def cond(_, plan):
@@ -41,8 +47,8 @@ def cond(_, plan):
4147

4248
def body(steps, plan):
4349
# Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
44-
plan = keras.ops.log_softmax(plan, axis=0)
45-
plan = keras.ops.log_softmax(plan, axis=1)
50+
plan = plan - keras.ops.logsumexp(plan, axis=0, keepdims=True) + log_b
51+
plan = plan - keras.ops.logsumexp(plan, axis=1, keepdims=True) + log_a
4652

4753
return steps + 1, plan
4854

bayesflow/utils/optimal_transport/sinkhorn.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten
1111
"""
1212
Matches elements from x2 onto x1 using the Sinkhorn-Knopp algorithm.
1313
14-
Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a doubly stochastic
14+
Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a
1515
transport plan, containing assignment probabilities.
1616
The permutation is then sampled randomly according to the transport plan.
1717
@@ -27,12 +27,15 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten
2727
:param seed: Random seed to use for sampling indices.
2828
Default: None, which means the seed will be auto-determined for non-compiled contexts.
2929
30-
:return: Tensor of shape (m,)
30+
:return: Tensor of shape (n,)
3131
Assignment indices for x2.
3232
3333
"""
3434
plan = sinkhorn_plan(x1, x2, **kwargs)
35-
assignments = keras.random.categorical(plan, num_samples=1, seed=seed)
35+
36+
# we sample from log(plan) to receive assignments of length n, corresponding to indices of x2
37+
# such that x2[assignments] matches x1
38+
assignments = keras.random.categorical(keras.ops.log(plan), num_samples=1, seed=seed)
3639
assignments = keras.ops.squeeze(assignments, axis=1)
3740

3841
return assignments
@@ -42,7 +45,7 @@ def sinkhorn_plan(
4245
x1: Tensor,
4346
x2: Tensor,
4447
regularization: float = 1.0,
45-
max_steps: int = 10_000,
48+
max_steps: int = None,
4649
rtol: float = 1e-5,
4750
atol: float = 1e-8,
4851
) -> Tensor:
@@ -59,7 +62,7 @@ def sinkhorn_plan(
5962
Controls the standard deviation of the Gaussian kernel.
6063
6164
:param max_steps: Maximum number of iterations, or None to run until convergence.
62-
Default: 10_000
65+
Default: None
6366
6467
:param rtol: Relative tolerance for convergence.
6568
Default: 1e-5.
@@ -71,17 +74,20 @@ def sinkhorn_plan(
7174
The transport probabilities.
7275
"""
7376
cost = euclidean(x1, x2)
77+
cost_scaled = -cost / regularization
7478

75-
# initialize the transport plan from a gaussian kernel
76-
plan = keras.ops.exp(cost / -(regularization * keras.ops.mean(cost) + 1e-16))
79+
# initialize transport plan from a gaussian kernel
80+
# (more numerically stable version of keras.ops.exp(-cost/regularization))
81+
plan = keras.ops.exp(cost_scaled - keras.ops.max(cost_scaled))
82+
n, m = keras.ops.shape(cost)
7783

7884
def contains_nans(plan):
7985
return keras.ops.any(keras.ops.isnan(plan))
8086

8187
def is_converged(plan):
82-
# for convergence, the plan should be doubly stochastic
83-
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=0), 1.0, rtol=rtol, atol=atol))
84-
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=1), 1.0, rtol=rtol, atol=atol))
88+
# for convergence, the target marginals must match
89+
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=0), 1.0 / m, rtol=rtol, atol=atol))
90+
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=1), 1.0 / n, rtol=rtol, atol=atol))
8591
return conv0 & conv1
8692

8793
def cond(_, plan):
@@ -90,8 +96,8 @@ def cond(_, plan):
9096

9197
def body(steps, plan):
9298
# Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
93-
plan = keras.ops.softmax(plan, axis=0)
94-
plan = keras.ops.softmax(plan, axis=1)
99+
plan = plan / keras.ops.sum(plan, axis=0, keepdims=True) * (1.0 / m)
100+
plan = plan / keras.ops.sum(plan, axis=1, keepdims=True) * (1.0 / n)
95101

96102
return steps + 1, plan
97103

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,16 @@ dependencies = [
3636
[project.optional-dependencies]
3737
all = [
3838
# dev
39+
"ipython",
40+
"ipykernel",
3941
"jupyter",
4042
"jupyterlab",
43+
"line-profiler",
4144
"nbconvert",
42-
"ipython",
43-
"ipykernel",
4445
"pre-commit",
4546
"ruff",
4647
"tox",
4748
# docs
48-
4949
"myst-nb ~= 1.2",
5050
"numpydoc ~= 1.8",
5151
"pydata-sphinx-theme ~= 0.16",
@@ -63,6 +63,7 @@ all = [
6363
dev = [
6464
"jupyter",
6565
"jupyterlab",
66+
"line-profiler",
6667
"pre-commit",
6768
"ruff",
6869
"tox",

tests/test_utils/test_optimal_transport.py

Lines changed: 118 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,37 +27,41 @@ def test_shapes(method):
2727
assert keras.ops.shape(oy) == keras.ops.shape(y)
2828

2929

30-
def test_transport_cost_improves():
30+
@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"])
31+
def test_transport_cost_improves(method):
3132
x = keras.random.normal((128, 2), seed=0)
3233
y = keras.random.normal((128, 2), seed=1)
3334

3435
before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
3536

36-
x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=1000)
37+
x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=1000, method=method)
3738

3839
after_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
3940

4041
assert after_cost < before_cost
4142

4243

43-
@pytest.mark.skip(reason="too unreliable")
44-
def test_assignment_is_optimal():
45-
x = keras.random.normal((16, 2), seed=0)
46-
p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0)
47-
optimal_assignments = keras.ops.argsort(p)
44+
@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"])
45+
def test_assignment_is_optimal(method):
46+
y = keras.random.normal((16, 2), seed=0)
47+
p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(y)[0]), seed=0)
4848

49-
y = x[p]
49+
x = keras.ops.take(y, p, axis=0)
5050

51-
x, y, assignments = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=10_000, return_assignments=True)
51+
_, _, assignments = optimal_transport(
52+
x, y, regularization=0.1, seed=0, max_steps=10_000, method=method, return_assignments=True
53+
)
5254

53-
assert_allclose(assignments, optimal_assignments)
55+
# transport is stochastic, so it is expected that a small fraction of assignments do not match
56+
assert keras.ops.sum(assignments == p) > 14
5457

5558

5659
def test_assignment_aligns_with_pot():
5760
try:
5861
from ot.bregman import sinkhorn_log
5962
except (ImportError, ModuleNotFoundError):
6063
pytest.skip("Need to install POT to run this test.")
64+
return
6165

6266
x = keras.random.normal((16, 2), seed=0)
6367
p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0)
@@ -68,10 +72,112 @@ def test_assignment_aligns_with_pot():
6872
M = x[:, None] - y[None, :]
6973
M = keras.ops.norm(M, axis=-1)
7074

71-
pot_plan = sinkhorn_log(a, b, M, reg=1e-3, numItermax=10_000, stopThr=1e-99)
72-
pot_assignments = keras.random.categorical(pot_plan, num_samples=1, seed=0)
75+
pot_plan = sinkhorn_log(a, b, M, numItermax=10_000, reg=1e-3, stopThr=1e-7)
76+
pot_assignments = keras.random.categorical(keras.ops.log(pot_plan), num_samples=1, seed=0)
7377
pot_assignments = keras.ops.squeeze(pot_assignments, axis=-1)
7478

7579
_, _, assignments = optimal_transport(x, y, regularization=1e-3, seed=0, max_steps=10_000, return_assignments=True)
7680

7781
assert_allclose(pot_assignments, assignments)
82+
83+
84+
def test_sinkhorn_plan_correct_marginals():
85+
from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan
86+
87+
x1 = keras.random.normal((10, 2), seed=0)
88+
x2 = keras.random.normal((20, 2), seed=1)
89+
90+
assert keras.ops.all(keras.ops.isclose(keras.ops.sum(sinkhorn_plan(x1, x2), axis=0), 0.05, atol=1e-6))
91+
assert keras.ops.all(keras.ops.isclose(keras.ops.sum(sinkhorn_plan(x1, x2), axis=1), 0.1, atol=1e-6))
92+
93+
94+
def test_sinkhorn_plan_aligns_with_pot():
95+
try:
96+
from ot.bregman import sinkhorn
97+
except (ImportError, ModuleNotFoundError):
98+
pytest.skip("Need to install POT to run this test.")
99+
100+
from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan
101+
from bayesflow.utils.optimal_transport.euclidean import euclidean
102+
103+
x1 = keras.random.normal((10, 3), seed=0)
104+
x2 = keras.random.normal((20, 3), seed=1)
105+
106+
a = keras.ops.ones(10) / 10
107+
b = keras.ops.ones(20) / 20
108+
M = euclidean(x1, x2)
109+
110+
pot_result = sinkhorn(a, b, M, 0.1, stopThr=1e-8)
111+
our_result = sinkhorn_plan(x1, x2, regularization=0.1, rtol=1e-7)
112+
113+
assert_allclose(pot_result, our_result)
114+
115+
116+
def test_sinkhorn_plan_matches_analytical_result():
117+
from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan
118+
119+
x1 = keras.ops.ones(16)
120+
x2 = keras.ops.ones(64)
121+
122+
marginal_x1 = keras.ops.ones(16) / 16
123+
marginal_x2 = keras.ops.ones(64) / 64
124+
125+
result = sinkhorn_plan(x1, x2, regularization=0.1)
126+
127+
# If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals
128+
expected = keras.ops.outer(marginal_x1, marginal_x2)
129+
130+
assert_allclose(result, expected)
131+
132+
133+
def test_log_sinkhorn_plan_correct_marginals():
134+
from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan
135+
136+
x1 = keras.random.normal((10, 2), seed=0)
137+
x2 = keras.random.normal((20, 2), seed=1)
138+
139+
assert keras.ops.all(
140+
keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=0), -keras.ops.log(20), atol=1e-3)
141+
)
142+
assert keras.ops.all(
143+
keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=1), -keras.ops.log(10), atol=1e-3)
144+
)
145+
146+
147+
def test_log_sinkhorn_plan_aligns_with_pot():
148+
try:
149+
from ot.bregman import sinkhorn_log
150+
except (ImportError, ModuleNotFoundError):
151+
pytest.skip("Need to install POT to run this test.")
152+
153+
from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan
154+
from bayesflow.utils.optimal_transport.euclidean import euclidean
155+
156+
x1 = keras.random.normal((100, 3), seed=0)
157+
x2 = keras.random.normal((200, 3), seed=1)
158+
159+
a = keras.ops.ones(100) / 100
160+
b = keras.ops.ones(200) / 200
161+
M = euclidean(x1, x2)
162+
163+
pot_result = keras.ops.log(sinkhorn_log(a, b, M, 0.1, stopThr=1e-7)) # sinkhorn_log returns probabilities
164+
our_result = log_sinkhorn_plan(x1, x2, regularization=0.1)
165+
166+
assert_allclose(pot_result, our_result)
167+
168+
169+
def test_log_sinkhorn_plan_matches_analytical_result():
170+
from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan
171+
172+
x1 = keras.ops.ones(16)
173+
x2 = keras.ops.ones(64)
174+
175+
marginal_x1 = keras.ops.ones(16) / 16
176+
marginal_x2 = keras.ops.ones(64) / 64
177+
178+
result = keras.ops.exp(log_sinkhorn_plan(x1, x2, regularization=0.1))
179+
180+
# If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals
181+
expected = keras.ops.outer(marginal_x1, marginal_x2)
182+
183+
assert_allclose(result, expected)

0 commit comments

Comments
 (0)