-
Notifications
You must be signed in to change notification settings - Fork 113
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #47 from tlycken/scaling
RFC: Scaling of interpolation objects (fixes #25)
- Loading branch information
Showing
20 changed files
with
324 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,39 @@ | ||
nindexes(N::Int) = N == 1 ? "1 index" : "$N indexes" | ||
|
||
|
||
type FilledInterpolation{T,N,ITP<:AbstractInterpolation,IT,GT,FT} <: AbstractExtrapolation{T,N,ITP,IT,GT} | ||
type FilledExtrapolation{T,N,ITP<:AbstractInterpolation,IT,GT,FT} <: AbstractExtrapolation{T,N,ITP,IT,GT} | ||
itp::ITP | ||
fillvalue::FT | ||
end | ||
@doc """ | ||
`FilledInterpolation(itp, fillvalue)` creates an extrapolation object that returns the `fillvalue` any time the indexes in `itp[x1,x2,...]` are out-of-bounds. | ||
""" | ||
`FilledExtrapolation(itp, fillvalue)` creates an extrapolation object that returns the `fillvalue` any time the indexes in `itp[x1,x2,...]` are out-of-bounds. | ||
By comparison with `extrapolate`, this version lets you control the `fillvalue`'s type directly. It's important for the `fillvalue` to be of the same type as returned by `itp[x1,x2,...]` for in-bounds regions for the index types you are using; otherwise, indexing will be type-unstable (and slow). | ||
""" -> | ||
function FilledInterpolation{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}, fillvalue) | ||
FilledInterpolation{T,N,typeof(itp),IT,GT,typeof(fillvalue)}(itp, fillvalue) | ||
""" | ||
function FilledExtrapolation{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}, fillvalue) | ||
FilledExtrapolation{T,N,typeof(itp),IT,GT,typeof(fillvalue)}(itp, fillvalue) | ||
end | ||
|
||
@doc """ | ||
""" | ||
`extrapolate(itp, fillvalue)` creates an extrapolation object that returns the `fillvalue` any time the indexes in `itp[x1,x2,...]` are out-of-bounds. | ||
""" -> | ||
extrapolate{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}, fillvalue) = FilledInterpolation(itp, convert(eltype(itp), fillvalue)) | ||
""" | ||
extrapolate{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}, fillvalue) = FilledExtrapolation(itp, convert(eltype(itp), fillvalue)) | ||
|
||
@generated function getindex{T,N}(fitp::FilledInterpolation{T,N}, args::Number...) | ||
@generated function getindex{T,N}(fitp::FilledExtrapolation{T,N}, args::Number...) | ||
n = length(args) | ||
n == N || return error("Must index $(N)-dimensional interpolation objects with $(nindexes(N))") | ||
meta = Expr(:meta, :inline) | ||
quote | ||
$meta | ||
# Check to see if we're in the extrapolation region, i.e., | ||
# out-of-bounds in an index | ||
@nexprs $N d->((args[d] < 1 || args[d] > size(fitp.itp, d)) && return fitp.fillvalue) | ||
@nexprs $N d->((args[d] < lbound(fitp,d) || args[d] > ubound(fitp, d)) && return fitp.fillvalue) | ||
# In the interpolation region | ||
return getindex(fitp.itp,args...) | ||
end | ||
end | ||
|
||
getindex{T}(fitp::FilledInterpolation{T,1}, x::Number, y::Int) = y == 1 ? fitp[x] : throw(BoundsError()) | ||
getindex{T}(fitp::FilledExtrapolation{T,1}, x::Number, y::Int) = y == 1 ? fitp[x] : throw(BoundsError()) | ||
|
||
lbound(etp::FilledExtrapolation, d) = lbound(etp.itp, d) | ||
ubound(etp::FilledExtrapolation, d) = ubound(etp.itp, d) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,13 @@ | ||
@generated function getindex{T}(exp::AbstractExtrapolation{T,1}, x) | ||
@generated function getindex{T}(etp::AbstractExtrapolation{T,1}, x) | ||
quote | ||
$(extrap_prep(exp, x)) | ||
exp.itp[x] | ||
$(extrap_prep(etp, x)) | ||
etp.itp[x] | ||
end | ||
end | ||
|
||
@generated function getindex{T,N,ITP,GT}(exp::AbstractExtrapolation{T,N,ITP,GT}, xs...) | ||
@generated function getindex{T,N,ITP,GT}(etp::AbstractExtrapolation{T,N,ITP,GT}, xs...) | ||
quote | ||
$(extrap_prep(exp, xs...)) | ||
exp.itp[xs...] | ||
$(extrap_prep(etp, xs...)) | ||
etp.itp[xs...] | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
export ScaledInterpolation | ||
|
||
type ScaledInterpolation{T,N,ITPT,IT,GT,RT} <: AbstractInterpolationWrapper{T,N,ITPT,IT,GT} | ||
itp::ITPT | ||
ranges::RT | ||
end | ||
ScaledInterpolation{T,ITPT,IT,GT,RT}(::Type{T}, N, itp::ITPT, ::Type{IT}, ::Type{GT}, ranges::RT) = | ||
ScaledInterpolation{T,N,ITPT,IT,GT,RT}(itp, ranges) | ||
""" | ||
`scale(itp, xs, ys, ...)` scales an existing interpolation object to allow for indexing using other coordinate axes than unit ranges, by wrapping the interpolation object and transforming the indices from the provided axes onto unit ranges upon indexing. | ||
The parameters `xs` etc must be either ranges or linspaces, and there must be one coordinate range/linspace for each dimension of the interpolation object. | ||
For every `NoInterp` dimension of the interpolation object, the range must be exactly `1:size(itp, d)`. | ||
""" | ||
function scale{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}, ranges::Range...) | ||
length(ranges) == N || throw(ArgumentError("Must scale $N-dimensional interpolation object with exactly $N ranges (you used $(length(ranges)))")) | ||
for d in 1:N | ||
if iextract(IT,d) != NoInterp | ||
length(ranges[d]) == size(itp,d) || throw(ArgumentError("The length of the range in dimension $d ($(length(ranges[d]))) did not equal the size of the interpolation object in that direction ($(size(itp,d)))")) | ||
elseif ranges[d] != 1:size(itp,d) | ||
throw(ArgumentError("NoInterp dimension $d must be scaled with unit range 1:$(size(itp,d))")) | ||
end | ||
end | ||
|
||
ScaledInterpolation(T,N,itp,IT,GT,ranges) | ||
end | ||
|
||
@generated function getindex{T,N,ITPT,IT<:DimSpec}(sitp::ScaledInterpolation{T,N,ITPT,IT}, xs::Number...) | ||
length(xs) == N || throw(ArgumentError("Must index into $N-dimensional scaled interpolation object with exactly $N indices (you used $(length(xs)))")) | ||
interp_types = length(IT.parameters) == N ? IT.parameters : tuple([IT.parameters[1] for _ in 1:N]...) | ||
interp_dimens = map(it -> interp_types[it] != NoInterp, 1:N) | ||
interp_indices = map(i -> interp_dimens[i] ? :(coordlookup(sitp.ranges[$i], xs[$i])) : :(xs[$i]), 1:N) | ||
return :(getindex(sitp.itp, $(interp_indices...))) | ||
end | ||
|
||
getindex{T}(sitp::ScaledInterpolation{T,1}, x::Number, y::Int) = y == 1 ? sitp[x] : throw(BoundsError()) | ||
|
||
size(sitp::ScaledInterpolation, d) = size(sitp.itp, d) | ||
lbound{T,N,ITPT,IT}(sitp::ScaledInterpolation{T,N,ITPT,IT,OnGrid}, d) = 1 <= d <= N ? sitp.ranges[d][1] : throw(BoundsError()) | ||
lbound{T,N,ITPT,IT}(sitp::ScaledInterpolation{T,N,ITPT,IT,OnCell}, d) = 1 <= d <= N ? sitp.ranges[d][1] - boundstep(sitp.ranges[d]) : throw(BoundsError()) | ||
ubound{T,N,ITPT,IT}(sitp::ScaledInterpolation{T,N,ITPT,IT,OnGrid}, d) = 1 <= d <= N ? sitp.ranges[d][end] : throw(BoundsError()) | ||
ubound{T,N,ITPT,IT}(sitp::ScaledInterpolation{T,N,ITPT,IT,OnCell}, d) = 1 <= d <= N ? sitp.ranges[d][end] + boundstep(sitp.ranges[d]) : throw(BoundsError()) | ||
|
||
boundstep(r::LinSpace) = ((r.stop - r.start) / r.divisor) / 2 | ||
boundstep(r::FloatRange) = r.step / 2 | ||
boundstep(r::StepRange) = r.step / 2 | ||
boundstep(r::UnitRange) = 1//2 | ||
|
||
""" | ||
Returns *half* the width of one step of the range. | ||
This function is used to calculate the upper and lower bounds of `OnCell` interpolation objects. | ||
""" boundstep | ||
|
||
coordlookup(r::LinSpace, x) = (r.divisor * x + r.stop - r.len * r.start) / (r.stop - r.start) | ||
coordlookup(r::FloatRange, x) = (r.divisor * x - r.start) / r.step + one(eltype(r)) | ||
coordlookup(r::StepRange, x) = (x - r.start) / r.step + one(eltype(r)) | ||
coordlookup(r::UnitRange, x) = x - r.start + one(eltype(r)) | ||
coordlookup(i::Bool, r::Range, x) = i ? coordlookup(r, x) : convert(typeof(coordlookup(r,x)), x) | ||
|
||
gradient{T,N,ITPT,IT<:DimSpec}(sitp::ScaledInterpolation{T,N,ITPT,IT}, xs::Number...) = gradient!(Array(T,count_interp_dims(IT,N)), sitp, xs...) | ||
@generated function gradient!{T,N,ITPT,IT}(g, sitp::ScaledInterpolation{T,N,ITPT,IT}, xs::Number...) | ||
ndims(g) == 1 || throw(DimensionMismatch("g must be a vector (but had $(ndims(g)) dimensions)")) | ||
length(xs) == N || throw(DimensionMismatch("Must index into $N-dimensional scaled interpolation object with exactly $N indices (you used $(length(xs)))")) | ||
|
||
interp_types = length(IT.parameters) == N ? IT.parameters : tuple([IT.parameters[1] for _ in 1:N]...) | ||
interp_dimens = map(it -> interp_types[it] != NoInterp, 1:N) | ||
interp_indices = map(i -> interp_dimens[i] ? :(coordlookup(sitp.ranges[$i], xs[$i])) : :(xs[$i]), 1:N) | ||
|
||
quote | ||
length(g) == $(count_interp_dims(IT, N)) || throw(ArgumentError(string("The length of the provided gradient vector (", length(g), ") did not match the number of interpolating dimensions (", $(count_interp_dims(IT, N)), ")"))) | ||
gradient!(g, sitp.itp, $(interp_indices...)) | ||
for i in eachindex(g) | ||
g[i] = rescale_gradient(sitp.ranges[i], g[i]) | ||
end | ||
g | ||
end | ||
end | ||
|
||
rescale_gradient(r::LinSpace, g) = g * r.divisor / (r.stop - r.start) | ||
rescale_gradient(r::FloatRange, g) = g * r.divisor / r.step | ||
rescale_gradient(r::StepRange, g) = g / r.step | ||
rescale_gradient(r::UnitRange, g) = g | ||
|
||
""" | ||
`rescale_gradient(r::Range)` | ||
Implements the chain rule dy/dx = dy/du * du/dx for use when calculating gradients with scaled interpolation objects. | ||
""" rescale_gradient |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.