Extremely long compilation time on recursive, branching functions (DynamicExpressions.jl) #1156

@MilesCranmer

Description

The old issue #1018 from August was getting a bit lengthy, so I'm moving to a new issue – feel free to close that one.

This relates to my year-long effort to get Enzyme.jl working as one of the AD backends for DynamicExpressions.jl, SymbolicRegression.jl, and PySR. The current status is:

  1. Working first-order gradients, if I disable some of the optimizations
  2. Hanging first-order gradients (extremely long compilation time), if all the optimizations are left on
  3. Hanging second-order gradients (extremely long compilation time), regardless of optimization settings
I've boiled the MWE down to the following code, which replicates the issues I am seeing:
using Enzyme

################################################################################
### OperatorEnum.jl
################################################################################
struct OperatorEnum{B,U}
    binops::B
    unaops::U
end
################################################################################

################################################################################
### Equation.jl
################################################################################
mutable struct Node{T}
    degree::UInt8  # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
    constant::Bool  # false if variable
    val::Union{T,Nothing}  # If this is a constant, stores the actual value
    # ------------------- (possibly undefined below)
    feature::UInt16  # If this is a variable (e.g., x in cos(x)), stores the feature index
    op::UInt8  # If this is an operator, stores its index into operators.binops or operators.unaops
    l::Node{T}  # Left child node. Only defined for degree=1 or degree=2.
    r::Node{T}  # Right child node. Only defined for degree=2.
    Node(d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
    Node(::Type{_T}, d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
    Node(::Type{_T}, d::Integer, c::Bool, v::Nothing, f::Integer) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f))
    Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l)
    Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}, r::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l, r)
end
function Node(::Type{T}; val::T1=nothing, feature::T2=nothing)::Node{T} where {T,T1,T2}
    if T2 <: Nothing
        !(T1 <: T) && (val = convert(T, val))
        return Node(T, 0, true, val)
    else
        return Node(T, 0, false, nothing, feature)
    end
end
Node(op::Integer, l::Node{T}) where {T} = Node(1, false, nothing, 0, op, l)
Node(op::Integer, l::Node{T}, r::Node{T}) where {T} = Node(2, false, nothing, 0, op, l, r)
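# (Illustration, not part of the original MWE.) With these constructors, an
# expression like cos(x1) * 3.0 is built as:
#     x1 = Node(Float64; feature=1)     # variable leaf
#     c3 = Node(Float64; val=3.0)       # constant leaf
#     ex = Node(3, Node(1, x1), c3)     # binops[3] = *, unaops[1] = cos
# (assuming the OperatorEnum defined further below)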
################################################################################

################################################################################
### Utils.jl
################################################################################
@inline function fill_similar(value, array, args...)
    out_array = similar(array, args...)
    out_array .= value
    return out_array
end
is_bad_array(array) = !(isempty(array) || isfinite(sum(array)))
function is_constant(tree::Node)
    if tree.degree == 0
        return tree.constant
    elseif tree.degree == 1
        return is_constant(tree.l)
    else
        return is_constant(tree.l) && is_constant(tree.r)
    end
end

################################################################################

################################################################################
### EvaluateEquation.jl
################################################################################
struct ResultOk{A<:AbstractArray}
    x::A
    ok::Bool
end

function eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, fuse_level=Val(2)) where {T<:Number}
    result = _eval_tree_array(tree, cX, operators, fuse_level)
    return (result.x, result.ok && !is_bad_array(result.x))
end

counttuple(::Type{<:NTuple{N,Any}}) where {N} = N
get_nuna(::Type{<:OperatorEnum{B,U}}) where {B,U} = counttuple(U)
get_nbin(::Type{<:OperatorEnum{B}}) where {B} = counttuple(B)
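# (Illustration, not part of the original MWE.) These helpers recover the operator
# counts from the OperatorEnum type alone, so the @generated function below can
# unroll over them at compile time:
#     get_nbin(typeof(OperatorEnum((+, -, *, /), (cos, sin))))  # == 4
#     get_nuna(typeof(OperatorEnum((+, -, *, /), (cos, sin))))  # == 2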

@generated function _eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{fuse_level})::ResultOk where {T<:Number,fuse_level}
    nuna = get_nuna(operators)
    nbin = get_nbin(operators)
    quote
        # First, we see if there are only constants in the tree - meaning
        # we can just return the constant result.
        if tree.degree == 0
            return deg0_eval(tree, cX)
        elseif is_constant(tree)
            # Speed hack for constant trees.
            const_result = _eval_constant_tree(tree, operators)::ResultOk{Vector{T}}
            !const_result.ok && return ResultOk(similar(cX, axes(cX, 2)), false)
            return ResultOk(fill_similar(const_result.x[], cX, axes(cX, 2)), true)
        elseif tree.degree == 1
            op_idx = tree.op
            # This @nif lets us generate an if statement over choice of operator,
            # which means the compiler will be able to completely avoid type inference on operators.
            return Base.Cartesian.@nif(
                $nuna,
                i -> i == op_idx,
                i -> let op = operators.unaops[i]
                    if fuse_level > 1 && tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
                        # op(op2(x, y)), where x and y are constants or variables.
                        l_op_idx = tree.l.op
                        Base.Cartesian.@nif(
                            $nbin,
                            j -> j == l_op_idx,
                            j -> let op_l = operators.binops[j]
                                deg1_l2_ll0_lr0_eval(tree, cX, op, op_l)
                            end,
                        )
                    elseif fuse_level > 1 && tree.l.degree == 1 && tree.l.l.degree == 0
                        # op(op2(x)), where x is a constant or variable.
                        l_op_idx = tree.l.op
                        Base.Cartesian.@nif(
                            $nuna,
                            j -> j == l_op_idx,
                            j -> let op_l = operators.unaops[j]
                                deg1_l1_ll0_eval(tree, cX, op, op_l)
                            end,
                        )
                    else
                        # op(x), for any x.
                        result = _eval_tree_array(tree.l, cX, operators, Val(fuse_level))
                        !result.ok && return result
                        deg1_eval(result.x, op)
                    end
                end
            )
        else
            op_idx = tree.op
            return Base.Cartesian.@nif(
                $nbin,
                i -> i == op_idx,
                i -> let op = operators.binops[i]
                    if fuse_level > 1 && tree.l.degree == 0 && tree.r.degree == 0
                        deg2_l0_r0_eval(tree, cX, op)
                    elseif tree.r.degree == 0
                        result_l = _eval_tree_array(tree.l, cX, operators, Val(fuse_level))
                        !result_l.ok && return result_l
                        # op(x, y), where y is a constant or variable but x is not.
                        deg2_r0_eval(tree, result_l.x, cX, op)
                    elseif tree.l.degree == 0
                        result_r = _eval_tree_array(tree.r, cX, operators, Val(fuse_level))
                        !result_r.ok && return result_r
                        # op(x, y), where x is a constant or variable but y is not.
                        deg2_l0_eval(tree, result_r.x, cX, op)
                    else
                        result_l = _eval_tree_array(tree.l, cX, operators, Val(fuse_level))
                        !result_l.ok && return result_l
                        result_r = _eval_tree_array(tree.r, cX, operators, Val(fuse_level))
                        !result_r.ok && return result_r
                        # op(x, y), for any x or y
                        deg2_eval(result_l.x, result_r.x, op)
                    end
                end
            )
        end
    end
end

function deg2_eval(
    cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F
)::ResultOk where {T<:Number,F}
    @inbounds @simd for j in eachindex(cumulator_l)
        x = op(cumulator_l[j], cumulator_r[j])::T
        cumulator_l[j] = x
    end
    return ResultOk(cumulator_l, true)
end

function deg1_eval(
    cumulator::AbstractVector{T}, op::F
)::ResultOk where {T<:Number,F}
    @inbounds @simd for j in eachindex(cumulator)
        x = op(cumulator[j])::T
        cumulator[j] = x
    end
    return ResultOk(cumulator, true)
end

function deg0_eval(tree::Node{T}, cX::AbstractMatrix{T})::ResultOk where {T<:Number}
    if tree.constant
        return ResultOk(fill_similar(tree.val::T, cX, axes(cX, 2)), true)
    else
        return ResultOk(cX[tree.feature, :], true)
    end
end

function deg1_l2_ll0_lr0_eval(
    tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2
) where {T<:Number,F,F2}
    if tree.l.l.constant && tree.l.r.constant
        val_ll = tree.l.l.val::T
        val_lr = tree.l.r.val::T
        x_l = op_l(val_ll, val_lr)::T
        x = op(x_l)::T
        return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
    elseif tree.l.l.constant
        val_ll = tree.l.l.val::T
        feature_lr = tree.l.r.feature
        cumulator = similar(cX, axes(cX, 2))
        @inbounds @simd for j in axes(cX, 2)
            x_l = op_l(val_ll, cX[feature_lr, j])::T
            x = isfinite(x_l) ? op(x_l)::T : T(Inf)
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    elseif tree.l.r.constant
        feature_ll = tree.l.l.feature
        val_lr = tree.l.r.val::T
        cumulator = similar(cX, axes(cX, 2))
        @inbounds @simd for j in axes(cX, 2)
            x_l = op_l(cX[feature_ll, j], val_lr)::T
            x = isfinite(x_l) ? op(x_l)::T : T(Inf)
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    else
        feature_ll = tree.l.l.feature
        feature_lr = tree.l.r.feature
        cumulator = similar(cX, axes(cX, 2))
        @inbounds @simd for j in axes(cX, 2)
            x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])::T
            x = isfinite(x_l) ? op(x_l)::T : T(Inf)
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end

# op(op2(x)) for x variable or constant
function deg1_l1_ll0_eval(
    tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2
) where {T<:Number,F,F2}
    if tree.l.l.constant
        val_ll = tree.l.l.val::T
        x_l = op_l(val_ll)::T
        x = op(x_l)::T
        return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
    else
        feature_ll = tree.l.l.feature
        cumulator = similar(cX, axes(cX, 2))
        @inbounds @simd for j in axes(cX, 2)
            x_l = op_l(cX[feature_ll, j])::T
            x = isfinite(x_l) ? op(x_l)::T : T(Inf)
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end

# op(x, y) for x and y variable/constant
function deg2_l0_r0_eval(
    tree::Node{T}, cX::AbstractMatrix{T}, op::F
) where {T<:Number,F}
    if tree.l.constant && tree.r.constant
        val_l = tree.l.val::T
        val_r = tree.r.val::T
        x = op(val_l, val_r)::T
        return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
    elseif tree.l.constant
        cumulator = similar(cX, axes(cX, 2))
        val_l = tree.l.val::T
        feature_r = tree.r.feature
        @inbounds @simd for j in axes(cX, 2)
            x = op(val_l, cX[feature_r, j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    elseif tree.r.constant
        cumulator = similar(cX, axes(cX, 2))
        feature_l = tree.l.feature
        val_r = tree.r.val::T
        @inbounds @simd for j in axes(cX, 2)
            x = op(cX[feature_l, j], val_r)::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    else
        cumulator = similar(cX, axes(cX, 2))
        feature_l = tree.l.feature
        feature_r = tree.r.feature
        @inbounds @simd for j in axes(cX, 2)
            x = op(cX[feature_l, j], cX[feature_r, j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end

# op(x, y) for x variable/constant, y arbitrary
function deg2_l0_eval(
    tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F
) where {T<:Number,F}
    if tree.l.constant
        val = tree.l.val::T
        @inbounds @simd for j in eachindex(cumulator)
            x = op(val, cumulator[j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    else
        feature = tree.l.feature
        @inbounds @simd for j in eachindex(cumulator)
            x = op(cX[feature, j], cumulator[j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end

# op(x, y) for x arbitrary, y variable/constant
function deg2_r0_eval(
    tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F
) where {T<:Number,F}
    if tree.r.constant
        val = tree.r.val::T
        @inbounds @simd for j in eachindex(cumulator)
            x = op(cumulator[j], val)::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    else
        feature = tree.r.feature
        @inbounds @simd for j in eachindex(cumulator)
            x = op(cumulator[j], cX[feature, j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end
@generated function _eval_constant_tree(tree::Node{T}, operators::OperatorEnum) where {T<:Number}
    nuna = get_nuna(operators)
    nbin = get_nbin(operators)
    quote
        if tree.degree == 0
            return deg0_eval_constant(tree)::ResultOk{Vector{T}}
        elseif tree.degree == 1
            op_idx = tree.op
            return Base.Cartesian.@nif(
                $nuna,
                i -> i == op_idx,
                i -> deg1_eval_constant(
                    tree, operators.unaops[i], operators
                )::ResultOk{Vector{T}}
            )
        else
            op_idx = tree.op
            return Base.Cartesian.@nif(
                $nbin,
                i -> i == op_idx,
                i -> deg2_eval_constant(
                    tree, operators.binops[i], operators
                )::ResultOk{Vector{T}}
            )
        end
    end
end
@inline function deg0_eval_constant(tree::Node{T}) where {T<:Number}
    output = tree.val::T
    return ResultOk([output], true)::ResultOk{Vector{T}}
end
function deg1_eval_constant(tree::Node{T}, op::F, operators::OperatorEnum) where {T<:Number,F}
    result = _eval_constant_tree(tree.l, operators)
    !result.ok && return result
    output = op(result.x[])::T
    return ResultOk([output], isfinite(output))::ResultOk{Vector{T}}
end
function deg2_eval_constant(tree::Node{T}, op::F, operators::OperatorEnum) where {T<:Number,F}
    cumulator = _eval_constant_tree(tree.l, operators)
    !cumulator.ok && return cumulator
    result_r = _eval_constant_tree(tree.r, operators)
    !result_r.ok && return result_r
    output = op(cumulator.x[], result_r.x[])::T
    return ResultOk([output], isfinite(output))::ResultOk{Vector{T}}
end
################################################################################
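(As an aside, for anyone unfamiliar with Base.Cartesian.@nif: it unrolls a runtime operator index into a static if/elseif chain, so that each branch sees its operator as a compile-time constant. Roughly, @nif(3, i -> i == op_idx, i -> operators.unaops[i](x)) expands to something like the following, which is exactly what makes the generated code so branchy:)

if 1 == op_idx
    operators.unaops[1](x)
elseif 2 == op_idx
    operators.unaops[2](x)
else
    operators.unaops[3](x)
end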

Now, we can see that the forward pass works okay:

# Operators to use:
operators = OperatorEnum((+, -, *, /), (cos, sin, exp, tanh))

# Variables:
x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3)

# Expression:
tree = Node(1, x1, Node(1, x2))  # == x1 + cos(x2)

# Input data
X = randn(3, 100);

# Output:
eval_tree_array(tree, X, operators, Val(2))

This evaluates x1 + cos(x2) over 100 random data points. Both Val(1) and Val(2) (fuse_level=1 and fuse_level=2, respectively) work here and produce the same output. You can ctrl-F through the code above to see which parts fuse_level activates – it basically just turns on a couple of branches for "fused" operators (e.g., sin(exp(x)) evaluated over the data inside a single loop).
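As a quick sanity check of that claim (not part of the MWE itself), the two fuse levels can be compared directly:

out1, ok1 = eval_tree_array(tree, X, operators, Val(1))
out2, ok2 = eval_tree_array(tree, X, operators, Val(2))
@assert ok1 && ok2
@assert out1 ≈ out2  # fused and unfused paths agree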

Now, if I try compiling the reverse-mode gradient with respect to the input data for fuse-level 1:

f(tree, X, operators, output) = (output[] = sum(eval_tree_array(tree, X, operators, Val(1))[1]); nothing)
dX = Enzyme.make_zero(X)
output = [0.0]
doutput = [1.0]

autodiff(
    Reverse,
    f,
    Const(tree),
    Duplicated(X, dX),
    Const(operators),
    Duplicated(output, doutput)
)

This takes about 1 minute to compile, but once it's compiled, it's pretty fast.
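(For reference, the compile-vs-run split is easy to see by timing the call twice; the comments are rough numbers and will vary by machine:)

@time autodiff(Reverse, f, Const(tree), Duplicated(X, dX), Const(operators), Duplicated(output, doutput))  # ~1 minute: dominated by compilation
dX .= 0  # reset the shadow so gradients don't accumulate
@time autodiff(Reverse, f, Const(tree), Duplicated(X, dX), Const(operators), Duplicated(output, doutput))  # fast: already compiled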

However, if I switch on some of the optimizations (fuse_level=2):

f(tree, X, operators, output) = (output[] = sum(eval_tree_array(tree, X, operators, Val(2))[1]); nothing)

dX = Enzyme.make_zero(X)  # fresh shadow for this run
output = [0.0]
doutput = [1.0]

autodiff(
    Reverse,
    f,
    Const(tree),
    Duplicated(X, dX),
    Const(operators),
    Duplicated(output, doutput)
)

This seems to hang forever. I left it running for about a day, came back, and it was still going. I assume it would finish eventually, but it's obviously not a workable solution, as the existing forward-mode AD backends compile in under a second. And if the user changes data types or operators, everything needs to recompile again.

If I force it to quit with ctrl-\, I see various LLVM calls:

[68568] signal (3): Quit: 3
in expression starting at REPL[44]:1
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
_ZN4llvm22MustBeExecutedIterator7advanceEv at /Users/mcranmer/.julia/juliaup/julia-1.10.0-rc1+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib (unknown line)
unknown function (ip: 0x0)
Allocations: 37668296 (Pool: 37621782; Big: 46514); GC: 45

Any ideas on how to get this to scale better? It seems like some step of the compilation is hanging here, and it appears to scale exponentially with the number of branches.
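For a rough sense of scale (my own back-of-the-envelope count from the generated code above): with nbin = 4 and nuna = 4, the degree-1 path of _eval_tree_array alone expands to 4 outer unary branches, each containing a nested 4-way @nif for the fused-binary case and another nested 4-way @nif for the fused-unary case, i.e. about 4 × (4 + 4 + 1) = 36 bodies; the degree-2 path adds roughly 4 × 4 = 16 more. That's on the order of 50 distinct specialized call sites in a single generated function, each of which Enzyme presumably has to analyze, before the recursion even enters the picture.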


Edit: Updated the MWE to reduce it further.
