Commit e1b67c6

hichamjanati (Hicham Janati), rflamary, and agramfort authored
[WIP] Add debiased barycenter (Sinkhorn + convolutional sinkhorn) (#291)
* add debiased sinkhorn barycenter + make loops pythonic
* add debiased arg in tests
* add 1d and 2d examples of debiased barycenters
* fix doctest
* fix flake8
* pep8 + make func private + add convergence warnings
* remove rel paths + add rng + pylab to pyplot
* fix stopping criterion debiased
* pass alex
* change params with new API
* add logdomain barycenters + separate debiased API
* test new API
* fix jax read-only ?
* raise error for jax
* test catch jax error
* fix pytest catch error
* fix relative path
* fix flake8
* add warn arg everywhere
* fix ref number
* catch warnings in tests
* add contrib to readme + change ref number
* fix convolution example + gallery thumbnails
* increase coverage
* fix flake

Co-authored-by: Hicham Janati <hicham.janati@inria.fr>
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
1 parent 61340d5 commit e1b67c6

11 files changed: +1837 -675 lines changed

README.md

Lines changed: 6 additions & 2 deletions
@@ -22,7 +22,8 @@ POT provides the following generic OT solvers (links to examples):
 * [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7].
 * Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html).
 * Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4].
-* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
+* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
+* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
 * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
 * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
 * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12])
@@ -188,7 +189,7 @@ The contributors to this library are
 * [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic solvers)
 * [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
 * [Vayer Titouan](https://tvayer.github.io/) (Gromov-Wasserstein -, Fused-Gromov-Wasserstein)
-* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT)
+* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT, Debiased barycenters)
 * [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
 * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
 * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
@@ -293,3 +294,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t
 (2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
 via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
 Machine Learning (pp. 4104-4113). PMLR.
+
+[37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International
+Conference on Machine Learning, PMLR 119:4692-4701, 2020

examples/barycenters/plot_barycenter_1D.py

Lines changed: 23 additions & 40 deletions
@@ -18,10 +18,10 @@
 #
 # License: MIT License
 
-# sphinx_gallery_thumbnail_number = 4
+# sphinx_gallery_thumbnail_number = 1
 
 import numpy as np
-import matplotlib.pylab as pl
+import matplotlib.pyplot as plt
 import ot
 # necessary for 3d plot even if not used
 from mpl_toolkits.mplot3d import Axes3D  # noqa
@@ -50,18 +50,6 @@
 M = ot.utils.dist0(n)
 M /= M.max()
 
-##############################################################################
-# Plot data
-# ---------
-
-#%% plot the distributions
-
-pl.figure(1, figsize=(6.4, 3))
-for i in range(n_distributions):
-    pl.plot(x, A[:, i])
-pl.title('Distributions')
-pl.tight_layout()
-
 ##############################################################################
 # Barycenter computation
 # ----------------------
@@ -78,24 +66,20 @@
 reg = 1e-3
 bary_wass = ot.bregman.barycenter(A, M, reg, weights)
 
-pl.figure(2)
-pl.clf()
-pl.subplot(2, 1, 1)
-for i in range(n_distributions):
-    pl.plot(x, A[:, i])
-pl.title('Distributions')
+f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1)
+ax1.plot(x, A, color="black")
+ax1.set_title('Distributions')
 
-pl.subplot(2, 1, 2)
-pl.plot(x, bary_l2, 'r', label='l2')
-pl.plot(x, bary_wass, 'g', label='Wasserstein')
-pl.legend()
-pl.title('Barycenters')
-pl.tight_layout()
+ax2.plot(x, bary_l2, 'r', label='l2')
+ax2.plot(x, bary_wass, 'g', label='Wasserstein')
+ax2.set_title('Barycenters')
+
+plt.legend()
+plt.show()
 
 ##############################################################################
 # Barycentric interpolation
 # -------------------------
-
 #%% barycenter interpolation
 
 n_alpha = 11
@@ -106,24 +90,23 @@
 
 B_wass = np.copy(B_l2)
 
-for i in range(0, n_alpha):
+for i in range(n_alpha):
     alpha = alpha_list[i]
     weights = np.array([1 - alpha, alpha])
     B_l2[:, i] = A.dot(weights)
     B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
 
 #%% plot interpolation
+plt.figure(2)
 
-pl.figure(3)
-
-cmap = pl.cm.get_cmap('viridis')
+cmap = plt.cm.get_cmap('viridis')
 verts = []
 zs = alpha_list
 for i, z in enumerate(zs):
     ys = B_l2[:, i]
     verts.append(list(zip(x, ys)))
 
-ax = pl.gcf().gca(projection='3d')
+ax = plt.gcf().gca(projection='3d')
 
 poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
 poly.set_alpha(0.7)
@@ -134,18 +117,18 @@
 ax.set_ylim3d(0, 1)
 ax.set_zlabel('')
 ax.set_zlim3d(0, B_l2.max() * 1.01)
-pl.title('Barycenter interpolation with l2')
-pl.tight_layout()
+plt.title('Barycenter interpolation with l2')
+plt.tight_layout()
 
-pl.figure(4)
-cmap = pl.cm.get_cmap('viridis')
+plt.figure(3)
+cmap = plt.cm.get_cmap('viridis')
 verts = []
 zs = alpha_list
 for i, z in enumerate(zs):
     ys = B_wass[:, i]
     verts.append(list(zip(x, ys)))
 
-ax = pl.gcf().gca(projection='3d')
+ax = plt.gcf().gca(projection='3d')
 
 poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
 poly.set_alpha(0.7)
@@ -156,7 +139,7 @@
 ax.set_ylim3d(0, 1)
 ax.set_zlabel('')
 ax.set_zlim3d(0, B_l2.max() * 1.01)
-pl.title('Barycenter interpolation with Wasserstein')
-pl.tight_layout()
+plt.title('Barycenter interpolation with Wasserstein')
+plt.tight_layout()
 
-pl.show()
+plt.show()
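
Reviewer note: the example above contrasts the plain average `A.dot(weights)` with `ot.bregman.barycenter`. For readers who want to see what the entropic solver does without installing POT, here is a minimal NumPy sketch of iterative Bregman projections. The function name `entropic_barycenter`, the helper `gauss`, the hand-built cost matrix, and the value `reg=1e-2` are assumptions made for illustration. This is not POT's implementation and it omits convergence checks and log-domain stabilization.

```python
import numpy as np


def entropic_barycenter(A, M, reg, weights, n_iter=1000):
    """Entropic barycenter of the columns of A (illustrative sketch only)."""
    K = np.exp(-M / reg)              # Gibbs kernel, shape (n, n)
    u = np.ones_like(A)               # one scaling vector per input histogram
    for _ in range(n_iter):
        v = A / (K.T @ u)             # match each plan's marginal to its input
        Kv = K @ v
        # weighted geometric mean of the barycenter-side marginals
        bary = np.exp(np.log(u * Kv) @ weights)
        u = bary[:, None] / Kv        # force every plan to share `bary`
    return bary


# Toy data mirroring the example: two bumps on a grid of 100 bins.
n = 100
x = np.arange(n, dtype=np.float64)


def gauss(mean, std):
    """Discretized, normalized Gaussian histogram (hypothetical helper)."""
    h = np.exp(-(x - mean) ** 2 / (2 * std ** 2))
    return h / h.sum()


A = np.vstack((gauss(20, 5), gauss(60, 8))).T
M = (x[:, None] - x[None, :]) ** 2    # squared Euclidean cost on the grid
M /= M.max()

weights = np.array([0.5, 0.5])
bary_wass = entropic_barycenter(A, M, reg=1e-2, weights=weights)
bary_l2 = A @ weights                 # plain average, keeps both modes
```

The plain average keeps two separate bumps, while the entropic barycenter places a single bump between them, which is the behaviour the second subplot of the example is meant to show.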

examples/barycenters/plot_barycenter_lp_vs_entropic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22
"""
33
=================================================================================
4-
1D Wasserstein barycenter comparison between exact LP and entropic regularization
4+
1D Wasserstein barycenter: exact LP vs entropic regularization
55
=================================================================================
66
77
This example illustrates the computation of regularized Wasserstein Barycenter

examples/barycenters/plot_convolutional_barycenter.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@
66
Convolutional Wasserstein Barycenter example
77
============================================
88
9-
This example is designed to illustrate how the Convolutional Wasserstein Barycenter
10-
function of POT works.
9+
This example is designed to illustrate how the Convolutional Wasserstein
10+
Barycenter function of POT works.
1111
"""
1212

1313
# Author: Nicolas Courty <ncourty@irisa.fr>
1414
#
1515
# License: MIT License
16-
16+
import os
17+
from pathlib import Path
1718

1819
import numpy as np
19-
import pylab as pl
20+
import matplotlib.pyplot as plt
2021
import ot
2122

2223
##############################################################################
@@ -25,22 +26,19 @@
2526
#
2627
# The four distributions are constructed from 4 simple images
2728

29+
this_file = os.path.realpath('__file__')
30+
data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
2831

29-
f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2]
30-
f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2]
31-
f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2]
32-
f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2]
32+
f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2]
33+
f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2]
34+
f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
35+
f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
3336

34-
A = []
3537
f1 = f1 / np.sum(f1)
3638
f2 = f2 / np.sum(f2)
3739
f3 = f3 / np.sum(f3)
3840
f4 = f4 / np.sum(f4)
39-
A.append(f1)
40-
A.append(f2)
41-
A.append(f3)
42-
A.append(f4)
43-
A = np.array(A)
41+
A = np.array([f1, f2, f3, f4])
4442

4543
nb_images = 5
4644

@@ -57,14 +55,13 @@
5755
# ----------------------------------------
5856
#
5957

60-
pl.figure(figsize=(10, 10))
61-
pl.title('Convolutional Wasserstein Barycenters in POT')
58+
fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7))
59+
plt.suptitle('Convolutional Wasserstein Barycenters in POT')
6260
cm = 'Blues'
6361
# regularization parameter
6462
reg = 0.004
6563
for i in range(nb_images):
6664
for j in range(nb_images):
67-
pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
6865
tx = float(i) / (nb_images - 1)
6966
ty = float(j) / (nb_images - 1)
7067

@@ -74,19 +71,19 @@
7471
weights = (1 - ty) * tmp1 + ty * tmp2
7572

7673
if i == 0 and j == 0:
77-
pl.imshow(f1, cmap=cm)
78-
pl.axis('off')
74+
axes[i, j].imshow(f1, cmap=cm)
7975
elif i == 0 and j == (nb_images - 1):
80-
pl.imshow(f3, cmap=cm)
81-
pl.axis('off')
76+
axes[i, j].imshow(f3, cmap=cm)
8277
elif i == (nb_images - 1) and j == 0:
83-
pl.imshow(f2, cmap=cm)
84-
pl.axis('off')
78+
axes[i, j].imshow(f2, cmap=cm)
8579
elif i == (nb_images - 1) and j == (nb_images - 1):
86-
pl.imshow(f4, cmap=cm)
87-
pl.axis('off')
80+
axes[i, j].imshow(f4, cmap=cm)
8881
else:
8982
# call to barycenter computation
90-
pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
91-
pl.axis('off')
92-
pl.show()
83+
axes[i, j].imshow(
84+
ot.bregman.convolutional_barycenter2d(A, reg, weights),
85+
cmap=cm
86+
)
87+
axes[i, j].axis('off')
88+
plt.tight_layout()
89+
plt.show()
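
Reviewer note: the "convolutional" solver in this example never builds a pixel-to-pixel cost matrix. With a squared Euclidean cost on a regular grid, the Gibbs kernel factorizes, so applying it to an image reduces to two small matrix products. The sketch below checks that identity on a tiny random image. The grid in pixel units, the value of `reg`, and all variable names are assumptions chosen for the demo and do not reflect how POT scales its grid.

```python
import numpy as np

n = 6                       # small image side so the dense kernel stays tiny
reg = 0.5                   # hypothetical regularization value for the demo
rng = np.random.default_rng(0)
img = rng.random((n, n))

t = np.arange(n, dtype=np.float64)
K1 = np.exp(-(t[:, None] - t[None, :]) ** 2 / reg)   # 1D Gibbs kernel, (n, n)

# Separable application: one 1D kernel per image axis.
fast = K1 @ img @ K1.T

# Dense reference: squared distances between all pixel pairs, (n*n, n*n).
ii, jj = np.meshgrid(t, t, indexing="ij")
coords = np.stack([ii.ravel(), jj.ravel()], axis=1)
M = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1)
slow = (np.exp(-M / reg) @ img.ravel()).reshape(n, n)

print(np.allclose(fast, slow))
```

The separable form costs two `(n, n)` products instead of one `(n*n, n*n)` product, which is what makes barycenters of full images tractable.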
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Debiased Sinkhorn barycenter demo
+=================================
+
+This example illustrates the computation of the debiased Sinkhorn barycenter
+as proposed in [37]_.
+
+
+.. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th
+    International Conference on Machine Learning, PMLR 119:4692-4701, 2020
+"""
+
+# Author: Hicham Janati <hicham.janati100@gmail.com>
+#
+# License: MIT License
+# sphinx_gallery_thumbnail_number = 3
+
+import os
+from pathlib import Path
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+import ot
+from ot.bregman import (barycenter, barycenter_debiased,
+                        convolutional_barycenter2d,
+                        convolutional_barycenter2d_debiased)
+
+##############################################################################
+# Debiased barycenter of 1D Gaussians
+# ------------------------------------
+
+#%% parameters
+
+n = 100  # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)  # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+#%% barycenter computation
+
+alpha = 0.2  # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+epsilons = [5e-3, 1e-2, 5e-2]
+
+
+bars = [barycenter(A, M, reg, weights) for reg in epsilons]
+bars_debiased = [barycenter_debiased(A, M, reg, weights) for reg in epsilons]
+labels = ["Sinkhorn barycenter", "Debiased barycenter"]
+colors = ["indianred", "gold"]
+
+f, axes = plt.subplots(1, len(epsilons), tight_layout=True, sharey=True,
+                       figsize=(12, 4), num=1)
+for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased):
+    ax.plot(A[:, 0], color="k", ls="--", label="Input data", alpha=0.3)
+    ax.plot(A[:, 1], color="k", ls="--", alpha=0.3)
+    for data, label, color in zip([bar, bar_debiased], labels, colors):
+        ax.plot(data, color=color, label=label, lw=2)
+    ax.set_title(r"$\varepsilon = %.3f$" % eps)
+plt.legend()
+plt.show()
+
+
+##############################################################################
+# Debiased barycenter of 2D images
+# ---------------------------------
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+f1 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
+f2 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
+
+A = np.asarray([f1, f2]) + 1e-2
+A /= A.sum(axis=(1, 2))[:, None, None]
+
+##############################################################################
+# Display the input images
+
+fig, axes = plt.subplots(1, 2, figsize=(7, 4), num=2)
+for ax, img in zip(axes, A):
+    ax.imshow(img, cmap="Greys")
+    ax.axis("off")
+fig.tight_layout()
+plt.show()
+
+
+##############################################################################
+# Barycenter computation and visualization
+# ----------------------------------------
+#
+
+bars_sinkhorn, bars_debiased = [], []
+epsilons = [5e-3, 7e-3, 1e-2]
+for eps in epsilons:
+    bar = convolutional_barycenter2d(A, eps)
+    bar_debiased, log = convolutional_barycenter2d_debiased(A, eps, log=True)
+    bars_sinkhorn.append(bar)
+    bars_debiased.append(bar_debiased)
+
+titles = ["Sinkhorn", "Debiased"]
+all_bars = [bars_sinkhorn, bars_debiased]
+fig, axes = plt.subplots(2, 3, figsize=(8, 6), num=3)
+for jj, (method, ax_row, bars) in enumerate(zip(titles, axes, all_bars)):
+    for ii, (ax, img, eps) in enumerate(zip(ax_row, bars, epsilons)):
+        ax.imshow(img, cmap="Greys")
+        if jj == 0:
+            ax.set_title(r"$\varepsilon = %.3f$" % eps, fontsize=13)
+        ax.set_xticks([])
+        ax.set_yticks([])
+        ax.spines['top'].set_visible(False)
+        ax.spines['right'].set_visible(False)
+        ax.spines['bottom'].set_visible(False)
+        ax.spines['left'].set_visible(False)
+        if ii == 0:
+            ax.set_ylabel(method, fontsize=15)
+fig.tight_layout()
+plt.show()
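
Reviewer note: to make the difference between the two rows of the last figure concrete, here is a rough NumPy sketch of the debiased fixed-point iterations as I read them from [37]. It should be checked against `ot.bregman.barycenter_debiased` before being relied on. The function name `debiased_barycenter`, the hand-built cost matrix, the iteration count, and the test histogram are all assumptions for illustration. The sketch feeds the solver two identical histograms: a debiased barycenter should then return something close to the input, while the undebiased iterations, started from unit scalings, return a smoothed copy.

```python
import numpy as np


def debiased_barycenter(A, M, reg, weights, n_iter=500):
    """Debiased Sinkhorn barycenter of the columns of A (illustrative sketch)."""
    K = np.exp(-M / reg)               # symmetric Gibbs kernel, shape (n, n)
    b = np.ones_like(A)                # scalings b_k, one column per input
    d = np.ones(A.shape[0])            # extra scaling that removes the blur
    for _ in range(n_iter):
        a = A / (K @ b)                # a_k = alpha_k / (K b_k)
        Ka = K.T @ a
        # d times the weighted geometric mean of K^T a_k
        bary = d * np.exp(np.log(Ka) @ weights)
        b = bary[:, None] / Ka
        d = np.sqrt(d * bary / (K @ d))
    return bary


n = 100
x = np.arange(n, dtype=np.float64)
hist = np.exp(-(x - 50) ** 2 / (2 * 5.0 ** 2))
hist /= hist.sum()

M = (x[:, None] - x[None, :]) ** 2     # squared Euclidean cost on the grid
M /= M.max()
reg = 1e-2

A = np.stack([hist, hist], axis=1)     # two identical inputs
bary = debiased_barycenter(A, M, reg, np.array([0.5, 0.5]))

# For identical inputs, undebiased iterations from unit scalings give this blur.
K = np.exp(-M / reg)
blurred = K @ (hist / K.sum(axis=0))

err_debiased = np.abs(bary - hist).sum()
err_blurred = np.abs(blurred - hist).sum()
```

Comparing `err_debiased` with `err_blurred` gives a quick numerical counterpart to the visual comparison in the example: the smaller the first is relative to the second, the more of the entropic smoothing has been removed.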
