Skip to content

Commit c84ef33

Browse files
[MRG] fix doc+example lowrank sinkhorn (#601)
* fix doc+example lowrank sinkhorn * fix autosummary for lowrank doc * update release --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 64c8374 commit c84ef33

File tree

5 files changed

+21
-22
lines changed

5 files changed

+21
-22
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#### Closed issues
66
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593)
77
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)
8+
- Fix doc and example for lowrank sinkhorn (PR #601)
89

910
## 0.9.2
1011
*December 2023*

docs/source/all.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ API and modules
2424
gaussian
2525
gnn
2626
gromov
27+
lowrank
2728
lp
2829
mapping
2930
optim

examples/others/plot_lowrank_sinkhorn.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,40 +88,35 @@
8888
#%%
8989

9090
# Plot sinkhorn vs low rank sinkhorn
91-
pl.figure(1, figsize=(10, 4))
91+
pl.figure(1, figsize=(10, 8))
9292

93-
pl.subplot(1, 3, 1)
93+
pl.subplot(2, 3, 1)
9494
pl.imshow(list_P_Sin[0], interpolation='nearest')
9595
pl.axis('off')
9696
pl.title('Sinkhorn (reg=0.05)')
9797

98-
pl.subplot(1, 3, 2)
98+
pl.subplot(2, 3, 2)
9999
pl.imshow(list_P_Sin[1], interpolation='nearest')
100100
pl.axis('off')
101101
pl.title('Sinkhorn (reg=0.005)')
102102

103-
pl.subplot(1, 3, 3)
103+
pl.subplot(2, 3, 3)
104104
pl.imshow(list_P_Sin[2], interpolation='nearest')
105105
pl.axis('off')
106106
pl.title('Sinkhorn (reg=0.001)')
107107
pl.show()
108108

109-
110-
#%%
111-
112-
pl.figure(2, figsize=(10, 4))
113-
114-
pl.subplot(1, 3, 1)
109+
pl.subplot(2, 3, 4)
115110
pl.imshow(list_P_LR[0], interpolation='nearest')
116111
pl.axis('off')
117112
pl.title('Low rank (rank=3)')
118113

119-
pl.subplot(1, 3, 2)
114+
pl.subplot(2, 3, 5)
120115
pl.imshow(list_P_LR[1], interpolation='nearest')
121116
pl.axis('off')
122117
pl.title('Low rank (rank=10)')
123118

124-
pl.subplot(1, 3, 3)
119+
pl.subplot(2, 3, 6)
125120
pl.imshow(list_P_LR[2], interpolation='nearest')
126121
pl.axis('off')
127122
pl.title('Low rank (rank=50)')

ot/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
:py:mod:`ot.utils`, :py:mod:`ot.datasets`,
66
:py:mod:`ot.gromov`, :py:mod:`ot.smooth`
77
:py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath`
8-
, :py:mod:`ot.unbalanced`, :py:mod`ot.mapping`.
8+
, :py:mod:`ot.unbalanced`, :py:mod:`ot.mapping` .
99
The following sub-modules are not imported due to additional dependencies:
1010
- :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`.
1111
- :any:`ot.plot` : depends on :code:`matplotlib`
@@ -71,4 +71,5 @@
7171
'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample',
7272
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
7373
'binary_search_circle', 'wasserstein_circle',
74-
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn']
74+
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif',
75+
'lowrank_sinkhorn']

ot/lowrank.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -319,17 +319,18 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, re
319319
The function solves the following optimization problem:
320320
321321
.. math::
322-
\mathop{\inf_{(Q,R,g) \in \mathcal{C(a,b,r)}}} \langle C, Q\mathrm{diag}(1/g)R^T \rangle -
323-
\mathrm{reg} \cdot H((Q,R,g))
322+
\mathop{\inf_{(\mathbf{Q},\mathbf{R},\mathbf{g}) \in \mathcal{C}(\mathbf{a},\mathbf{b},r)}} \langle \mathbf{C}, \mathbf{Q}\mathrm{diag}(1/\mathbf{g})\mathbf{R}^\top \rangle -
323+
\mathrm{reg} \cdot H((\mathbf{Q}, \mathbf{R}, \mathbf{g}))
324324
325325
where :
326-
- :math:`C` is the (`dim_a`, `dim_b`) metric cost matrix
327-
- :math:`H((Q,R,g))` is the values of the three respective entropies evaluated for each term.
328-
- :math: `Q` and `R` are the low-rank matrix decomposition of the OT plan
329-
- :math: `g` is the weight vector for the low-rank decomposition of the OT plan
326+
327+
- :math:`\mathbf{C}` is the (`dim_a`, `dim_b`) metric cost matrix
328+
- :math:`H((\mathbf{Q}, \mathbf{R}, \mathbf{g}))` is the values of the three respective entropies evaluated for each term.
329+
- :math:`\mathbf{Q}` and :math:`\mathbf{R}` are the low-rank matrix decomposition of the OT plan
330+
- :math:`\mathbf{g}` is the weight vector for the low-rank decomposition of the OT plan
330331
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
331-
- :math: `r` is the rank of the OT plan
332-
- :math: `\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem
332+
- :math:`r` is the rank of the OT plan
333+
- :math:`\mathcal{C}(\mathbf{a}, \mathbf{b}, r)` are the low-rank couplings of the OT problem
333334
334335
335336
Parameters

0 commit comments

Comments
 (0)