Skip to content

Add smooth_ot_dual #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using PythonOT
DocMeta.setdocmeta!(PythonOT, :DocTestSetup, :(using PythonOT); recursive=true)

makedocs(;
modules=[PythonOT],
modules=[PythonOT, PythonOT.Smooth],
authors="David Widmann",
repo="https://github.com/devmotion/PythonOT.jl/blob/{commit}{path}#{line}",
sitename="PythonOT.jl",
Expand Down
13 changes: 12 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,25 @@ emd
emd2
```

## Entropically regularised optimal transport
## Regularized optimal transport

```@docs
sinkhorn
sinkhorn2
barycenter
```

The submodule `Smooth` contains a function for solving regularized optimal
transport problems with L2- and entropic regularization using the dual
formulation. You can load the submodule with
```julia
using PythonOT.Smooth
```

```@docs
PythonOT.Smooth.smooth_ot_dual
```

## Unbalanced optimal transport

```@docs
Expand Down
1 change: 1 addition & 0 deletions src/PythonOT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export emd, emd2, sinkhorn, sinkhorn2, barycenter, sinkhorn_unbalanced, sinkhorn
const pot = PyCall.PyNULL()

include("lib.jl")
include("smooth.jl")

function __init__()
return copy!(pot, PyCall.pyimport_conda("ot", "pot", "conda-forge"))
Expand Down
61 changes: 61 additions & 0 deletions src/smooth.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
module Smooth

using ..PythonOT: PythonOT
using ..PyCall: PyCall

export smooth_ot_dual

"""
smooth_ot_dual(μ, ν, C, ε; reg_type="l2", kwargs...)

Compute the optimal transport plan for a regularized optimal transport problem
with source and target marginals `μ` and `ν`, cost matrix `C` of size
`(length(μ), length(ν))`, and regularization parameter `ε`.

The optimal transport map `γ` is of the same size as `C` and solves
```math
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\langle \\gamma, C \\rangle
+ \\varepsilon \\Omega(\\gamma),
```
where ``\\Omega(\\gamma)`` is the L2-regularization term
``\\Omega(\\gamma) = \\|\\gamma\\|_F^2/2`` if `reg_type="l2"` (the default) or
the entropic regularization term
``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` if `reg_type="kl"`.

The function solves the dual formulation[^BSR2018]
```math
\\max_{\\alpha, \\beta} \\mu^{\\mathsf{T}} \\alpha + \\nu^{\\mathsf{T}} \\beta −
\\sum_{j} \\delta_{\\Omega}(\\alpha + \\beta_j - C_j),
```
where ``C_j`` is the ``j``th column of the cost matrix and ``\\delta_{\\Omega}`` is the
conjugate of the regularization term ``\\Omega``.

This function is a wrapper of the function
[`smooth_ot_dual`](https://pythonot.github.io/gen_modules/ot.smooth.html#ot.smooth.smooth_ot_dual)
in the Python Optimal Transport package. Keyword arguments are listed in the documentation
of the Python function.

# Examples

```jldoctest; setup=:(using PythonOT.Smooth)
julia> μ = [0.5, 0.2, 0.3];

julia> ν = [0.0, 1.0];

julia> C = [0.0 1.0;
2.0 0.0;
0.5 1.5];

julia> smooth_ot_dual(μ, ν, C, 0.01)
3×2 Matrix{Float64}:
0.0 0.5
0.0 0.2
0.0 0.300001
```

[^BSR2018]: Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. In *Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS)*.
"""
function smooth_ot_dual(μ, ν, C, ε; kwargs...)
return PythonOT.pot.smooth.smooth_ot_dual(μ, ν, C, ε; kwargs...)
end
end