Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
reverse_rules::Bool
inactive_rules::Bool
broadcast_rewrite::Bool

# When false, leave the check for within_autodiff to the handler.
within_autodiff_rewrite::Bool

handler::T
end

Expand Down Expand Up @@ -169,6 +173,7 @@ function EnzymeInterpreter(
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
within_autodiff_rewrite::Bool = true,
handler = nothing
)
@assert world <= Base.get_world_counter()
Expand Down Expand Up @@ -229,6 +234,7 @@ function EnzymeInterpreter(
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool,
within_autodiff_rewrite::Bool,
handler
)
end
Expand All @@ -240,8 +246,42 @@ EnzymeInterpreter(
mode::API.CDerivativeMode,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
within_autodiff_rewrite::Bool = true,
handler = nothing
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, handler)
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, within_autodiff_rewrite, handler)

function EnzymeInterpreter(interp::EnzymeInterpreter;
cache_or_token = (@static if HAS_INTEGRATED_CACHE
interp.token
else
interp.code_cache
end),
mt = interp.method_table,
local_cache = interp.local_cache,
world = interp.world,
inf_params = interp.inf_params,
opt_params = interp.opt_params,
forward_rules = interp.forward_rules,
reverse_rules = interp.reverse_rules,
inactive_rules = interp.inactive_rules,
broadcast_rewrite = interp.broadcast_rewrite,
within_autodiff_rewrite = interp.within_autodiff_rewrite,
handler = interp.handler)
return EnzymeInterpreter(
cache_or_token,
mt,
local_cache,
world,
inf_params,
opt_params,
forward_rules,
reverse_rules,
inactive_rules,
broadcast_rewrite,
within_autodiff_rewrite,
handler
)
end

Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params
Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params
Expand Down Expand Up @@ -933,7 +973,7 @@ function abstract_call_known(

(; fargs, argtypes) = arginfo

if f === Enzyme.within_autodiff
if interp.within_autodiff_rewrite && f === Enzyme.within_autodiff
if length(argtypes) != 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
Expand Down
Loading