Improve DiffRules integration and tests #209

Merged: 4 commits from dw/diffrules into master on Oct 16, 2022

Conversation

devmotion (Member) commented on Oct 1, 2022:

This PR fixes some problems with the DiffRules integration and its tests. It is needed for JuliaDiff/DiffRules.jl#79 (relevant DiffRules tests pass with that PR).

Mainly, the PR

  • disables DiffRules definitions and tests for functions with complex outputs and derivatives, since these are not supported by ReverseDiff
  • adds NaN-aware comparisons to the tests (in DiffRules, undefined and non-existent derivatives are implemented as NaN, so without this, comparisons with ForwardDiff would fail even when both sides return NaN); see the sketch after this list
  • changes the tests so that ReverseDiff results are compared with the results of the corresponding ForwardDiff calls. Currently, e.g., derivatives with respect to two arguments in ReverseDiff are compared with two separately computed ForwardDiff derivatives, even though internally both derivatives are computed with a single ForwardDiff call. This causes test errors: if one derivative is NaN, the two approaches can disagree about the derivative of the other argument (one will be NaN, the other might not).
  • improves the map/broadcasting of DiffRules functions: currently, derivatives are always computed internally with ForwardDiff for both arguments, even if only one is tracked. This was uncovered by the test changes above; the fix ensures that functions whose derivative is defined for only one argument return non-NaN results for that argument, as in ForwardDiff.
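
A minimal sketch of the kind of NaN-aware comparison the tests need (the helper name and exact definition are illustrative, not the code added in this PR):

# Treat two numbers as equal when both are NaN, so that a NaN derivative from
# DiffRules matches a NaN computed via ForwardDiff.
nan_compare(x::Real, y::Real; kwargs...) = (isnan(x) && isnan(y)) || isapprox(x, y; kwargs...)
nan_compare(x::AbstractArray, y::AbstractArray; kwargs...) =
    all(map((a, b) -> nan_compare(a, b; kwargs...), x, y))

nan_compare(NaN, NaN)  # true
nan_compare(NaN, 1.0)  # false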

The vcat test error is unrelated and also present on the master branch and in other PRs. Edit: fixed on the master branch.

I also assume we could do better than ForwardDiff here and avoid all results becoming NaN when derivatives are computed with respect to both arguments but only one of them is defined/exists.
However, replacing ForwardDiff with a direct implementation of the DiffRules derivatives seemed to require much larger changes, so I tried to apply only a somewhat minimal set of changes needed for JuliaDiff/DiffRules.jl#79.
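
For reference, a rough sketch of what using the DiffRules definitions directly (instead of re-deriving them via ForwardDiff) could look like; the ^ function is only an example here, and none of this is the PR's code:

using DiffRules

# DiffRules stores derivatives as expressions; for a binary function, diffrule
# returns one expression per argument, which can be turned into plain functions.
dx_expr, dy_expr = DiffRules.diffrule(:Base, :^, :x, :y)
@eval dpow_dx(x, y) = $dx_expr  # derivative of x^y with respect to x
@eval dpow_dy(x, y) = $dy_expr  # derivative of x^y with respect to y

dpow_dx(2.0, 3.0)  # 12.0 (= 3 * 2^2)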

codecov-commenter commented on Oct 3, 2022:

Codecov Report

Base: 85.16% // Head: 81.24% // Decreases project coverage by 3.92% ⚠️

Coverage data is based on head (088182c) compared to base (8ac1f7d).
Patch coverage: 58.06% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #209      +/-   ##
==========================================
- Coverage   85.16%   81.24%   -3.93%     
==========================================
  Files          18       18              
  Lines        1861     1578     -283     
==========================================
- Hits         1585     1282     -303     
- Misses        276      296      +20     
Impacted Files Coverage Δ
src/ReverseDiff.jl 100.00% <ø> (ø)
src/derivatives/scalars.jl 95.74% <ø> (-0.98%) ⬇️
src/derivatives/elementwise.jl 74.71% <58.06%> (-4.16%) ⬇️
src/derivatives/broadcast.jl 76.62% <0.00%> (-13.99%) ⬇️
src/tape.jl 55.55% <0.00%> (-9.16%) ⬇️
src/tracked.jl 86.81% <0.00%> (-5.49%) ⬇️
src/api/hessians.jl 84.00% <0.00%> (-3.50%) ⬇️
src/api/tape.jl 72.22% <0.00%> (-3.05%) ⬇️
src/derivatives/linalg/arithmetic.jl 69.67% <0.00%> (-1.45%) ⬇️
... and 10 more


devmotion requested a review from mohamed82008 on October 3, 2022 12:29
devmotion (Member, Author) commented:

Bump 🙂

It would be good to fix ReverseDiff such that we can move forward with JuliaDiff/DiffRules.jl#79.

mohamed82008 (Member) commented:

Sorry for the delay, been swamped recently. I will take a look tonight.

mohamed82008 (Member) left a comment:

Do we have tests for ForwardOptimize where both x and y are tracked? It seems there might be a method ambiguity error in this case.

mohamed82008 (Member):

Ah, the methods are defined a bit further down in the same file, never mind.

@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::TrackedArray{X}, y::$A) where {F,X}
    result = DiffResults.GradientResult(SVector(zero(X)))
    df = (vx, vy) -> let vy=vy
        ForwardDiff.gradient!(result, s -> f.f(s[1], vy), SVector(vx))
mohamed82008 (Member):

why s[1]?

devmotion (Member, Author):

Since s is an SVector with a single element, vx, which we want to use here. That's just the one-argument version of the current implementation on the master branch:

result = DiffResults.GradientResult(SVector(zero(S), zero(S)))
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
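
A standalone toy example of this one-argument pattern (not code from the PR; the function and values are arbitrary):

using ForwardDiff, StaticArrays, DiffResults

f(x, y) = x * y + sin(x)
vx, vy = 2.0, 3.0

# Differentiate only with respect to the first argument by closing over vy and
# wrapping vx in a one-element SVector, as in the snippet above.
result = DiffResults.GradientResult(SVector(zero(vx)))
result = ForwardDiff.gradient!(result, s -> f(s[1], vy), SVector(vx))

DiffResults.value(result)     # f(2.0, 3.0) == 2 * 3 + sin(2)
DiffResults.gradient(result)  # one-element SVector containing vy + cos(vx)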

@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::$A, y::TrackedArray{Y}) where {F,Y}
    result = DiffResults.GradientResult(SVector(zero(Y)))
    df = let vx=vx
        (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(vx, s[1]), SVector(vy))
mohamed82008 (Member):

s[1]?

devmotion (Member, Author):

Same as above.

@eval function Base.$(g)(f::ForwardOptimize{F}, x::TrackedArray{X,D}, y::$A) where {F,X,D}
    result = DiffResults.GradientResult(SVector(zero(X)))
    df = (vx, vy) -> let vy=vy
        ForwardDiff.gradient!(result, s -> f.f(s[1], vy), SVector(vx))
mohamed82008 (Member):

s[1]?

devmotion (Member, Author):

Same as above.

istracked(b) && diffresult_increment_deriv!(b, output_deriv, results, 2, b_bound)
if istracked(a)
    p += 1
    diffresult_increment_deriv!(a, output_deriv, results, p, a_bound)
mohamed82008 (Member):

why change the value of p here?

devmotion (Member, Author):

To extract the correct partial. If a is tracked, its partial has index p = 1; if only b is tracked, the first partial (p = 1) corresponds to b; and if both a and b are tracked, p = 1 corresponds to a and p = 2 to b. Incrementing p in the branches lets us avoid checking and handling all three scenarios separately.

Note that on the master branch, p = 1 for a and p = 2 for b are hardcoded. That only works because on master the partials with respect to both arguments are always computed and stored, even if only one argument is tracked.
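
A simplified sketch of that bookkeeping (the function name is illustrative, not ReverseDiff's internals):

# The ForwardDiff call only produces partials for tracked arguments, so the
# index of each argument's partial depends on which arguments are tracked.
function partial_indices(a_tracked::Bool, b_tracked::Bool)
    p = 0
    pa = pb = nothing
    if a_tracked
        p += 1
        pa = p  # always 1 when a is tracked
    end
    if b_tracked
        p += 1
        pb = p  # 1 if only b is tracked, 2 if both are tracked
    end
    return (a = pa, b = pb)
end

partial_indices(true, true)   # (a = 1, b = 2)
partial_indices(false, true)  # (a = nothing, b = 1)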

end
if istracked(b)
    p += 1
    diffresult_increment_deriv!(b, output_deriv, results, p, b_bound)
mohamed82008 (Member):

Same p comment

devmotion (Member, Author):

Same as above.

@eval function Base.$(g)(f::ForwardOptimize{F}, x::$A, y::TrackedArray{Y,D}) where {F,Y,D}
    result = DiffResults.GradientResult(SVector(zero(Y)))
    df = (vx, vy) -> let vx=vx
        ForwardDiff.gradient!(result, s -> f.f(vx, s[1]), SVector(vy))
mohamed82008 (Member):

Same s[1] comment

devmotion (Member, Author):

Same as above.

mohamed82008 (Member) left a comment:

LGTM

mohamed82008 merged commit f06b776 into master on Oct 16, 2022
devmotion deleted the dw/diffrules branch on October 17, 2022 07:05
3 participants