Skip to content

use SpecialFunctions package due to function deprecations in Julia v0.6 #200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ DiffBase 0.0.3
Compat 0.17.0
Calculus 0.2.0
NaNMath 0.2.2
SpecialFunctions 0.1.0
9 changes: 9 additions & 0 deletions src/ForwardDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using DiffBase: DiffResult

import Calculus
import NaNMath
import SpecialFunctions

#############################
# types/functions/constants #
Expand Down Expand Up @@ -35,9 +36,17 @@ end
#---------------------#

const AUTO_DEFINED_UNARY_FUNCS = map(first, Calculus.symbolic_derivatives_1arg())

const NANMATH_FUNCS = (:sin, :cos, :tan, :asin, :acos, :acosh,
:atanh, :log, :log2, :log10, :lgamma, :log1p)

const SPECIAL_FUNCS = (:erf, :erfc, :erfinv, :erfcinv, :erfi, :erfcx,
:dawson, :digamma, :eta, :zeta, :airyai, :airyaiprime,
:airybi, :airybiprime, :airyaix, :besselj, :besselj0,
:besselj1, :besseljx, :bessely, :bessely0, :bessely1,
:besselyx, :besselh, :hankelh1, :hankelh1x, :hankelh2,
:hankelh2x, :besseli, :besselix, :besselk, :besselkx)

# chunk settings #
#----------------#

Expand Down
19 changes: 15 additions & 4 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,21 @@ for fsym in AUTO_DEFINED_UNARY_FUNCS

# exp and sqrt are manually defined below
if !(in(fsym, (:exp, :sqrt)))
@eval begin
@inline function Base.$(fsym)(n::Dual)
$(v) = value(n)
return Dual($(fsym)($v), $(deriv) * partials(n))
is_special_function = in(fsym, SPECIAL_FUNCS)
if is_special_function
@eval begin
@inline function SpecialFunctions.$(fsym)(n::Dual)
$(v) = value(n)
return Dual(SpecialFunctions.$(fsym)($v), $(deriv) * partials(n))
end
end
end
if !(is_special_function) || VERSION < v"0.6.0-dev.2767"
@eval begin
@inline function Base.$(fsym)(n::Dual)
$(v) = value(n)
return Dual(Base.$(fsym)($v), $(deriv) * partials(n))
end
end
end
end
Expand Down
37 changes: 19 additions & 18 deletions test/DualTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ForwardDiff: Partials, Dual, value, partials

import NaNMath
import Calculus
import SpecialFunctions

samerng() = MersenneTwister(1)

Expand Down Expand Up @@ -405,29 +406,29 @@ for N in (0,3), M in (0,4), T in (Int, Float32)
try
v = :v
deriv = Calculus.differentiate(:($(fsym)($v)), v)
is_domain_err_func = fsym in DOMAIN_ERR_FUNCS
is_nanmath_func = fsym in ForwardDiff.NANMATH_FUNCS
is_unsupported_nested_func = fsym in UNSUPPORTED_NESTED_FUNCS
@eval begin
fdnum = $(is_domain_err_func ? FDNUM + 1 : FDNUM)
$(v) = ForwardDiff.value(fdnum)
$(test_approx_diffnums)($(fsym)(fdnum), ForwardDiff.Dual($(fsym)($v), $(deriv) * ForwardDiff.partials(fdnum)))
if $(is_nanmath_func)
$(test_approx_diffnums)(NaNMath.$(fsym)(fdnum), ForwardDiff.Dual(NaNMath.$(fsym)($v), $(deriv) * ForwardDiff.partials(fdnum)))
end

if $(!(is_unsupported_nested_func))
nested_fdnum = $(is_domain_err_func ? NESTED_FDNUM + 1 : NESTED_FDNUM)
$(v) = ForwardDiff.value(nested_fdnum)
$(test_approx_diffnums)($(fsym)(nested_fdnum), ForwardDiff.Dual($(fsym)($v), $(deriv) * ForwardDiff.partials(nested_fdnum)))
if $(is_nanmath_func)
$(test_approx_diffnums)(NaNMath.$(fsym)(nested_fdnum), ForwardDiff.Dual(NaNMath.$(fsym)($v), $(deriv) * ForwardDiff.partials(nested_fdnum)))
is_nanmath_func = in(fsym, ForwardDiff.NANMATH_FUNCS)
is_special_func = in(fsym, ForwardDiff.SPECIAL_FUNCS)
is_domain_err_func = in(fsym, DOMAIN_ERR_FUNCS)
is_unsupported_nested_func = in(fsym, UNSUPPORTED_NESTED_FUNCS)
tested_funcs = Vector{Expr}(0)
is_nanmath_func && push!(tested_funcs, :(NaNMath.$(fsym)))
is_special_func && push!(tested_funcs, :(SpecialFunctions.$(fsym)))
(!(is_special_func) || VERSION < v"0.6.0-dev.2767") && push!(tested_funcs, :(Base.$(fsym)))
for func in tested_funcs
@eval begin
fdnum = $(is_domain_err_func ? FDNUM + 1 : FDNUM)
$(v) = ForwardDiff.value(fdnum)
$(test_approx_diffnums)($(func)(fdnum), ForwardDiff.Dual($(func)($v), $(deriv) * ForwardDiff.partials(fdnum)))
if $(!(is_unsupported_nested_func))
nested_fdnum = $(is_domain_err_func ? NESTED_FDNUM + 1 : NESTED_FDNUM)
$(v) = ForwardDiff.value(nested_fdnum)
$(test_approx_diffnums)($(func)(nested_fdnum), ForwardDiff.Dual($(func)($v), $(deriv) * ForwardDiff.partials(nested_fdnum)))
end
end
end
catch err
warn("Encountered error when testing $(fsym)(::Dual):")
throw(err)
rethrow(err)
end
end
end
Expand Down