
Chain rules for certain functions do not respect numerical precision #307

Closed
@torfjelde


Because they use irrational constants, some functions have adjoints that mistakenly promote the numerical precision of the derivative/gradient. This happens because certain implementations first apply a function to the irrational number, which by default converts it to Float64. For example, the rule for erfc first calls sqrt(π), which returns a Float64. So instead of the Irrational being promoted to the expected output type, the output is promoted to Float64 whenever the input uses a lower-precision float such as Float32:
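A minimal Base-only sketch of the promotion behaviour described above. The variable names are illustrative, and the expression only mirrors the mathematical form of the erfc derivative (-2/√π · exp(-x²)); it is not a quote of the actual rule code:

```julia
x = 1f0                          # Float32 input

# sqrt on an Irrational falls back to Float64 arithmetic
c64 = sqrt(π)
println(typeof(c64))             # Float64

# Combining the Float32 term exp(-x^2) with that Float64 constant widens the result
d_bad = -2 / c64 * exp(-x^2)
println(typeof(d_bad))           # Float64

# By contrast, an Irrational used directly in arithmetic with a Float32 stays Float32
println(typeof(π * x))           # Float32
```

The last line shows the behaviour one would want: left unevaluated, π adapts to the precision of the other operand. The widening happens only because `sqrt(π)` is evaluated first, in Float64.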

```julia
julia> using SpecialFunctions, ChainRulesCore

julia> y, ȳ = ChainRulesCore.frule((ChainRulesCore.NO_FIELDS, 1f0), SpecialFunctions.erfc, 1f0)
(0.1572992f0, -0.41510750774498784)

julia> typeof(y), typeof(ȳ)
(Float32, Float64)
```
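One possible way to keep the derivative in the input's precision is to convert π to the input's float type before calling `sqrt`. The function `erfc_derivative` below is a hypothetical helper written for illustration only. It is not part of SpecialFunctions.jl or ChainRules.jl, and it is not necessarily how the DiffRules PR resolved the issue:

```julia
# Hypothetical helper (not a library API): d/dx erfc(x) = -2/sqrt(π) * exp(-x^2),
# evaluated in the floating-point type of the input x.
function erfc_derivative(x::Real)
    T = float(typeof(x))          # Float32 for Float32 input, Float64 for Int/Float64, etc.
    # Convert π to T *before* taking sqrt, so no Float64 intermediate is created
    return -2 / sqrt(T(π)) * exp(-x^2)
end

erfc_derivative(1f0)   # returns a Float32
erfc_derivative(1.0)   # returns a Float64
```

The design choice is simply to make the input type, rather than the default conversion of the Irrational, decide the working precision. The same pattern works for `BigFloat` inputs, since `BigFloat(π)` is defined.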

This is essentially the same issue as in DiffRules (JuliaDiff/DiffRules.jl#55).

Does anyone have a better idea of what to do here, or should I just make a similar PR to SpecialFunctions.jl?
