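Status: Closed
Description
Calling only the pullback of a scalar rule is noticeably slower than evaluating the derivative directly. For `sin`, evaluating the derivative `cos(x)` takes about 4.4 ns (minimum). Calling the pullback obtained from `rrule(sin, x)` takes about 7.8 ns (minimum):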
julia> @benchmark cos(x) setup=(x=0.5)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 4.409 ns (0.00% GC)
median time: 4.421 ns (0.00% GC)
mean time: 4.530 ns (0.00% GC)
maximum time: 101.587 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 1000
julia> @benchmark rrule(sin, x)[2](1.0) setup=(x=0.5)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 7.840 ns (0.00% GC)
median time: 8.102 ns (0.00% GC)
mean time: 8.157 ns (0.00% GC)
maximum time: 35.264 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 999
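This is because `rrule` always computes the primal value `Ω` before returning the pullback, even when the caller only needs the pullback. For example, this is the code the `@scalar_rule` macro generates (shown here for `sinc`):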
julia> MacroTools.prettify(@macroexpand @scalar_rule sinc(x) cosc(x))
quote
if !(sinc isa ChainRulesCore.Type) && ChainRulesCore.fieldcount(ChainRulesCore.typeof(sinc)) > 0
ChainRulesCore.throw(ChainRulesCore.ArgumentError("@scalar_rule cannot be used on closures/functors (such as $(sinc))"))
end
function (ChainRulesCore.ChainRulesCore).frule((ChainRulesCore._, Δ1), ::ChainRulesCore.typeof(sinc), x::Number)
Ω = sinc(x)
nothing
return (Ω, cosc(x) * Δ1)
end
function (ChainRulesCore.ChainRulesCore).rrule(::ChainRulesCore.typeof(sinc), x::Number)
Ω = sinc(x)
nothing
return (Ω, function sinc_pullback(gull)
return (ChainRulesCore.NO_FIELDS, ChainRulesCore.conj(cosc(x)) * gull)
end)
end
end
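To make these scalar rules easier to use from packages that only need the pullback, could `@scalar_rule` also generate a standalone `scalar_pullback` function? The pullback could then be obtained without computing the primal value. For example, the generated code might look like this: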
julia> MacroTools.prettify(@macroexpand @scalar_rule sinc(x) cosc(x))
quote
if !(sinc isa ChainRulesCore.Type) && ChainRulesCore.fieldcount(ChainRulesCore.typeof(sinc)) > 0
ChainRulesCore.throw(ChainRulesCore.ArgumentError("@scalar_rule cannot be used on closures/functors (such as $(sinc))"))
end
function (ChainRulesCore.ChainRulesCore).frule((ChainRulesCore._, Δ1), ::ChainRulesCore.typeof(sinc), x::Number)
Ω = sinc(x)
nothing
return (Ω, cosc(x) * Δ1)
end
function scalar_pullback(::ChainRulesCore.typeof(sinc), x::Number)
function sinc_pullback(gull)
return (ChainRulesCore.NO_FIELDS, ChainRulesCore.conj(cosc(x)) * gull)
end
end
function (ChainRulesCore.ChainRulesCore).rrule(f::ChainRulesCore.typeof(sinc), x::Number)
Ω = sinc(x)
nothing
return (Ω, scalar_pullback(f, x))
end
end
Note: there are a limited number of functions, such as `+`, whose pullbacks do not need the value of `x`. Those can be defined separately.