Skip to content

Task switch error on Enzyme v0.13 #2081

Open
@MilesCranmer

Description

@MilesCranmer

When trying to use Enzyme as the autodiff backend for SymbolicRegression searches, I ran into this error:

        nested task error: task switch not allowed from inside staged nor pure functions

The full stack trace:

┌ Error: Problem fitting the machine machine(SRRegressor(defaults = nothing, ), ). 
└ @ MLJBase ~/.julia/packages/MLJBase/7nGJF/src/machines.jl:694
[ Info: Running type checks... 
[ Info: Type checks okay. 
ERROR: LoadError: TaskFailedException
Stacktrace:
  [1] wait(t::Task)
    @ Base ./task.jl:370
  [2] fetch
    @ ./task.jl:390 [inlined]
  [3] _main_search_loop!(state::SymbolicRegression.SearchUtilsModule.SearchState{…}, datasets::Vector{…}, ropt::SymbolicRegression.SearchUtilsModule.RuntimeOptions{…}, options::Options{…})
    @ SymbolicRegression ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/SymbolicRegression.jl:833
  [4] _equation_search(datasets::Vector{…}, ropt::SymbolicRegression.SearchUtilsModule.RuntimeOptions{…}, options::Options{…}, saved_state::Nothing)
    @ SymbolicRegression ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/SymbolicRegression.jl:535
  [5] equation_search(datasets::Vector{…}; options::Options{…}, saved_state::Nothing, runtime_options::Nothing, runtime_options_kws::@Kwargs{})
    @ SymbolicRegression ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/SymbolicRegression.jl:525
  [6] equation_search(X::Matrix{…}, y::Matrix{…}; niterations::Int64, weights::Nothing, options::Options{…}, variable_names::Vector{…}, display_variable_names::Vector{…}, y_variable_names::Nothing, parallelism::Symbol, numprocs::Nothing, procs::Nothing, addprocs_function::Nothing, heap_size_hint_in_bytes::Nothing, runtests::Bool, saved_state::Nothing, return_state::Bool, run_id::Nothing, loss_type::Type{…}, verbosity::Int64, logger::Nothing, progress::Nothing, X_units::Nothing, y_units::Nothing, extra::@NamedTuple{}, v_dim_out::Val{…}, multithreaded::Nothing)
    @ SymbolicRegression ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/SymbolicRegression.jl:476
  [7] #equation_search#34
    @ ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/SymbolicRegression.jl:499 [inlined]
  [8] _update(m::SRRegressor{…}, verbosity::Int64, old_fitresult::Nothing, old_cache::Nothing, X::@NamedTuple{}, y::Vector{…}, w::Nothing, options::Options{…}, class::Vector{…})
    @ SymbolicRegression.MLJInterfaceModule ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/MLJInterface.jl:253
  [9] _update(m::SRRegressor{…}, verbosity::Int64, old_fitresult::Nothing, old_cache::Nothing, X::@NamedTuple{}, y::Vector{…}, w::Nothing, options::Options{…}, class::Nothing)
    @ SymbolicRegression.MLJInterfaceModule ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/MLJInterface.jl:220
 [10] update(m::SRRegressor{DynamicQuantities.SymbolicDimensions{…}, DataType}, verbosity::Int64, old_fitresult::Nothing, old_cache::Nothing, X::@NamedTuple{x1::Vector{…}, x2::Vector{…}, class::Vector{…}}, y::Vector{Float64}, w::Nothing)
    @ SymbolicRegression.MLJInterfaceModule ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/MLJInterface.jl:201
 [11] fit
    @ ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/MLJInterface.jl:189 [inlined]
 [12] fit(m::SRRegressor{DynamicQuantities.SymbolicDimensions{DynamicQuantities.FixedRational{Int32, 25200}}, DataType}, verbosity::Int64, X::@NamedTuple{x1::Vector{Float64}, x2::Vector{Float64}, class::Vector{Int64}}, y::Vector{Float64})
    @ SymbolicRegression.MLJInterfaceModule ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/MLJInterface.jl:189
 [13] fit_only!(mach::MLJBase.Machine{SRRegressor{DynamicQuantities.SymbolicDimensions{…}, DataType}, SRRegressor{DynamicQuantities.SymbolicDimensions{…}, DataType}, true}; rows::Nothing, verbosity::Int64, force::Bool, composite::Nothing)
    @ MLJBase ~/.julia/packages/MLJBase/7nGJF/src/machines.jl:692
 [14] fit_only!
    @ ~/.julia/packages/MLJBase/7nGJF/src/machines.jl:617 [inlined]
 [15] #fit!#63
    @ ~/.julia/packages/MLJBase/7nGJF/src/machines.jl:789 [inlined]
 [16] fit!(mach::MLJBase.Machine{SRRegressor{DynamicQuantities.SymbolicDimensions{DynamicQuantities.FixedRational{Int32, 25200}}, DataType}, SRRegressor{DynamicQuantities.SymbolicDimensions{DynamicQuantities.FixedRational{Int32, 25200}}, DataType}, true})
    @ MLJBase ~/.julia/packages/MLJBase/7nGJF/src/machines.jl:786
 [17] top-level scope
    @ ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/examples/parameterized_function.jl:103
 [18] include(fname::String)
    @ Main ./sysimg.jl:38
 [19] top-level scope
    @ REPL[3]:1

    nested task error: TaskFailedException
    Stacktrace:
     [1] wait(t::Task)
       @ Base ./task.jl:370
     [2] fetch
       @ ./task.jl:390 [inlined]
     [3] (::SymbolicRegression.var"#89#94"{SymbolicRegression.SearchUtilsModule.SearchState{}, Int64, Int64})()
       @ SymbolicRegression ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/SymbolicRegression.jl:810
    
        nested task error: task switch not allowed from inside staged nor pure functions
        Stacktrace:
          [1] try_yieldto(undo::typeof(Base.ensure_rescheduled))
            @ Base ./task.jl:948
          [2] wait()
            @ Base ./task.jl:1022
          [3] wait(c::Base.GenericCondition{Base.Threads.SpinLock}; first::Bool)
            @ Base ./condition.jl:130
          [4] wait
            @ ./condition.jl:125 [inlined]
          [5] (::Base.var"#slowlock#733")(rl::ReentrantLock)
            @ Base ./lock.jl:157
          [6] lock
            @ ./lock.jl:147 [inlined]
          [7] cached_compilation
            @ ~/.julia/packages/Enzyme/RvNgp/src/compiler.jl:8412 [inlined]
          [8] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…})
            @ Enzyme.Compiler ~/.julia/packages/Enzyme/RvNgp/src/compiler.jl:8548
          [9] #s2104#19135
            @ ~/.julia/packages/Enzyme/RvNgp/src/compiler.jl:8685 [inlined]
         [10] 
            @ Enzyme.Compiler ./none:0
         [11] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
            @ Core ./boot.jl:707
         [12] autodiff
            @ ~/.julia/packages/Enzyme/RvNgp/src/Enzyme.jl:473 [inlined]
         [13] autodiff
            @ ~/.julia/packages/Enzyme/RvNgp/src/Enzyme.jl:537 [inlined]
         [14] autodiff
            @ ~/.julia/packages/Enzyme/RvNgp/src/Enzyme.jl:504 [inlined]
         [15] (::SymbolicRegression.ConstantOptimizationModule.GradEvaluator{…})(::Float64, G::Vector{…}, x::Vector{…})
            @ SymbolicRegressionEnzymeExt ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/ext/SymbolicRegressionEnzymeExt.jl:42
         [16] (::NLSolversBase.var"#69#70"{NLSolversBase.InplaceObjective{}, Float64})(G::Vector{Float64}, x::Vector{Float64})
            @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/objective_types/incomplete.jl:54
         [17] value_gradient!!(obj::NLSolversBase.OnceDifferentiable{Float64, Vector{Float64}, Vector{Float64}}, x::Vector{Float64})
            @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/interface.jl:82
         [18] initial_state(method::Optim.BFGS{…}, options::Optim.Options{…}, d::NLSolversBase.OnceDifferentiable{…}, initial_x::Vector{…})
            @ Optim ~/.julia/packages/Optim/ZhuZN/src/multivariate/solvers/first_order/bfgs.jl:94
         [19] optimize
            @ ~/.julia/packages/Optim/ZhuZN/src/multivariate/optimize/optimize.jl:36 [inlined]
         [20] optimize(f::NLSolversBase.InplaceObjective{…}, initial_x::Vector{…}, method::Optim.BFGS{…}, options::Optim.Options{…}; inplace::Bool, autodiff::Symbol)
            @ Optim ~/.julia/packages/Optim/ZhuZN/src/multivariate/optimize/interface.jl:143
         [21] optimize
            @ ~/.julia/packages/Optim/ZhuZN/src/multivariate/optimize/interface.jl:139 [inlined]
         [22] macro expansion
            @ ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/ConstantOptimization.jl:76 [inlined]
         [23] _optimize_constants(dataset::Dataset{…}, member::PopMember{…}, options::Options{…}, algorithm::Optim.BFGS{…}, optimizer_options::Optim.Options{…}, idx::Nothing)
            @ SymbolicRegression.ConstantOptimizationModule ~/.julia/packages/DispatchDoctor/ZmxWH/src/stabilization.jl:314
         [24] macro expansion
            @ ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/ConstantOptimization.jl:46 [inlined]
         [25] dispatch_optimize_constants(dataset::Dataset{…}, member::PopMember{…}, options::Options{…}, idx::Nothing)
            @ SymbolicRegression.ConstantOptimizationModule ~/.julia/packages/DispatchDoctor/ZmxWH/src/stabilization.jl:314
         [26] macro expansion
            @ ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/ConstantOptimization.jl:27 [inlined]
         [27] optimize_constants(dataset::Dataset{…}, member::PopMember{…}, options::Options{…})
            @ SymbolicRegression.ConstantOptimizationModule ~/.julia/packages/DispatchDoctor/ZmxWH/src/stabilization.jl:314
         [28] macro expansion
            @ ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/SingleIteration.jl:118 [inlined]
         [29] macro expansion
            @ ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/Utils.jl:159 [inlined]
         [30] macro expansion
            @ ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/SingleIteration.jl:109 [inlined]
         [31] optimize_and_simplify_population(dataset::Dataset{…}, pop::Population{…}, options::Options{…}, curmaxsize::Int64, record::Dict{…})
            @ SymbolicRegression.SingleIterationModule ~/.julia/packages/DispatchDoctor/ZmxWH/src/stabilization.jl:314
         [32] macro expansion
            @ ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/SymbolicRegression.jl:1087 [inlined]
         [33] _dispatch_s_r_cycle(in_pop::Population{…}, dataset::Dataset{…}, options::Options{…}; pop::Int64, out::Int64, iteration::Int64, verbosity::Int64, cur_maxsize::Int64, running_search_statistics::SymbolicRegression.AdaptiveParsimonyModule.RunningSearchStatistics)
            @ SymbolicRegression ~/.julia/packages/DispatchDoctor/ZmxWH/src/stabilization.jl:314
         [34] macro expansion
            @ ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/SymbolicRegression.jl:762 [inlined]
         [35] (::SymbolicRegression.var"#86#88"{})()
            @ SymbolicRegression ~/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/src/SearchUtils.jl:263
in expression starting at /Users/mcranmer/PermaDocuments/SymbolicRegressionMonorepo/SymbolicRegression.jl/examples/parameterized_function.jl:103
Some type information was truncated. Use `show(err)` to see complete types.

To reproduce, you can run the example here: https://ai.damtp.cam.ac.uk/symbolicregression/dev/examples/parameterized_function/ and swap :Zygote for :Enzyme.

For example:

# Reproduction script: fits an SRRegressor through MLJ with `autodiff_backend=:Enzyme`.
using SymbolicRegression
using Random: MersenneTwister
using Enzyme
using MLJBase: machine, fit!, predict, report
using Test

# Input table with n = 30 rows drawn from a seeded RNG: two normally distributed
# features (x1, x2) and an integer `class` column taking the values 1 or 2.
X = let rng = MersenneTwister(0), n = 30
    (; x1=randn(rng, n), x2=randn(rng, n), class=rand(rng, 1:2, n))
end

# Targets: `class` indexes into P1 and P2, so each row uses the pair of
# parameter values belonging to its class.
y = let P1 = [0.1, 1.5], P2 = [3.2, 0.5]
    [2 * cos(x2 + P1[class]) + x1^2 - P2[class] for (x1, x2, class) in zip(X.x1, X.x2, X.class)]
end

# Model configuration. The only change relative to the linked example is
# `autodiff_backend`, which is swapped from :Zygote to :Enzyme.
model = SRRegressor(;
    niterations=100,
    binary_operators=[+, *, /, -],
    unary_operators=[cos, exp],
    populations=30,
    expression_type=ParametricExpression,
    expression_options=(; max_parameters=2),
    autodiff_backend=:Enzyme,
);

# Bind the model to the data and fit it; the reported stack trace passes
# through this `fit!(mach)` call.
mach = machine(model, X, y)
fit!(mach)

Could it be because I am running Enzyme.jl from a task within the code?

More context: Enzyme.jl used to work for this, and I don't think I changed anything on my end that would cause new behavior. However, I only just switched to v0.13, so I'm not sure whether something changed there. I can't test v0.12 due to the error here: #2080

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions