Closed
Description
using Flux, Enzyme, Statistics, Random
# Wrap a `Number` argument as an Enzyme `Active` (its derivative is returned by `autodiff`).
_enzyme_annotate(a::Number) = Enzyme.Active(a)
# Wrap any other argument as `Duplicated` with a zeroed shadow that accumulates its gradient.
_enzyme_annotate(a) = Enzyme.Duplicated(a, Enzyme.make_zero(a))

# Gradient of an `Active` argument: the `i`-th entry of `autodiff`'s derivative tuple.
_enzyme_gradient(::Enzyme.Active, derivs, i) = derivs[i]
# Gradient of a `Duplicated` argument: its shadow, filled in by the reverse pass.
_enzyme_gradient(a::Enzyme.Duplicated, derivs, i) = a.dval

"""
    enzyme_withgradient(f, x...)

Differentiate `f(x...)` with Enzyme in reverse mode and return `(value, grads)`.

Each argument is annotated according to its type:
- A `Number` is passed as `Enzyme.Active`. Its gradient is read from the derivative
  tuple that `Enzyme.autodiff` returns.
- Any other argument is passed as `Enzyme.Duplicated` with a shadow created by
  `Enzyme.make_zero`. Its gradient is that shadow after the call.

`f` is passed as `Enzyme.Const`. The return activity is `Enzyme.Active`, so `f` is
expected to return a scalar. Runtime activity is enabled with
`Enzyme.set_runtime_activity`.

# Returns
- `value`: the primal result of `f(x...)`.
- `grads`: a tuple with one gradient per element of `x`, in the same order.
"""
function enzyme_withgradient(f, x...)
    # Build a tuple (not an untyped `[]` / `Vector{Any}`), so that each annotation's
    # concrete type is still known when it is splatted into `autodiff`.
    args = map(_enzyme_annotate, x)
    mode = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
    # With `ReverseWithPrimal`, `autodiff` returns `(derivatives, primal)`.
    derivs, val = Enzyme.autodiff(mode, Enzyme.Const(f), Enzyme.Active, args...)
    grads = ntuple(i -> _enzyme_gradient(args[i], derivs, i), length(args))
    return val, grads
end
"""
    loss(model, x)

Apply `model` to `x` and return the arithmetic mean of every entry of the result.
"""
function loss(model, x)
    prediction = model(x)
    return mean(prediction)
end
# Reproduction: differentiate `loss` through a `Flux.BatchNorm(2)` layer applied to
# a 2×5 Float32 input, first in test mode and then in train mode.
model = Flux.BatchNorm(2)
x = randn(Float32, 2, 5)
# Test mode: the gradient call succeeds.
Flux.testmode!(model)
enzyme_withgradient(loss, model, x) # ok
# Train mode: the same call fails inside Enzyme with
# "No create nofree of empty function (julia.gc_loaded)".
# The trace below mentions Flux's `_norm_layer_forward`, Statistics'
# `centralize_sumabs2!`, and Base's `reducedim_initarray` / `reduced_indices`.
Flux.trainmode!(model)
enzyme_withgradient(loss, model, x) # ERROR
Output of the final (`# ERROR`) call:
ERROR:
No create nofree of empty function (julia.gc_loaded) julia.gc_loaded)
at context: call fastcc void @julia_reduced_indices_135883([2 x [1 x i64]]* noalias nocapture noundef nonnull sret([2 x [1 x i64]]) align 8 dereferenceable(16) %5, [2 x [1 x i64]] addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(16) %15, {} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %1) #241, !dbg !245 (julia_reduced_indices_135883)
Stacktrace:
[1] reduced_indices
@ ./reducedim.jl:15
[2] reducedim_initarray
@ ./reducedim.jl:53
Stacktrace:
[1] make_typealiases
@ ./show.jl:849
[2] macro expansion
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler/interpreter.jl:562 [inlined]
[3] lindex_v1
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler/interpreter.jl:529 [inlined]
[4] macro expansion
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler/interpreter.jl:769 [inlined]
[5] lindex_v3
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler/interpreter.jl:700 [inlined]
[6] override_bc_copyto!
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler/interpreter.jl:802 [inlined]
[7] copyto!
@ ./broadcast.jl:920 [inlined]
[8] copy
@ ./broadcast.jl:892 [inlined]
[9] materialize
@ ./broadcast.jl:867 [inlined]
[10] #_norm_layer_forward#252
@ ~/.julia/dev/Flux/src/layers/normalise.jl:248
[11] #_norm_layer_forward#252
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler/interpreter.jl:0
[12] getproperty
@ ./Base.jl:49 [inlined]
[13] setindex!
@ ./array.jl:987 [inlined]
[14] centralize_sumabs2!
@ ~/.julia/packages/Statistics/gbcbG/src/Statistics.jl:275
[15] macro expansion
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5317 [inlined]
[16] enzyme_call
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:4863 [inlined]
[17] CombinedAdjointThunk
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:4735 [inlined]
[18] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…})
@ Enzyme ~/.julia/packages/Enzyme/DiEvV/src/Enzyme.jl:503
[19] enzyme_withgradient(::Function, ::BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, ::Vararg{Any})
@ Main ./REPL[14]:11
[20] top-level scope
@ REPL[21]:1
Some type information was truncated. Use `show(err)` to see complete types.
cc @wsmoses