-
Notifications
You must be signed in to change notification settings - Fork 79
Add defer_within_autodiff to EnzymeInterpreter
#2254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
src/compiler/interpreter.jl
Outdated
| (; fargs, argtypes) = arginfo | ||
|
|
||
| if f === Enzyme.within_autodiff | ||
| if !(interp.defer_within_autodiff) && f === Enzyme.within_autodiff |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this necessary? This fundamentally breaks this functionality?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is for a Reactant bug: EnzymeAD/Reactant.jl#442 (comment)
Reason being that Reactant uses EnzymeInterpreter as well, while not necessarily doing autodiff.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2254 +/- ##
==========================================
- Coverage 74.93% 74.92% -0.01%
==========================================
Files 56 56
Lines 17434 17436 +2
==========================================
Hits 13064 13064
- Misses 4370 4372 +2 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
@vchuravy can you give this a review before merge |
|
Seems fine. |
3e37885 to
9f54068
Compare
|
bump on this |
…_autodiff` to no return true during Reactant compilation. When this flag is true, `interp.handler` is responsible for handling within_autodiff, or to toggle defer_within_autodiff to false somewhere down the call chain.
9f54068 to
f1e15c9
Compare
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl
index 77f0027..260c90b 100644
--- a/src/compiler/interpreter.jl
+++ b/src/compiler/interpreter.jl
@@ -173,7 +173,7 @@ function EnzymeInterpreter(
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
- within_autodiff_rewrite::Bool = true,
+ within_autodiff_rewrite::Bool = true,
handler = nothing
)
@assert world <= Base.get_world_counter()
@@ -250,23 +250,27 @@ EnzymeInterpreter(
handler = nothing
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, within_autodiff_rewrite, handler)
-function EnzymeInterpreter(interp::EnzymeInterpreter;
- cache_or_token = (@static if HAS_INTEGRATED_CACHE
- interp.token
- else
- interp.code_cache
- end),
- mt = interp.method_table,
- local_cache = interp.local_cache,
- world = interp.world,
- inf_params = interp.inf_params,
- opt_params = interp.opt_params,
- forward_rules = interp.forward_rules,
- reverse_rules = interp.reverse_rules,
- inactive_rules = interp.inactive_rules,
- broadcast_rewrite = interp.broadcast_rewrite,
- within_autodiff_rewrite = interp.within_autodiff_rewrite,
- handler = interp.handler)
+function EnzymeInterpreter(
+ interp::EnzymeInterpreter;
+ cache_or_token = (
+ @static if HAS_INTEGRATED_CACHE
+ interp.token
+ else
+ interp.code_cache
+ end
+ ),
+ mt = interp.method_table,
+ local_cache = interp.local_cache,
+ world = interp.world,
+ inf_params = interp.inf_params,
+ opt_params = interp.opt_params,
+ forward_rules = interp.forward_rules,
+ reverse_rules = interp.reverse_rules,
+ inactive_rules = interp.inactive_rules,
+ broadcast_rewrite = interp.broadcast_rewrite,
+ within_autodiff_rewrite = interp.within_autodiff_rewrite,
+ handler = interp.handler
+ )
return EnzymeInterpreter(
cache_or_token,
mt, |
Together with Reactant pr: EnzymeAD/Reactant.jl#490