Skip to content

Commit

Permalink
Support specializing on functions (#615)
Browse files Browse the repository at this point in the history
* specialize on function in jacobian

* specilize on function parameters for derivative, gradient, hessian
  • Loading branch information
j-fu authored Dec 14, 2022
1 parent 76335e6 commit 6a19554
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 33 deletions.
8 changes: 4 additions & 4 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ stored in `y`.
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
"""
@inline function derivative(f!, y::AbstractArray, x::Real,
cfg::DerivativeConfig{T} = DerivativeConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {T, CHK}
@inline function derivative(f!::F, y::AbstractArray, x::Real,
cfg::DerivativeConfig{T} = DerivativeConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F, T, CHK}
require_one_based_indexing(y)
CHK && checktag(T, f!, x)
ydual = cfg.duals
Expand Down Expand Up @@ -60,8 +60,8 @@ called as `f!(y, x)` where the result is stored in `y`.
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
"""
@inline function derivative!(result::Union{AbstractArray,DiffResult},
f!, y::AbstractArray, x::Real,
cfg::DerivativeConfig{T} = DerivativeConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {T, CHK}
f!::F, y::AbstractArray, x::Real,
cfg::DerivativeConfig{T} = DerivativeConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F, T, CHK}
result isa DiffResult ? require_one_based_indexing(y) : require_one_based_indexing(result, y)
CHK && checktag(T, f!, x)
ydual = cfg.duals
Expand Down
14 changes: 7 additions & 7 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ This method assumes that `isa(f(x), Real)`.
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
"""
function gradient(f, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK}
function gradient(f::F, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {F, T, CHK}
require_one_based_indexing(x)
CHK && checktag(T, f, x)
if chunksize(cfg) == length(x)
Expand Down Expand Up @@ -43,13 +43,13 @@ function gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArr
return result
end

@inline gradient(f, x::StaticArray) = vector_mode_gradient(f, x)
@inline gradient(f, x::StaticArray, cfg::GradientConfig) = gradient(f, x)
@inline gradient(f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient(f, x)
@inline gradient(f::F, x::StaticArray) where F = vector_mode_gradient(f, x)
@inline gradient(f::F, x::StaticArray, cfg::GradientConfig) where F = gradient(f, x)
@inline gradient(f::F, x::StaticArray, cfg::GradientConfig, ::Val) where F = gradient(f, x)

@inline gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_gradient!(result, f, x)
@inline gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig) = gradient!(result, f, x)
@inline gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient!(result, f, x)
@inline gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where F = vector_mode_gradient!(result, f, x)
@inline gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig) where F = gradient!(result, f, x)
@inline gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig, ::Val) where F = gradient!(result, f, x)

gradient(f, x::Real) = throw(DimensionMismatch("gradient(f, x) expects that x is an array. Perhaps you meant derivative(f, x)?"))

Expand Down
23 changes: 11 additions & 12 deletions src/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ This method assumes that `isa(f(x), Real)`.
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
"""
function hessian(f, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, x), ::Val{CHK}=Val{true}()) where {T,CHK}
function hessian(f::F, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, x), ::Val{CHK}=Val{true}()) where {F, T,CHK}
require_one_based_indexing(x)
CHK && checktag(T, f, x)
∇f = y -> gradient(f, y, cfg.gradient_config, Val{false}())
Expand All @@ -28,7 +28,7 @@ This method assumes that `isa(f(x), Real)`.
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
"""
function hessian!(result::AbstractArray, f, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, x), ::Val{CHK}=Val{true}()) where {T,CHK}
function hessian!(result::AbstractArray, f::F, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
require_one_based_indexing(result, x)
CHK && checktag(T, f, x)
∇f = y -> gradient(f, y, cfg.gradient_config, Val{false}())
Expand Down Expand Up @@ -63,26 +63,25 @@ because `isa(result, DiffResult)`, `cfg` is constructed as `HessianConfig(f, res
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
"""
function hessian!(result::DiffResult, f, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, result, x), ::Val{CHK}=Val{true}()) where {T,CHK}
require_one_based_indexing(x)
function hessian!(result::DiffResult, f::F, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, result, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
CHK && checktag(T, f, x)
∇f! = InnerGradientForHess(result, cfg, f)
jacobian!(DiffResults.hessian(result), ∇f!, DiffResults.gradient(result), x, cfg.jacobian_config, Val{false}())
return ∇f!.result
end

hessian(f, x::StaticArray) = jacobian(y -> gradient(f, y), x)
hessian(f, x::StaticArray, cfg::HessianConfig) = hessian(f, x)
hessian(f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian(f, x)
hessian(f::F, x::StaticArray) where F = jacobian(y -> gradient(f, y), x)
hessian(f::F, x::StaticArray, cfg::HessianConfig) where F = hessian(f, x)
hessian(f::F, x::StaticArray, cfg::HessianConfig, ::Val) where F = hessian(f, x)

hessian!(result::AbstractArray, f, x::StaticArray) = jacobian!(result, y -> gradient(f, y), x)
hessian!(result::AbstractArray, f::F, x::StaticArray) where F = jacobian!(result, y -> gradient(f, y), x)

hessian!(result::MutableDiffResult, f, x::StaticArray) = hessian!(result, f, x, HessianConfig(f, result, x))
hessian!(result::MutableDiffResult, f::F, x::StaticArray) where F = hessian!(result, f, x, HessianConfig(f, result, x))

hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig) = hessian!(result, f, x)
hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian!(result, f, x)
hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig) where F = hessian!(result, f, x)
hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig, ::Val) where F = hessian!(result, f, x)

function hessian!(result::ImmutableDiffResult, f, x::StaticArray)
function hessian!(result::ImmutableDiffResult, f::F, x::StaticArray) where F
T = typeof(Tag(f, eltype(x)))
d1 = dualize(T, x)
d2 = dualize(T, d1)
Expand Down
20 changes: 10 additions & 10 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ This method assumes that `isa(f(x), AbstractArray)`.
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
"""
function jacobian(f, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {T,CHK}
function jacobian(f::F, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
require_one_based_indexing(x)
CHK && checktag(T, f, x)
if chunksize(cfg) == length(x)
Expand All @@ -33,7 +33,7 @@ stored in `y`.
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
"""
function jacobian(f!, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {T, CHK}
function jacobian(f!::F, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F,T, CHK}
require_one_based_indexing(y, x)
CHK && checktag(T, f!, x)
if chunksize(cfg) == length(x)
Expand All @@ -54,7 +54,7 @@ This method assumes that `isa(f(x), AbstractArray)`.
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
"""
function jacobian!(result::Union{AbstractArray,DiffResult}, f, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK}
function jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {F,T, CHK}
result isa DiffResult ? require_one_based_indexing(x) : require_one_based_indexing(result, x)
CHK && checktag(T, f, x)
if chunksize(cfg) == length(x)
Expand All @@ -75,7 +75,7 @@ This method assumes that `isa(f(x), AbstractArray)`.
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
"""
function jacobian!(result::Union{AbstractArray,DiffResult}, f!, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {T,CHK}
function jacobian!(result::Union{AbstractArray,DiffResult}, f!::F, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
result isa DiffResult ? require_one_based_indexing(y, x) : require_one_based_indexing(result, y, x)
CHK && checktag(T, f!, x)
if chunksize(cfg) == length(x)
Expand All @@ -86,13 +86,13 @@ function jacobian!(result::Union{AbstractArray,DiffResult}, f!, y::AbstractArray
return result
end

@inline jacobian(f, x::StaticArray) = vector_mode_jacobian(f, x)
@inline jacobian(f, x::StaticArray, cfg::JacobianConfig) = jacobian(f, x)
@inline jacobian(f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian(f, x)
@inline jacobian(f::F, x::StaticArray) where F = vector_mode_jacobian(f, x)
@inline jacobian(f::F, x::StaticArray, cfg::JacobianConfig) where F = jacobian(f, x)
@inline jacobian(f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where F = jacobian(f, x)

@inline jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_jacobian!(result, f, x)
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig) = jacobian!(result, f, x)
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian!(result, f, x)
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where F = vector_mode_jacobian!(result, f, x)
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig) where F = jacobian!(result, f, x)
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where F = jacobian!(result, f, x)

jacobian(f, x::Real) = throw(DimensionMismatch("jacobian(f, x) expects that x is an array. Perhaps you meant derivative(f, x)?"))

Expand Down

0 comments on commit 6a19554

Please sign in to comment.