Skip to content

Commit 05079c8

Browse files
committed
Set up fuse_level option for Enzyme compatibility
1 parent fed4a69 commit 05079c8

File tree

4 files changed

+13
-5
lines changed

4 files changed

+13
-5
lines changed

src/ConstantOptimization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function dispatch_optimize_constants(
5656
algorithm,
5757
options.optimizer_options,
5858
idx,
59-
options.enable_enzyme ? Val(true) : Val(false),
59+
options.v_enable_enzyme,
6060
)
6161
end
6262
if T <: Complex

src/InterfaceDynamicExpressions.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ which speed up evaluation significantly.
5353
to the equation.
5454
"""
5555
function eval_tree_array(tree::Node, X::AbstractArray, options::Options; kws...)
56-
return eval_tree_array(tree, X, options.operators; turbo=options.v_turbo, kws...)
56+
fuse_level = options.v_enable_enzyme === Val(true) ? Val(1) : Val(2)
57+
return eval_tree_array(
58+
tree, X, options.operators; turbo=options.v_turbo, fuse_level, kws...
59+
)
5760
end
5861

5962
"""

src/Options.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,8 @@ function Options end
732732
mutation_weights
733733
end
734734

735+
v_enable_enzyme = Val(enable_enzyme)
736+
735737
@assert print_precision > 0
736738

737739
options = Options{
@@ -741,6 +743,7 @@ function Options end
741743
typeof(optimizer_options),
742744
typeof(tournament_selection_weights),
743745
turbo,
746+
enable_enzyme,
744747
}(
745748
operators,
746749
bin_constraints,
@@ -802,7 +805,7 @@ function Options end
802805
nested_constraints,
803806
deterministic,
804807
define_helper_functions,
805-
enable_enzyme,
808+
v_enable_enzyme,
806809
)
807810

808811
return options

src/OptionsStruct.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,9 @@ function ComplexityMapping(;
129129
)
130130
end
131131

132-
struct Options{CT,OP<:AbstractOperatorEnum,use_recorder,OPT<:Optim.Options,W,_turbo}
132+
struct Options{
133+
CT,OP<:AbstractOperatorEnum,use_recorder,OPT<:Optim.Options,W,_turbo,enable_enzyme
134+
}
133135
operators::OP
134136
bin_constraints::Vector{Tuple{Int,Int}}
135137
una_constraints::Vector{Int}
@@ -190,7 +192,7 @@ struct Options{CT,OP<:AbstractOperatorEnum,use_recorder,OPT<:Optim.Options,W,_tu
190192
nested_constraints::Union{Vector{Tuple{Int,Int,Vector{Tuple{Int,Int,Int}}}},Nothing}
191193
deterministic::Bool
192194
define_helper_functions::Bool
193-
enable_enzyme::Bool
195+
v_enable_enzyme::Val{enable_enzyme}
194196
end
195197

196198
function Base.print(io::IO, options::Options)

0 commit comments

Comments
 (0)