Commit 9076f02

[FEAT] Entropic gw/fgw/srgw/srfgw solvers (#455)
* add entropic fgw + fgw bary + srgw + srfgw with tests
* add examples for entropic srgw - srfgw solvers
* add PPA solvers for GW/FGW + complete previous commits
* update readme
* add tests
* add examples + tests + warning in entropic solvers + releases
* reduce testing runtimes for test_gromov
* fix conflicts
* optional marginals
* improve coverage
* gromov doc harmonization
* fix pep8
* complete optional marginal for entropic srfgw

---------

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent f0dab2f commit 9076f02

13 files changed, +2948 -375 lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ The contributors to this library are:
 * [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
 * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
 * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
-* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, semi-relaxed FGW)
+* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW)
 * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
 * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
 * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)

README.md

Lines changed: 5 additions & 3 deletions
@@ -27,8 +27,8 @@ POT provides the following generic OT solvers (links to examples):
 * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
 * Weak OT solver between empirical distributions [39]
 * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) with LP solver (only small scale).
-* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from Graph Dictionary Learning [38]
-* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
+* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12,51]), differentiable using gradients from Graph Dictionary Learning [38]
+* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) (exact [24] and regularized [12,51]).
 * [Stochastic
 solver](https://pythonot.github.io/auto_examples/others/plot_stochastic.html) and
 [differentiable losses](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html) for
@@ -42,7 +42,7 @@ POT provides the following generic OT solvers (links to examples):
 * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45]
 * [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
 * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
-* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) [48].
+* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
 * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
 
 POT provides the following Machine Learning related solvers:
@@ -310,3 +310,5 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
 [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33.
 
 [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR).
+
+[51] Xu, H., Luo, D., Zha, H., & Carin, L. (2019). [Gromov-wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019.

RELEASES.md

Lines changed: 8 additions & 2 deletions
@@ -5,13 +5,18 @@
 #### New features
 - Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483)
 - Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
-- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
+- Added the sparsity-constrained OT solver to `ot.smooth` and added `projection_sparse_simplex` to `ot.utils` (PR #459)
 - Add tests on GPU for master branch and approved PR (PR #473)
 - Add `median` method to all inherited classes of `backend.Backend` (PR #472)
 - Update tests for macOS and Windows, speedup documentation (PR #484)
+- Added Proximal Point algorithm to solve GW problems via a new parameter `solver="PPA"` in `ot.gromov.entropic_gromov_wasserstein` + examples (PR #455)
+- Added features `warmstart` and `kwargs` in `ot.gromov.entropic_gromov_wasserstein` to respectively perform warmstart on dual potentials and pass parameters to `ot.sinkhorn` (PR #455)
+- Added sinkhorn projection based solvers for FGW `ot.gromov.entropic_fused_gromov_wasserstein` and entropic FGW barycenters + examples (PR #455)
+- Added features `warmstartT` and `kwargs` to all CG and entropic (F)GW barycenter solvers (PR #455)
+- Added entropic semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #455)
+- Make marginal parameters optional for (F)GW solvers in `._gw`, `._bregman` and `._semirelaxed` (PR #455)
 
 #### Closed issues
-
 - Fix circleci-redirector action and codecov (PR #460)
 - Fix issues with cuda for ot.binary_search_circle and with gradients for ot.sliced_wasserstein_sphere (PR #457)
 - Major documentation cleanup (PR #462, #467, #475)
@@ -22,6 +27,7 @@
 - Fix `utils.cost_normalization` function issue to work with multiple backends (PR #472)
 
 ## 0.9.0
+*April 2023*
 
 This new release contains so many new features and bug fixes since 0.8.2 that we
 decided to make it a new minor release at 0.9.0.
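
The `warmstart` entry above refers to reusing dual potentials between Sinkhorn calls. As a rough illustration of why that helps, here is a minimal NumPy sketch. It is not POT's implementation: `sinkhorn_sketch` is a hypothetical name, and the convention that the potentials are the logarithms of the scaling vectors is an assumption made for this example.

```python
import numpy as np


def sinkhorn_sketch(a, b, M, reg, warmstart=None, max_iter=1000, tol=1e-9):
    """Sinkhorn-Knopp scaling with an optional warm start (illustrative only)."""
    K = np.exp(-M / reg)
    if warmstart is None:
        u, v = np.ones_like(a), np.ones_like(b)
    else:
        # Assumed convention: potentials are the logs of the scaling vectors.
        u, v = np.exp(warmstart[0]), np.exp(warmstart[1])
    for n_iter in range(1, max_iter + 1):
        # Alternate the two scalings; only u is consumed by the first update.
        v = b / (K.T @ u)
        u = a / (K @ v)
        T = u[:, None] * K * v[None, :]
        # Rows match a exactly after the u update, so check the columns.
        if np.abs(T.sum(axis=0) - b).sum() < tol:
            break
    return T, (np.log(u), np.log(v)), n_iter


# Toy problem: squared Euclidean cost between two small point clouds.
rng = np.random.default_rng(0)
x, y = rng.normal(size=(5, 2)), rng.normal(size=(6, 2))
M = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
M /= M.max()
a, b = np.full(5, 1 / 5), np.full(6, 1 / 6)

T_cold, potentials, n_cold = sinkhorn_sketch(a, b, M, reg=0.5)
T_warm, _, n_warm = sinkhorn_sketch(a, b, M, reg=0.5, warmstart=potentials)
print("iterations (cold, warm):", n_cold, n_warm)
```

In an outer loop that repeatedly solves nearby Sinkhorn problems, as the entropic GW solvers do, restarting from the previous potentials should typically shorten each inner solve.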
Lines changed: 304 additions & 0 deletions
@@ -0,0 +1,304 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================================
+Entropic-regularized semi-relaxed (Fused) Gromov-Wasserstein example
+====================================================================
+
+This example is designed to show how to use the entropic semi-relaxed Gromov-Wasserstein
+and the entropic semi-relaxed Fused Gromov-Wasserstein divergences.
+
+Entropic-regularized sr(F)GW between two graphs G1 and G2 searches for a reweighting of the nodes of
+G2 at a minimal entropic-regularized (F)GW distance from G1.
+
+First, we generate two graphs following Stochastic Block Models, then show
+how to compute their srGW matchings and illustrate them. These graphs are then
+endowed with node features and we follow the same process with srFGW.
+
+[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+International Conference on Learning Representations (ICLR), 2022.
+"""
+
+# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+import numpy as np
+import matplotlib.pylab as pl
+from ot.gromov import entropic_semirelaxed_gromov_wasserstein, entropic_semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein
+import networkx
+from networkx.generators.community import stochastic_block_model as sbm
+
+#############################################################################
+#
+# Generate two graphs following Stochastic Block models of 2 and 3 clusters.
+# ---------------------------------------------
+
+
+N2 = 20  # 2 communities
+N3 = 30  # 3 communities
+p2 = [[1., 0.1],
+      [0.1, 0.9]]
+p3 = [[1., 0.1, 0.],
+      [0.1, 0.95, 0.1],
+      [0., 0.1, 0.9]]
+G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)
+G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)
+
+
+C2 = networkx.to_numpy_array(G2)
+C3 = networkx.to_numpy_array(G3)
+
+h2 = np.ones(C2.shape[0]) / C2.shape[0]
+h3 = np.ones(C3.shape[0]) / C3.shape[0]
+
+# Add weights on the edges for visualization later on
+weight_intra_G2 = 5
+weight_inter_G2 = 0.5
+weight_intra_G3 = 1.
+weight_inter_G3 = 1.5
+
+weightedG2 = networkx.Graph()
+part_G2 = [G2.nodes[i]['block'] for i in range(N2)]
+
+for node in G2.nodes():
+    weightedG2.add_node(node)
+for i, j in G2.edges():
+    if part_G2[i] == part_G2[j]:
+        weightedG2.add_edge(i, j, weight=weight_intra_G2)
+    else:
+        weightedG2.add_edge(i, j, weight=weight_inter_G2)
+
+weightedG3 = networkx.Graph()
+part_G3 = [G3.nodes[i]['block'] for i in range(N3)]
+
+for node in G3.nodes():
+    weightedG3.add_node(node)
+for i, j in G3.edges():
+    if part_G3[i] == part_G3[j]:
+        weightedG3.add_edge(i, j, weight=weight_intra_G3)
+    else:
+        weightedG3.add_edge(i, j, weight=weight_inter_G3)
+
+#############################################################################
+#
+# Compute their entropic-regularized semi-relaxed Gromov-Wasserstein divergences
+# ---------------------------------------------
+
+# 0) GW(C2, h2, C3, h3) for reference
+OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True)
+gw = log['gw_dist']
+
+# 1) srGW_e(C2, h2, C3)
+OT_23, log_23 = entropic_semirelaxed_gromov_wasserstein(
+    C2, C3, h2, symmetric=True, epsilon=1., G0=None, log=True)
+srgw_23 = log_23['srgw_dist']
+
+# 2) srGW_e(C3, h3, C2)
+
+OT_32, log_32 = entropic_semirelaxed_gromov_wasserstein(
+    C3, C2, h3, symmetric=None, epsilon=1., G0=None, log=True)
+srgw_32 = log_32['srgw_dist']
+
+print('GW(C2, C3) = ', gw)
+print('srGW_e(C2, h2, C3) = ', srgw_23)
+print('srGW_e(C3, h3, C2) = ', srgw_32)
+
+
+#############################################################################
+#
+# Visualization of the entropic-regularized semi-relaxed Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the left - then project its node colors
+# based on the optimal transport plan from the entropic srGW matching.
+# We adjust the intensity of links across domains proportionally to the mass
+# sent, adding a minimal intensity of 0.1 if mass sent is not zero.
+
+
+def draw_graph(G, C, nodes_color_part, Gweights=None,
+               pos=None, edge_color='black', node_size=None,
+               shiftx=0, seed=0):
+
+    if (pos is None):
+        pos = networkx.spring_layout(G, scale=1., seed=seed)
+
+    if shiftx != 0:
+        for k, v in pos.items():
+            v[0] = v[0] + shiftx
+
+    alpha_edge = 0.7
+    width_edge = 1.8
+    if Gweights is None:
+        networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color)
+    else:
+        # We make more visible connections between activated nodes
+        n = len(Gweights)
+        edgelist_activated = []
+        edgelist_deactivated = []
+        for i in range(n):
+            for j in range(n):
+                if Gweights[i] * Gweights[j] * C[i, j] > 0:
+                    edgelist_activated.append((i, j))
+                elif C[i, j] > 0:
+                    edgelist_deactivated.append((i, j))
+
+        networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated,
+                                     width=width_edge, alpha=alpha_edge,
+                                     edge_color=edge_color)
+        networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated,
+                                     width=width_edge, alpha=0.1,
+                                     edge_color=edge_color)
+
+    if Gweights is None:
+        for node, node_color in enumerate(nodes_color_part):
+            networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+                                         node_size=node_size, alpha=1,
+                                         node_color=node_color)
+    else:
+        scaled_Gweights = Gweights / (0.5 * Gweights.max())
+        nodes_size = node_size * scaled_Gweights
+        for node, node_color in enumerate(nodes_color_part):
+            networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+                                         node_size=nodes_size[node], alpha=1,
+                                         node_color=node_color)
+    return pos
+
+
+def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1,
+                             p1, p2, T, pos1=None, pos2=None,
+                             shiftx=4, switchx=False, node_size=70,
+                             seed_G1=0, seed_G2=0):
+    starting_color = 0
+    # get graphs partition and their coloring
+    part1 = part_G1.copy()
+    unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)]
+    nodes_color_part1 = []
+    for cluster in part1:
+        nodes_color_part1.append(unique_colors[cluster])
+
+    nodes_color_part2 = []
+    # T: getting colors assignment from argmax of columns
+    for i in range(len(G2.nodes())):
+        j = np.argmax(T[:, i])
+        nodes_color_part2.append(nodes_color_part1[j])
+    pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1,
+                      pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1)
+    pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2,
+                      node_size=node_size, shiftx=shiftx, seed=seed_G2)
+    for k1, v1 in pos1.items():
+        max_Tk1 = np.max(T[k1, :])
+        for k2, v2 in pos2.items():
+            if (T[k1, k2] > 0):
+                pl.plot([pos1[k1][0], pos2[k2][0]],
+                        [pos1[k1][1], pos2[k2][1]],
+                        '-', lw=0.6, alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.),
+                        color=nodes_color_part1[k1])
+    return pos1, pos2
+
+
+node_size = 40
+fontsize = 10
+seed_G2 = 0
+seed_G3 = 4
+
+pl.figure(1, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.axis('off')
+pl.axis
+pl.title(r'$srGW_e(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize)
+
+hbar2 = OT_23.sum(axis=0)
+pos1, pos2 = draw_transp_colored_srGW(
+    weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
+    shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+pl.subplot(122)
+pl.axis('off')
+hbar3 = OT_32.sum(axis=0)
+pl.title(r'$srGW_e(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize)
+pos1, pos2 = draw_transp_colored_srGW(
+    weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
+    pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
+pl.tight_layout()
+
+pl.show()
+
+#############################################################################
+#
+# Add node features
+# ---------------------------------------------
+
+# We add node features with given mean - by clusters
+# and inversely proportional to clusters' intra-connectivity
+
+F2 = np.zeros((N2, 1))
+for i, c in enumerate(part_G2):
+    F2[i, 0] = np.random.normal(loc=c, scale=0.01)
+
+F3 = np.zeros((N3, 1))
+for i, c in enumerate(part_G3):
+    F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01)
+
+#############################################################################
+#
+# Compute their semi-relaxed Fused Gromov-Wasserstein divergences
+# ---------------------------------------------
+
+alpha = 0.5
+# Compute pairwise squared Euclidean distances between node features
+M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T)
+
+# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference
+
+OT, log = fused_gromov_wasserstein(
+    M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True)
+fgw = log['fgw_dist']
+
+# 1) srFGW_e(C2, F2, h2, C3, F3)
+OT_23, log_23 = entropic_semirelaxed_fused_gromov_wasserstein(
+    M, C2, C3, h2, symmetric=True, epsilon=1., alpha=alpha, log=True, G0=None)
+srfgw_23 = log_23['srfgw_dist']
+
+# 2) srFGW_e(C3, F3, h3, C2, F2)
+
+OT_32, log_32 = entropic_semirelaxed_fused_gromov_wasserstein(
+    M.T, C3, C2, h3, symmetric=None, epsilon=1., alpha=alpha, log=True, G0=None)
+srfgw_32 = log_32['srfgw_dist']
+
+print('FGW(C2, F2, C3, F3) = ', fgw)
+print('srFGW_e(C2, F2, h2, C3, F3) = ', srfgw_23)
+print('srFGW_e(C3, F3, h3, C2, F2) = ', srfgw_32)
+
+#############################################################################
+#
+# Visualization of the entropic semi-relaxed Fused Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the left - then project its node colors
+# based on the optimal transport plan from the srFGW matching
+# NB: colors refer to clusters - not to node features
+
+pl.figure(2, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.axis('off')
+pl.axis
+pl.title(r'$srFGW_e(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize)
+
+hbar2 = OT_23.sum(axis=0)
+pos1, pos2 = draw_transp_colored_srGW(
+    weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
+    shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+pl.subplot(122)
+pl.axis('off')
+hbar3 = OT_32.sum(axis=0)
+pl.title(r'$srFGW_e(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize)
+pos1, pos2 = draw_transp_colored_srGW(
+    weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
+    pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
+pl.tight_layout()
+
+pl.show()
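
The example file treats the entropic srGW solver as a black box. The NumPy sketch below shows one plausible shape for such a solver: a KL mirror-descent step on the square-loss GW objective, followed by a row rescaling that enforces only the first marginal. It is an illustration under stated assumptions (symmetric inputs, square loss, uniform weights when `p` is `None`), not the code of `ot.gromov.entropic_semirelaxed_gromov_wasserstein`, and the name `entropic_srgw_sketch` is hypothetical.

```python
import numpy as np


def entropic_srgw_sketch(C1, C2, p=None, epsilon=1.0, max_iter=200, tol=1e-9):
    """Mirror-descent sketch for entropic semi-relaxed GW (square loss).

    Assumes symmetric C1 and C2. `p=None` falls back to uniform weights.
    """
    n, m = C1.shape[0], C2.shape[0]
    if p is None:
        p = np.full(n, 1.0 / n)
    # Start from the product coupling with a uniform second marginal.
    T = np.outer(p, np.full(m, 1.0 / m))
    for _ in range(max_iter):
        q = T.sum(axis=0)  # current (free) second marginal
        # Gradient of sum_{ijkl} (C1_ij - C2_kl)^2 T_ik T_jl for symmetric inputs.
        const = (C1 ** 2 @ p)[:, None] + (C2 ** 2 @ q)[None, :]
        grad = 2.0 * (const - 2.0 * C1 @ T @ C2)
        # KL mirror step, then rescale rows so that T.sum(axis=1) equals p.
        K = T * np.exp(-grad / epsilon)
        T_new = K * (p / K.sum(axis=1))[:, None]
        converged = np.abs(T_new - T).sum() < tol
        T = T_new
        if converged:
            break
    return T


# Two toy block-structured graphs: 2 cliques of 3 nodes vs 3 cliques of 2 nodes.
C1 = np.kron(np.eye(2), np.ones((3, 3)))
C2 = np.kron(np.eye(3), np.ones((2, 2)))
T = entropic_srgw_sketch(C1, C2, epsilon=1.0)
hbar = T.sum(axis=0)  # learned reweighting of the target nodes
print("first marginal:", T.sum(axis=1))
print("second marginal:", hbar)
```

Only the rows of `T` are constrained, so the second marginal `hbar` is an output of the solver rather than an input, which is what the example's `OT_23.sum(axis=0)` recovers.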
