Skip to content

Add derivatives_given_output for scalar functions #453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 45 additions & 6 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,15 @@ macro scalar_rule(call, maybe_setup, partials...)
)
f = call.args[1]

frule_expr = scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
rrule_expr = scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
# Generate variables to store derivatives named dfi/dxj
derivatives = map(keys(partials)) do i
syms = map(j -> Symbol("∂f", i, "/∂x", j), keys(inputs))
return Expr(:tuple, syms...)
end

derivative_expr = scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials)
frule_expr = scalar_frule_expr(__source__, f, call, [], inputs, derivatives)
rrule_expr = scalar_rrule_expr(__source__, f, call, [], inputs, derivatives)

# Final return: building the expression to insert in the place of this macro
code = quote
Expand All @@ -99,6 +106,7 @@ macro scalar_rule(call, maybe_setup, partials...)
))
end

$(derivative_expr)
$(frule_expr)
$(rrule_expr)
end
Expand Down Expand Up @@ -135,16 +143,45 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
# For consistency in code that follows we make all partials tuple expressions
partials = map(partials) do partial
if Meta.isexpr(partial, :tuple)
partial
Expr(:tuple, map(esc, partial.args)...)
else
length(inputs) == 1 || error("Invalid use of `@scalar_rule`")
Expr(:tuple, partial)
Expr(:tuple, esc(partial))
end
end

return call, setup_stmts, inputs, partials
end

"""
derivatives_given_output(Ω, f, xs...)

Compute the derivative of scalar function `f` at primal input point `xs...`,
given that it had primal output `Ω`.
Return a tuple of tuples with the partial derivatives of `f` with respect to the `xs...`.
The derivative of the `i`-th component of `f` with respect to the `j`-th input can be
accessed as `Df[i][j]`, where `Df = derivatives_given_output(Ω, f, xs...)`.

!!! warning "Experimental"
This function is experimental and not part of the stable API.
At the moment, it can be considered an implementation detail of the macro
[`@scalar_rule`](@ref), in which it is used.
In the future, the exact semantics of this function will stabilize, and it
will be added to the stable API.
When that happens, this warning will be removed.

"""
function derivatives_given_output end

function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials)
return @strip_linenos quote
function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...))
$(__source__)
$(setup_stmts...)
return $(Expr(:tuple, partials...))
end
end
end

function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
n_outputs = length(partials)
Expand Down Expand Up @@ -173,6 +210,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
$(__source__)
$(esc(:Ω)) = $call
$(setup_stmts...)
$(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a little surprised that his doesn't break our IsolatedSubmodule stuff.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My reasoning was that ChainRulesCore was available a few lines above, so it should be available here as well...

IIUC, the key trick for this to work is that the expression returned by scalar_frule_expr is not escaped, one only escapes the user input, so outside of those escaped subexpression, things are evaluated in the context of ChainRulesCore.

return $(esc(:Ω)), $pushforward_returns
end
end
Expand Down Expand Up @@ -210,6 +248,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
$(__source__)
$(esc(:Ω)) = $call
$(setup_stmts...)
$(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...))
return $(esc(:Ω)), $pullback
end
end
Expand Down Expand Up @@ -240,9 +279,9 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity)
# This is basically Δs ⋅ ∂s
_∂s = map(∂s) do ∂s_i
if _conj
:(conj($(esc(∂s_i))))
:(conj($∂s_i))
else
esc(∂s_i)
∂s_i
end
end

Expand Down
7 changes: 6 additions & 1 deletion test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ end
@test ẏ == Tangent{typeof(y)}(50f0, 100f0)
# make sure type is exactly as expected:
@test ẏ isa Tangent{Tuple{Irrational{:π}, Float64}, Tuple{Float32, Float32}}

xs, Ω = (3,), (3, 6)
@test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == ((1f0,), (2f0,))
end

@testset "@scalar_rule projection" begin
Expand Down Expand Up @@ -298,7 +301,7 @@ module IsolatedModuleForTestingScoping
module IsolatedSubmodule
# check that rules defined in isolated module without imports can be called
# without errors
using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent
using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output
using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id
using Test

Expand Down Expand Up @@ -328,6 +331,8 @@ module IsolatedModuleForTestingScoping
y, f_pullback = rrule(my_id, x)
@test y == x
@test f_pullback(Δy) == (NoTangent(), Δy)

@test derivatives_given_output(y, my_id, x) == ((1.0,),)
end
end
end