Skip to content

Get pullback without running primal pass for @scalar_rules #246

Closed
@GiggleLiu

Description

@GiggleLiu
julia> @benchmark cos(x) setup=(x=0.5)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     4.409 ns (0.00% GC)
  median time:      4.421 ns (0.00% GC)
  mean time:        4.530 ns (0.00% GC)
  maximum time:     101.587 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark rrule(sin, x)[2](1.0) setup=(x=0.5)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     7.840 ns (0.00% GC)
  median time:      8.102 ns (0.00% GC)
  mean time:        8.157 ns (0.00% GC)
  maximum time:     35.264 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     999

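The pullback of sin only needs cos(x), yet obtaining it through rrule takes almost twice as long as evaluating cos(x) directly. This is because the macro @scalar_rule generates the following code, in which the primal value Ω is always computed before the pullback is returned.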

julia> MacroTools.prettify(@macroexpand @scalar_rule sinc(x) cosc(x))
quote
    if !(sinc isa ChainRulesCore.Type) && ChainRulesCore.fieldcount(ChainRulesCore.typeof(sinc)) > 0
        ChainRulesCore.throw(ChainRulesCore.ArgumentError("@scalar_rule cannot be used on closures/functors (such as $(sinc))"))
    end
    function (ChainRulesCore.ChainRulesCore).frule((ChainRulesCore._, Δ1), ::ChainRulesCore.typeof(sinc), x::Number)
        Ω = sinc(x)
        nothing
        return (Ω, cosc(x) * Δ1)
    end
    function (ChainRulesCore.ChainRulesCore).rrule(::ChainRulesCore.typeof(sinc), x::Number)
        Ω = sinc(x)
        nothing
        return (Ω, function sinc_pullback(gull)
                    return (ChainRulesCore.NO_FIELDS, ChainRulesCore.conj(cosc(x)) * gull)
                end)
    end
end

To make the code friendlier to packages that want to make use of these scalar rules, the following generated code might be better:

julia> MacroTools.prettify(@macroexpand @scalar_rule sinc(x) cosc(x))
quote
    if !(sinc isa ChainRulesCore.Type) && ChainRulesCore.fieldcount(ChainRulesCore.typeof(sinc)) > 0
        ChainRulesCore.throw(ChainRulesCore.ArgumentError("@scalar_rule cannot be used on closures/functors (such as $(sinc))"))
    end
    function (ChainRulesCore.ChainRulesCore).frule((ChainRulesCore._, Δ1), ::ChainRulesCore.typeof(sinc), x::Number)
        Ω = sinc(x)
        nothing
        return (Ω, cosc(x) * Δ1)
    end

    function scalar_pullback(::ChainRulesCore.typeof(sinc), x::Number)
           function sinc_pullback(gull)
                return (ChainRulesCore.NO_FIELDS, ChainRulesCore.conj(cosc(x)) * gull)
           end
    end

    function (ChainRulesCore.ChainRulesCore).rrule(f::ChainRulesCore.typeof(sinc), x::Number)
        Ω = sinc(x)
        nothing
        return (Ω, scalar_pullback(f, x))
    end
end

Note: there is only a limited number of functions, such as +, whose pullbacks do not need to know the value of x. We can define them separately.

Metadata

Metadata

Assignees

No one assigned

    Labels

    design: Requires some design before changes are made · enhancement: New feature or request · rule definition helper: relating to helpers for declaring rules

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions