-
Notifications
You must be signed in to change notification settings - Fork 64
Add derivatives_given_output
for scalar functions
#453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f37d6a3
6d29b7c
69e1e4f
4fd98d4
7d99bb6
2339e05
d8e8f63
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -88,8 +88,15 @@ macro scalar_rule(call, maybe_setup, partials...) | |
) | ||
f = call.args[1] | ||
|
||
frule_expr = scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) | ||
rrule_expr = scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) | ||
# Generate variables to store derivatives named dfi/dxj | ||
derivatives = map(keys(partials)) do i | ||
syms = map(j -> Symbol("∂f", i, "/∂x", j), keys(inputs)) | ||
return Expr(:tuple, syms...) | ||
end | ||
|
||
derivative_expr = scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) | ||
frule_expr = scalar_frule_expr(__source__, f, call, [], inputs, derivatives) | ||
rrule_expr = scalar_rrule_expr(__source__, f, call, [], inputs, derivatives) | ||
|
||
# Final return: building the expression to insert in the place of this macro | ||
code = quote | ||
|
@@ -99,6 +106,7 @@ macro scalar_rule(call, maybe_setup, partials...) | |
)) | ||
end | ||
|
||
$(derivative_expr) | ||
$(frule_expr) | ||
$(rrule_expr) | ||
end | ||
|
@@ -135,16 +143,45 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) | |
# For consistency in code that follows we make all partials tuple expressions | ||
partials = map(partials) do partial | ||
if Meta.isexpr(partial, :tuple) | ||
partial | ||
Expr(:tuple, map(esc, partial.args)...) | ||
else | ||
length(inputs) == 1 || error("Invalid use of `@scalar_rule`") | ||
Expr(:tuple, partial) | ||
Expr(:tuple, esc(partial)) | ||
end | ||
end | ||
|
||
return call, setup_stmts, inputs, partials | ||
end | ||
|
||
""" | ||
derivatives_given_output(Ω, f, xs...) | ||
|
||
Compute the derivative of scalar function `f` at primal input point `xs...`, | ||
given that it had primal output `Ω`. | ||
Return a tuple of tuples with the partial derivatives of `f` with respect to the `xs...`. | ||
The derivative of the `i`-th component of `f` with respect to the `j`-th input can be | ||
accessed as `Df[i][j]`, where `Df = derivatives_given_output(Ω, f, xs...)`. | ||
|
||
!!! warning "Experimental" | ||
This function is experimental and not part of the stable API. | ||
At the moment, it can be considered an implementation detail of the macro | ||
[`@scalar_rule`](@ref), in which it is used. | ||
In the future, the exact semantics of this function will stabilize, and it | ||
will be added to the stable API. | ||
When that happens, this warning will be removed. | ||
|
||
""" | ||
function derivatives_given_output end | ||
oxinabox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) | ||
return @strip_linenos quote | ||
function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) | ||
$(__source__) | ||
$(setup_stmts...) | ||
return $(Expr(:tuple, partials...)) | ||
end | ||
end | ||
end | ||
|
||
function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) | ||
n_outputs = length(partials) | ||
|
@@ -173,6 +210,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) | |
$(__source__) | ||
$(esc(:Ω)) = $call | ||
$(setup_stmts...) | ||
$(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am a little surprised that his doesn't break our There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My reasoning was that IIUC, the key trick for this to work is that the expression returned by |
||
return $(esc(:Ω)), $pushforward_returns | ||
end | ||
end | ||
|
@@ -210,6 +248,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) | |
$(__source__) | ||
$(esc(:Ω)) = $call | ||
$(setup_stmts...) | ||
$(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) | ||
return $(esc(:Ω)), $pullback | ||
end | ||
end | ||
|
@@ -240,9 +279,9 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) | |
# This is basically Δs ⋅ ∂s | ||
_∂s = map(∂s) do ∂s_i | ||
if _conj | ||
:(conj($(esc(∂s_i)))) | ||
:(conj($∂s_i)) | ||
else | ||
esc(∂s_i) | ||
∂s_i | ||
end | ||
end | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.