
Add derivatives_given_output for scalar functions #453


Merged: 7 commits, Sep 11, 2021

Conversation

@piever (Contributor) commented Sep 8, 2021

This follows up on the discussion on Slack. It is a small refactor of @scalar_rule such that the actual derivative computation is factored out into an (internal) method

derivatives_given_output(Ω, f, xs...) # not sure about the signature, but this seemed reasonable

This allows some optimization if one happens to have both xs and Ω and does not want to recompute f to get derivatives.

This does fix #246 in its original formulation (getting the derivative without running the primal pass for scalar rules). It is specific to scalar rules; it does not try to define a general API for frule and rrule when the output is known.
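As a concrete illustration (a standalone sketch with a locally defined method, not the actual ChainRulesCore internals): exp is a natural example because its scalar rule, @scalar_rule exp(x) Ω, reuses the primal output as the derivative.

```julia
# Standalone stand-in for the internal method; the real implementation is
# generated by @scalar_rule. It returns a tuple of tuples of partials:
# one inner tuple per output, one entry per input.
derivatives_given_output(Ω, ::typeof(exp), x) = ((Ω,),)  # d/dx exp(x) = exp(x) = Ω

x = 1.2
Ω = exp(x)                                  # primal output, already in hand
∂ = only(only(derivatives_given_output(Ω, exp, x)))
∂ == Ω                                      # derivative obtained without re-running exp
```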

Notes:

  • The second commit does a small esc fix, in that _normalize_scalarrules_macro_input was "lying" about correctly escaping everything. It was not escaping partials, which were instead escaped in propagation_expr. That caused some complications in this refactor, so I fixed it.
  • Still needs a version bump. I'm not sure whether it should be 1.4.1 or 1.3.2.

EDIT: I've also added some tests of using derivatives_given_output independently of frule and rrule

@codecov-commenter commented Sep 8, 2021

Codecov Report

Merging #453 (0d91fdd) into master (3b46ac5) will increase coverage by 0.06%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master     #453      +/-   ##
==========================================
+ Coverage   92.88%   92.94%   +0.06%     
==========================================
  Files          14       14              
  Lines         787      794       +7     
==========================================
+ Hits          731      738       +7     
  Misses         56       56              
Impacted Files Coverage Δ
src/rule_definition_tools.jl 96.15% <100.00%> (+0.18%) ⬆️


@oxinabox (Member) commented Sep 8, 2021

@shashi @YingboMa this should be useful for Symbolics.jl
@GiggleLiu this should be useful for NiLang

@devmotion (Member):

I guess this could be useful for defining more general rules for broadcasting of @scalar_rule functions as well, similar to what exists in NNlib and lib/logexpfunctions.jl in Zygote (LogExpFunctions contains the @scalar_rules now, but Zygote defines some additional Zygote-adjoints for broadcast). On the other hand, I guess the function is not needed for this purpose and we could just build these rules with a macro?

@piever (Contributor, Author) commented Sep 8, 2021

I guess this could be useful for defining more general rules for broadcasting of @scalar_rule as well, similar to what exists in NNlib and lib/logexpfunctions.jl in Zygote

Yes! You could have something like

function rrule(::typeof(broadcast), f, x)
    y = broadcast(f, x)
    function pullback_broadcast(ȳ)
        x̄ = @. only(only(derivatives_given_output(y, f, x))) * ȳ
        return NoTangent(), NoTangent(), x̄
    end
    return y, pullback_broadcast
end

which would be pretty efficient. Some approximation of this may exist already: computing the derivative together with the output in the forward pass, then splitting into two separate arrays and remembering the derivatives for the backward pass. Still, the above snippet seems simpler and doesn't require the array-of-structs to struct-of-arrays transformation on the fly. The AoS-to-SoA conversion is a separate issue to think about (I think right now it works out of the box with StructArrays on CPU but fails on GPU, even though that should be fixable).

That being said, this PR is just a way to allow package developers to get some optimizations for specific use cases since it is almost for free, separate tools to make this easier should probably be added in separate PRs.

@@ -173,6 +202,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
$(__source__)
$(esc(:Ω)) = $call
$(setup_stmts...)
$(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...))
Review comment (Member):

I am a little surprised that this doesn't break our IsolatedSubmodule stuff.

Reply (Contributor, Author):

My reasoning was that ChainRulesCore was available a few lines above, so it should be available here as well...

IIUC, the key trick for this to work is that the expression returned by scalar_frule_expr is not escaped; one only escapes the user input, so outside of those escaped subexpressions, things are evaluated in the context of ChainRulesCore.
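The hygiene behaviour described above can be demonstrated with a minimal standalone sketch (module and names here are hypothetical, not ChainRulesCore's code): unescaped symbols in a macro's returned expression resolve in the macro's defining module, while esc'd ones resolve in the caller's scope.

```julia
module MacroMod
    helper() = :from_macro_module
    macro demo(x)
        # `helper` is unescaped, so it resolves inside MacroMod;
        # `esc(x)` resolves in the caller's scope.
        return :( (helper(), $(esc(x))) )
    end
end

using .MacroMod: @demo
caller_val = :from_caller
@demo(caller_val)  # => (:from_macro_module, :from_caller)
```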

@oxinabox (Member) commented Sep 8, 2021

Using this for broadcasting is well and truly outside the scope of this PR.

We have talked about doing this, and other variations on this for broadcasting for a while.
One downside is that it basically breaks broadcast fusing.
Many ADs already break broadcast-fusing though.
Maybe we could use a configured rule for this:
RuleConfig{>:OKWithBreakingBroadcastFusing}
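The dispatch pattern behind such a configured rule could look like the following sketch (all types here are local stand-ins; OKWithBreakingBroadcastFusing is a hypothetical capability, not an existing ChainRulesCore trait):

```julia
# Local stand-in for ChainRulesCore's RuleConfig{T} trait-dispatch pattern.
abstract type RuleConfig{T} end

struct OKWithBreakingBroadcastFusing end            # hypothetical capability
struct FusingTolerantAD <: RuleConfig{Union{OKWithBreakingBroadcastFusing}} end
struct PlainAD <: RuleConfig{Union{}} end

# A rule applies only when the AD's config declares the capability,
# i.e. when T is a supertype of OKWithBreakingBroadcastFusing:
can_unfuse(::RuleConfig{>:OKWithBreakingBroadcastFusing}) = true
can_unfuse(::RuleConfig) = false

can_unfuse(FusingTolerantAD())  # true
can_unfuse(PlainAD())           # false
```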

@oxinabox (Member) commented Sep 8, 2021

hmm why is this breaking ChainRules.jl?

@piever (Contributor, Author) commented Sep 8, 2021

hmm why is this breaking ChainRules.jl?

Seems spurious; locally the ChainRules tests pass with this PR. We can see whether it gets fixed after I push the changes suggested above.

@piever (Contributor, Author) commented Sep 8, 2021

hmm why is this breaking ChainRules.jl?

Actually, I've investigated the failure: it is real on Julia 1.5.4 (where the integration test is run) but seems unrelated. It is an inference failure on diag here. I've checked, and I get the same error for the offending test even using the released ChainRulesCore:

@testset "diag" begin
    N = 7
    VERSION ≥ v"1.3" && @testset "k=$k" for k in (-1, 0, 2)
        test_rrule(diag, randn(N, N), k)
    end
end

@oxinabox (Member) commented Sep 10, 2021

Ah, ok.
I don't care about v1.5, I will update the integration tests CI.
#454

@oxinabox (Member):
Waiting for @shashi 's review.
If we don't have that by 5pm UTC on Monday, I will merge and we can handle it in a follow-up PR.
It is no big deal since we marked this as breaking.

@shashi (Collaborator) commented Sep 10, 2021

This looks good to me, it should be enough to get differentiation in Symbolics.jl to work. 👍 I love how simple a change this is..!

@mcabbott (Member) commented Sep 11, 2021

One thing this function would be useful for is things like sum(f, xs), where instead of storing all the closures you could (when this is available) store nothing extra and broadcast it at the end. Maybe not.

derivatives_given_output is what you want for broadcasting. For sum(f, xs), you would want to know whether a derivatives_given_input is cheap or not.
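To make the sum(f, xs) idea concrete, here is a hedged, self-contained sketch (not an actual ChainRules rule; rrule_sum and the local derivatives_given_output method are illustrative stand-ins, and the pullback recomputes f per element rather than storing closures):

```julia
# Stand-in for the internal method, defined here only for exp,
# whose derivative equals its output.
derivatives_given_output(Ω, ::typeof(exp), x) = ((Ω,),)

struct NoTangent end  # local stand-in for ChainRulesCore.NoTangent

function rrule_sum(f, xs::AbstractArray)
    y = sum(f, xs)
    function sum_pullback(ȳ)
        # Nothing extra was stored in the forward pass; recompute f(x) per
        # element and feed it to derivatives_given_output at the end.
        x̄ = map(x -> only(only(derivatives_given_output(f(x), f, x))) * ȳ, xs)
        return NoTangent(), NoTangent(), x̄
    end
    return y, sum_pullback
end

y, pb = rrule_sum(exp, [0.0, 1.0])
pb(1.0)[3] == exp.([0.0, 1.0])  # gradient of sum(exp, xs) is exp.(xs)
```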

@oxinabox oxinabox merged commit 344f3d5 into JuliaDiff:master Sep 11, 2021
Successfully merging this pull request may close these issues.

Get pullback without running primal pass for @scalar_rules
6 participants