The old issue #1018 from August was getting a bit lengthy, so I'm moving to a new issue – feel free to close that one.
This relates to my year-long effort to get Enzyme.jl working as one of the AD backends for DynamicExpressions.jl, SymbolicRegression.jl, and PySR. The current status is:
- Working first-order gradients, if I disable some of the optimizations
- Hanging first-order gradients (extremely long compilation time), if all the optimizations are left on
- Hanging second-order gradients (extremely long compilation time), regardless of optimization settings
I've boiled down the MWE to the following code, which replicates the issues I'm seeing:
using Enzyme
################################################################################
### OperatorEnum.jl
################################################################################
struct OperatorEnum{B,U}
    binops::B
    unaops::U
end
################################################################################
################################################################################
### Equation.jl
################################################################################
mutable struct Node{T}
    degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
    constant::Bool # false if variable
    val::Union{T,Nothing} # If is a constant, this stores the actual value
    # ------------------- (possibly undefined below)
    feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
    op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops
    l::Node{T} # Left child node. Only defined for degree=1 or degree=2.
    r::Node{T} # Right child node. Only defined for degree=2.
    Node(d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
    Node(::Type{_T}, d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
    Node(::Type{_T}, d::Integer, c::Bool, v::Nothing, f::Integer) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f))
    Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l)
    Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}, r::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l, r)
end
function Node(::Type{T}; val::T1=nothing, feature::T2=nothing)::Node{T} where {T,T1,T2}
    if T2 <: Nothing
        !(T1 <: T) && (val = convert(T, val))
        return Node(T, 0, true, val)
    else
        return Node(T, 0, false, nothing, feature)
    end
end
Node(op::Integer, l::Node{T}) where {T} = Node(1, false, nothing, 0, op, l)
Node(op::Integer, l::Node{T}, r::Node{T}) where {T} = Node(2, false, nothing, 0, op, l, r)
################################################################################
################################################################################
### Utils.jl
################################################################################
@inline function fill_similar(value, array, args...)
    out_array = similar(array, args...)
    out_array .= value
    return out_array
end
is_bad_array(array) = !(isempty(array) || isfinite(sum(array)))
function is_constant(tree::Node)
    if tree.degree == 0
        return tree.constant
    elseif tree.degree == 1
        return is_constant(tree.l)
    else
        return is_constant(tree.l) && is_constant(tree.r)
    end
end
################################################################################
################################################################################
### EvaluateEquation.jl
################################################################################
struct ResultOk{A<:AbstractArray}
    x::A
    ok::Bool
end
function eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, fuse_level=Val(2)) where {T<:Number}
    result = _eval_tree_array(tree, cX, operators, fuse_level)
    return (result.x, result.ok && !is_bad_array(result.x))
end
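# Recover the operator counts from the tuple type parameters, so the @generated evaluators below can unroll one branch per operator: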
counttuple(::Type{<:NTuple{N,Any}}) where {N} = N
get_nuna(::Type{<:OperatorEnum{B,U}}) where {B,U} = counttuple(U)
get_nbin(::Type{<:OperatorEnum{B}}) where {B} = counttuple(B)
@generated function _eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{fuse_level})::ResultOk where {T<:Number,fuse_level}
    nuna = get_nuna(operators)
    nbin = get_nbin(operators)
    quote
        # First, we see if there are only constants in the tree - meaning
        # we can just return the constant result.
        if tree.degree == 0
            return deg0_eval(tree, cX)
        elseif is_constant(tree)
            # Speed hack for constant trees.
            const_result = _eval_constant_tree(tree, operators)::ResultOk{Vector{T}}
            !const_result.ok && return ResultOk(similar(cX, axes(cX, 2)), false)
            return ResultOk(fill_similar(const_result.x[], cX, axes(cX, 2)), true)
        elseif tree.degree == 1
            op_idx = tree.op
            # This @nif lets us generate an if statement over choice of operator,
            # which means the compiler will be able to completely avoid type inference on operators.
            return Base.Cartesian.@nif(
                $nuna,
                i -> i == op_idx,
                i -> let op = operators.unaops[i]
                    if fuse_level > 1 && tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
                        # op(op2(x, y)), where x and y are constants or variables.
                        l_op_idx = tree.l.op
                        Base.Cartesian.@nif(
                            $nbin,
                            j -> j == l_op_idx,
                            j -> let op_l = operators.binops[j]
                                deg1_l2_ll0_lr0_eval(tree, cX, op, op_l)
                            end,
                        )
                    elseif fuse_level > 1 && tree.l.degree == 1 && tree.l.l.degree == 0
                        # op(op2(x)), where x is a constant or variable.
                        l_op_idx = tree.l.op
                        Base.Cartesian.@nif(
                            $nuna,
                            j -> j == l_op_idx,
                            j -> let op_l = operators.unaops[j]
                                deg1_l1_ll0_eval(tree, cX, op, op_l)
                            end,
                        )
                    else
                        # op(x), for any x.
                        result = _eval_tree_array(tree.l, cX, operators, Val(fuse_level))
                        !result.ok && return result
                        deg1_eval(result.x, op)
                    end
                end
            )
        else
            op_idx = tree.op
            return Base.Cartesian.@nif(
                $nbin,
                i -> i == op_idx,
                i -> let op = operators.binops[i]
                    if fuse_level > 1 && tree.l.degree == 0 && tree.r.degree == 0
                        deg2_l0_r0_eval(tree, cX, op)
                    elseif tree.r.degree == 0
                        result_l = _eval_tree_array(tree.l, cX, operators, Val(fuse_level))
                        !result_l.ok && return result_l
                        # op(x, y), where y is a constant or variable but x is not.
                        deg2_r0_eval(tree, result_l.x, cX, op)
                    elseif tree.l.degree == 0
                        result_r = _eval_tree_array(tree.r, cX, operators, Val(fuse_level))
                        !result_r.ok && return result_r
                        # op(x, y), where x is a constant or variable but y is not.
                        deg2_l0_eval(tree, result_r.x, cX, op)
                    else
                        result_l = _eval_tree_array(tree.l, cX, operators, Val(fuse_level))
                        !result_l.ok && return result_l
                        result_r = _eval_tree_array(tree.r, cX, operators, Val(fuse_level))
                        !result_r.ok && return result_r
                        # op(x, y), for any x or y.
                        deg2_eval(result_l.x, result_r.x, op)
                    end
                end
            )
        end
    end
end
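# Note on branch counts: with fuse_level=2, the nested @nif calls above expand to
# roughly nbin branches inside each of the nuna unary branches (and nuna within nuna
# for the unary-unary case), so the number of generated code paths grows
# multiplicatively with the operator counts - which is presumably what stresses
# the reverse-mode compiler.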
function deg2_eval(
    cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F
)::ResultOk where {T<:Number,F}
    @inbounds @simd for j in eachindex(cumulator_l)
        x = op(cumulator_l[j], cumulator_r[j])::T
        cumulator_l[j] = x
    end
    return ResultOk(cumulator_l, true)
end
function deg1_eval(
    cumulator::AbstractVector{T}, op::F
)::ResultOk where {T<:Number,F}
    @inbounds @simd for j in eachindex(cumulator)
        x = op(cumulator[j])::T
        cumulator[j] = x
    end
    return ResultOk(cumulator, true)
end
function deg0_eval(tree::Node{T}, cX::AbstractMatrix{T})::ResultOk where {T<:Number}
    if tree.constant
        return ResultOk(fill_similar(tree.val::T, cX, axes(cX, 2)), true)
    else
        return ResultOk(cX[tree.feature, :], true)
    end
end
function deg1_l2_ll0_lr0_eval(
    tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2
) where {T<:Number,F,F2}
    if tree.l.l.constant && tree.l.r.constant
        val_ll = tree.l.l.val::T
        val_lr = tree.l.r.val::T
        x_l = op_l(val_ll, val_lr)::T
        x = op(x_l)::T
        return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
    elseif tree.l.l.constant
        val_ll = tree.l.l.val::T
        feature_lr = tree.l.r.feature
        cumulator = similar(cX, axes(cX, 2))
        @inbounds @simd for j in axes(cX, 2)
            x_l = op_l(val_ll, cX[feature_lr, j])::T
            x = isfinite(x_l) ? op(x_l)::T : T(Inf)
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    elseif tree.l.r.constant
        feature_ll = tree.l.l.feature
        val_lr = tree.l.r.val::T
        cumulator = similar(cX, axes(cX, 2))
        @inbounds @simd for j in axes(cX, 2)
            x_l = op_l(cX[feature_ll, j], val_lr)::T
            x = isfinite(x_l) ? op(x_l)::T : T(Inf)
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    else
        feature_ll = tree.l.l.feature
        feature_lr = tree.l.r.feature
        cumulator = similar(cX, axes(cX, 2))
        @inbounds @simd for j in axes(cX, 2)
            x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])::T
            x = isfinite(x_l) ? op(x_l)::T : T(Inf)
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end
# op(op2(x)) for x variable or constant
function deg1_l1_ll0_eval(
    tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2
) where {T<:Number,F,F2}
    if tree.l.l.constant
        val_ll = tree.l.l.val::T
        x_l = op_l(val_ll)::T
        x = op(x_l)::T
        return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
    else
        feature_ll = tree.l.l.feature
        cumulator = similar(cX, axes(cX, 2))
        @inbounds @simd for j in axes(cX, 2)
            x_l = op_l(cX[feature_ll, j])::T
            x = isfinite(x_l) ? op(x_l)::T : T(Inf)
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end
# op(x, y) for x and y variable/constant
function deg2_l0_r0_eval(
    tree::Node{T}, cX::AbstractMatrix{T}, op::F
) where {T<:Number,F}
    if tree.l.constant && tree.r.constant
        val_l = tree.l.val::T
        val_r = tree.r.val::T
        x = op(val_l, val_r)::T
        return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
    elseif tree.l.constant
        cumulator = similar(cX, axes(cX, 2))
        val_l = tree.l.val::T
        feature_r = tree.r.feature
        @inbounds @simd for j in axes(cX, 2)
            x = op(val_l, cX[feature_r, j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    elseif tree.r.constant
        cumulator = similar(cX, axes(cX, 2))
        feature_l = tree.l.feature
        val_r = tree.r.val::T
        @inbounds @simd for j in axes(cX, 2)
            x = op(cX[feature_l, j], val_r)::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    else
        cumulator = similar(cX, axes(cX, 2))
        feature_l = tree.l.feature
        feature_r = tree.r.feature
        @inbounds @simd for j in axes(cX, 2)
            x = op(cX[feature_l, j], cX[feature_r, j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end
# op(x, y) for x variable/constant, y arbitrary
function deg2_l0_eval(
    tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F
) where {T<:Number,F}
    if tree.l.constant
        val = tree.l.val::T
        @inbounds @simd for j in eachindex(cumulator)
            x = op(val, cumulator[j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    else
        feature = tree.l.feature
        @inbounds @simd for j in eachindex(cumulator)
            x = op(cX[feature, j], cumulator[j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end
# op(x, y) for x arbitrary, y variable/constant
function deg2_r0_eval(
    tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F
) where {T<:Number,F}
    if tree.r.constant
        val = tree.r.val::T
        @inbounds @simd for j in eachindex(cumulator)
            x = op(cumulator[j], val)::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    else
        feature = tree.r.feature
        @inbounds @simd for j in eachindex(cumulator)
            x = op(cumulator[j], cX[feature, j])::T
            cumulator[j] = x
        end
        return ResultOk(cumulator, true)
    end
end
@generated function _eval_constant_tree(tree::Node{T}, operators::OperatorEnum) where {T<:Number}
    nuna = get_nuna(operators)
    nbin = get_nbin(operators)
    quote
        if tree.degree == 0
            return deg0_eval_constant(tree)::ResultOk{Vector{T}}
        elseif tree.degree == 1
            op_idx = tree.op
            return Base.Cartesian.@nif(
                $nuna,
                i -> i == op_idx,
                i -> deg1_eval_constant(
                    tree, operators.unaops[i], operators
                )::ResultOk{Vector{T}}
            )
        else
            op_idx = tree.op
            return Base.Cartesian.@nif(
                $nbin,
                i -> i == op_idx,
                i -> deg2_eval_constant(
                    tree, operators.binops[i], operators
                )::ResultOk{Vector{T}}
            )
        end
    end
end
@inline function deg0_eval_constant(tree::Node{T}) where {T<:Number}
    output = tree.val::T
    return ResultOk([output], true)::ResultOk{Vector{T}}
end
function deg1_eval_constant(tree::Node{T}, op::F, operators::OperatorEnum) where {T<:Number,F}
    result = _eval_constant_tree(tree.l, operators)
    !result.ok && return result
    output = op(result.x[])::T
    return ResultOk([output], isfinite(output))::ResultOk{Vector{T}}
end
function deg2_eval_constant(tree::Node{T}, op::F, operators::OperatorEnum) where {T<:Number,F}
    cumulator = _eval_constant_tree(tree.l, operators)
    !cumulator.ok && return cumulator
    result_r = _eval_constant_tree(tree.r, operators)
    !result_r.ok && return result_r
    output = op(cumulator.x[], result_r.x[])::T
    return ResultOk([output], isfinite(output))::ResultOk{Vector{T}}
end
################################################################################
Now, we can see that the forward pass works okay:
# Operators to use:
operators = OperatorEnum((+, -, *, /), (cos, sin, exp, tanh))
# Variables:
x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3)
# Expression:
tree = Node(1, x1, Node(1, x2)) # == x1 + cos(x2)
# Input data
X = randn(3, 100);
# Output:
eval_tree_array(tree, X, operators, Val(2))
This evaluates x1 + cos(x2) over 100 random rows. Both Val(1) and Val(2) (fuse_level=1 and fuse_level=2, respectively) work here and produce the same output. You can ctrl-F for fuse_level to see which parts it activates in the code; it basically just turns on a couple of branches related to "fused" operators (e.g., sin(exp(x)) evaluated over the data inside a single loop).
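To make concrete what fuse_level=2 activates, here's a tree that hits the fused unary-unary branch (deg1_l1_ll0_eval). This is just a quick sketch on top of the MWE above, and tree_fused is a name I'm introducing here:
# sin(exp(x1)): op indices 2 and 3 into operators.unaops == (cos, sin, exp, tanh)
tree_fused = Node(2, Node(3, x1))
# Val(2) evaluates this in a single fused loop (deg1_l1_ll0_eval), while
# Val(1) recurses and calls deg1_eval twice; the outputs should agree:
y2, ok2 = eval_tree_array(tree_fused, X, operators, Val(2))
y1, ok1 = eval_tree_array(tree_fused, X, operators, Val(1))
@assert ok1 && ok2 && y1 ≈ y2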
Now, if I try compiling the reverse-mode gradient with respect to the input data at fuse_level=1:
f(tree, X, operators, output) = (output[] = sum(eval_tree_array(tree, X, operators, Val(1))[1]); nothing)
dX = Enzyme.make_zero(X)
output = [0.0]
doutput = [1.0]
autodiff(
    Reverse,
    f,
    Const(tree),
    Duplicated(X, dX),
    Const(operators),
    Duplicated(output, doutput)
)
This takes about 1 minute to compile. But once it's compiled, it's pretty fast.
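As a sanity check on that gradient, you can compare dX against a finite difference (a rough sketch; the step size h and the entry checked are arbitrary):
# Finite-difference check of one entry of dX (entry (2, 1) feeds the cos branch):
h = 1e-6
base = sum(eval_tree_array(tree, X, operators, Val(1))[1])
Xp = copy(X)
Xp[2, 1] += h
fd = (sum(eval_tree_array(tree, Xp, operators, Val(1))[1]) - base) / h
@assert isapprox(dX[2, 1], fd; atol=1e-4, rtol=1e-3)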
However, if I switch on some of the optimizations (fuse_level=2):
f(tree, X, operators, output) = (output[] = sum(eval_tree_array(tree, X, operators, Val(2))[1]); nothing)
dX = Enzyme.make_zero(X)  # re-zero the shadow, since Enzyme accumulates into it
output = [0.0]
doutput = [1.0]
autodiff(
    Reverse,
    f,
    Const(tree),
    Duplicated(X, dX),
    Const(operators),
    Duplicated(output, doutput)
)
This seems to hang forever. I left it going for about a day, and it was still running when I came back. I'm assuming it would finish eventually, but it's obviously not a good solution: the existing AD backends, which use forward-mode autodiff, compile in under a second. And if the user changes data types or operators, everything has to recompile again.
If I force it to quit with ctrl-\, I see various LLVM calls in the backtrace:
[68568] signal (3): Quit: 3
in expression starting at REPL[44]:1
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
__psynch_cvwait at /usr/lib/system/libsystem_kernel.dylib (unknown line)
unknown function (ip: 0x0)
_ZN4llvm22MustBeExecutedIterator7advanceEv at /Users/mcranmer/.julia/juliaup/julia-1.10.0-rc1+0.aarch64.apple.darwin14/lib/julia/libLLVM.dylib (unknown line)
unknown function (ip: 0x0)
Allocations: 37668296 (Pool: 37621782; Big: 46514); GC: 45
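For reference, the second-order gradients mentioned at the top are composed as forward-over-reverse, roughly like the sketch below, following the Hessian-vector-product pattern from Enzyme's docs. grad! is a name I'm introducing here, and the exact activity annotations may need adjustment for your Enzyme version; this kind of composition hangs for me regardless of fuse_level:
# Forward-over-reverse sketch for second-order gradients (illustrative):
function grad!(dX, tree, X, operators, output, doutput)
    autodiff_deferred(
        Reverse, f,
        Const(tree), Duplicated(X, dX), Const(operators), Duplicated(output, doutput),
    )
    return nothing
end
vX = Enzyme.make_zero(X); vX[1, 1] = 1.0  # direction for the Hessian-vector product
gX = Enzyme.make_zero(X)                  # gradient buffer
hvpX = Enzyme.make_zero(X)                # receives the HVP along vX
autodiff(
    Forward, grad!,
    Duplicated(gX, hvpX),
    Const(tree),
    Duplicated(X, vX),
    Const(operators),
    Duplicated([0.0], [0.0]),
    Duplicated([1.0], [0.0]),
)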
Any ideas on how to get this to scale better? It seems like some step of the compilation is hanging here, and that it scales exponentially with the number of branches.
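One way to probe the branch scaling directly is to shrink the operator tuples and time the first call. This is an experiment sketch: f1 is a fuse_level=1 variant I'm introducing so that every case actually finishes, and tree only uses operator index 1 in each tuple, so it stays valid for all three enums:
f1(tree, X, operators, output) = (output[] = sum(eval_tree_array(tree, X, operators, Val(1))[1]); nothing)
for ops in (
    OperatorEnum((+,), (cos,)),
    OperatorEnum((+, -), (cos, sin)),
    OperatorEnum((+, -, *, /), (cos, sin, exp, tanh)),
)
    # The first call on each new OperatorEnum type triggers a fresh compile:
    t = @elapsed autodiff(
        Reverse, f1,
        Const(tree), Duplicated(X, Enzyme.make_zero(X)),
        Const(ops), Duplicated([0.0], [1.0]),
    )
    @info "first-call time" nbin=length(ops.binops) nuna=length(ops.unaops) seconds=t
end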
Edit: Updated code MWE to further reduce it.