Skip to content

Commit

Permalink
Add SlidingConvOptimalTransportDistance
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Jun 26, 2020
1 parent 5f16a7c commit 1b708f7
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ DSP = "0.6"
Distances = "0.8, 0.9"
DocStringExtensions = "0.8"
DoubleFloats = "1.1"
DynamicAxisWarping = "0.1, 0.2"
DynamicAxisWarping = "0.1, 0.2, 0.3"
FillArrays = "0.8"
Hungarian = "0.6"
Lazy = "0.14, 0.15"
Expand Down
1 change: 1 addition & 0 deletions src/SpectralDistances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ DiscretizedRationalDistance,
WelchOptimalTransportDistance,
WelchLPDistance,
ConvOptimalTransportDistance,
SlidingConvOptimalTransportDistance,
RationalOptimalTransportDistance,
OptimalTransportHistogramDistance,
DiscreteGridTransportDistance,
Expand Down
48 changes: 48 additions & 0 deletions src/sinkhorn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,8 @@ It's important to tune the two parameters below, see the docstring for [`sinkhor
- `β = 0.001`
- `dynamic_floor = -10.0`
- `invariant_axis::Int = 0` If this is set to 1 or 2, the distance will be approximately invariant to translations along the invariant axis. As an example, to be invariant to a spectrogram being shifted slightly in time, set `invariant_axis = 2`.
See also [`SlidingConvOptimalTransportDistance`](@ref)
"""
ConvOptimalTransportDistance
Base.@kwdef mutable struct ConvOptimalTransportDistance{T} <: AbstractDistance
Expand Down Expand Up @@ -664,6 +666,52 @@ function evaluate(d::ConvOptimalTransportDistance, A::AbstractMatrix, B::Abstrac
return c
end

"""
SlidingConvOptimalTransportDistance
Similar to [`ConvOptimalTransportDistance`](@ref) but lets the shorter signal slide across the longer signal and returns the minimum distance.
Equivalent to calling
```julia
minimum(distance_profile(d::ConvOptimalTransportDistance, a, b; kwargs...))
```
"""
struct SlidingConvOptimalTransportDistance{D} <: AbstractDistance
d::D
end

function SlidingConvOptimalTransportDistance(;β = 0.001,
dynamic_floor = NaN,
invariant_axis::Int = 0,
workspace = nothing)

SlidingConvOptimalTransportDistance(
ConvOptimalTransportDistance(β, dynamic_floor, invariant_axis, workspace)
)
end

function Distances.evaluate(d::SlidingConvOptimalTransportDistance, a, b; kwargs...)
if size(a,2) > size(b,2)
a,b = b,a
end
minimum(distance_profile(d.d, a, b; kwargs...))
end

function Distances.evaluate(
d::SlidingConvOptimalTransportDistance,
a::DSP.Periodograms.TFR,
b::DSP.Periodograms.TFR;
kwargs...,
)
evaluate(
d,
normalize_spectrogram(a, d.d.dynamic_floor),
normalize_spectrogram(b, d.d.dynamic_floor);
kwargs...,
)
end



"""
normalize_spectrogram(S, dynamic_floor = default_dynamic_floor(S))
"""
Expand Down
4 changes: 4 additions & 0 deletions test/test_convolutional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ end
isinteractive() && plot(D)
@test 40 <= argmin(D) <= 70

dslide = SlidingConvOptimalTransportDistance(d)
@test dslide(A1, A3; tol=1e-6, stride=15) minimum(D)
@test SlidingConvOptimalTransportDistance=0.05, dynamic_floor=-3.0).d.β == SlidingConvOptimalTransportDistance(ConvOptimalTransportDistance=0.05, dynamic_floor=-3.0)).d.β




Expand Down

0 comments on commit 1b708f7

Please sign in to comment.