Skip to content

Commit 82452e0

Browse files
authored
[MRG] Add factored coupling (#358)
* add gfactored ot * pep8 and add doc * add exmaple for factotred OT * final number of PR * correct test on backends * remove useless loss * better tests
1 parent 7671715 commit 82452e0

File tree

8 files changed

+303
-2
lines changed

8 files changed

+303
-2
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,4 +305,6 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020
305305
[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
306306
Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.
307307

308-
[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
308+
[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
309+
310+
[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR.

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#### New features
77

8+
- Implementation of factored OT with emd and sinkhorn (PR #358).
89
- A brand new logo for POT (PR #357)
910
- Better list of related examples in quick start guide with `minigallery` (PR #334).
1011
- Add optional log-domain Sinkhorn implementation in WDA to support smaller values

docs/source/all.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ API and modules
2929
partial
3030
sliced
3131
weak
32+
factored
3233

3334
.. autosummary::
3435
:toctree: ../modules/generated/
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
==========================================
4+
Optimal transport with factored couplings
5+
==========================================
6+
7+
Illustration of the factored coupling OT between 2D empirical distributions
8+
9+
"""
10+
11+
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
12+
#
13+
# License: MIT License
14+
15+
# sphinx_gallery_thumbnail_number = 2
16+
17+
import numpy as np
18+
import matplotlib.pylab as pl
19+
import ot
20+
import ot.plot
21+
22+
# %%
23+
# Generate data an plot it
24+
# ------------------------
25+
26+
# parameters and data generation
27+
28+
np.random.seed(42)
29+
30+
n = 100 # nb samples
31+
32+
xs = np.random.rand(n, 2) - .5
33+
34+
xs = xs + np.sign(xs)
35+
36+
xt = np.random.rand(n, 2) - .5
37+
38+
a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
39+
40+
#%% plot samples
41+
42+
pl.figure(1)
43+
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
44+
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
45+
pl.legend(loc=0)
46+
pl.title('Source and target distributions')
47+
48+
49+
# %%
50+
# Compute Factore OT and exact OT solutions
51+
# --------------------------------------
52+
53+
#%% EMD
54+
M = ot.dist(xs, xt)
55+
G0 = ot.emd(a, b, M)
56+
57+
#%% factored OT OT
58+
59+
Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4)
60+
61+
62+
# %%
63+
# Plot factored OT and exact OT solutions
64+
# --------------------------------------
65+
66+
pl.figure(2, (14, 4))
67+
68+
pl.subplot(1, 3, 1)
69+
ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.2, .2, .2], alpha=0.1)
70+
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
71+
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
72+
pl.title('Exact OT with samples')
73+
74+
pl.subplot(1, 3, 2)
75+
ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[.6, .6, .9], alpha=0.5)
76+
ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[.9, .6, .6], alpha=0.5)
77+
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
78+
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
79+
pl.plot(xb[:, 0], xb[:, 1], 'og', label='Template samples')
80+
pl.title('Factored OT with template samples')
81+
82+
pl.subplot(1, 3, 3)
83+
ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[.2, .2, .2], alpha=0.1)
84+
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
85+
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
86+
pl.title('Factored OT low rank OT plan')

ot/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from . import backend
3434
from . import regpath
3535
from . import weak
36+
from . import factored
3637

3738
# OT functions
3839
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -44,6 +45,9 @@
4445
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
4546
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
4647
from .weak import weak_optimal_transport
48+
from .factored import factored_optimal_transport
49+
50+
4751
# utils functions
4852
from .utils import dist, unif, tic, toc, toq
4953

@@ -57,4 +61,5 @@
5761
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
5862
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
5963
'max_sliced_wasserstein_distance', 'weak_optimal_transport',
64+
'factored_optimal_transport',
6065
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']

ot/factored.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""
2+
Factored OT solvers (low rank, cost or OT plan)
3+
"""
4+
5+
# Author: Remi Flamary <remi.flamary@polytehnique.edu>
6+
#
7+
# License: MIT License
8+
9+
from .backend import get_backend
10+
from .utils import dist
11+
from .lp import emd
12+
from .bregman import sinkhorn
13+
14+
__all__ = ['factored_optimal_transport']
15+
16+
17+
def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs):
18+
r"""Solves factored OT problem and return OT plans and intermediate distribution
19+
20+
This function solve the following OT problem [40]_
21+
22+
.. math::
23+
\mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b)
24+
25+
where :
26+
27+
- :math:`\mu_a` and :math:`\mu_b` are empirical distributions.
28+
- :math:`\mu` is an empirical distribution with r samples
29+
30+
And returns the two OT plans between
31+
32+
.. note:: This function is backend-compatible and will work on arrays
33+
from all compatible backends. But the algorithm uses the C++ CPU backend
34+
which can lead to copy overhead on GPU arrays.
35+
36+
Uses the conditional gradient algorithm to solve the problem proposed in
37+
:ref:`[39] <references-weak>`.
38+
39+
Parameters
40+
----------
41+
Xa : (ns,d) array-like, float
42+
Source samples
43+
Xb : (nt,d) array-like, float
44+
Target samples
45+
a : (ns,) array-like, float
46+
Source histogram (uniform weight if empty list)
47+
b : (nt,) array-like, float
48+
Target histogram (uniform weight if empty list))
49+
numItermax : int, optional
50+
Max number of iterations
51+
stopThr : float, optional
52+
Stop threshold on the relative variation (>0)
53+
verbose : bool, optional
54+
Print information along iterations
55+
log : bool, optional
56+
record log if True
57+
58+
59+
Returns
60+
-------
61+
Ga: array-like, shape (ns, r)
62+
Optimal transportation matrix between source and the intermediate
63+
distribution
64+
Gb: array-like, shape (r, nt)
65+
Optimal transportation matrix between the intermediate and target
66+
distribution
67+
X: array-like, shape (r, d)
68+
Support of the intermediate distribution
69+
log: dict, optional
70+
If input log is true, a dictionary containing the cost and dual
71+
variables and exit status
72+
73+
74+
.. _references-factored:
75+
References
76+
----------
77+
.. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
78+
G., & Weed, J. (2019, April). Statistical optimal transport via factored
79+
couplings. In The 22nd International Conference on Artificial
80+
Intelligence and Statistics (pp. 2454-2465). PMLR.
81+
82+
See Also
83+
--------
84+
ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General
85+
regularized OT
86+
"""
87+
88+
nx = get_backend(Xa, Xb)
89+
90+
n_a = Xa.shape[0]
91+
n_b = Xb.shape[0]
92+
d = Xa.shape[1]
93+
94+
if a is None:
95+
a = nx.ones((n_a), type_as=Xa) / n_a
96+
if b is None:
97+
b = nx.ones((n_b), type_as=Xb) / n_b
98+
99+
if X0 is None:
100+
X = nx.randn(r, d, type_as=Xa)
101+
else:
102+
X = X0
103+
104+
w = nx.ones(r, type_as=Xa) / r
105+
106+
def solve_ot(X1, X2, w1, w2):
107+
M = dist(X1, X2)
108+
if reg > 0:
109+
G, log = sinkhorn(w1, w2, M, reg, log=True, **kwargs)
110+
log['cost'] = nx.sum(G * M)
111+
return G, log
112+
else:
113+
return emd(w1, w2, M, log=True, **kwargs)
114+
115+
norm_delta = []
116+
117+
# solve the barycenter
118+
for i in range(numItermax):
119+
120+
old_X = X
121+
122+
# solve OT with template
123+
Ga, loga = solve_ot(Xa, X, a, w)
124+
Gb, logb = solve_ot(X, Xb, w, b)
125+
126+
X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r
127+
128+
delta = nx.norm(X - old_X)
129+
if delta < stopThr:
130+
break
131+
if log:
132+
norm_delta.append(delta)
133+
134+
if log:
135+
log_dic = {'delta_iter': norm_delta,
136+
'ua': loga['u'],
137+
'va': loga['v'],
138+
'ub': logb['u'],
139+
'vb': logb['v'],
140+
'costa': loga['cost'],
141+
'costb': logb['cost'],
142+
}
143+
return Ga, Gb, X, log_dic
144+
145+
return Ga, Gb, X

ot/plot.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,13 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
8585
if ('color' not in kwargs) and ('c' not in kwargs):
8686
kwargs['color'] = 'k'
8787
mx = G.max()
88+
if 'alpha' in kwargs:
89+
scale = kwargs['alpha']
90+
del kwargs['alpha']
91+
else:
92+
scale = 1
8893
for i in range(xs.shape[0]):
8994
for j in range(xt.shape[0]):
9095
if G[i, j] / mx > thr:
9196
pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
92-
alpha=G[i, j] / mx, **kwargs)
97+
alpha=G[i, j] / mx * scale, **kwargs)

test/test_factored.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""Tests for main module ot.weak """
2+
3+
# Author: Remi Flamary <remi.flamary@unice.fr>
4+
#
5+
# License: MIT License
6+
7+
import ot
8+
import numpy as np
9+
10+
11+
def test_factored_ot():
12+
# test weak ot solver and identity stationary point
13+
n = 50
14+
rng = np.random.RandomState(0)
15+
16+
xs = rng.randn(n, 2)
17+
xt = rng.randn(n, 2)
18+
u = ot.utils.unif(n)
19+
20+
Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, r=10, log=True)
21+
22+
# check constraints
23+
np.testing.assert_allclose(u, Ga.sum(1))
24+
np.testing.assert_allclose(u, Gb.sum(0))
25+
26+
Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, reg=1, r=10, log=True)
27+
28+
# check constraints
29+
np.testing.assert_allclose(u, Ga.sum(1))
30+
np.testing.assert_allclose(u, Gb.sum(0))
31+
32+
33+
def test_factored_ot_backends(nx):
34+
# test weak ot solver for different backends
35+
n = 50
36+
rng = np.random.RandomState(0)
37+
38+
xs = rng.randn(n, 2)
39+
xt = rng.randn(n, 2)
40+
u = ot.utils.unif(n)
41+
42+
xs2 = nx.from_numpy(xs)
43+
xt2 = nx.from_numpy(xt)
44+
u2 = nx.from_numpy(u)
45+
46+
Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, u2, u2, r=10)
47+
48+
# check constraints
49+
np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1))
50+
np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0))
51+
52+
Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, reg=1, r=10, X0=X2)
53+
54+
# check constraints
55+
np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1))
56+
np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0))

0 commit comments

Comments
 (0)