
Commit 0afd84d

[WIP] Add backend dual loss and plan computation for stochastic optimization of regularized OT (#360)
* add losses and plan computations and example for dual optimization
* pep8
* add nice example
* update awesome example stochastic dual
* add all tests
* pep8 + speedup example
* add release info
1 parent 82452e0 commit 0afd84d

File tree

5 files changed: +713 -3 lines changed

RELEASES.md

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@
 
 #### New features
 
+- Add stochastic loss and OT plan computation for regularized OT and
+  backend examples (PR #360).
 - Implementation of factored OT with emd and sinkhorn (PR #358).
 - A brand new logo for POT (PR #357)
 - Better list of related examples in quick start guide with `minigallery` (PR #334).
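
For context, the entry points exercised by the two new examples below are `ot.stochastic.loss_dual_entropic` / `ot.stochastic.plan_dual_entropic` and their quadratic counterparts `loss_dual_quadratic` / `plan_dual_quadratic`. A minimal usage sketch, assuming uniform sample weights and the call signatures used in this commit:

# Minimal sketch of the new API (assumes uniform weights; signatures as
# used in the examples from this commit).
import torch
import ot

xs = torch.randn(50, 2)  # source samples
xt = torch.randn(60, 2)  # target samples

u = torch.randn(50, requires_grad=True)  # dual variable, one per source sample
v = torch.randn(60, requires_grad=True)  # dual variable, one per target sample

# the dual objective is maximized, so a torch optimizer minimizes its negation
loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=0.5)
loss.backward()

# dense (50, 60) transport plan recovered from the dual variables
G = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=0.5)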
Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-
r"""
======================================================================
Dual OT solvers for entropic and quadratic regularized OT with Pytorch
======================================================================


"""

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 3

import numpy as np
import matplotlib.pyplot as pl
import torch

import ot
import ot.plot

# %%
# Data generation
# ---------------

torch.manual_seed(1)

n_source_samples = 100
n_target_samples = 100
theta = 2 * np.pi / 20
noise_level = 0.1

Xs, ys = ot.datasets.make_data_classif(
    'gaussrot', n_source_samples, nz=noise_level)
Xt, yt = ot.datasets.make_data_classif(
    'gaussrot', n_target_samples, theta=theta, nz=noise_level)

# one of the target modes changes its variance (no linear mapping)
Xt[yt == 2] *= 3
Xt = Xt + 4


# %%
# Plot data
# ---------

pl.figure(1, (10, 5))
pl.clf()
pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples')
pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')

# %%
# Convert data to torch tensors
# -----------------------------

xs = torch.tensor(Xs)
xt = torch.tensor(Xt)

# %%
# Estimating dual variables for entropic OT
# -----------------------------------------

# one dual variable per source sample, one per target sample
u = torch.randn(n_source_samples, requires_grad=True)
v = torch.randn(n_target_samples, requires_grad=True)

reg = 0.5

optimizer = torch.optim.Adam([u, v], lr=1)

# number of iterations
n_iter = 200


losses = []

for i in range(n_iter):

    # minus because we maximize the dual loss
    loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


pl.figure(2)
pl.plot(losses)
pl.grid()
pl.title('Dual objective (negative)')
pl.xlabel("Iterations")

Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg)

# %%
# Plot the estimated entropic OT plan
# -----------------------------------

pl.figure(3, (10, 5))
pl.clf()
ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1)
pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
pl.legend(loc=0)
pl.title('OT plan with entropic regularization')


# %%
# Estimating dual variables for quadratic OT
# ------------------------------------------

u = torch.randn(n_source_samples, requires_grad=True)
v = torch.randn(n_target_samples, requires_grad=True)

reg = 0.01

optimizer = torch.optim.Adam([u, v], lr=1)

# number of iterations
n_iter = 200


losses = []

for i in range(n_iter):

    # minus because we maximize the dual loss
    loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


pl.figure(4)
pl.plot(losses)
pl.grid()
pl.title('Dual objective (negative)')
pl.xlabel("Iterations")

Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg)


# %%
# Plot the estimated quadratic OT plan
# ------------------------------------

pl.figure(5, (10, 5))
pl.clf()
ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1)
pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
pl.legend(loc=0)
pl.title('OT plan with quadratic regularization')
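
For reference, the objective maximized in the entropic blocks above is the dual of entropically regularized OT; with uniform weights a_i = 1/n and b_j = 1/m it reads, up to an additive constant that does not change the maximizer,

    max_{u, v}  sum_i a_i u_i + sum_j b_j v_j - reg * sum_{i,j} a_i b_j exp((u_i + v_j - C_ij) / reg)

with plan G_ij = a_i b_j exp((u_i + v_j - C_ij) / reg); the quadratic case replaces the exponential with a positive part, which is why its plan is typically sparser. A naive dense sketch of this objective, written only to make the formula concrete (an illustrative reimplementation assuming a squared Euclidean cost, not the library code):

# Illustrative dense version of the entropic dual objective (assumes uniform
# weights and a squared Euclidean ground cost); not the library implementation.
import torch

def dual_entropic_naive(u, v, xs, xt, reg):
    C = torch.cdist(xs, xt, p=2) ** 2  # pairwise ground cost
    a = torch.full((xs.shape[0],), 1.0 / xs.shape[0], dtype=xs.dtype)
    b = torch.full((xt.shape[0],), 1.0 / xt.shape[0], dtype=xt.dtype)
    K = torch.exp((u[:, None] + v[None, :] - C) / reg)  # plan, up to weights
    return u @ a + v @ b - reg * (a[:, None] * b[None, :] * K).sum()

Gradient ascent on (u, v) with this function matches the Adam loops above; the library version additionally accepts sample weights and mini-batches, which is what enables the stochastic training in the next example.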
Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
# -*- coding: utf-8 -*-
r"""
======================================================================
Continuous OT plan estimation with Pytorch
======================================================================


"""

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 3

import numpy as np
import matplotlib.pyplot as pl
import torch
from torch import nn

import ot
import ot.plot

# %%
# Data generation
# ---------------

torch.manual_seed(42)
np.random.seed(42)

n_source_samples = 10000
n_target_samples = 10000
theta = 2 * np.pi / 20
noise_level = 0.1

Xs = np.random.randn(n_source_samples, 2) * 0.5
Xt = np.random.randn(n_target_samples, 2) * 2

# shift the target distribution
Xt = Xt + 4


# %%
# Plot data
# ---------
nvisu = 300
pl.figure(1, (5, 5))
pl.clf()
pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', label='Source samples', alpha=0.5)
pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', label='Target samples', alpha=0.5)
pl.legend(loc=0)
ax_bounds = pl.axis()
pl.title('Source and target distributions')

# %%
# Convert data to torch tensors
# -----------------------------

xs = torch.tensor(Xs)
xt = torch.tensor(Xt)

# %%
# Estimating deep dual variables for entropic OT
# ----------------------------------------------

torch.manual_seed(42)

# define the MLP model


class Potential(torch.nn.Module):
    def __init__(self):
        super(Potential, self).__init__()
        self.fc1 = nn.Linear(2, 200)
        self.fc2 = nn.Linear(200, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        output = self.fc1(x)
        output = self.relu(output)
        output = self.fc2(output)
        return output.ravel()


u = Potential().double()
v = Potential().double()

reg = 1

optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005)

# number of iterations
n_iter = 1000
n_batch = 500


losses = []

for i in range(n_iter):

    # draw a random mini-batch of source and target samples
    iperms = torch.randint(0, n_source_samples, (n_batch,))
    ipermt = torch.randint(0, n_target_samples, (n_batch,))

    xsi = xs[iperms]
    xti = xt[ipermt]

    # minus because we maximize the dual loss
    loss = -ot.stochastic.loss_dual_entropic(u(xsi), v(xti), xsi, xti, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


pl.figure(2)
pl.plot(losses)
pl.grid()
pl.title('Dual objective (negative)')
pl.xlabel("Iterations")


# %%
# Plot the density on target for a given source sample
# -----------------------------------------------------


nv = 100
xl = np.linspace(ax_bounds[0], ax_bounds[1], nv)
yl = np.linspace(ax_bounds[2], ax_bounds[3], nv)

XX, YY = np.meshgrid(xl, yl)

xg = np.concatenate((XX.ravel()[:, None], YY.ravel()[:, None]), axis=1)

# Gaussian weights on the grid points, centered on the target distribution
wxg = np.exp(-((xg[:, 0] - 4)**2 + (xg[:, 1] - 4)**2) / (2 * 2))
wxg = wxg / np.sum(wxg)

xg = torch.tensor(xg)
wxg = torch.tensor(wxg)


pl.figure(4, (12, 4))
pl.clf()
pl.subplot(1, 3, 1)

iv = 2
Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
Gg = Gg.reshape((nv, nv)).detach().numpy()

pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
pl.legend(loc=0)
ax_bounds = pl.axis()
pl.title('Density of transported source sample')

pl.subplot(1, 3, 2)

iv = 3
Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
Gg = Gg.reshape((nv, nv)).detach().numpy()

pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
pl.legend(loc=0)
ax_bounds = pl.axis()
pl.title('Density of transported source sample')

pl.subplot(1, 3, 3)

iv = 6
Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
Gg = Gg.reshape((nv, nv)).detach().numpy()

pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
pl.legend(loc=0)
ax_bounds = pl.axis()
pl.title('Density of transported source sample')
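
Beyond conditional densities, the trained potentials can also be turned into a pointwise map: the barycentric projection T(x) = sum_j G(x, y_j) y_j / sum_j G(x, y_j) sends each source point to the weighted mean of its target mass. A possible follow-up sketch reusing u, v, xs, xt and reg from the script above (illustrative, not part of the diff):

# Hypothetical follow-up (not in this commit): barycentric map built from the
# learned potentials, evaluated on a few source points.
xnew = xs[:5]  # any points with the same dimension and dtype would do
G = ot.stochastic.plan_dual_entropic(u(xnew), v(xt), xnew, xt, reg=reg)
T = (G @ xt) / G.sum(dim=1, keepdim=True)  # weighted mean of target samples
print(T.detach().numpy())  # estimated images of the source points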
