Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ function overload_autodiff(
argprefix,
resprefix,
resargprefix,
within_autodiff=true,
)
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
fnwrap = mlir_fn_res.fnwrapped
Expand Down
24 changes: 20 additions & 4 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ function set_reactant_abi(
if f === Reactant.call_with_reactant
arginfo2 = ArgInfo(fargs isa Nothing ? nothing : fargs[2:end], argtypes[2:end])
return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods)
elseif !(interp.within_autodiff_rewrite) && f === overload_autodiff
interp′ = Enzyme.Compiler.Interpreter.EnzymeInterpreter(
interp; within_autodiff_rewrite=true
)
return Base.@invoke abstract_call_known(
interp′::Enzyme.Compiler.Interpreter.EnzymeInterpreter,
f,
arginfo,
si,
sv,
max_methods,
)
end

return Base.@invoke abstract_call_known(
Expand All @@ -78,7 +90,9 @@ end
@static if Enzyme.GPUCompiler.HAS_INTEGRATED_CACHE
struct ReactantCacheToken end

function ReactantInterpreter(; world::UInt=Base.get_world_counter())
function ReactantInterpreter(;
world::UInt=Base.get_world_counter(), within_autodiff=false
)
return Enzyme.Compiler.Interpreter.EnzymeInterpreter(
ReactantCacheToken(),
REACTANT_METHOD_TABLE,
Expand All @@ -87,15 +101,17 @@ end
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
false, #=within_autodiff_rewrite=#
within_autodiff, #=within_autodiff_rewrite=#
set_reactant_abi,
)
end
else
const REACTANT_CACHE = Enzyme.GPUCompiler.CodeCache()

function ReactantInterpreter(;
world::UInt=Base.get_world_counter(), code_cache=REACTANT_CACHE
world::UInt=Base.get_world_counter(),
code_cache=REACTANT_CACHE,
within_autodiff=false,
)
return Enzyme.Compiler.Interpreter.EnzymeInterpreter(
REACTANT_CACHE,
Expand All @@ -105,7 +121,7 @@ else
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
false, #=within_autodiff_rewrite=#
within_autodiff, #=within_autodiff_rewrite=#
set_reactant_abi,
)
end
Expand Down
15 changes: 13 additions & 2 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ function make_mlir_fn(
args_in_result::Symbol=:all,
construct_function_without_args::Bool=false,
do_transpose=true,
within_autodiff=false,
input_shardings=nothing, # This is not meant to be used by the user.
output_shardings=nothing, # This is not meant to be used by the user.
runtime=nothing,
Expand Down Expand Up @@ -329,9 +330,19 @@ function make_mlir_fn(
process_linear_args!(linear_args, fnbody, do_transpose, optimize_then_pad, inv_map)

if isempty(kwargs)
Reactant.call_with_reactant(f, traced_args...)
if within_autodiff
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like a cleaner way to do this, is not to have a second interpreter. But instead we can create a new global ref set to false, and overlay within_autodiff to lookup that var, and during autodiff set that to true

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with this though, but if we were to do it in this form, I would probably change call_with_reactant to take a config type var, which stores the current state of whether in autodiff or not (and also we can extend to other things down the line as well)

Reactant.call_with_reactant_within_autodiff(f, traced_args...)
else
Reactant.call_with_reactant(f, traced_args...)
end
else
Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...)
if within_autodiff
Reactant.call_with_reactant_within_autodiff(
Core.kwcall, kwargs, f, traced_args...
)
else
Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...)
end
end
finally
MLIR.IR.deactivate!(fnbody)
Expand Down
9 changes: 8 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ function apply(f::F, args...; kwargs...) where {F}
end

function call_with_reactant end
function call_with_reactant_within_autodiff end

function maybe_argextype(@nospecialize(x), src)
return try
Expand Down Expand Up @@ -632,7 +633,9 @@ function call_with_reactant_generator(
))
end

interp = ReactantInterpreter(; world)
interp = ReactantInterpreter(;
world, within_autodiff=self == typeof(Reactant.call_with_reactant_within_autodiff)
)

min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))
Expand Down Expand Up @@ -876,6 +879,10 @@ end
$(Expr(:meta, :generated_only))
return $(Expr(:meta, :generated, call_with_reactant_generator))
end
@eval function call_with_reactant_within_autodiff($REDUB_ARGUMENTS_NAME...)
$(Expr(:meta, :generated_only))
return $(Expr(:meta, :generated, call_with_reactant_generator))
end

@static if isdefined(Core, :BFloat16)
nmantissa(::Type{Core.BFloat16}) = 7
Expand Down
15 changes: 15 additions & 0 deletions test/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,21 @@ end
@test res1[1] ≈ ores1[1]
end

function error_not_within_autodiff()
!Enzyme.within_autodiff() && error("Not within autodiff")
return nothing
end

fwd_within_autodiff(Mode, RT) = Enzyme.autodiff(Mode, error_not_within_autodiff, RT)

@testset "within_autodiff" begin
@test_throws ErrorException error_not_within_autodiff()
@test fwd_within_autodiff(Forward, Const) == ()

@test_throws ErrorException @jit error_not_within_autodiff()
@test (@jit fwd_within_autodiff(Forward, Const)) == ()
end

function gw(z)
return Enzyme.gradient(Forward, sum, z; chunk=Val(1))
end
Expand Down
Loading