-
Couldn't load subscription status.
- Fork 34
Properly set within_autodiff (#442) #490
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
dd14b85 to
8731c7f
Compare
|
|
||
| if isempty(kwargs) | ||
| Reactant.call_with_reactant(f, traced_args...) | ||
| if within_autodiff |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like a cleaner way to do this, is not to have a second interpreter. But instead we can create a new global ref set to false, and overlay within_autodiff to lookup that var, and during autodiff set that to true
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm okay with this though, but if we were to do it in this form, I would probably change call_with_reactant to take a config type var, which stores the current state of whether in autodiff or not (and also we can extend to other things down the line as well)
… === overload_autodiff`. This doesn't work for some reason, the function within overload autodiff uses the original interpreter (?)
…mlir_fn. In order to pass this information from make_mlir_fn to call_with_reactant_generator, I introduced a new function `call_with_reactant_within_autodiff` which allows detection by looking at `self`.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
a3115b5 to
2a4d0f2
Compare
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #490 +/- ##
==========================================
+ Coverage 42.55% 42.56% +0.01%
==========================================
Files 123 123
Lines 21816 21826 +10
==========================================
+ Hits 9283 9290 +7
- Misses 12533 12536 +3 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
julia> using Enzyme
julia> function error_not_within_autodiff()
!Enzyme.within_autodiff() && error("Not within autodiff")
return nothing
end
error_not_within_autodiff (generic function with 1 method)
julia> fwd_within_autodiff(Mode, RT) = Enzyme.autodiff(Mode, error_not_within_autodiff, RT)
fwd_within_autodiff (generic function with 1 method)
julia> error_not_within_autodiff()
ERROR: Not within autodiff
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] error_not_within_autodiff()
@ Main ./REPL[5]:2
[3] top-level scope
@ REPL[7]:1
[4] top-level scope
@ none:1
julia> fwd_within_autodiff(Forward, Const)
()
julia> error_not_within_autodiff()
julia> Enzyme.within_autodiff()
falseI am extremely confused why is the 2nd call not throw an error here. Only happens if I call fwd_within_autodiff in between. cc @wsmoses this is in isolation from Reactant |
|
@vchuravy er wat |
fixes #442
needs Enzyme.jl: EnzymeAD/Enzyme.jl#2254
I had to introduce a new function
call_with_reactant_within_autodiffto smuggle thewithin_autodiffin thecall_with_reactant_generatorthrough theselfargument.I also tried doing things through
set_reactant_abibut that didn't seem to suffice (first commit).Perhaps the extra code in
set_reactant_abiisn't strictly necessary now so I can try removing it again if wanted.