Improve DiffRules integration and tests #209
Conversation
Codecov Report — Base: 85.16% // Head: 81.24% // Decreases project coverage by -3.93%.
Additional details and impacted files:

@@ Coverage Diff @@
##           master     #209      +/-   ##
==========================================
- Coverage   85.16%   81.24%    -3.93%
==========================================
  Files          18       18
  Lines        1861     1578     -283
==========================================
- Hits         1585     1282     -303
- Misses        276      296      +20
☔ View full report at Codecov.
Bump 🙂 It would be good to fix ReverseDiff such that we can move forward with JuliaDiff/DiffRules.jl#79.
Sorry for the delay, been swamped recently. I will take a look tonight.
Do we have tests for ForwardOptimize where both x and y are tracked? Seems there might be a method ambiguity error in this case?
Ah, the methods are defined a bit further in the same file, nevermind.
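For reference, a minimal standalone illustration of the kind of ambiguity asked about (using hypothetical placeholder types, not ReverseDiff's actual definitions): with only the two mixed-argument methods defined, a call where both arguments are tracked matches both methods and neither is more specific, so a dedicated both-tracked method is needed.

```julia
# Hypothetical stand-ins for illustration only:
struct Tracked end
struct Plain end

h(x::Tracked, y) = "x tracked"
h(x, y::Tracked) = "y tracked"

h(Tracked(), Plain())      # "x tracked"
h(Plain(), Tracked())      # "y tracked"
# h(Tracked(), Tracked())  # would throw a MethodError: ambiguous

# A dedicated method for the fully tracked case resolves the ambiguity:
h(x::Tracked, y::Tracked) = "both tracked"
h(Tracked(), Tracked())    # "both tracked"
```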
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::TrackedArray{X}, y::$A) where {F,X}
    result = DiffResults.GradientResult(SVector(zero(X)))
    df = (vx, vy) -> let vy=vy
        ForwardDiff.gradient!(result, s -> f.f(s[1], vy), SVector(vx))
why s[1]?
Since s is an SVector with a single element, vx, which we want to use here. That's just the one-argument version of the current implementation on the master branch (ReverseDiff.jl/src/derivatives/elementwise.jl, lines 116 to 117 at d522508):
result = DiffResults.GradientResult(SVector(zero(S), zero(S)))
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
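For context, here is a self-contained sketch of that pattern with a concrete function in place of f.f (the function and values are made up for illustration): the scalar input is wrapped into a one-element SVector so ForwardDiff.gradient! can be reused, and s[1] simply unwraps it again inside the closure.

```julia
using ForwardDiff, DiffResults, StaticArrays

f(x, y) = x * sin(y)
vx, vy = 2.0, 0.5

# Differentiate w.r.t. the first argument only: wrap vx in a one-element
# SVector; the closure unwraps it via s[1] and closes over vy.
# With an SVector-backed GradientResult the result is immutable, so the
# return value of gradient! must be captured.
result = DiffResults.GradientResult(SVector(zero(vx)))
result = ForwardDiff.gradient!(result, s -> f(s[1], vy), SVector(vx))

DiffResults.value(result)        # f(vx, vy) == 2.0 * sin(0.5)
DiffResults.gradient(result)[1]  # ∂f/∂x == sin(0.5)
```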
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::$A, y::TrackedArray{Y}) where {F,Y}
    result = DiffResults.GradientResult(SVector(zero(Y)))
    df = (vx, vy) -> let vx=vx
        ForwardDiff.gradient!(result, s -> f.f(vx, s[1]), SVector(vy))
s[1]?
Same as above.
@eval function Base.$(g)(f::ForwardOptimize{F}, x::TrackedArray{X,D}, y::$A) where {F,X,D}
    result = DiffResults.GradientResult(SVector(zero(X)))
    df = (vx, vy) -> let vy=vy
        ForwardDiff.gradient!(result, s -> f.f(s[1], vy), SVector(vx))
s[1]?
Same as above.
istracked(b) && diffresult_increment_deriv!(b, output_deriv, results, 2, b_bound)
if istracked(a)
    p += 1
    diffresult_increment_deriv!(a, output_deriv, results, p, a_bound)
why change the value of p here?
To extract the correct partial. If a is tracked, its corresponding partial has index p = 1; but if only b is tracked, the first partial (p = 1) corresponds to b. And if both a and b are tracked, p = 1 corresponds to a and p = 2 to b. So incrementing p in the branches allows us to avoid checking and handling all three scenarios separately.

Note that on the master branch p = 1 for a and p = 2 for b are hardcoded. That only works because on the master branch the partials with respect to both arguments are always computed and stored, even if only one argument is tracked.
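Schematically, the indexing logic described above looks like the following sketch (a simplified, runnable toy with hypothetical stand-ins, not the actual ReverseDiff internals):

```julia
# Hypothetical stand-ins for illustration only:
istracked(x) = x isa Ref                        # pretend Ref-wrapped values are "tracked"
increment_deriv!(x, results, p) = (x[] += results[p])

# `results` stores one partial per *tracked* argument only.
function increment_derivs!(a, b, results)
    p = 0
    if istracked(a)
        p += 1  # a, when tracked, always owns the first stored partial
        increment_deriv!(a, results, p)
    end
    if istracked(b)
        p += 1  # p == 1 if only b is tracked, p == 2 if both are
        increment_deriv!(b, results, p)
    end
end

a, b = Ref(0.0), Ref(0.0)
increment_derivs!(a, b, [3.0, 4.0])  # both tracked: a[] == 3.0, b[] == 4.0
increment_derivs!(1.0, b, [5.0])     # only b tracked: partial 1 belongs to b
```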
end
if istracked(b)
    p += 1
    diffresult_increment_deriv!(b, output_deriv, results, p, b_bound)
Same p comment
Same as above.
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$A, y::TrackedArray{Y,D}) where {F,Y,D}
    result = DiffResults.GradientResult(SVector(zero(Y)))
    df = (vx, vy) -> let vx=vx
        ForwardDiff.gradient!(result, s -> f.f(vx, s[1]), SVector(vy))
Same s[1]
comment
Same as above.
LGTM
This PR fixes some problems with the DiffRules integration and its tests. It is needed for JuliaDiff/DiffRules.jl#79 (relevant DiffRules tests pass with that PR).

Mainly, the PR
- adds NaN comparisons to the tests (necessary since in DiffRules undefined and non-existing derivatives are implemented as NaN, and hence otherwise comparisons with ForwardDiff will fail if both return NaN; see the sketch below)
- compares derivatives only for the tracked arguments (otherwise, if a derivative is NaN, the ForwardDiff results of the two approaches differ for the derivative of the other argument: one will be NaN and one might not)
- fixes map/broadcasting of DiffRules, as currently internally derivatives are computed with ForwardDiff always for both arguments, even if only one is tracked (this was uncovered by the changes to the tests mentioned above, and it ensures that derivatives of functions whose derivatives are defined only for one argument return non-NaN results, as in ForwardDiff)

The vcat test error is unrelated and also present on the master branch and other PRs. (Edit: fixed on the master branch.)

I also assume we could do better than ForwardDiff here and also avoid that all results become NaN if derivatives are computed with respect to both arguments and only one is defined/exists. But replacing ForwardDiff with a direct implementation of the DiffRules derivatives seemed to require much larger changes, and I tried to apply only a somewhat minimal set of changes required for JuliaDiff/DiffRules.jl#79.
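As an illustration of the NaN comparisons mentioned in the first bullet above, here is a generic sketch of a NaN-aware test check (a common pattern, not necessarily the exact helper used in this PR):

```julia
using Test

# NaN == NaN is false, so a plain ≈ comparison fails even when both
# implementations agree that a derivative is undefined. Treat two NaNs
# as equal and fall back to isapprox otherwise:
nan_or_approx(x, y) = (isnan(x) && isnan(y)) || isapprox(x, y)

@test nan_or_approx(NaN, NaN)          # passes: both derivatives undefined
@test nan_or_approx(1.0, 1.0 + 1e-12)  # passes: approximately equal
```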