Open
Description
using Flux, Enzyme, Statistics, Random
"""
    enzyme_withgradient(f, x...)

Differentiate `f(x...)` in reverse mode with Enzyme and return the primal value
together with one gradient per argument.

Each argument is annotated according to its type:
- Numeric arguments (`isa Number`) are wrapped in `Enzyme.Active`. Their gradients
  are read from the derivative tuple returned by `Enzyme.autodiff`.
- Every other argument is wrapped in `Enzyme.Duplicated` with a zeroed shadow
  (`Enzyme.make_zero`). Enzyme accumulates that argument's gradient into the shadow.

`f` itself is passed as `Enzyme.Const`, runtime activity is enabled via
`Enzyme.set_runtime_activity`, and the return of `f` is marked `Enzyme.Active`.
NOTE(review): this presumably requires `f` to return a real scalar (e.g. a loss).

# Returns
A tuple `(val, g)`:
- `val` is the primal result of `f(x...)`.
- `g` is a tuple with one entry per argument of `x`, in order. Each entry is either
  the `Active` derivative or the `Duplicated` shadow.
"""
function enzyme_withgradient(f, x...)
    # Build the annotations as a tuple via `map` (rather than pushing into an
    # untyped `[]`), so the splat into `autodiff` below has inferable types.
    args = map(x) do xi
        xi isa Number ? Enzyme.Active(xi) : Enzyme.Duplicated(xi, Enzyme.make_zero(xi))
    end
    ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
    # `ret[1]` holds the per-argument derivatives (used for the Active ones);
    # `ret[2]` holds the primal value returned alongside them.
    ret = Enzyme.autodiff(ad, Enzyme.Const(f), Enzyme.Active, args...)
    # Duplicated arguments carry their gradient in the shadow `.dval`, which
    # `autodiff` has filled in place.
    g = ntuple(length(x)) do i
        x[i] isa Number ? ret[1][i] : args[i].dval
    end
    return ret[2], g
end
# Scalar loss: apply `model` to `x` and average the result with `mean`.
function loss(model, x)
    prediction = model(x)
    return mean(prediction)
end
# Minimal reproducer: differentiate `loss` for a `Flux.Bilinear((2, 2) => 3)`
# layer on a single Float32 2×1 input through `enzyme_withgradient`.
# Seed the global RNG first. `randn` below draws from it, and presumably so does
# Flux's default weight initialisation. Seeding makes every run of the
# reproducer see the same model and input.
Random.seed!(0)
model = Flux.Bilinear((2, 2) => 3)
x = randn(Float32, 2, 1)
enzyme_withgradient(loss, model, x)
Output:
ERROR: MethodError: no method matching function_attributes(::LLVM.UserOperandSet)
The function `function_attributes` exists, but no method is defined for this combination of argument types.
Closest candidates are:
function_attributes(::LLVM.Function)
@ LLVM ~/.julia/packages/LLVM/wMjUU/src/core/function.jl:127
Stacktrace:
[1] check_ir!(job::GPUCompiler.CompilerJob, errors::Vector{…}, imported::Set{…}, f::LLVM.Function, deletedfns::Vector{…}, mod::LLVM.Module)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:402
[2] check_ir!(job::GPUCompiler.CompilerJob, errors::Vector{Tuple{String, Vector{…}, Any}}, mod::LLVM.Module)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:210
[3] check_ir
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:179 [inlined]
[4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:3413
[5] codegen
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:3338 [inlined]
[6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5387
[7] _thunk
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5387 [inlined]
[8] cached_compilation
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5439 [inlined]
[9] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5550
[10] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5735
[11] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…})
@ Enzyme ~/.julia/packages/Enzyme/DiEvV/src/Enzyme.jl:485
[12] enzyme_withgradient(::Function, ::Flux.Bilinear{typeof(identity), Array{Float32, 3}, Vector{Float32}}, ::Vararg{Any})
@ Main ~/.julia/dev/Flux/test/test_utils.jl:32
[13] top-level scope
@ ~/.julia/dev/Flux/prova.jl:17
Some type information was truncated. Use `show(err)` to see complete types.
cc @wsmoses
Activity