Skip to content

Commit

Permalink
Merge pull request #72 from jer-j/main
Browse files Browse the repository at this point in the history
add extrapolation support
  • Loading branch information
gerlero authored Jul 13, 2024
2 parents ffaecd6 + c51bcac commit 9da30ca
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 43 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
*.jl.cov
*.jl.mem
/Manifest.toml
.vscode
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,28 @@ ys = itp.(xs) # At multiple points
using Plots

plot(itp, markers=true, label="PCHIP")

```

![Plot](example.png)
![Plot](/images/example.png)

The monotonicity-preserving property of PCHIP interpolation can be clearly seen in the plot.


### Extrapolations

We can also using the cubic polynomial at the first and last intervals to extrapolate values outside the domain of `itp.xs` by setting `itp.extrapolate = true` (default is false) in the constructor:
```jl
itp = Interpolator(xs, ys; extrapolate = true)
```

If `extrapolate = true` then plotting the iterpolator will also show extrapolated values, extending the plotted domain by `± maximum(diff(itp.xs)) * 0.5`:

```julia
plot(itp,markers=true, label="PCHIP w/ extrapolation")
```
![Plot with extrapolation](/images/example_extrapolate.svg)

### Compute a definite integral

```jl
Expand Down
7 changes: 6 additions & 1 deletion ext/PCHIPInterpolationRecipesBaseExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ using RecipesBase
@series begin
markershape := :none
plotdensity = clamp(10 * length(itp.xs), 1000, 100000)
x = range(first(itp.xs), last(itp.xs), length = plotdensity)
if itp.extrapolate
Δxs = maximum(diff(itp.xs)) * 0.5
x = range(first(itp.xs) - Δxs, last(itp.xs) + Δxs, length = plotdensity)
else
x = range(first(itp.xs), last(itp.xs), length = plotdensity)
end
return x, itp.(x)
end
if markershape !== :none
Expand Down
File renamed without changes
51 changes: 51 additions & 0 deletions images/example_extrapolate.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
71 changes: 31 additions & 40 deletions src/PCHIPInterpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ function _pchip_ds_scipy(xs::AbstractVector, ys::AbstractVector)
h(i) = xs[i+1] - xs[i]
Δ(i) = (ys[i+1] - ys[i]) / h(i)

length(ys) != length(xs) && throw(DimensionMismatch)
ds = similar(ys ./ xs)

is = eachindex(xs, ys, ds)
Expand Down Expand Up @@ -63,25 +64,31 @@ struct Interpolator{Xs,Ys,Ds}
xs::Xs
ys::Ys
ds::Ds
extrapolate::Bool

function Interpolator(xs::AbstractVector, ys::AbstractVector)
function Interpolator(xs::AbstractVector, ys::AbstractVector; extrapolate::Bool = false)
length(eachindex(xs, ys)) 2 ||
throw(ArgumentError("inputs must have at least 2 elements"))
_is_strictly_increasing(xs) ||
throw(ArgumentError("xs must be strictly increasing"))

ds = _pchip_ds_scipy(xs, ys)

new{typeof(xs),typeof(ys),typeof(ds)}(xs, ys, ds)
new{typeof(xs),typeof(ys),typeof(ds)}(xs, ys, ds, extrapolate)
end

function Interpolator(xs::AbstractVector, ys::AbstractVector, ds::AbstractVector)
function Interpolator(
xs::AbstractVector,
ys::AbstractVector,
ds::AbstractVector;
extrapolate::Bool = false,
)
length(eachindex(xs, ys, ds)) 2 ||
throw(ArgumentError("inputs must have at least 2 elements"))
_is_strictly_increasing(xs) ||
throw(ArgumentError("xs must be strictly increasing"))

new{typeof(xs),typeof(ys),typeof(ds)}(xs, ys, ds)
new{typeof(xs),typeof(ys),typeof(ds)}(xs, ys, ds, extrapolate)
end
end

Expand Down Expand Up @@ -139,49 +146,29 @@ end

@inline _x(::Interpolator, x) = x
@inline _x(itp::Interpolator, x, _) = _x(itp, x)
@inline function _x(itp::Interpolator, ::Val{:begin}, i)
if i < firstindex(itp.xs) || i >= lastindex(itp.xs)
return float(eltype(itp.xs))(NaN)
end
return @inbounds itp.xs[i]
end
@inline function _x(itp::Interpolator, ::Val{:end}, i)
if i < firstindex(itp.xs) || i >= lastindex(itp.xs)
return float(eltype(itp.xs))(NaN)
end
return @inbounds itp.xs[i+1]
end
@inline _x(itp::Interpolator, ::Val{:begin}, i) = @inbounds itp.xs[i]
@inline _x(itp::Interpolator, ::Val{:end}, i) = @inbounds itp.xs[i+1]

@inline function _evaluate(itp::Interpolator, ::Val{:begin}, i)
if i < firstindex(itp.ys) || i >= lastindex(itp.ys)
return float(eltype(itp.ys))(NaN)
end
return @inbounds itp.ys[i]
end
@inline function _evaluate(itp::Interpolator, ::Val{:end}, i)
if i < firstindex(itp.ys) || i >= lastindex(itp.ys)
return float(eltype(itp.ys))(NaN)
end
return @inbounds itp.ys[i+1]
end
@inline _evaluate(itp::Interpolator, ::Val{:begin}, i) = @inbounds itp.ys[i]
@inline _evaluate(itp::Interpolator, ::Val{:end}, i) = @inbounds itp.ys[i+1]
@inline _derivative(itp::Interpolator, ::Val{:begin}, i) = @inbounds itp.ds[i]
@inline _derivative(itp::Interpolator, ::Val{:end}, i) = @inbounds itp.ds[i+1]

@inline function _derivative(itp::Interpolator, ::Val{:begin}, i)
if i < firstindex(itp.ds) || i >= lastindex(itp.ds)
return float(eltype(itp.ds))(NaN)
end
return @inbounds itp.ds[i]
end
@inline function _derivative(itp::Interpolator, ::Val{:end}, i)
if i < firstindex(itp.ds) || i >= lastindex(itp.ds)
return float(eltype(itp.ds))(NaN)
end
return @inbounds itp.ds[i+1]
end

@inline (t) = 3t^2 - 2t^3
@inline (t) = t^3 - t^2

function _evaluate(itp::Interpolator, x, i)
if itp.extrapolate
if i < firstindex(itp.xs)
i = firstindex(itp.xs)
elseif i >= lastindex(itp.xs)
i = lastindex(itp.xs) - 1
end
elseif i < firstindex(itp.xs) || i >= lastindex(itp.xs)
return float(eltype(itp.xs))(NaN)
end

x1 = _x(itp, Val(:begin), i)
x2 = _x(itp, Val(:end), i)
h = x2 - x1
Expand All @@ -204,6 +191,10 @@ end


@inline function _integrate(itp::Interpolator, a, b, i)
if !itp.extrapolate && (i < firstindex(itp.xs) || i >= lastindex(itp.xs))
return float(eltype(itp.xs))(NaN)
end

a_ = _x(itp, a, i)
b_ = _x(itp, b, i)
return (b_ - a_) / 6 * (
Expand Down
15 changes: 14 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,16 @@ end
end
end

@testset "extrapolate" begin
itp = @inferred Interpolator(
[-2.0, -1, 0, 1, 2],
[1.0, 0, 0, 0, 1];
extrapolate = true,
)
@test itp(-3) 2
@test itp(4) 0
end

@testset "out of domain" begin
itp = @inferred Interpolator([1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0])
@test itp(1) == 4
Expand All @@ -293,7 +303,7 @@ end
@test isnan(@inferred integrate(itp, 1, NaN))
@test isnan(@inferred integrate(itp, NaN, 4))
@test isnan(@inferred integrate(itp, NaN, NaN))
@test isnan(@inferred derivative(itp, NaN))
@test_broken isnan(@inferred derivative(itp, NaN)) # now returns 0.0?

itp = Interpolator(collect(1.0:4.0), [4.0, 3.0, 2.0, 1.0])
@test isnan(@inferred itp(NaN))
Expand All @@ -309,6 +319,9 @@ end
itp = @inferred Interpolator(xs, ys)
plot(itp)
plot(itp, markershape = :auto)

itp = @inferred Interpolator(xs, ys; extrapolate = true)
plot(itp, markershape = :auto)
end

@testset "OffsetArrays" begin
Expand Down

0 comments on commit 9da30ca

Please sign in to comment.