Skip to content

Get Mooncake direct adjoints working #2723

@ChrisRackauckas

Description

@ChrisRackauckas

Minimal working example (MWE). Running this script fails with the error shown below:

using Mooncake, OrdinaryDiffEq, StaticArrays

"""
    lorenz!(du, u, p, t)

In-place right-hand side of the Lorenz system with the parameters hard-coded
(σ = 10, ρ = 28, β = 8/3); `p` and `t` are accepted for the ODE-function
signature but not used. Writes the time derivative of the 3-element state `u`
into `du` and returns `nothing`.
"""
function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
    # Called only for its side effect on `du`; return `nothing` explicitly
    # instead of leaking the last assigned value (`du[3]`) as a result.
    return nothing
end

# Save times for the solve: 0.0 to 3.0 in steps of 0.25 (13 points),
# held in a StaticArrays static vector.
const _saveat = SA[
    0.0, 0.25, 0.5, 0.75,
    1.0, 1.25, 1.5, 1.75,
    2.0, 2.25, 2.5, 2.75,
    3.0
]

"""
    f(u0::Array{Float64})

Solve the in-place problem defined by `lorenz!` from initial state `u0` over the
time span `(0.0, 3.0)` with `Tsit5()`, saving at the times in `_saveat`, and
return the sum of all saved solution values.

`sensealg` is set to `DiffEqBase.SensitivityADPassThrough()`; presumably this
asks the AD tool to trace through the solver itself rather than use an adjoint
rule (the "direct" path this issue is about) — confirm against DiffEqBase docs.
"""
function f(u0::Array{Float64})
    problem = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, (0.0, 3.0))
    solution = DiffEqBase.solve(problem, Tsit5();
        saveat = _saveat,
        sensealg = DiffEqBase.SensitivityADPassThrough())
    return sum(solution)
end;
u0 = [1.0, 0.0, 0.0]

# Gradient of `f` at `x` using Mooncake. A fresh rule is built on every call.
# `value_and_gradient!!` returns `(value, (df, dx))`; only `dx`, the gradient
# with respect to `x`, is returned.
function mooncake_gradient(f, x)
    rule = Mooncake.build_rrule(f, x)
    _, (_, dx) = Mooncake.value_and_gradient!!(rule, f, x)
    return dx
end

mooncake_gradient(f, u0)
caused by: Mooncake.UnhandledLanguageFeatureException("Encountered UpsilonNode: ϒ (_4). These are generated as part of some try / catch / finally blocks. At the present time, Mooncake.jl cannot differentiate through these, so they must be avoided. Strategies for resolving this error include re-writing code such that it avoids generating any UpsilonNodes, or writing a rule to differentiate the code by hand. If you are in any doubt as to what to do, please request assistance by opening an issue at github.com/chalk-lab/Mooncake.jl.")
Stacktrace:
  [1] unhandled_feature(msg::String)
    @ Mooncake ~/.julia/packages/Mooncake/5I5qv/src/interpreter/ir_utils.jl:247
  [2] make_ad_stmts!(stmt::Core.UpsilonNode, ::Mooncake.BasicBlockCode.ID, ::Mooncake.ADInfo)
    @ Mooncake ~/.julia/packages/Mooncake/5I5qv/src/interpreter/s2s_reverse_mode_ad.jl:630
  [3] _broadcast_getindex_evalf
    @ ./broadcast.jl:709 [inlined]
  [4] _broadcast_getindex
    @ ./broadcast.jl:682 [inlined]
  [5] getindex
    @ ./broadcast.jl:636 [inlined]
  [6] macro expansion
    @ ./broadcast.jl:1004 [inlined]
  [7] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [8] copyto!
    @ ./broadcast.jl:1003 [inlined]
  [9] copyto!
    @ ./broadcast.jl:956 [inlined]
 [10] copy
    @ ./broadcast.jl:928 [inlined]
 [11] materialize(bc::Base.Broadcast.Broadcasted{…})
    @ Base.Broadcast ./broadcast.jl:903
 [12] (::Mooncake.var"#208#210"{Mooncake.ADInfo})(primal_blk::Mooncake.BasicBlockCode.BBlock)
    @ Mooncake ~/.julia/packages/Mooncake/5I5qv/src/interpreter/s2s_reverse_mode_ad.jl:1175
 [13] iterate
    @ ./generator.jl:47 [inlined]
 [14] collect_to!(dest::Vector{Tuple{…}}, itr::Base.Generator{Vector{…}, Mooncake.var"#208#210"{…}}, offs::Int64, st::Int64)
    @ Base ./array.jl:892
 [15] collect_to_with_first!(dest::Vector{…}, v1::Tuple{…}, itr::Base.Generator{…}, st::Int64)
    @ Base ./array.jl:870
 [16] _collect(c::Vector{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
    @ Base ./array.jl:864
 [17] collect_similar(cont::Vector{Mooncake.BasicBlockCode.BBlock}, itr::Base.Generator{Vector{…}, Mooncake.var"#208#210"{…}})
    @ Base ./array.jl:763
 [18] map
    @ ./abstractarray.jl:3286 [inlined]
 [19] generate_ir(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Type; debug_mode::Bool, do_inline::Bool)
    @ Mooncake ~/.julia/packages/Mooncake/5I5qv/src/interpreter/s2s_reverse_mode_ad.jl:1172
 [20] generate_ir
    @ ~/.julia/packages/Mooncake/5I5qv/src/interpreter/s2s_reverse_mode_ad.jl:1149 [inlined]
 [21] build_rrule(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Type; debug_mode::Bool, silence_debug_messages::Bool)
    @ Mooncake ~/.julia/packages/Mooncake/5I5qv/src/interpreter/s2s_reverse_mode_ad.jl:1114
 [22] top-level scope
    @ ~/Desktop/test2.jl:27
Some type information was truncated. Use `show(err)` to see complete types.

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions