Integrate ReverseDiff with ChainRules #180

Merged
merged 35 commits into from
Oct 27, 2021

KDr2
Contributor

@KDr2 KDr2 commented Jul 13, 2021

This PR introduces a new macro, @grad_from_cr, for importing rrules from ChainRules into ReverseDiff.
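
Later in this thread the macro is referred to as @grad_from_chainrules. Below is a minimal sketch of how it might be used, assuming that name and a ReverseDiff release that includes this PR. The function `sumsq` and its rrule are hypothetical and exist only for illustration.

```julia
using ChainRulesCore
using ReverseDiff
using ReverseDiff: TrackedArray

# Hypothetical function for illustration; it is not part of the PR.
sumsq(x) = sum(abs2, x)

# Hand-written ChainRules rule: the pullback maps the output cotangent Δ to 2Δx.
function ChainRulesCore.rrule(::typeof(sumsq), x::AbstractVector)
    sumsq_pullback(Δ) = (NoTangent(), 2 .* Δ .* x)
    return sumsq(x), sumsq_pullback
end

# Import the rrule for tracked array inputs (macro name assumed, see above).
ReverseDiff.@grad_from_chainrules sumsq(x::TrackedArray)

x = [1.0, 2.0, 3.0]
g = ReverseDiff.gradient(sumsq, x)
@show g
```

The argument whose derivative is wanted is annotated with a tracked type in the macro call, while the rrule itself is written against plain arrays.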

@codecov-commenter

codecov-commenter commented Jul 27, 2021

Codecov Report

Merging #180 (90a55ad) into master (01041c8) will increase coverage by 0.49%.
The diff coverage is 96.36%.

@@            Coverage Diff             @@
##           master     #180      +/-   ##
==========================================
+ Coverage   84.36%   84.86%   +0.49%     
==========================================
  Files          18       18              
  Lines        1721     1777      +56     
==========================================
+ Hits         1452     1508      +56     
  Misses        269      269              
Impacted Files                 Coverage Δ
src/ReverseDiff.jl             100.00% <ø> (ø)
src/macros.jl                  93.40% <96.36%> (+1.14%) ⬆️
src/api/tape.jl                75.00% <0.00%> (+0.28%) ⬆️
src/derivatives/broadcast.jl   91.01% <0.00%> (+1.12%) ⬆️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 01041c8...90a55ad. Read the comment docs.

@ChrisRackauckas ChrisRackauckas marked this pull request as ready for review July 27, 2021 15:56
@KDr2
Contributor Author

KDr2 commented Aug 3, 2021

Next, I am trying to find the built-in functions that have rrules in ChainRules but no grad rules in ReverseDiff.jl, and then import those rrules into this package.

Does anyone have ideas on how to do this?

@KDr2
Contributor Author

KDr2 commented Aug 4, 2021

to find the built-in functions that have their rrules in ChainRules but don't have grad rules in ReverseDiff.jl

Currently, I find the functions that have rrules in ChainRules with the following code:

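using ChainRulesCore, ChainRules

# Fallback returned when no function can be extracted from a signature.
const noop = (x) -> x

# Return the function that an `rrule` method is defined for, or `noop`.
function ruleof(method)
    # Parameters of the signature Tuple{typeof(rrule), [RuleConfig,] typeof(f), ...}.
    parameters = if isa(method.sig, DataType)
        method.sig.parameters
    elseif method.sig.body |> typeof == DataType
        # A UnionAll signature with a single `where` clause: inspect its body.
        method.sig.body.parameters
    else
        []
    end
    # Usual case: the second parameter is the type of the differentiated function.
    if length(parameters) >= 2
        type2 = parameters[2]
        typeof(type2) == DataType && type2.super == Function && return type2.instance
    end

    # type2 is RuleConfig, use the next parameter
    if length(parameters) >= 3
        type3 = parameters[3]
        typeof(type3) == DataType && type3.super == Function && return type3.instance
    end

    return noop # return a function to keep type-stable
end

# Collect the functions found across all `rrule` methods.
const methods_with_rrule = Set([ruleof(m) for m in methods(rrule)])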

methods_with_rrule is a Set{Function}.

In a similar way, I can list the functions that can be the first argument of ReverseDiff.track (i.e., the functions that have custom derivative rules provided via @grad). Combining this with the list from DiffRules.diffrules() gives the functions that have derivative rules in ReverseDiff.

So far, it seems possible to compare the two sets mentioned above and import the rules of interest into RD. Is there any flaw in this approach?
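
The comparison described above amounts to a set difference. Below is a minimal sketch using placeholder sets; the contents of `with_rrule` and `with_rd_rule` are hypothetical stand-ins, not the real lists produced from ChainRules or ReverseDiff.

```julia
# Hypothetical stand-ins for the two sets described above.
with_rrule   = Set{Function}([sin, cos, sum, prod])  # functions with an rrule
with_rd_rule = Set{Function}([sin, cos])             # functions with a ReverseDiff rule

# Functions whose rrules would be candidates for import.
candidates = setdiff(with_rrule, with_rd_rule)
@show candidates
```

In practice the first set would come from methods_with_rrule, with the noop placeholder removed, and the second from the ReverseDiff.track and DiffRules lists.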

@yebai
Contributor

yebai commented Aug 17, 2021

It seems CI is breaking for reasons not related to this PR, but it would be good to fix it.

@KDr2
Contributor Author

KDr2 commented Aug 18, 2021

It seems CI is breaking for reasons not related to this PR, but it would be good to fix it.

Yes, and I can't reproduce it in my local environment. I am trying to print some debug info so that I can see the details in the CI logs.

@KDr2
Contributor Author

KDr2 commented Aug 18, 2021

It seems CI is breaking for reasons not related to this PR, but it would be good to fix it.

Yes, and I can't reproduce it in my local environment. I am trying to print some debug info so that I can see the details in the CI logs.

https://github.com/JuliaDiff/ReverseDiff.jl/pull/180/checks?check_run_id=3357730072#step:6:288 Here is the detail. It somehow has many scalar instructions in the tape.

@KDr2
Contributor Author

KDr2 commented Aug 18, 2021

https://github.com/JuliaDiff/ReverseDiff.jl/pull/180/checks?check_run_id=3357730072#step:6:288 Here is the detail. It somehow has many scalar instructions in the tape.

  testing Array -> Number functions: `#19`...
f = Main.LinAlgTests.var"#19#23"()
x = [0.8701349878780872 0.28204319051637505 0.9123350298255161; 0.006457014103689707 0.5531597766343656 0.4730983492512397; 0.773618934964177 0.08793093987132328 0.8364989250826362]
1-element InstructionTape:
tp = 1 => SpecialInstruction(*):
  input:  (1×9 adjoint(::Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}) with eltype ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}},
           9-element ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}})
  output: TrackedReal<E5Q>(3.5048365030607127, 0.0, KM2, ---)
  cache:  (1×9 adjoint(::Vector{Float64}) with eltype Float64,
           9-element Vector{Float64})

Above is the tape produced in my local environment. In the CI run, it seems that adjoint was propagated to the vector's elements and the tape recorded those per-element operations.

@mohamed82008
Member

This PR needs more tests covering different types of functions, inputs, outputs, kwargs, etc. We also need to throw an error if the arguments are not suitable, e.g. ChainRules supports struct and arbitrary inputs but ReverseDiff, as currently defined, does not. See the @grad macro tests for reference.

@KDr2
Contributor Author

KDr2 commented Aug 23, 2021

e.g. ChainRules supports struct and arbitrary inputs but ReverseDiff, as currently defined, does not. See the @grad macro tests for reference.

This PR doesn't try to introduce support for the diverse inputs that ChainRules handles; it aims to use rrule on the data types that ReverseDiff already supports. In any case, I will look at the tests for @grad and add more for @grad_from_chainrules.

@KDr2
Contributor Author

KDr2 commented Sep 10, 2021

https://github.com/JuliaDiff/ReverseDiff.jl/pull/180/checks?check_run_id=3357730072#step:6:288 Here is the detail. It somehow has many scalar instructions in the tape.

I found that this is caused by a change in DiffRules (JuliaDiff/DiffRules.jl@cbf17ea): it added adjoint to the list of diff rules, so adjoint was recorded for each element.

This adjoint rule also caused another test failure:
https://github.com/JuliaDiff/ReverseDiff.jl/pull/180/checks?check_run_id=3564090993#step:6:352
The gradients from ReverseDiff and ForwardDiff differ a lot, and it seems that the latter is correct.

@mohamed82008 @yebai @devmotion do you have any idea about the reason for this failure?

@mohamed82008
Member

If you write a MWE here, I can try to look into it.

@KDr2
Contributor Author

KDr2 commented Sep 10, 2021

If you write a MWE here, I can try to look into it.

Thank you, here is a MWE:

using DiffResults, ForwardDiff, ReverseDiff
using LinearAlgebra: norm

f(x) = norm(x' .* x)
x = rand(5)

test = ForwardDiff.gradient!(DiffResults.GradientResult(x), f, x)
@show DiffResults.gradient(test)

seedx = rand(eltype(x), size(x))
tp = ReverseDiff.GradientTape(f, seedx)
@show ReverseDiff.gradient!(tp, x)

When run on the master branch of ReverseDiff with ForwardDiff=0.10.18 and DiffRules=1.0.2, the two results are the same.
When run on the master branch of ReverseDiff with ForwardDiff=0.10.19 and DiffRules=1.3.0, the two results differ.

It is these two lines that cause the problem:

@define_diffrule Base.conj(x)                 = :(  1                                  )
@define_diffrule Base.adjoint(x)              = :(  1                                  )

https://github.com/JuliaDiff/DiffRules.jl/blob/cbf17ea233d16deb6f0a32bca661c63a79126adc/src/rules.jl#L70
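
One way to check which of the two results is right is to use the closed form. For a real vector x, x' .* x is the outer product of x with itself, and its Frobenius norm equals sum(abs2, x), so the gradient of f should be 2x. The sketch below compares ForwardDiff against that value. The fixed input vector is an arbitrary choice for illustration.

```julia
using ForwardDiff
using LinearAlgebra: norm

f(x) = norm(x' .* x)

# Arbitrary fixed input, so the check is reproducible.
x = [0.5, 1.0, 1.5, 2.0, 2.5]

g = ForwardDiff.gradient(f, x)

# Analytic gradient of sum(abs2, x).
expected = 2 .* x
@show g ≈ expected
```

The same comparison against `2 .* x` can be applied to the output of ReverseDiff.gradient! from the MWE.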

@KDr2 KDr2 requested a review from devmotion October 19, 2021 09:10
@yebai
Contributor

yebai commented Oct 26, 2021

@KDr2 has addressed all the comments, I think. Can everyone take another look before it's merged?

Member

@devmotion devmotion left a comment

Looks good to me.

@yebai yebai merged commit df00674 into JuliaDiff:master Oct 27, 2021
@yebai
Contributor

yebai commented Oct 27, 2021

Many thanks, @KDr2 @devmotion @oxinabox and others for the help!

@KDr2 KDr2 deleted the chainrules branch October 29, 2021 05:50
6 participants