Because some rules use irrational numbers, the adjoints of certain functions mistakenly promote the numerical precision of the derivative/gradient. This happens because certain implementations first apply an operation to the irrational number, and that operation by default converts it to `Float64`. For example, for `erfc` we first call `sqrt(π)`, which returns a `Float64`. So instead of the `Irrational` being promoted to the output type we expect, the output type ends up promoted to `Float64` (when using floats of lower precision):
```julia
julia> using SpecialFunctions, ChainRulesCore

julia> y, ȳ = ChainRulesCore.frule((ChainRulesCore.NO_FIELDS, 1f0), SpecialFunctions.erfc, 1f0)
(0.1572992f0, -0.41510750774498784)

julia> typeof(y), typeof(ȳ)
(Float32, Float64)
```
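To illustrate the mechanism, here is a minimal sketch in plain Julia (no SpecialFunctions.jl needed) of the derivative used by the `erfc` rule, d/dx erfc(x) = -2/√π · exp(-x²). The helper names `erfc_deriv_naive` and `erfc_deriv` are hypothetical and are not the actual SpecialFunctions.jl or ChainRules code. Converting `π` to the input's floating-point type before calling `sqrt` is one possible fix, assumed here rather than taken from the package.

```julia
# Hypothetical helpers, NOT the real SpecialFunctions.jl rule.
# Both compute d/dx erfc(x) = -2/√π * exp(-x^2).

# Naive version: sqrt(π) first converts the Irrational to Float64,
# so the result is Float64 even when x is a Float32.
erfc_deriv_naive(x) = -2 / sqrt(π) * exp(-x^2)

# Precision-preserving version (assumed fix): convert π to the
# floating-point type of x before taking the square root.
# Using float(x) also covers integer inputs, where oftype(x, π)
# would fail because π cannot be represented as an Int.
function erfc_deriv(x::Real)
    T = typeof(float(x))
    return -2 / sqrt(T(π)) * exp(-x^2)
end
```

With this version, `typeof(erfc_deriv(1f0))` is `Float32`, while `typeof(erfc_deriv_naive(1f0))` is `Float64`, matching the promotion seen in the `frule` output.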
This is essentially the same issue as in DiffRules (JuliaDiff/DiffRules.jl#55).
Does anyone have a better idea of what to do here, or should I just make a similar PR to SpecialFunctions.jl?