Skip to content

[enzyme] broken BatchNorm gradient #2566

Closed
@CarloLucibello

Description

@CarloLucibello
using Flux, Enzyme, Statistics, Random

"""
    enzyme_withgradient(f, x...)

Differentiate `f(x...)` with Enzyme in reverse mode and return `(value, grads)`.

Each argument that is a `Number` is passed to Enzyme as `Enzyme.Active`; every other
argument is passed as `Enzyme.Duplicated`, paired with a shadow built by
`Enzyme.make_zero`. `value` is the second element of the `Enzyme.autodiff` result and
`grads` is a tuple with one entry per argument: the entry of the first element of the
`autodiff` result for `Number` arguments, and the shadow (`dval`) otherwise.
"""
function enzyme_withgradient(f, x...)
    # Pick the Enzyme activity annotation for a single argument.
    annotate(arg) = arg isa Number ? Enzyme.Active(arg) :
                    Enzyme.Duplicated(arg, Enzyme.make_zero(arg))
    # Kept as a `Vector{Any}` since the annotations are heterogeneous.
    annotated = Any[annotate(arg) for arg in x]

    mode = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
    result = Enzyme.autodiff(mode, Enzyme.Const(f), Enzyme.Active, annotated...)

    grads = ntuple(length(x)) do i
        if x[i] isa Number
            result[1][i]
        else
            # Read the gradient from the shadow stored in the annotation.
            annotated[i].dval
        end
    end
    return result[2], grads
end

# Scalar loss: the mean of the model's output on `x`.
loss(model, x) = mean(model(x))
# Layer under test. NOTE(review): the `2` is presumably the channel/feature count,
# matching the first dimension of `x` below — confirm against the Flux docs.
model = Flux.BatchNorm(2)
# 2×5 Float32 input drawn from a standard normal; no seed is set, so values vary per run.
x = randn(Float32, 2, 5)
# First differentiate with the layer switched to test mode.
Flux.testmode!(model)
enzyme_withgradient(loss, model, x) # ok
# Then switch to train mode and differentiate again; this call raises the error below.
Flux.trainmode!(model)
enzyme_withgradient(loss, model, x) # ERROR

output:

ERROR: 
No create nofree of empty function (julia.gc_loaded) julia.gc_loaded)
 at context:   call fastcc void @julia_reduced_indices_135883([2 x [1 x i64]]* noalias nocapture noundef nonnull sret([2 x [1 x i64]]) align 8 dereferenceable(16) %5, [2 x [1 x i64]] addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(16) %15, {} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %1) #241, !dbg !245 (julia_reduced_indices_135883)

Stacktrace:
 [1] reduced_indices
   @ ./reducedim.jl:15
 [2] reducedim_initarray
   @ ./reducedim.jl:53


Stacktrace:
  [1] make_typealiases
    @ ./show.jl:849
  [2] macro expansion
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler/interpreter.jl:562 [inlined]
  [3] lindex_v1
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler/interpreter.jl:529 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler/interpreter.jl:769 [inlined]
  [5] lindex_v3
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler/interpreter.jl:700 [inlined]
  [6] override_bc_copyto!
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler/interpreter.jl:802 [inlined]
  [7] copyto!
    @ ./broadcast.jl:920 [inlined]
  [8] copy
    @ ./broadcast.jl:892 [inlined]
  [9] materialize
    @ ./broadcast.jl:867 [inlined]
 [10] #_norm_layer_forward#252
    @ ~/.julia/dev/Flux/src/layers/normalise.jl:248
 [11] #_norm_layer_forward#252
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler/interpreter.jl:0
 [12] getproperty
    @ ./Base.jl:49 [inlined]
 [13] setindex!
    @ ./array.jl:987 [inlined]
 [14] centralize_sumabs2!
    @ ~/.julia/packages/Statistics/gbcbG/src/Statistics.jl:275
 [15] macro expansion
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5317 [inlined]
 [16] enzyme_call
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:4863 [inlined]
 [17] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:4735 [inlined]
 [18] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/DiEvV/src/Enzyme.jl:503
 [19] enzyme_withgradient(::Function, ::BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, ::Vararg{Any})
    @ Main ./REPL[14]:11
 [20] top-level scope
    @ REPL[21]:1
Some type information was truncated. Use `show(err)` to see complete types.

cc @wsmoses

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions