Skip to content

Commit 2dbeeda

Browse files
authored
[MRG] Batch/Lazy Log Sinkhorn Knopp on samples (#259)
* Add batch implementation of Sinkhorn * Reformat to pep8 and modify parameter * Fix error in batch size * Code review and add test * Fix accidental typo in test_empirical_sinkhorn * Remove whitespace * Edit config.yml
1 parent 982510e commit 2dbeeda

File tree

3 files changed

+158
-21
lines changed

3 files changed

+158
-21
lines changed

.circleci/config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ jobs:
7373
command: |
7474
cd docs;
7575
make html;
76+
no_output_timeout: 30m
7677

7778
# Save the outputs
7879
- store_artifacts:

ot/bregman.py

Lines changed: 113 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
1212
# Alexander Tong <alexander.tong@yale.edu>
1313
# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
14+
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
1415
#
1516
# License: MIT License
1617

1718
import warnings
1819

1920
import numpy as np
2021
from scipy.optimize import fmin_l_bfgs_b
22+
from scipy.special import logsumexp
2123

2224
from ot.utils import unif, dist, list_to_array
2325
from .backend import get_backend
@@ -1684,7 +1686,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
16841686

16851687

16861688
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
1687-
numIterMax=10000, stopThr=1e-9, verbose=False,
1689+
numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
16881690
log=False, **kwargs):
16891691
r'''
16901692
Solve the entropic regularization optimal transport problem and return the
@@ -1723,6 +1725,12 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
17231725
Max number of iterations
17241726
stopThr : float, optional
17251727
Stop threshold on error (>0)
1728+
isLazy: boolean, optional
1729+
If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory)
1730+
If False, calculate full cost matrix and return outputs of sinkhorn function.
1731+
batchSize: int or tuple of 2 int, optional
1732+
Size of the batches used to compute the sinkhorn update without memory overhead.
1733+
When a tuple is provided it sets the size of the left/right batches.
17261734
verbose : bool, optional
17271735
Print information along iterations
17281736
log : bool, optional
@@ -1758,24 +1766,78 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
17581766
17591767
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
17601768
'''
1761-
1769+
ns, nt = X_s.shape[0], X_t.shape[0]
17621770
if a is None:
1763-
a = unif(np.shape(X_s)[0])
1771+
a = unif(ns)
17641772
if b is None:
1765-
b = unif(np.shape(X_t)[0])
1773+
b = unif(nt)
1774+
1775+
if isLazy:
1776+
if log:
1777+
dict_log = {"err": []}
17661778

1767-
M = dist(X_s, X_t, metric=metric)
1779+
log_a, log_b = np.log(a), np.log(b)
1780+
f, g = np.zeros(ns), np.zeros(nt)
1781+
1782+
if isinstance(batchSize, int):
1783+
bs, bt = batchSize, batchSize
1784+
elif isinstance(batchSize, tuple) and len(batchSize) == 2:
1785+
bs, bt = batchSize[0], batchSize[1]
1786+
else:
1787+
raise ValueError("Batch size must be in integer or a tuple of two integers")
1788+
1789+
range_s, range_t = range(0, ns, bs), range(0, nt, bt)
1790+
1791+
lse_f = np.zeros(ns)
1792+
lse_g = np.zeros(nt)
1793+
1794+
for i_ot in range(numIterMax):
1795+
1796+
for i in range_s:
1797+
M = dist(X_s[i:i + bs, :], X_t, metric=metric)
1798+
lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1)
1799+
f = log_a - lse_f
1800+
1801+
for j in range_t:
1802+
M = dist(X_s, X_t[j:j + bt, :], metric=metric)
1803+
lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0)
1804+
g = log_b - lse_g
1805+
1806+
if (i_ot + 1) % 10 == 0:
1807+
m1 = np.zeros_like(a)
1808+
for i in range_s:
1809+
M = dist(X_s[i:i + bs, :], X_t, metric=metric)
1810+
m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1)
1811+
err = np.abs(m1 - a).sum()
1812+
if log:
1813+
dict_log["err"].append(err)
1814+
1815+
if verbose and (i_ot + 1) % 100 == 0:
1816+
print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
1817+
1818+
if err <= stopThr:
1819+
break
1820+
1821+
if log:
1822+
dict_log["u"] = f
1823+
dict_log["v"] = g
1824+
return (f, g, dict_log)
1825+
else:
1826+
return (f, g)
17681827

1769-
if log:
1770-
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
1771-
return pi, log
17721828
else:
1773-
pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
1774-
return pi
1829+
M = dist(X_s, X_t, metric=metric)
1830+
1831+
if log:
1832+
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
1833+
return pi, log
1834+
else:
1835+
pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
1836+
return pi
17751837

17761838

17771839
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
1778-
verbose=False, log=False, **kwargs):
1840+
isLazy=False, batchSize=100, verbose=False, log=False, **kwargs):
17791841
r'''
17801842
Solve the entropic regularization optimal transport problem from empirical
17811843
data and return the OT loss
@@ -1814,6 +1876,12 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
18141876
Max number of iterations
18151877
stopThr : float, optional
18161878
Stop threshold on error (>0)
1879+
isLazy: boolean, optional
1880+
If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory)
1881+
If False, calculate full cost matrix and return outputs of sinkhorn function.
1882+
batchSize: int or tuple of 2 int, optional
1883+
Size of the batches used to compute the sinkhorn update without memory overhead.
1884+
When a tuple is provided it sets the size of the left/right batches.
18171885
verbose : bool, optional
18181886
Print information along iterations
18191887
log : bool, optional
@@ -1850,21 +1918,45 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
18501918
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
18511919
'''
18521920

1921+
ns, nt = X_s.shape[0], X_t.shape[0]
18531922
if a is None:
1854-
a = unif(np.shape(X_s)[0])
1923+
a = unif(ns)
18551924
if b is None:
1856-
b = unif(np.shape(X_t)[0])
1925+
b = unif(nt)
18571926

1858-
M = dist(X_s, X_t, metric=metric)
1927+
if isLazy:
1928+
if log:
1929+
f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
1930+
isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
1931+
else:
1932+
f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
1933+
isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
1934+
1935+
bs = batchSize if isinstance(batchSize, int) else batchSize[0]
1936+
range_s = range(0, ns, bs)
1937+
1938+
loss = 0
1939+
for i in range_s:
1940+
M_block = dist(X_s[i:i + bs, :], X_t, metric=metric)
1941+
pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
1942+
loss += np.sum(M_block * pi_block)
1943+
1944+
if log:
1945+
return loss, dict_log
1946+
else:
1947+
return loss
18591948

1860-
if log:
1861-
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
1862-
**kwargs)
1863-
return sinkhorn_loss, log
18641949
else:
1865-
sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
1866-
**kwargs)
1867-
return sinkhorn_loss
1950+
M = dist(X_s, X_t, metric=metric)
1951+
1952+
if log:
1953+
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
1954+
**kwargs)
1955+
return sinkhorn_loss, log
1956+
else:
1957+
sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
1958+
**kwargs)
1959+
return sinkhorn_loss
18681960

18691961

18701962
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,

test/test_bregman.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# Author: Remi Flamary <remi.flamary@unice.fr>
44
# Kilian Fatras <kilian.fatras@irisa.fr>
5+
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
56
#
67
# License: MIT License
78

@@ -329,6 +330,49 @@ def test_empirical_sinkhorn():
329330
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
330331

331332

333+
def test_lazy_empirical_sinkhorn():
334+
# test sinkhorn
335+
n = 100
336+
a = ot.unif(n)
337+
b = ot.unif(n)
338+
numIterMax = 1000
339+
340+
X_s = np.reshape(np.arange(n), (n, 1))
341+
X_t = np.reshape(np.arange(0, n), (n, 1))
342+
M = ot.dist(X_s, X_t)
343+
M_m = ot.dist(X_s, X_t, metric='minkowski')
344+
345+
f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True)
346+
G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
347+
sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
348+
349+
f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
350+
G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
351+
sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
352+
353+
f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
354+
G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
355+
sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
356+
357+
loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
358+
loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
359+
360+
# check constraints
361+
np.testing.assert_allclose(
362+
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05)  # metric sqeuclidean
363+
np.testing.assert_allclose(
364+
sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05)  # metric sqeuclidean
365+
np.testing.assert_allclose(
366+
sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log
367+
np.testing.assert_allclose(
368+
sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log
369+
np.testing.assert_allclose(
370+
sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05)  # metric euclidean
371+
np.testing.assert_allclose(
372+
sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05)  # metric euclidean
373+
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
374+
375+
332376
def test_empirical_sinkhorn_divergence():
333377
# Test sinkhorn divergence
334378
n = 10

0 commit comments

Comments
 (0)