-
Notifications
You must be signed in to change notification settings - Fork 528
[MRG] MM algorithms for UOT #362
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
35 commits
Select commit
Hold shift + click to select a range
850d116
bugfix
lchapel 573fbd5
Merge branch 'lchapel-partial-W-and-GW'
lchapel 2b433d1
update refs partial OT
lchapel bf40e97
fixes small typos in plot_partial_wass_and_gromov
lchapel 8bfdd43
fix small bugs in partial.py
lchapel 096b878
Merge branch 'master' of https://github.com/PythonOT/POT
lchapel 5302f6d
update README
lchapel 5a1c998
pep8 bugfix
lchapel d136c8e
modif doctest
lchapel dcbd1b5
fix bugtests
lchapel eedf08d
update on test_partial and test on the numerical precision on ot/partial
lchapel 26a843f
resolve merge pb
lchapel fc7bc8d
Merge branch 'master' into master
rflamary 02b2c72
Delete partial.py
lchapel e4c50d5
Merge branch 'master' of https://github.com/PythonOT/POT into PythonO…
lchapel 96500a4
Merge branch 'PythonOT-master'
lchapel 63601c4
Merge branch 'PythonOT:master' into master
lchapel a33a994
update unbalanced: mm algo+plot
lchapel f37bcf2
update unbalanced: mm algo+plot
lchapel ef0d15c
update unbalanced: mm algo+plot
lchapel 50ebbed
Merge branch 'master' into master
rflamary 44d799d
update unbalanced: mm algo+plot
lchapel ecfbbcc
Merge branch 'master' of https://github.com/lchapel/POT
lchapel 8fa78be
update unbalanced: mm algo+plot
lchapel f11caf2
add test mm algo unbalanced OT
lchapel b944132
add test mm algo unbalanced OT
lchapel 97282df
add test mm algo unbalanced OT
lchapel adf0662
add test mm algo unbalanced OT
lchapel 49a71a0
add test mm algo unbalanced OT
lchapel fc37e17
add test mm algo unbalanced OT
lchapel 915a4c6
add test mm algo unbalanced OT
lchapel c574874
add test mm algo unbalanced OT
lchapel de19311
update unbalanced: mm algo+plot
lchapel 4251357
update unbalanced: mm algo+plot
lchapel 8385b6d
update releases.md with new MM UOT algorithms
lchapel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ API and modules | |
plot | ||
stochastic | ||
unbalanced | ||
regpath | ||
partial | ||
sliced | ||
weak | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
============================================================== | ||
2D examples of exact and entropic unbalanced optimal transport | ||
============================================================== | ||
This example is designed to show how to compute unbalanced and | ||
partial OT in POT. | ||
|
||
UOT aims at solving the following optimization problem: | ||
|
||
.. math:: | ||
W = \min_{\gamma} <\gamma, \mathbf{M}>_F + | ||
\mathrm{reg}\cdot\Omega(\gamma) + | ||
\mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + | ||
\mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) | ||
|
||
s.t. | ||
\gamma \geq 0 | ||
|
||
where :math:`\mathrm{div}` is a divergence. | ||
When using the entropic UOT, :math:`\mathrm{reg}>0` and :math:`\mathrm{div}` | ||
should be the Kullback-Leibler divergence. | ||
When solving exact UOT, :math:`\mathrm{reg}=0` and :math:`\mathrm{div}` | ||
can be either the Kullback-Leibler or the quadratic divergence. | ||
Using :math:`\ell_1` norm gives the so-called partial OT. | ||
""" | ||
|
||
# Author: Laetitia Chapel <laetitia.chapel@univ-ubs.fr> | ||
# License: MIT License | ||
|
||
import numpy as np | ||
import matplotlib.pylab as pl | ||
import ot | ||
|
||
############################################################################## | ||
# Generate data | ||
# ------------- | ||
|
||
# %% parameters and data generation | ||
|
||
n = 40 # nb samples | ||
|
||
mu_s = np.array([-1, -1]) | ||
cov_s = np.array([[1, 0], [0, 1]]) | ||
|
||
mu_t = np.array([4, 4]) | ||
cov_t = np.array([[1, -.8], [-.8, 1]]) | ||
|
||
np.random.seed(0) | ||
xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) | ||
xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) | ||
|
||
n_noise = 10 | ||
|
||
xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0) | ||
xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0) | ||
|
||
n = n + n_noise | ||
|
||
a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples | ||
|
||
# loss matrix | ||
M = ot.dist(xs, xt) | ||
M /= M.max() | ||
|
||
|
||
############################################################################## | ||
# Compute entropic kl-regularized UOT, kl- and l2-regularized UOT | ||
# ----------- | ||
|
||
reg = 0.005 | ||
reg_m_kl = 0.05 | ||
reg_m_l2 = 5 | ||
mass = 0.7 | ||
|
||
entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl) | ||
kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl') | ||
l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2') | ||
partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass) | ||
|
||
############################################################################## | ||
# Plot the results | ||
# ---------------- | ||
|
||
pl.figure(2) | ||
transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot] | ||
title = ["partial OT \n m=" + str(mass), "$\ell_2$-UOT \n $\mathrm{reg_m}$=" + | ||
str(reg_m_l2), "kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl), | ||
"entropic kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl)] | ||
|
||
for p in range(4): | ||
pl.subplot(2, 4, p + 1) | ||
P = transp[p] | ||
if P.sum() > 0: | ||
P = P / P.max() | ||
for i in range(n): | ||
for j in range(n): | ||
if P[i, j] > 0: | ||
pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', | ||
alpha=P[i, j] * 0.3) | ||
pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) | ||
pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) | ||
pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2) | ||
pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2) | ||
pl.title(title[p]) | ||
pl.yticks(()) | ||
pl.xticks(()) | ||
if p < 1: | ||
pl.ylabel("mappings") | ||
pl.subplot(2, 4, p + 5) | ||
pl.imshow(P, cmap='jet') | ||
pl.yticks(()) | ||
pl.xticks(()) | ||
if p < 1: | ||
pl.ylabel("transport plans") | ||
pl.show() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is all this? was the implementation false? why is there so much more steps?