Skip to content

Add specialized StaticArrays-based methods #215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 27, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
julia 0.6-
StaticArrays 0.5.0
DiffBase 0.0.3
Calculus 0.2.0
NaNMath 0.2.2
Expand Down
1 change: 0 additions & 1 deletion appveyor.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
environment:
matrix:
- JULIAVERSION: "julianightlies/bin/winnt/x86/julia-latest-win32.exe"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cutting the Gordian knot.

- JULIAVERSION: "julianightlies/bin/winnt/x64/julia-latest-win64.exe"

branches:
Expand Down
24 changes: 24 additions & 0 deletions src/ForwardDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ module ForwardDiff

using DiffBase
using DiffBase: DiffResult
using StaticArrays

import Calculus
import NaNMath
Expand Down Expand Up @@ -41,6 +42,29 @@ const REAL_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real)

const DEFAULT_CHUNK_THRESHOLD = 10

struct Chunk{N} end

function Chunk(input_length::Integer, threshold::Integer = DEFAULT_CHUNK_THRESHOLD)
N = pickchunksize(input_length, threshold)
return Chunk{N}()
end

function Chunk(x::AbstractArray, threshold::Integer = DEFAULT_CHUNK_THRESHOLD)
return Chunk(length(x), threshold)
end

# Constrained to `N <= threshold`, minimize (in order of priority):
# 1. the number of chunks that need to be computed
# 2. the number of "left over" perturbations in the final chunk
function pickchunksize(input_length, threshold = DEFAULT_CHUNK_THRESHOLD)
if input_length <= threshold
return input_length
else
nchunks = round(Int, input_length / DEFAULT_CHUNK_THRESHOLD, RoundUp)
return round(Int, input_length / nchunks, RoundUp)
end
end

############
# includes #
############
Expand Down
27 changes: 0 additions & 27 deletions src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,6 @@ struct Tag{F,H} end
end
end

#########
# Chunk #
#########

struct Chunk{N} end

function Chunk(input_length::Integer, threshold::Integer = DEFAULT_CHUNK_THRESHOLD)
N = pickchunksize(input_length, threshold)
return Chunk{N}()
end

function Chunk(x::AbstractArray, threshold::Integer = DEFAULT_CHUNK_THRESHOLD)
return Chunk(length(x), threshold)
end

# Constrained to `N <= threshold`, minimize (in order of priority):
# 1. the number of chunks that need to be computed
# 2. the number of "left over" perturbations in the final chunk
function pickchunksize(input_length, threshold = DEFAULT_CHUNK_THRESHOLD)
if input_length <= threshold
return input_length
else
nchunks = round(Int, input_length / DEFAULT_CHUNK_THRESHOLD, RoundUp)
return round(Int, input_length / nchunks, RoundUp)
end
end

##################
# AbstractConfig #
##################
Expand Down
2 changes: 1 addition & 1 deletion src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ end

@inline extract_derivative(y::Dual{T,V,1}) where {T,V} = partials(y, 1)
@inline extract_derivative(y::Real) = zero(y)
@inline extract_derivative(y::AbstractArray) = extract_derivative!(similar(y, valtype(eltype(y))), y)
@inline extract_derivative(y::AbstractArray) = map(extract_derivative, y)

# mutating #
#----------#
Expand Down
2 changes: 1 addition & 1 deletion src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ end
@inline (::Type{Dual{T}})(value::Real, partials::Tuple) where {T} = Dual{T}(value, Partials(partials))
@inline (::Type{Dual{T}})(value::Real, partials::Tuple{}) where {T} = Dual{T}(value, Partials{0,typeof(value)}(partials))
@inline (::Type{Dual{T}})(value::Real, partials::Real...) where {T} = Dual{T}(value, partials)
@inline (::Type{Dual{T}})(value::V, ::Type{Val{N}}, ::Type{Val{i}}) where {T,V<:Real,N,i} = Dual{T}(value, single_seed(Partials{N,V}, Val{i}))
@inline (::Type{Dual{T}})(value::V, ::Chunk{N}, p::Val{i}) where {T,V<:Real,N,i} = Dual{T}(value, single_seed(Partials{N,V}, p))

@inline Dual(args...) = Dual{Void}(args...)

Expand Down
20 changes: 20 additions & 0 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,22 @@ function gradient!(out, f::F, x, cfg::AllowedGradientConfig{F,H} = GradientConfi
return out
end

@inline gradient(f::F, x::SArray) where {F} = vector_mode_gradient(f, x)

@inline gradient!(out, f::F, x::SArray) where {F} = vector_mode_gradient!(out, f, x)

#####################
# result extraction #
#####################

@generated function extract_gradient(y::Real, ::SArray{S,X,D,N}) where {S,X,D,N}
result = Expr(:tuple, [:(partials(y, $i)) for i in 1:N]...)
return quote
$(Expr(:meta, :inline))
return SArray{S}($result)
end
end

function extract_gradient!(out::DiffResult, y::Real)
DiffBase.value!(out, y)
grad = DiffBase.gradient(out)
Expand Down Expand Up @@ -73,6 +85,14 @@ function vector_mode_gradient!(out, f::F, x, cfg) where {F}
return out
end

@inline function vector_mode_gradient(f::F, x::SArray) where F
return extract_gradient(vector_mode_dual_eval(f, x), x)
end

@inline function vector_mode_gradient!(out, f::F, x::SArray) where F
return extract_gradient!(out, vector_mode_dual_eval(f, x))
end

##############
# chunk mode #
##############
Expand Down
6 changes: 6 additions & 0 deletions src/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,9 @@ function hessian!(out::DiffResult, f::F, x, cfg::AllowedHessianConfig{F,H} = Hes
jacobian!(DiffBase.hessian(out), ∇f!, DiffBase.gradient(out), x, cfg.jacobian_config)
return out
end

hessian(f::F, x::SArray) where {F} = jacobian(y -> gradient(f, y), x)

hessian!(out, f::F, x::SArray) where {F} = jacobian!(out, y -> gradient(f, y), x)

hessian!(out::DiffResult, f::F, x::SArray) where {F} = hessian!(out, f, x, HessianConfig(f, out, x))
29 changes: 29 additions & 0 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,28 @@ function jacobian!(out, f!::F, y, x, cfg::AllowedJacobianConfig{F,H} = JacobianC
return out
end

@inline jacobian(f::F, x::SArray) where {F} = vector_mode_jacobian(f, x)

@inline jacobian!(out, f::F, x::SArray) where {F} = vector_mode_jacobian!(out, f, x)

#####################
# result extraction #
#####################

@generated function extract_jacobian(ydual::SArray{SY,VY,DY,M},
x::SArray{SX,VX,DX,N}) where {SY,VY,DY,M,SX,VX,DX,N}
result = Expr(:tuple, [:(partials(ydual[$i], $j)) for i in 1:M, j in 1:N]...)
return quote
$(Expr(:meta, :inline))
return SArray{Tuple{M,N}}($result)
end
end

function extract_jacobian(ydual::AbstractArray, x::SArray{S,V,D,N}) where {S,V,D,N}
out = similar(ydual, valtype(eltype(ydual)), length(ydual), N)
return extract_jacobian!(out, ydual, N)
end

function extract_jacobian!(out::AbstractArray, ydual::AbstractArray, n)
out_reshaped = reshape(out, length(ydual), n)
for col in 1:size(out_reshaped, 2), row in 1:size(out_reshaped, 1)
Expand Down Expand Up @@ -110,6 +128,17 @@ function vector_mode_jacobian!(out, f!::F, y, x, cfg::JacobianConfig{T,V,N}) whe
return out
end

@inline function vector_mode_jacobian(f::F, x::SArray) where F
return extract_jacobian(vector_mode_dual_eval(f, x), x)
end

@inline function vector_mode_jacobian!(out, f::F, x::SArray{S,V,D,N}) where {F,S,V,D,N}
ydual = vector_mode_dual_eval(f, x)
extract_jacobian!(out, ydual, N)
extract_value!(out, ydual)
return out
end

# chunk mode #
#------------#

Expand Down
2 changes: 1 addition & 1 deletion src/partials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ end
# Utility/Accessor Functions #
##############################

@generated function single_seed(::Type{Partials{N,V}}, ::Type{Val{i}}) where {N,V,i}
@generated function single_seed(::Type{Partials{N,V}}, ::Val{i}) where {N,V,i}
ex = Expr(:tuple, [ifelse(i === j, :(one(V)), :(zero(V))) for j in 1:N]...)
return :(Partials($(ex)))
end
Expand Down
15 changes: 14 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@ end
# vector mode function evaluation #
###################################

@generated function dualize(::F, x::SArray{S,V,D,N}) where {F,S,V,D,N}
tag = Tag(F, x)
dx = Expr(:tuple, [:(Dual{T}(x[$i], chunk, Val{$i}())) for i in 1:N]...)
return quote
chunk = Chunk{N}()
T = typeof($tag)
$(Expr(:meta, :inline))
return SArray{S}($(dx))
end
end

@inline vector_mode_dual_eval(f::F, x::SArray) where {F} = f(dualize(f, x))

function vector_mode_dual_eval(f::F, x, cfg::Union{JacobianConfig,GradientConfig}) where F
xdual = cfg.duals
seed!(xdual, x, cfg.seeds)
Expand All @@ -36,7 +49,7 @@ end
##################################

@generated function construct_seeds(::Type{Partials{N,V}}) where {N,V}
return Expr(:tuple, [:(single_seed(Partials{N,V}, Val{$i})) for i in 1:N]...)
return Expr(:tuple, [:(single_seed(Partials{N,V}, Val{$i}())) for i in 1:N]...)
end

function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
Expand Down
25 changes: 25 additions & 0 deletions test/GradientTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Calculus

using Base.Test
using ForwardDiff
using StaticArrays

include(joinpath(dirname(@__FILE__), "utils.jl"))

Expand Down Expand Up @@ -67,4 +68,28 @@ for f in DiffBase.VECTOR_TO_NUMBER_FUNCS
end
end

##########################################
# test specialized StaticArray codepaths #
##########################################

x = rand(3, 3)
sx = StaticArrays.SArray{Tuple{3,3}}(x)
out = similar(x)
actual = ForwardDiff.gradient(prod, x)

@test ForwardDiff.gradient(prod, sx) == actual

ForwardDiff.gradient!(out, prod, sx)

@test out == actual

result = DiffBase.GradientResult(x)
sresult = DiffBase.GradientResult(sx)

ForwardDiff.gradient!(result, prod, x)
ForwardDiff.gradient!(sresult, prod, sx)

@test DiffBase.value(sresult) == DiffBase.value(result)
@test DiffBase.gradient(sresult) == DiffBase.gradient(result)

end # module
26 changes: 26 additions & 0 deletions test/HessianTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Calculus

using Base.Test
using ForwardDiff
using StaticArrays

include(joinpath(dirname(@__FILE__), "utils.jl"))

Expand Down Expand Up @@ -78,4 +79,29 @@ for f in DiffBase.VECTOR_TO_NUMBER_FUNCS
end
end

##########################################
# test specialized StaticArray codepaths #
##########################################

x = rand(3, 3)
sx = StaticArrays.SArray{Tuple{3,3}}(x)
out = similar(x, 9, 9)
actual = ForwardDiff.hessian(prod, x)

@test ForwardDiff.hessian(prod, sx) == actual

ForwardDiff.hessian!(out, prod, sx)

@test out == actual

result = DiffBase.HessianResult(x)
sresult = DiffBase.HessianResult(sx)

ForwardDiff.hessian!(result, prod, x)
ForwardDiff.hessian!(sresult, prod, sx)

@test DiffBase.value(sresult) == DiffBase.value(result)
@test DiffBase.gradient(sresult) == DiffBase.gradient(result)
@test DiffBase.hessian(sresult) == DiffBase.hessian(result)

end # module
25 changes: 25 additions & 0 deletions test/JacobianTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Calculus
using Base.Test
using ForwardDiff
using ForwardDiff: JacobianConfig
using StaticArrays

include(joinpath(dirname(@__FILE__), "utils.jl"))

Expand Down Expand Up @@ -146,4 +147,28 @@ for f! in DiffBase.INPLACE_ARRAY_TO_ARRAY_FUNCS
end
end

##########################################
# test specialized StaticArray codepaths #
##########################################

x = rand(3, 3)
sx = StaticArrays.SArray{Tuple{3,3}}(x)
out = similar(x, 6, 9)
actual = ForwardDiff.jacobian(diff, x)

@test ForwardDiff.jacobian(diff, sx) == actual

ForwardDiff.jacobian!(out, diff, sx)

@test out == actual

result = DiffBase.JacobianResult(similar(x, 6), x)
sresult = DiffBase.JacobianResult(similar(sx, 6), sx)

ForwardDiff.jacobian!(result, diff, x)
ForwardDiff.jacobian!(sresult, diff, sx)

@test DiffBase.value(sresult) == DiffBase.value(result)
@test DiffBase.jacobian(sresult) == DiffBase.jacobian(result)

end # module