Integrate ReverseDiff with ChainRules #180
Codecov Report
@@ Coverage Diff @@
## master #180 +/- ##
==========================================
+ Coverage 84.36% 84.86% +0.49%
==========================================
Files 18 18
Lines 1721 1777 +56
==========================================
+ Hits 1452 1508 +56
Misses 269 269
|
The next thing I am trying to do is to find the built-in functions that have their [...]. Does anyone have any idea about this? |
Currently, I find the functions that have an rrule with this code:
using ChainRulesCore, ChainRules

const noop = (x) -> x

# Return the function that an rrule method is defined for, or noop if none is found.
function ruleof(method)
    # The signature is either a plain tuple type or a UnionAll wrapping one.
    parameters = if isa(method.sig, DataType)
        method.sig.parameters
    elseif method.sig.body |> typeof == DataType
        method.sig.body.parameters
    else
        []
    end
    # parameters[1] is typeof(rrule); the differentiated function usually comes next.
    if length(parameters) >= 2
        type2 = parameters[2]
        typeof(type2) == DataType && type2.super == Function && return type2.instance
    end
    # type2 is RuleConfig, use the next parameter
    if length(parameters) >= 3
        type3 = parameters[3]
        typeof(type3) == DataType && type3.super == Function && return type3.instance
    end
    return noop # return a function to keep type-stable
end

const methods_with_rrule = Set([ruleof(m) for m in methods(rrule)])
In a similar way, I can find a list of functions that can be the first argument of [...]. So far, it seems possible to compare the two sets mentioned above and import the rules of interest into RD. Is there any flaw in this approach? |
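The set comparison described above can be sketched without loading any AD package. This is only an illustration under stated assumptions: `toy_rrule` is a hypothetical stand-in for `rrule`, `already_supported` is a made-up set rather than anything taken from ReverseDiff, and whether the rules of interest are the difference or the intersection of the two sets is left open here.

```julia
# Hypothetical stand-in for rrule: each method returns (value, pullback).
toy_rrule(::typeof(sin), x) = (sin(x), dy -> dy * cos(x))
toy_rrule(::typeof(cos), x) = (cos(x), dy -> -dy * sin(x))
toy_rrule(::typeof(exp), x) = (exp(x), dy -> dy * exp(x))

# Collect the functions that a generic function dispatches on in its
# first argument position, by inspecting every method signature.
function dispatched_functions(g)
    found = Set{Function}()
    for m in methods(g)
        sig = Base.unwrap_unionall(m.sig)   # strip `where` wrappers
        params = sig.parameters             # params[1] is typeof(g)
        length(params) >= 2 || continue
        T = params[2]
        if T isa DataType && T <: Function && Base.issingletontype(T)
            push!(found, T.instance)
        end
    end
    return found
end

with_rule = dispatched_functions(toy_rrule)

# Hypothetical set of functions the tracing package already handles.
already_supported = Set{Function}([sin, abs])

to_import = setdiff(with_rule, already_supported)   # have a rule, not yet handled
overlap = intersect(with_rule, already_supported)   # have a rule and are handled
```

Using `Base.unwrap_unionall` avoids branching on the shape of the signature, and the singleton check skips methods that dispatch on abstract function types and so have no `instance`.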
It seems CI is breaking for reasons not related to this PR, but it would be good to fix it. |
Yeah, and I can't reproduce it in my local environment. I am trying to print some debug info so I can see the details in the CI logs. |
https://github.com/JuliaDiff/ReverseDiff.jl/pull/180/checks?check_run_id=3357730072#step:6:288 Here is the detail. It somehow has many scalar instructions in the tape. |
Above is the tape when run under my local ENV, it seems |
This PR needs more tests of different types of functions, inputs, outputs, kwargs, etc. Then we need to error if the arguments are not suitable, e.g. ChainRules supports struct and arbitrary inputs but ReverseDiff doesn't the way it's defined currently. See the |
This PR doesn't try to introduce the diverse inputs support which |
I found this is caused by a change in DiffRules (JuliaDiff/DiffRules.jl@cbf17ea), it added And, this @mohamed82008 @yebai @devmotion do you have any idea about the reason for this failure? |
If you write a MWE here, I can try to look into it. |
Thank you, here is a MWE:
# DiffResults must be loaded explicitly for DiffResults.GradientResult
using ForwardDiff, ReverseDiff, DiffResults
using LinearAlgebra: norm

f(x) = norm(x' .* x)
x = rand(5)
test = ForwardDiff.gradient!(DiffResults.GradientResult(x), f, x)
@show DiffResults.gradient(test)

seedx = rand(eltype(x), size(x))
tp = ReverseDiff.GradientTape(f, seedx)
@show ReverseDiff.gradient!(tp, x)
Run it on the master branch of ReverseDiff with ForwardDiff=0.10.18 and DiffRules=1.0.2, and the two results will be the same. It is these two lines that cause the problem:
|
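When the two AD results disagree, a reference value helps decide which one is wrong. For the `f` in the MWE, `x' .* x` is the outer product of `x` with itself, so its Frobenius norm is `sum(abs2, x)` and the gradient should be `2x`. The sketch below checks this with central finite differences using only the standard library; `fd_gradient` is a hypothetical helper written for this illustration, not part of ForwardDiff or ReverseDiff.

```julia
using LinearAlgebra: norm, dot

f(x) = norm(x' .* x)

# Central finite-difference gradient (hypothetical helper for illustration).
function fd_gradient(f, x; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h                      # perturb only the i-th coordinate
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

x = [0.1, 0.2, 0.3, 0.4, 0.5]
value_matches = f(x) ≈ dot(x, x)                              # f(x) == sum of squares
grad_matches = isapprox(fd_gradient(f, x), 2 .* x; atol=1e-6) # gradient == 2x
```

Comparing each AD backend's gradient against `2 .* x` shows which of the two results has drifted.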
And, add test cases for varargs and kwargs
@KDr2 has fixed all the comments, I think. Can everyone take another look before it's merged? |
Looks good to me.
Many thanks, @KDr2 @devmotion @oxinabox and others for the help! |
This PR introduces a new macro, @grad_from_cr, for importing rrules from ChainRules into ReverseDiff.