Skip to content

[MRG] Semi-relaxed (fused) gromov-wasserstein divergence and improvements of gromov-wasserstein solvers #431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 39 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
574b003
maj gw/ srgw/ generic cg solver
cedricvincentcuaz Sep 22, 2022
ff1e28d
correct pep8 on current state
cedricvincentcuaz Sep 22, 2022
6d879fe
fix bug previous tests
cedricvincentcuaz Sep 22, 2022
f80671a
fix pep8
cedricvincentcuaz Sep 22, 2022
ee330c0
fix bug srGW constC in loss and gradient
cedricvincentcuaz Sep 23, 2022
9670fc9
Merge branch 'master' into semirelaxed_gromov
rflamary Sep 27, 2022
adaae47
fix doc html
cedricvincentcuaz Sep 27, 2022
8d9827e
Merge branch 'semirelaxed_gromov' of https://github.com/cedricvincent…
cedricvincentcuaz Sep 27, 2022
7c758e3
fix doc html
cedricvincentcuaz Sep 27, 2022
98a880e
Merge branch 'master' into semirelaxed_gromov
rflamary Jan 4, 2023
c384f45
update generic_cg and dependencies
cedricvincentcuaz Feb 2, 2023
e6f0bb1
start updating test_optim.py
cedricvincentcuaz Feb 2, 2023
efbba2e
update tests gromov and optim - plus fix gromov dependencies
cedricvincentcuaz Feb 6, 2023
71be9d0
add symmetry feature to entropic gw
cedricvincentcuaz Feb 6, 2023
90fcd48
add symmetry feature to entropic gw
cedricvincentcuaz Feb 6, 2023
6ab9514
add exemple for sr(F)GW matchings
cedricvincentcuaz Feb 9, 2023
dc53fe1
Merge branch 'master' into semirelaxed_gromov
rflamary Feb 23, 2023
43bc857
Merge branch 'master' into semirelaxed_gromov
rflamary Feb 23, 2023
a9cbe08
factor linesearch dependencies /transpose + srgw to backend unfinished
cedricvincentcuaz Feb 27, 2023
6202290
merge releases.md
cedricvincentcuaz Feb 27, 2023
f7fa3ee
small stuff
cedricvincentcuaz Feb 27, 2023
54f0ba1
remove (reg,M) from line-search/ complete srgw tests with backend
cedricvincentcuaz Feb 28, 2023
a875a12
remove backend repetitions / rename fG to costG/ fix innerlog to True
cedricvincentcuaz Feb 28, 2023
92c69d4
fix pep8
cedricvincentcuaz Feb 28, 2023
be55ea2
Merge branch 'master' into semirelaxed_gromov
rflamary Feb 28, 2023
6069b0a
take comments into account / new nx parameters still to test
cedricvincentcuaz Mar 1, 2023
e50a750
Merge branch 'semirelaxed_gromov' of https://github.com/cedricvincent…
cedricvincentcuaz Mar 1, 2023
746bc1d
factor (f)gw2 + test new backend parameters in ot.gromov + harmonize …
cedricvincentcuaz Mar 3, 2023
b028b36
split gromov.py in ot/gromov/ + update test_gromov with helper_backen…
cedricvincentcuaz Mar 3, 2023
49fdbeb
manual documentaion gromov
rflamary Mar 9, 2023
95f2033
remove circular autosummary
rflamary Mar 9, 2023
5882cd1
Merge branch 'master' into semirelaxed_gromov
rflamary Mar 9, 2023
2409983
trying stuff
rflamary Mar 9, 2023
a1f1172
Merge branch 'semirelaxed_gromov' of https://github.com/cedricvincent…
rflamary Mar 9, 2023
dc1ac92
debug documentation
rflamary Mar 9, 2023
97c1e6b
alphabetic ordering of module
rflamary Mar 9, 2023
2d03ba6
merge into branch
cedricvincentcuaz Mar 9, 2023
679b74f
Merge branch 'semirelaxed_gromov' of https://github.com/cedricvincent…
cedricvincentcuaz Mar 9, 2023
fb86e46
add note in entropic gw solvers
cedricvincentcuaz Mar 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update generic_cg and dependencies
  • Loading branch information
cedricvincentcuaz committed Feb 2, 2023
commit c384f457f81ce75ca222421137c838d05c05343a
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#### New features
- Added feature to (Fused) Gromov-Wasserstein solvers to handle asymmetric matrices (PR #401)
- Added semi-relaxed (Fused) Gromov-Wasserstein solvers + examples (PR #401)
- Added Bures Wasserstein distance in `ot.gaussian` (PR ##428)
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
- Added Free Support Sinkhorn Barycenter + example (PR #387)
- New API for OT solver using function `ot.solve` (PR #388)
Expand Down
67 changes: 13 additions & 54 deletions ot/gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .bregman import sinkhorn
from .utils import dist, UndefinedParameter, list_to_array
from .optim import generic_cg, line_search_armijo, solve_gromov_linesearch, solve_semirelaxed_gromov_linesearch
from .optim import cg, semirelaxed_cg, line_search_armijo, solve_gromov_linesearch, solve_semirelaxed_gromov_linesearch
from .lp import emd_1d, emd
from .utils import check_random_state, unif
from .backend import get_backend
Expand Down Expand Up @@ -442,24 +442,20 @@ def df(G, qG=None):
def df(G, qG=None):
return 0.5 * (gwggrad(constC, hC1, hC2, G) + gwggrad(constCt, hC1t, hC2t, G))

def lp_solver(a, b, Mi, numItermax, log, **kwargs):
return emd(a, b, Mi, numItermax, log)

if armijo:
def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_max, qG, qdeltaG, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max)
else:
def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_max, qG, qdeltaG, **kwargs):
return solve_gromov_linesearch(G, deltaG, Mi, f_val, C1, C2t, reg, Gc, M, alpha_min, alpha_max)
if log:

res, log = generic_cg(p, q, 0, f, df, 1, None, lp_solver, line_search_solver, G0, log=True, **kwargs)
res, log = cg(p, q, 0., 1., f, df, G0, line_search_solver, log=True, **kwargs)
log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10)
log['u'] = nx.from_numpy(log['u'], type_as=C10)
log['v'] = nx.from_numpy(log['v'], type_as=C10)
return nx.from_numpy(res, type_as=C10), log
else:
return nx.from_numpy(generic_cg(p, q, 0, f, df, 1, None, lp_solver, line_search_solver, G0, log=False, **kwargs), type_as=C10)
return nx.from_numpy(cg(p, q, 0., 1., f, df, G0, line_search_solver, log=False, **kwargs), type_as=C10)


def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, **kwargs):
Expand Down Expand Up @@ -587,17 +583,14 @@ def df(G, qG=None):
def df(G, qG=None):
return 0.5 * (gwggrad(constC, hC1, hC2, G) + gwggrad(constCt, hC1t, hC2t, G))

def lp_solver(a, b, Mi, numItermax, log, **kwargs):
return emd(a, b, Mi, numItermax, log)

if armijo:
def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_max, qG, qdeltaG, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max)
else:
def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_max, qG, qdeltaG, **kwargs):
return solve_gromov_linesearch(G, deltaG, Mi, f_val, C1, C2t, reg, Gc, M, alpha_min, alpha_max)

T, log_gw = generic_cg(p, q, 0, f, df, 1, None, lp_solver, line_search_solver, G0, log=True, **kwargs)
T, log_gw = cg(p, q, 0., 1., f, df, G0, line_search_solver, log=True, **kwargs)

T0 = nx.from_numpy(T, type_as=C10)

Expand Down Expand Up @@ -744,8 +737,6 @@ def df(G, qG=None):
def df(G, qG=None):
return 0.5 * (gwggrad(constC, hC1, hC2, G) + gwggrad(constCt, hC1t, hC2t, G))

def lp_solver(a, b, Mi, numItermax, log, **kwargs):
return emd(a, b, Mi, numItermax, log)
if armijo:
def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_max, qG, qdeltaG, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max)
Expand All @@ -754,14 +745,14 @@ def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_
return solve_gromov_linesearch(G, deltaG, Mi, f_val, C1, C2t, reg, Gc, M, alpha_min, alpha_max)

if log:
res, log = generic_cg(p, q, (1 - alpha) * M, f, df, alpha, None, lp_solver, line_search_solver, G0, log=True, **kwargs)
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search_solver, log=True, **kwargs)
fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
log['fgw_dist'] = fgw_dist
log['u'] = nx.from_numpy(log['u'], type_as=C10)
log['v'] = nx.from_numpy(log['v'], type_as=C10)
return nx.from_numpy(res, type_as=C10), log
else:
return nx.from_numpy(generic_cg(p, q, (1 - alpha) * M, f, df, alpha, None, lp_solver, line_search_solver, G0, log=False, **kwargs), type_as=C10)
return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search_solver, log=False, **kwargs), type_as=C10)


def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
Expand Down Expand Up @@ -894,17 +885,14 @@ def df(G, qG=None):
def df(G, qG=None):
return 0.5 * (gwggrad(constC, hC1, hC2, G) + gwggrad(constCt, hC1t, hC2t, G))

def lp_solver(a, b, Mi, numItermax, log, **kwargs):
return emd(a, b, Mi, numItermax, log)

if armijo:
def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_max, qG, qdeltaG, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max)
else:
def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_max, qG, qdeltaG, **kwargs):
return solve_gromov_linesearch(G, deltaG, Mi, f_val, C1, C2t, reg, Gc, M, alpha_min, alpha_max)

T, log_fgw = generic_cg(p, q, (1 - alpha) * M, f, df, alpha, None, lp_solver, line_search_solver, G0, log=True, **kwargs)
T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search_solver, log=True, **kwargs)

fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10)

Expand Down Expand Up @@ -3198,13 +3186,6 @@ def df(G, qG):
marginal_product_2 = nx.dot(ones_p[:, None], nx.dot(qG[None, :], fC2))
return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G))

def lp_solver(a, b, Mi, numItermax, log, **kwargs):

min_ = Mi.min(axis=1)
Gc = (Mi == min_[:, None]).astype(Mi.dtype)
Gc *= (a / Gc.sum(axis=1))[:, None]
return Gc

if armijo:
def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_max, qG, qdeltaG, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max)
Expand All @@ -3213,16 +3194,15 @@ def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_
return solve_semirelaxed_gromov_linesearch(G, deltaG, Mi, f_val, C1, C2t, ones_p, qG, qdeltaG, reg, Gc, M, alpha_min, alpha_max)

if log:

res, log = generic_cg(p, q, 0, f, df, 1, None, lp_solver, line_search_solver, semirelaxed=True, G0=G0, log=True, **kwargs)
res, log = semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search_solver, log=True, **kwargs)
q = res.sum(0)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10)
log['u'] = nx.from_numpy(log['u'], type_as=C10)
log['v'] = nx.from_numpy(log['v'], type_as=C10)
return nx.from_numpy(res, type_as=C10), log
else:
return nx.from_numpy(generic_cg(p, q, 0, f, df, 1, None, lp_solver, line_search_solver, semirelaxed=True, G0=G0, log=False, **kwargs), type_as=C10)
return nx.from_numpy(semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search_solver, log=False, **kwargs), type_as=C10)


def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, **kwargs):
Expand Down Expand Up @@ -3343,21 +3323,14 @@ def df(G, qG):
marginal_product_2 = nx.dot(ones_p[:, None], nx.dot(qG[None, :], fC2))
return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G))

def lp_solver(a, b, Mi, numItermax, log, **kwargs):

min_ = Mi.min(axis=1)
Gc = (Mi == min_[:, None]).astype(Mi.dtype)
Gc *= (a / Gc.sum(axis=1))[:, None]
return Gc

if armijo:
def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_max, qG, qdeltaG, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max)
else:
def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_max, qG, qdeltaG, **kwargs):
return solve_semirelaxed_gromov_linesearch(G, deltaG, Mi, f_val, C1, C2t, ones_p, qG, qdeltaG, reg, Gc, M, alpha_min, alpha_max)

T, log_gw = generic_cg(p, q, 0, f, df, 1, None, lp_solver, line_search_solver, G0, semirelaxed=True, log=True, **kwargs)
T, log_gw = semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search_solver, log=True, **kwargs)

T0 = nx.from_numpy(T, type_as=C10)

Expand Down Expand Up @@ -3507,13 +3480,6 @@ def df(G, qG):
marginal_product_2 = nx.dot(ones_p[:, None], nx.dot(qG[None, :], fC2))
return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G))

def lp_solver(a, b, Mi, numItermax, log, **kwargs):

min_ = Mi.min(axis=1)
Gc = (Mi == min_[:, None]).astype(Mi.dtype)
Gc *= (a / Gc.sum(axis=1))[:, None]
return Gc

if armijo:
def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_max, qG, qdeltaG, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max)
Expand All @@ -3522,14 +3488,14 @@ def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_
return solve_semirelaxed_gromov_linesearch(G, deltaG, Mi, f_val, C1, C2t, ones_p, qG, qdeltaG, reg, Gc, M, alpha_min, alpha_max)

if log:
res, log = generic_cg(p, q, (1 - alpha) * M, f, df, alpha, None, lp_solver, line_search_solver, G0, semirelaxed=True, log=True, **kwargs)
res, log = semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search_solver, log=True, **kwargs)
fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
log['fgw_dist'] = fgw_dist
log['u'] = nx.from_numpy(log['u'], type_as=C10)
log['v'] = nx.from_numpy(log['v'], type_as=C10)
return nx.from_numpy(res, type_as=C10), log
else:
return nx.from_numpy(generic_cg(p, q, (1 - alpha) * M, f, df, alpha, None, lp_solver, line_search_solver, G0, semirelaxed=True, log=False, **kwargs), type_as=C10)
return nx.from_numpy(semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search_solver, log=True, **kwargs), type_as=C10)


def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
Expand Down Expand Up @@ -3659,21 +3625,14 @@ def df(G, qG):
marginal_product_2 = nx.dot(ones_p[:, None], nx.dot(qG[None, :], fC2))
return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G))

def lp_solver(a, b, Mi, numItermax, log, **kwargs):

min_ = Mi.min(axis=1)
Gc = (Mi == min_[:, None]).astype(Mi.dtype)
Gc *= (a / Gc.sum(axis=1))[:, None]
return Gc

if armijo:
def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_max, qG, qdeltaG, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max)
else:
def line_search_solver(cost, G, deltaG, Mi, f_val, reg, M, Gc, alpha_min, alpha_max, qG, qdeltaG, **kwargs):
return solve_semirelaxed_gromov_linesearch(G, deltaG, Mi, f_val, C1, C2t, ones_p, qG, qdeltaG, reg, Gc, M, alpha_min, alpha_max)

T, log_fgw = generic_cg(p, q, (1 - alpha) * M, f, df, alpha, None, lp_solver, line_search_solver, G0, semirelaxed=True, log=True, **kwargs)
T, log_fgw = semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search_solver, log=True, **kwargs)
q = T.sum(0)

fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10)
Expand Down
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.