[MRG] Linear Circular OT #736
Changes from 15 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -791,10 +791,12 @@ def binary_search_circle( | |
| -1, 1 | ||
| ) | ||
|
|
||
| mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) | ||
| tc[mask_end > 0] = ( | ||
| (Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp) | ||
| )[mask_end > 0] | ||
| with warnings.catch_warnings(): | ||
|
Collaborator
This is a bit worrying. Why are you catching those warnings? When do they happen? It is probably OK, but we need at least a comment there to explain.
Contributor
Author
These warnings are raised when some elements of the denominator `dCptm - dCmtp` are (close to) zero: the division is evaluated on the whole array before the mask is applied, so those entries trigger RuntimeWarnings even though they are discarded afterwards. So I thought that it would be better to not raise them. There are maybe other solutions to avoid these warnings (I tried to put the masks inside the operations, but it broke the shapes, so I tried with the catch). I will add a comment.
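For illustration, here is a NumPy-only sketch (dummy arrays, not the PR's backend-generic code) of the alternative hinted at above: pushing the mask into the division with `np.divide(..., where=...)` avoids the RuntimeWarning without catching it. The nx backend wrapper may not expose this keyword, which would explain preferring `catch_warnings` here.

```python
import numpy as np

# Dummy data standing in for the quantities in binary_search_circle.
dCptm = np.array([2.0, 0.0, 1.5])
dCmtp = np.array([1.0, 0.0, 0.5])
Ctp = np.array([3.0, 1.0, 2.0])
Ctm = np.array([1.0, 1.0, 1.0])
tp, tm = np.full(3, 0.8), np.full(3, 0.2)
tc = np.zeros(3)

# The warning comes from evaluating the division on the full array before masking;
# with `where=`, entries where the mask is False are skipped, so nothing is raised
# and tc keeps its previous value there.
mask_end = np.abs(dCptm - dCmtp) > 0.001
num = Ctp - Ctm + tm * dCptm - tp * dCmtp
np.divide(num, dCptm - dCmtp, out=tc, where=mask_end)
```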
||
| warnings.simplefilter("ignore", category=RuntimeWarning) | ||
| mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) | ||
| tc[mask_end > 0] = ( | ||
| (Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp) | ||
| )[mask_end > 0] | ||
| done[nx.prod(mask, axis=-1) > 0] = 1 | ||
| elif nx.any(1 - done): | ||
| tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0] | ||
|
|
@@ -933,8 +935,8 @@ def wasserstein_circle( | |
| eps=1e-6, | ||
| require_sort=True, | ||
| ): | ||
| r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or | ||
| the binary search algorithm proposed in [44] otherwise. | ||
| r"""Computes the Wasserstein distance on the circle using either :ref:`[45] <references-wasserstein-circle>` for p=1 or | ||
| the binary search algorithm proposed in :ref:`[44] <references-wasserstein-circle>` otherwise. | ||
| Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, | ||
| takes the value modulo 1. | ||
| If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates | ||
|
|
@@ -996,17 +998,19 @@ def wasserstein_circle( | |
| >>> wasserstein_circle(u.T, v.T) | ||
| array([0.1]) | ||
|
|
||
|
|
||
| .. _references-wasserstein-circle: | ||
| References | ||
| ---------- | ||
| .. [44] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. | ||
| .. [45] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. | ||
| """ | ||
| assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) | ||
|
|
||
| if p == 1: | ||
| return wasserstein1_circle( | ||
| u_values, v_values, u_weights, v_weights, require_sort | ||
| ) | ||
| # if p == 1: | ||
clbonet marked this conversation as resolved.
|
||
| # return wasserstein1_circle( | ||
| # u_values, v_values, u_weights, v_weights, require_sort | ||
| # ) | ||
|
|
||
| return binary_search_circle( | ||
| u_values, | ||
|
|
@@ -1042,7 +1046,7 @@ def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): | |
| .. math:: | ||
| u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}, | ||
|
|
||
| using e.g. ot.utils.get_coordinate_circle(x) | ||
| using e.g. ot.utils.get_coordinate_circle(x). | ||
|
|
||
| Parameters | ||
| ---------- | ||
|
|
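As an aside on the coordinate convention used in these docstrings: a minimal NumPy illustration (not part of the diff) of mapping points on :math:`S^1\subset\mathbb{R}^2` to :math:`[0,1[` with the formula above, which is what `ot.utils.get_coordinate_circle` provides.

```python
import numpy as np

# Points on the unit circle given as (cos(theta), sin(theta)).
theta = np.array([0.0, np.pi / 2, np.pi])
x = np.stack([np.cos(theta), np.sin(theta)], axis=1)  # shape (n, 2)

# u = (pi + atan2(-x2, -x1)) / (2*pi), the convention used by the circle solvers.
u = (np.pi + np.arctan2(-x[:, 1], -x[:, 0])) / (2 * np.pi)
# u ≈ [0., 0.25, 0.5]
```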
@@ -1095,3 +1099,150 @@ def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): | |
| cpt2 = nx.sum(u_values * u_weights * ns, axis=0) | ||
|
|
||
| return cpt1 - u_mean**2 + cpt2 + 1 / 12 | ||
|
|
||
|
|
||
| def linear_circular_embedding(x, u_values, u_weights=None, require_sort=True): | ||
| r"""Returns the embedding :math:`\hat{\mu}(x)` of Linear Circular OT with reference | ||
| :math:`\eta=\mathrm{Unif}(S^1)` evaluated at :math:`x`. | ||
|
|
||
| For any :math:`x\in [0,1[`, the embedding is given by (see :ref:`[76] <references-lcot>`) | ||
|
|
||
| .. math:: | ||
| \hat{\mu}(x) = F_{\mu}^{-1}\big(x - \int z\mathrm{d}\mu(z) + \frac12\big) - x. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| x : ndarray, shape (m,) | ||
| Points in [0,1[ where to evaluate the embedding | ||
| u_values : ndarray, shape (n, ...) | ||
| samples in the source domain (coordinates on [0,1[) | ||
| u_weights : ndarray, shape (n, ...), optional | ||
| samples weights in the source domain | ||
|
|
||
| Returns | ||
| ------- | ||
| embedding: ndarray of shape (m, ...) | ||
| Embedding evaluated at :math:`x` | ||
|
|
||
| .. _references-lcot: | ||
| References | ||
| ---------- | ||
| .. [76] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations. | ||
| """ | ||
| if u_weights is not None: | ||
| nx = get_backend(u_values, u_weights) | ||
| else: | ||
| nx = get_backend(u_values) | ||
|
|
||
| n = u_values.shape[0] | ||
| u_values = u_values % 1 | ||
|
|
||
| if len(u_values.shape) == 1: | ||
| u_values = nx.reshape(u_values, (n, 1)) | ||
|
|
||
| if u_weights is None: | ||
| u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) | ||
| elif u_weights.ndim != u_values.ndim: | ||
| u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) | ||
|
|
||
| if require_sort: | ||
| u_sorter = nx.argsort(u_values, 0) | ||
| u_values = nx.take_along_axis(u_values, u_sorter, 0) | ||
| u_weights = nx.take_along_axis(u_weights, u_sorter, 0) | ||
|
|
||
| u_cdf = nx.cumsum(u_weights, 0) | ||
| u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) | ||
|
|
||
| q_s = ( | ||
| x[:, None] - nx.sum(u_values * u_weights, axis=0)[None] + 0.5 | ||
| ) # shape (m, ...) | ||
|
|
||
| u_quantiles = quantile_function(q_s % 1, u_cdf, u_values) | ||
|
|
||
| return (u_quantiles - x[:, None]) % 1 | ||
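To see the formula without the backend machinery, here is a hedged pure-NumPy sketch of the same embedding for the 1-D, single-batch case (`lcot_embedding_np` is a hypothetical helper name; it mirrors the steps above, i.e. zero-padded CDF, shifted quantile levels, clipped `searchsorted`, but it is an illustration, not the library code).

```python
import numpy as np

def lcot_embedding_np(x, u_values, u_weights=None):
    """Sketch of the LCOT embedding against Unif(S^1), 1-D samples only."""
    u_values = np.asarray(u_values, dtype=float) % 1.0
    n = u_values.shape[0]
    u_weights = np.full(n, 1.0 / n) if u_weights is None else np.asarray(u_weights, dtype=float)
    sorter = np.argsort(u_values)
    u_values, u_weights = u_values[sorter], u_weights[sorter]
    # zero-padded empirical CDF, as with nx.zero_pad above
    u_cdf = np.concatenate(([0.0], np.cumsum(u_weights)))
    # shifted quantile levels x - E[mu] + 1/2, taken modulo 1
    q_s = (x - np.sum(u_values * u_weights) + 0.5) % 1.0
    # quantile lookup, mimicking quantile_function (searchsorted on the padded CDF)
    idx = np.clip(np.searchsorted(u_cdf, q_s, side="left"), 0, n - 1)
    return (u_values[idx] - x) % 1.0
```

For example, `lcot_embedding_np(np.linspace(0, 1, 100, endpoint=False), np.array([0.2, 0.5, 0.8]))` evaluates the embedding on the same 100-point grid that `linear_circular_ot` uses below.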
|
|
||
|
|
||
| def linear_circular_ot(u_values, v_values=None, u_weights=None, v_weights=None): | ||
| r"""Computes the Linear Circular Optimal Transport distance from :ref:`[76] <references-lcot>` using :math:`\eta=\mathrm{Unif}(S^1)` | ||
| as reference measure. | ||
| Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, | ||
| takes the value modulo 1. | ||
| If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates | ||
| using e.g. the atan2 function. | ||
|
|
||
| General loss returned: | ||
|
|
||
| .. math:: | ||
| \mathrm{LCOT}_2^2(\mu, \nu) = \int_0^1 d_{S^1}\big(\hat{\mu}(t), \hat{\nu}(t)\big)^2\ \mathrm{d}t | ||
|
|
||
| where :math:`\hat{\mu}(x)=F_{\mu}^{-1}(x-\int z\mathrm{d}\mu(z)+\frac12) - x` for all :math:`x\in [0,1[`, | ||
| and :math:`d_{S^1}(x,y)=\min(|x-y|, 1-|x-y|)` for :math:`x,y\in [0,1[`. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| u_values : ndarray, shape (n, ...) | ||
| samples in the source domain (coordinates on [0,1[) | ||
| v_values : ndarray, shape (n, ...), optional | ||
| samples in the target domain (coordinates on [0,1[), if None, compute distance against uniform distribution | ||
| u_weights : ndarray, shape (n, ...), optional | ||
| samples weights in the source domain | ||
| v_weights : ndarray, shape (n, ...), optional | ||
| samples weights in the target domain | ||
|
|
||
| Returns | ||
| ------- | ||
| loss: float | ||
clbonet marked this conversation as resolved.
|
||
| Cost associated to the linear optimal transportation | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> u = np.array([[0.2,0.5,0.8]])%1 | ||
| >>> v = np.array([[0.4,0.5,0.7]])%1 | ||
| >>> linear_circular_ot(u.T, v.T) | ||
| array([0.0127]) | ||
|
|
||
|
|
||
| .. _references-lcot: | ||
| References | ||
| ---------- | ||
| .. [76] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations. | ||
| """ | ||
| if u_weights is not None: | ||
| nx = get_backend(u_values, u_weights) | ||
| else: | ||
| nx = get_backend(u_values) | ||
|
|
||
| n = u_values.shape[0] | ||
| u_values = u_values % 1 | ||
|
|
||
| if len(u_values.shape) == 1: | ||
| u_values = nx.reshape(u_values, (n, 1)) | ||
|
|
||
| if u_weights is None: | ||
| u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) | ||
| elif u_weights.ndim != u_values.ndim: | ||
| u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) | ||
|
|
||
| unif_s1 = nx.linspace(0, 1, 101, type_as=u_values)[:-1] | ||
|
|
||
| emb_u = linear_circular_embedding(unif_s1, u_values, u_weights) | ||
|
|
||
| if v_values is None: | ||
| dist_u = nx.minimum(nx.abs(emb_u), 1 - nx.abs(emb_u)) | ||
| return nx.mean(dist_u**2, axis=0) | ||
| else: | ||
| m = v_values.shape[0] | ||
| if len(v_values.shape) == 1: | ||
| v_values = nx.reshape(v_values, (m, 1)) | ||
|
|
||
| if u_values.shape[1] != v_values.shape[1]: | ||
| raise ValueError( | ||
| "u and v must have the same number of batchs {} and {} respectively given".format( | ||
| u_values.shape[1], v_values.shape[1] | ||
| ) | ||
| ) | ||
|
|
||
| emb_v = linear_circular_embedding(unif_s1, v_values, v_weights) | ||
|
|
||
| dist_uv = nx.minimum(nx.abs(emb_u - emb_v), 1 - nx.abs(emb_u - emb_v)) | ||
| return nx.mean(dist_uv**2, axis=0) | ||
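Finally, reusing `lcot_embedding_np` from the sketch above, the returned loss can be spelled out as follows (again a hypothetical helper; exposing `n_grid` as a parameter is an assumption, the function above hard-codes a 100-point grid):

```python
import numpy as np

def lcot2_np(u_values, v_values, u_weights=None, v_weights=None, n_grid=100):
    # LCOT_2^2(mu, nu) = mean over a regular grid of [0, 1[ of d_{S^1}(mu_hat(t), nu_hat(t))^2,
    # with d_{S^1}(x, y) = min(|x - y|, 1 - |x - y|).
    t = np.linspace(0.0, 1.0, n_grid, endpoint=False)
    diff = lcot_embedding_np(t, u_values, u_weights) - lcot_embedding_np(t, v_values, v_weights)
    d = np.minimum(np.abs(diff), 1.0 - np.abs(diff))
    return np.mean(d**2)

# lcot2_np([0.2, 0.5, 0.8], [0.4, 0.5, 0.7]) gives ~0.0127, consistent with the doctest above.
```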