Add derivatives_given_output for scalar functions #453
e65c9ed
to
7c69ac8
Compare
Codecov Report
@@            Coverage Diff             @@
##           master     #453      +/-   ##
==========================================
+ Coverage   92.88%   92.94%   +0.06%
==========================================
  Files          14       14
  Lines         787      794       +7
==========================================
+ Hits          731      738       +7
  Misses         56       56
@shashi @YingboMa this should be useful for Symbolics.jl
I guess this could be useful for defining more general rules for broadcasting of |
Yes! You could have something like
function rrule(::typeof(broadcast), f, x)
    y = broadcast(f, x)
    function pullback_broadcast(ȳ)
        # One derivative per element, recovered from the stored output y.
        x̄ = @. only(only(derivatives_given_output(y, f, x))) * ȳ
        return NoTangent(), NoTangent(), x̄
    end
    return y, pullback_broadcast
end
which would be pretty efficient. Some approximation of this may exist already: compute the derivative together with the output in the forward pass, split the result into two separate arrays, and remember the derivatives for the backward pass. Still, the above snippet seems simpler and doesn't require the array-of-structs to struct-of-arrays (AoS to SoA) transformation on the fly. The AoS to SoA transformation is a separate issue to think about (I think right now it works out-of-the-box with StructArrays on CPU but fails on GPU, even though it should be fixable). That being said, this PR is just a way to allow package developers to get some optimizations for specific use cases, since it is almost for free. Separate tools to make this easier should probably be added in separate PRs.
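The idea in the comment above can be tried without any packages. The following is only a sketch: `toy_derivatives_given_output` and `toy_broadcast_with_pullback` are hypothetical stand-ins written for this illustration (they are not ChainRulesCore functions), and the rule is defined for `tanh` only. It assumes the same return convention the snippet relies on, a tuple with one inner tuple per output and one entry per input.

```julia
# Hypothetical stand-in for ChainRulesCore.derivatives_given_output, defined
# here only for tanh so that the sketch runs without any packages.
# Assumed convention: one inner tuple per output, one entry per input.
toy_derivatives_given_output(Ω, ::typeof(tanh), x) = ((1 - Ω^2,),)

# Forward pass plus pullback for broadcast(f, x). Only `y` and `x` are
# captured by the closure; no separate array of derivatives is stored.
function toy_broadcast_with_pullback(f, x)
    y = broadcast(f, x)
    function pullback(ȳ)
        # Derivatives are rebuilt elementwise from the stored output `y`.
        return @. only(only(toy_derivatives_given_output(y, f, x))) * ȳ
    end
    return y, pullback
end

x = [0.1, -0.5, 2.0]
y, back = toy_broadcast_with_pullback(tanh, x)
x̄ = back(ones(length(x)))
```

With a seed of ones, `x̄` should match the analytic derivative `1 .- tanh.(x) .^ 2`, and `tanh` is never called again in the backward pass.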
@@ -173,6 +202,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
$(__source__)
$(esc(:Ω)) = $call
$(setup_stmts...)
$(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...))
I am a little surprised that this doesn't break our IsolatedSubmodule stuff.
My reasoning was that ChainRulesCore was available a few lines above, so it should be available here as well...
IIUC, the key trick for this to work is that the expression returned by scalar_frule_expr is not escaped; one only escapes the user input. So outside of those escaped subexpressions, things are evaluated in the context of ChainRulesCore.
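The hygiene argument can be reproduced in a few lines. This is a minimal sketch, not the actual `@scalar_rule` machinery: the module `ToyRules`, its `helper` function, and the `@scaled` macro are all hypothetical names introduced for illustration. Only the user-supplied expression is escaped; the qualified call to the helper is left unescaped and therefore resolves inside the module that defines the macro.

```julia
module ToyRules

# Internal helper; code that calls the macro never imports this name.
helper(Ω, k) = Ω * k

# Only the user's expression is escaped. The rest of the returned expression,
# including the name `ToyRules`, is resolved in the macro's own module.
macro scaled(call)
    return :(ToyRules.helper($(esc(call)), 2))
end

end # module

x = 3
result = ToyRules.@scaled(x + 1)  # x is looked up in the caller's scope
```

Here `x + 1` is evaluated where the macro is called, while `ToyRules.helper` is found even though the calling scope defines no `helper` of its own.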
Using this for broadcasting is well and truly outside the scope of this PR. We have talked about doing this, and other variations on this, for broadcasting for a while.
hmm why is this breaking ChainRules.jl? |
Seems spurious: locally, the ChainRules tests pass with this PR. We can see if it gets fixed after I push the changes suggested above.
011a23e
to
7492c45
Compare
Actually, I've investigated the failure and it is real on Julia 1.5.4 (which is where the integration test is run) but seems unrelated. It is an inference failure on
@testset "diag" begin
    N = 7
    VERSION ≥ v"1.3" && @testset "k=$k" for k in (-1, 0, 2)
        test_rrule(diag, randn(N, N), k)
    end
end
Ah, ok. |
Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
0d91fdd
to
d8e8f63
Compare
Waiting for @shashi 's review. |
This looks good to me, it should be enough to get differentiation in Symbolics.jl to work. 👍 I love how simple a change this is..! |
Derivatives_given_output is what you want for broadcasting. For |
This follows up on the discussion on Slack. It is a small refactor of @scalar_rule such that the actual derivative computation is factored out in an (internal) method. This allows some optimization if one happens to have both xs and Ω and does not want to recompute f to get derivatives.
This does fix #246 in its original formulation (getting derivative without running primal pass for scalar rules). It is specific to scalar rules; it does not try to define a general API for frule and rrule when the output is known.
Notes:
esc fix, in that _normalize_scalarrules_macro_input was "lying" about correctly escaping everything. It was not escaping partials, which were instead escaped in propagation_expr. That caused some complications in this refactor, so I fixed it.
EDIT: I've also added some tests of using derivatives_given_output independently of frule and rrule.
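To make the "have both xs and Ω" case concrete, here is a small sketch in plain Julia. The function `toy_derivatives_given_output` is a hypothetical stand-in written for this example, not the method added by the PR, and the two rules below are assumptions chosen because their derivatives can be expressed through the output. The return shape (a tuple of tuples) mirrors how the broadcasting snippet in the conversation unpacks the result with `only(only(...))`.

```julia
# Hypothetical stand-ins; these are not ChainRulesCore's actual definitions.
# Each method returns one inner tuple per output, one entry per input.
toy_derivatives_given_output(Ω, ::typeof(exp), x) = ((Ω,),)
toy_derivatives_given_output(Ω, ::typeof(/), x, y) = ((inv(y), -Ω / y),)

x, y = 3.0, 4.0
Ω = x / y                      # primal output, computed once

# Reuse Ω instead of recomputing x / y inside the derivative.
(∂x, ∂y) = only(toy_derivatives_given_output(Ω, /, x, y))

# For exp the derivative is the output itself, so no second exp call is needed.
Ωe = exp(1.0)
∂e = only(only(toy_derivatives_given_output(Ωe, exp, 1.0)))
```

The saving is small for `/` but real for functions like `exp`, where the derivative is exactly the value already computed in the primal pass.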