2 changes: 2 additions & 0 deletions README.md
@@ -355,3 +355,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems.

[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).

[66] Pooladian, A.-A., & Niles-Weed, J. (2021). [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004.
147 changes: 142 additions & 5 deletions ot/da.py
@@ -493,7 +493,7 @@ class label

# pairwise distance
self.cost_ = dist(Xs, Xt, metric=self.metric)
self.cost_ = cost_normalization(self.cost_, self.norm)
self.cost_, self.norm_cost_ = cost_normalization(self.cost_, self.norm, return_value=True)

if (ys is not None) and (yt is not None):

@@ -1058,10 +1058,14 @@ class SinkhornTransport(BaseTransport):
can occur with large metric values.
distribution_estimation : callable, optional (defaults to the uniform)
The kind of distribution estimation to employ
out_of_sample_map : string, optional (default="ferradans")
out_of_sample_map : string, optional (default="continuous")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the possible options are
"ferradans", which uses the nearest neighbor method proposed in
:ref:`[6] <references-sinkhorntransport>`, and "continuous", which uses
the out of sample method from :ref:`[66] <references-sinkhorntransport>`
and :ref:`[19] <references-sinkhorntransport>`.
limit_max: float, optional (default=np.infty)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit a cost defined
Expand Down Expand Up @@ -1089,13 +1093,26 @@ class SinkhornTransport(BaseTransport):
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
Regularized discrete optimal transport. SIAM Journal on Imaging
Sciences, 7(3), 1853-1882.

.. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.
& Blondel, M. Large-scale Optimal Transport and Mapping Estimation.
International Conference on Learning Representation (2018)

.. [66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. "Entropic
estimation of optimal transport maps." arXiv preprint
arXiv:2109.12004 (2021).
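
The "continuous" mapping is, in essence, the entropic barycentric
projection of :ref:`[66] <references-sinkhorntransport>` (a sketch in the
notations of this class, matching the computation in :meth:`transform`):

.. math::
    T(\mathbf{x}) = \frac{\sum_j \exp\left(g_j - d(\mathbf{x}, \mathbf{x}^t_j) / \mathrm{reg\_e}\right) \mathbf{x}^t_j}{\sum_j \exp\left(g_j - d(\mathbf{x}, \mathbf{x}^t_j) / \mathrm{reg\_e}\right)}

where :math:`g` is the log-domain Sinkhorn potential of the target samples
and :math:`d` the chosen metric.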

"""

def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000,
def __init__(self, reg_e=1., method="sinkhorn_log", max_iter=1000,
tol=10e-9, verbose=False, log=False,
metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=np.infty):
out_of_sample_map='continuous', limit_max=np.infty):

if out_of_sample_map not in ['ferradans', 'continuous']:
raise ValueError('Unknown out_of_sample_map method')

self.reg_e = reg_e
self.method = method
self.max_iter = max_iter
@@ -1135,6 +1152,12 @@ class label

super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)

if self.out_of_sample_map == 'continuous':
    self.log = True
    if self.method != 'sinkhorn_log':
        self.method = 'sinkhorn_log'
        warnings.warn("The method has been set to 'sinkhorn_log' as it is the "
                      "only method available for out_of_sample_map='continuous'")

# coupling estimation
returned_ = sinkhorn(
a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
@@ -1150,6 +1173,120 @@ class label

return self

def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`

Parameters
----------
Xs : array-like, shape (n_source_samples, n_features)
The source input samples.
ys : array-like, shape (n_source_samples,)
The class labels for source samples
Xt : array-like, shape (n_target_samples, n_features)
The target input samples.
yt : array-like, shape (n_target_samples,)
The class labels for target. If some target samples are unlabelled, fill the
:math:`\mathbf{y_t}`'s elements with -1.

Warning: Note that, due to this convention -1 cannot be used as a
class label
batch_size : int, optional (default=128)
The batch size for the out of sample transform

Returns
-------
transp_Xs : array-like, shape (n_source_samples, n_features)
The transported source samples.
"""
nx = self.nx

if self.out_of_sample_map == 'ferradans':
return super(SinkhornTransport, self).transform(Xs, ys, Xt, yt, batch_size)

else: # self.out_of_sample_map == 'continuous':

# log-domain Sinkhorn potential of the target samples
g = self.log_['log_v']

indices = nx.arange(Xs.shape[0])
batch_ind = [
indices[i:i + batch_size]
for i in range(0, len(indices), batch_size)]

transp_Xs = []
for bi in batch_ind:
# pairwise distances between the batch samples and the target samples
M = dist(Xs[bi], self.xt_, metric=self.metric)

# normalize the cost with the same value as at fit time
M = cost_normalization(M, self.norm, value=self.norm_cost_)

# entropic kernel between the batch and the target samples
K = nx.exp(-M / self.reg_e + g[None, :])

# barycentric projection of the batch onto the target samples
transp_Xs_ = nx.dot(K, self.xt_) / nx.sum(K, axis=1)[:, None]

transp_Xs.append(transp_Xs_)

transp_Xs = nx.concatenate(transp_Xs, axis=0)

return transp_Xs

def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`

Parameters
----------
Xs : array-like, shape (n_source_samples, n_features)
The source input samples.
ys : array-like, shape (n_source_samples,)
The class labels for source samples
Xt : array-like, shape (n_target_samples, n_features)
The target input samples.
yt : array-like, shape (n_target_samples,)
The class labels for target. If some target samples are unlabelled, fill the
:math:`\mathbf{y_t}`'s elements with -1.

Warning: Note that, due to this convention -1 cannot be used as a
class label
batch_size : int, optional (default=128)
The batch size for the out of sample inverse transform

Returns
-------
transp_Xt : array-like, shape (n_target_samples, n_features)
The transported target samples.
"""

nx = self.nx

if self.out_of_sample_map == 'ferradans':
return super(SinkhornTransport, self).inverse_transform(Xs, ys, Xt, yt, batch_size)

else: # self.out_of_sample_map == 'continuous':

# log-domain Sinkhorn potential of the source samples
f = self.log_['log_u']

indices = nx.arange(Xt.shape[0])
batch_ind = [
indices[i:i + batch_size]
for i in range(0, len(indices), batch_size)]

transp_Xt = []
for bi in batch_ind:

# pairwise distances between the batch samples and the source samples
M = dist(Xt[bi], self.xs_, metric=self.metric)

# normalize the cost with the same value as at fit time
M = cost_normalization(M, self.norm, value=self.norm_cost_)

# entropic kernel between the batch and the source samples
K = nx.exp(-M / self.reg_e + f[None, :])

# barycentric projection of the batch onto the source samples
transp_Xt_ = nx.dot(K, self.xs_) / nx.sum(K, axis=1)[:, None]

transp_Xt.append(transp_Xt_)

transp_Xt = nx.concatenate(transp_Xt, axis=0)

return transp_Xt


class EMDTransport(BaseTransport):

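A minimal usage sketch of the new mapping (synthetic data for illustration; the parameter values are assumptions, not from this PR):

import numpy as np
import ot

rng = np.random.RandomState(42)
Xs = rng.randn(50, 2)  # source samples
Xt = rng.randn(40, 2) + 2  # target samples, shifted

# 'continuous' silently selects the 'sinkhorn_log' solver (see fit above)
otda = ot.da.SinkhornTransport(reg_e=1e-1, out_of_sample_map='continuous')
otda.fit(Xs=Xs, Xt=Xt)

# map unseen source samples with the entropic barycentric projection
transp_Xs = otda.transform(Xs=rng.randn(10, 2))

# and map unseen target samples back towards the source domain
transp_Xt = otda.inverse_transform(Xt=rng.randn(10, 2) + 2)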
15 changes: 11 additions & 4 deletions ot/utils.py
@@ -360,7 +360,7 @@ def dist0(n, method='lin_square'):
return res


def cost_normalization(C, norm=None):
def cost_normalization(C, norm=None, return_value=False, value=None):
r""" Apply normalization to the loss matrix

Parameters
@@ -382,9 +382,13 @@ def cost_normalization(C, norm=None):
if norm is None:
pass
elif norm == "median":
C /= float(nx.median(C))
if value is None:
value = nx.median(C)
C /= value
elif norm == "max":
C /= float(nx.max(C))
if value is None:
value = nx.max(C)
C /= float(value)
elif norm == "log":
C = nx.log(1 + C)
elif norm == "loglog":
@@ -393,7 +397,10 @@ def cost_normalization(C, norm=None):
raise ValueError('Norm %s is not a valid option.\n'
'Valid options are:\n'
'median, max, log, loglog' % norm)
return C
if return_value:
return C, value
else:
return C
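
A minimal round-trip sketch (illustrative arrays, not from the PR): normalize a training cost once, keep the normalization value, and reuse it so a test cost shares the same scale:

import numpy as np
from ot.utils import cost_normalization

C_train = np.array([[1., 4.], [2., 8.]])
# normalize the training cost and keep the normalization value
C_train_n, value = cost_normalization(C_train, norm='max', return_value=True)

# reuse the same value on a test cost matrix (as transform() does above)
C_test = np.array([[2., 6.]])
C_test_n = cost_normalization(C_test, norm='max', value=value)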


def dots(*args):
20 changes: 20 additions & 0 deletions test/test_da.py
@@ -346,6 +346,26 @@ def test_sinkhorn_transport_class(nx):
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
assert len(otda.log_.keys()) != 0

# test transform and inverse transform with out_of_sample_map='ferradans'
otda = ot.da.SinkhornTransport(out_of_sample_map='ferradans')
transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)

# test transform and inverse transform with out_of_sample_map='continuous'
otda = ot.da.SinkhornTransport(out_of_sample_map='continuous', method='sinkhorn')
transp_Xs2 = otda.fit_transform(Xs=Xs, Xt=Xt)
assert_equal(transp_Xs2.shape, Xs.shape)
transp_Xt2 = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt2.shape, Xt.shape)

np.testing.assert_almost_equal(nx.to_numpy(transp_Xs), nx.to_numpy(transp_Xs2), decimal=5)
np.testing.assert_almost_equal(nx.to_numpy(transp_Xt), nx.to_numpy(transp_Xt2), decimal=5)

with pytest.raises(ValueError):
otda = ot.da.SinkhornTransport(out_of_sample_map='unknown')
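
# a possible companion check (a sketch, not part of this PR): the silent
# solver switch for out_of_sample_map='continuous' should warn and end on
# 'sinkhorn_log'
with pytest.warns(UserWarning):
    otda = ot.da.SinkhornTransport(out_of_sample_map='continuous', method='sinkhorn')
    otda.fit(Xs=Xs, Xt=Xt)
assert otda.method == 'sinkhorn_log'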


@pytest.skip_backend("jax")
@pytest.skip_backend("tf")