From 152ce4a164a4602bf8804090e79b1ca59772f70c Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 22 Jan 2019 10:07:42 +0000 Subject: [PATCH] conversions for dual numbers --- src/tracker/lib/real.jl | 6 ++++++ test/tracker.jl | 2 ++ 2 files changed, 8 insertions(+) diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index b1fbb19fd8..a4f90a0c1d 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -99,6 +99,12 @@ import Base:^ ^(a::TrackedReal, b::Integer) = track(^, a, b) +# Hack for conversions + +using ForwardDiff: Dual + +(T::Type{<:Real})(x::Dual) = Dual(T(x.value), map(T, x.partials.values)) + # Tuples struct TrackedTuple{T<:Tuple} diff --git a/test/tracker.jl b/test/tracker.jl index 6b35f9cfc2..4380402e80 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -189,6 +189,8 @@ end @test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2)) @test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2)) +@test gradtest(x -> Float64.(x), 5) + @testset "equality & order" begin # TrackedReal @test param(2)^2 == param(4)