8 changes: 6 additions & 2 deletions Project.toml
@@ -27,17 +27,20 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[extensions]
SymbolicRegressionEnzymeExt = "Enzyme"
SymbolicRegressionJSON3Ext = "JSON3"
SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils"

[compat]
Compat = "^4.2"
DynamicExpressions = "0.13"
DynamicExpressions = "0.14"
DynamicQuantities = "^0.6.2"
Enzyme = "0.11"
JSON3 = "1"
LineSearches = "7"
LossFunctions = "0.10, 0.11"
@@ -57,6 +60,7 @@ julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -69,4 +73,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "SafeTestsets", "Aqua", "ForwardDiff", "LinearAlgebra", "JSON3", "MLJBase", "MLJTestInterface", "Suppressor", "SymbolicUtils", "Zygote"]
test = ["Test", "SafeTestsets", "Aqua", "Enzyme", "ForwardDiff", "LinearAlgebra", "JSON3", "MLJBase", "MLJTestInterface", "Suppressor", "SymbolicUtils", "Zygote"]
40 changes: 40 additions & 0 deletions ext/SymbolicRegressionEnzymeExt.jl
@@ -0,0 +1,40 @@
module SymbolicRegressionEnzymeExt

import Enzyme: autodiff, Duplicated, Const, Reverse
import SymbolicRegression: Dataset, Options
import SymbolicRegression.ConstantOptimizationModule: opt_func!, opt_func_g!

@inline function opt_func_g!(
x,
dx,
dataset::Dataset{T,L},
tree,
ctree,
constant_nodes,
c_constant_nodes,
options::Options,
idx,
) where {T,L}
result = [zero(L)]
dresult = [one(L)]
fill!(dx, one(T))
foreach(ctree) do t
if t.degree == 0 && t.constant
t.val::T = zero(T)
end
end
autodiff(
Reverse,
opt_func!,
Duplicated(result, dresult),
Duplicated(x, dx),
Const(dataset),
Duplicated(tree, ctree),
Duplicated(constant_nodes, c_constant_nodes),
Const(options),
Const(idx),
)
return nothing
end

end
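The wrapper above seeds reverse mode through the mutating `opt_func!`: `dresult = [one(L)]` is the seed for the scalar loss written into `result[1]`, and `dx` ends up holding the gradient with respect to the constant vector `x`. A standalone sketch of the same `Duplicated` shadow pattern on a toy mutating function (my own example, assuming Enzyme ~0.11 as pinned in Project.toml):

    using Enzyme: autodiff, Duplicated, Reverse

    # f! writes a scalar loss into out[1] instead of returning it,
    # mirroring the opt_func! wrapper used by the extension.
    f!(out, x) = (out[1] = sum(abs2, x); nothing)

    x = [1.0, 2.0, 3.0]
    dx = zero(x)    # shadow for x; accumulates the gradient
    out = [0.0]
    dout = [1.0]    # reverse-mode seed for the scalar output

    autodiff(Reverse, f!, Duplicated(out, dout), Duplicated(x, dx))

    dx    # == 2 .* x, the gradient of sum(abs2, x)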
77 changes: 56 additions & 21 deletions src/ConstantOptimization.jl
@@ -8,6 +8,8 @@ import ..UtilsModule: get_birth_order
import ..LossFunctionsModule: score_func, eval_loss, batch_sample
import ..PopMemberModule: PopMember

opt_func_g!(args...) = error("Please load the Enzyme.jl package.")

# Proxy function for optimization
@inline function opt_func(
x, dataset::Dataset{T,L}, tree, constant_nodes, options, idx
@@ -17,6 +19,10 @@ import ..PopMemberModule: PopMember
loss = eval_loss(tree, dataset, options; regularization=false, idx=idx)
return loss::L
end
function opt_func!(result, x, dataset, tree, constant_nodes, options, idx)
result[1] = opt_func(x, dataset, tree, constant_nodes, options, idx)
return nothing
end

function _set_constants!(x::AbstractArray{T}, constant_nodes) where {T}
for (xi, node) in zip(x, constant_nodes)
@@ -42,48 +48,77 @@ function dispatch_optimize_constants(
) where {T<:DATA_TYPE,L<:LOSS_TYPE}
nconst = count_constants(member.tree)
nconst == 0 && return (member, 0.0)
if T <: Complex
# TODO: Make this more general. Also, do we even need Newton here at all??
algorithm = Optim.BFGS(; linesearch=LineSearches.BackTracking())#order=3))
function call_opt(algorithm)
return _optimize_constants(
dataset, member, options, algorithm, options.optimizer_options, idx
dataset,
member,
options,
algorithm,
options.optimizer_options,
idx,
options.v_enable_enzyme,
)
end
if T <: Complex
# TODO: Make this more general. Also, do we even need Newton here at all??
return call_opt(Optim.BFGS(; linesearch=LineSearches.BackTracking()))
elseif nconst == 1
algorithm = Optim.Newton(; linesearch=LineSearches.BackTracking())
return _optimize_constants(
dataset, member, options, algorithm, options.optimizer_options, idx
)
return call_opt(Optim.Newton(; linesearch=LineSearches.BackTracking()))
else
if options.optimizer_algorithm == "NelderMead"
algorithm = Optim.NelderMead(; linesearch=LineSearches.BackTracking())
return _optimize_constants(
dataset, member, options, algorithm, options.optimizer_options, idx
)
return call_opt(Optim.NelderMead(; linesearch=LineSearches.BackTracking()))
elseif options.optimizer_algorithm == "BFGS"
algorithm = Optim.BFGS(; linesearch=LineSearches.BackTracking())#order=3))
return _optimize_constants(
dataset, member, options, algorithm, options.optimizer_options, idx
)
return call_opt(Optim.BFGS(; linesearch=LineSearches.BackTracking()))
else
error("Optimization function not implemented.")
end
end
end

function _optimize_constants(
dataset, member::PopMember{T,L}, options, algorithm, optimizer_options, idx
)::Tuple{PopMember{T,L},Float64} where {T,L}
dataset,
member::PopMember{T,L},
options,
algorithm,
optimizer_options,
idx,
::Val{use_autodiff},
)::Tuple{PopMember{T,L},Float64} where {T,L,use_autodiff}
tree = member.tree
constant_nodes = filter(t -> t.degree == 0 && t.constant, tree)

ctree = use_autodiff ? copy(tree) : nothing
c_constant_nodes =
use_autodiff ? filter(t -> t.degree == 0 && t.constant, ctree) : nothing

x0 = [n.val::T for n in constant_nodes]
f(x) = opt_func(x, dataset, tree, constant_nodes, options, idx)
result = Optim.optimize(f, x0, algorithm, optimizer_options)
function f(x)
return opt_func(x, dataset, tree, constant_nodes, options, idx)
end
function g!(storage, x)
# TODO: Can we just use storage instead here?
dx = similar(x)
opt_func_g!(
x, dx, dataset, tree, ctree, constant_nodes, c_constant_nodes, options, idx
)
storage .= dx
return nothing
end
result = if use_autodiff
Optim.optimize(f, g!, x0, algorithm, optimizer_options)
else
Optim.optimize(f, x0, algorithm, optimizer_options)
end
num_evals = 0.0
num_evals += result.f_calls
# Try other initial conditions:
for i in 1:(options.optimizer_nrestarts)
new_start = x0 .* (T(1) .+ T(1//2) * randn(T, size(x0, 1)))
tmpresult = Optim.optimize(f, new_start, algorithm, optimizer_options)
tmpresult = if use_autodiff
Optim.optimize(f, g!, new_start, algorithm, optimizer_options)
else
Optim.optimize(f, new_start, algorithm, optimizer_options)
end
num_evals += tmpresult.f_calls

if tmpresult.minimum < result.minimum
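With `use_autodiff` set, the hunk above switches from `Optim.optimize(f, x0, algorithm, optimizer_options)` to the gradient-aware form `Optim.optimize(f, g!, x0, ...)`, where `g!(storage, x)` fills `storage` with the Enzyme gradient. A self-contained sketch of that Optim.jl calling convention, using a made-up quadratic objective instead of the package's loss:

    using Optim, LineSearches

    f(x) = (x[1] - 2)^2 + 3 * (x[2] + 1)^2

    function g!(storage, x)    # writes the gradient of f into storage
        storage[1] = 2 * (x[1] - 2)
        storage[2] = 6 * (x[2] + 1)
        return nothing
    end

    algorithm = Optim.BFGS(; linesearch=LineSearches.BackTracking())
    result = Optim.optimize(f, g!, [0.0, 0.0], algorithm)
    Optim.minimizer(result)    # approximately [2.0, -1.0]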
19 changes: 5 additions & 14 deletions src/InterfaceDynamicExpressions.jl
@@ -10,8 +10,7 @@ import DynamicExpressions:
eval_diff_tree_array,
eval_grad_tree_array,
print_tree,
string_tree,
differentiable_eval_tree_array
string_tree
import DynamicQuantities: dimension, ustrip
import ..CoreModule: Options
import ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap
@@ -54,7 +53,10 @@ which speed up evaluation significantly.
to the equation.
"""
function eval_tree_array(tree::Node, X::AbstractArray, options::Options; kws...)
return eval_tree_array(tree, X, options.operators; turbo=options.turbo, kws...)
fuse_level = options.v_enable_enzyme === Val(true) ? Val(1) : Val(2)
return eval_tree_array(
tree, X, options.operators; turbo=options.v_turbo, fuse_level, kws...
)
end

"""
@@ -112,17 +114,6 @@ function eval_grad_tree_array(tree::Node, X::AbstractArray, options::Options; kw
return eval_grad_tree_array(tree, X, options.operators; kws...)
end

"""
differentiable_eval_tree_array(tree::Node, X::AbstractArray, options::Options)

Evaluate an expression tree in a way that can be auto-differentiated.
"""
function differentiable_eval_tree_array(
tree::Node, X::AbstractArray, options::Options; kws...
)
return differentiable_eval_tree_array(tree, X, options.operators; kws...)
end

"""
string_tree(tree::Node, options::Options; kws...)

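The `eval_tree_array` wrapper earlier in this file's diff now chooses a `fuse_level` from `options.v_enable_enzyme` and forwards `options.v_turbo`, while the separate `differentiable_eval_tree_array` method is removed. The user-facing call is unchanged; a rough usage sketch (the expression and data are made up, and the `Node` construction follows the package's documented API):

    using SymbolicRegression

    options = Options(; binary_operators=[+, *], unary_operators=[cos])
    x1 = Node(Float64; feature=1)     # feature variable
    tree = cos(x1 * 3.2) + x1         # example expression
    X = randn(Float64, 2, 100)        # 2 features by 100 rows
    output, completed = eval_tree_array(tree, X, options)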
13 changes: 3 additions & 10 deletions src/LossFunctions.jl
@@ -3,7 +3,6 @@ module LossFunctionsModule
import Random: MersenneTwister
using StatsBase: StatsBase
import DynamicExpressions: Node
using LossFunctions: LossFunctions
import LossFunctions: SupervisedLoss
import ..InterfaceDynamicExpressionsModule: eval_tree_array
import ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE
@@ -13,22 +12,16 @@ import ..DimensionalAnalysisModule: violates_dimensional_constraints
function _loss(
x::AbstractArray{T}, y::AbstractArray{T}, loss::LT
) where {T<:DATA_TYPE,LT<:Union{Function,SupervisedLoss}}
if LT <: SupervisedLoss
return LossFunctions.mean(loss, x, y)
else
l(i) = loss(x[i], y[i])
return LossFunctions.mean(l, eachindex(x))
end
return sum(@. loss(x, y)) / length(x)
end

function _weighted_loss(
x::AbstractArray{T}, y::AbstractArray{T}, w::AbstractArray{T}, loss::LT
) where {T<:DATA_TYPE,LT<:Union{Function,SupervisedLoss}}
if LT <: SupervisedLoss
return LossFunctions.sum(loss, x, y, w; normalize=true)
return sum(@. loss(x, y) * w) / sum(w)
else
l(i) = loss(x[i], y[i], w[i])
return sum(l, eachindex(x)) / sum(w)
return sum(@. loss(x, y, w)) / sum(w)
[Review comment from the PR author] Would be nice to avoid this, but it seems like sum is using the dataset array for temporary storage somehow? (Or Enzyme.jl thinks it is?)

end
end
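The rewritten `_loss` and `_weighted_loss` above drop the LossFunctions.jl `mean`/`sum` helpers in favor of plain broadcasts, which sidesteps the issue flagged in the author's review note above. A small worked sketch of the two formulas with hypothetical numbers:

    # Elementwise loss applied via broadcasting, as in the new _loss/_weighted_loss.
    loss(yhat, y) = abs2(yhat - y)    # e.g. an L2-style elementwise loss

    prediction = [1.0, 2.0, 3.0]      # hypothetical values
    target = [1.1, 1.9, 3.4]
    weights = [1.0, 2.0, 1.0]

    unweighted = sum(@. loss(prediction, target)) / length(prediction)
    weighted = sum(@. loss(prediction, target) * weights) / sum(weights)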

20 changes: 15 additions & 5 deletions src/Options.jl
@@ -330,8 +330,6 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators
- `max_evals`: Int (or Nothing) - the maximum number of evaluations of expressions to perform.
- `skip_mutation_failures`: Whether to simply skip over mutations that fail or are rejected, rather than to replace the mutated
expression with the original expression and proceed normally.
- `enable_autodiff`: Whether to enable automatic differentiation functionality. This is turned off by default.
If turned on, this will be turned off if one of the operators does not have well-defined gradients.
- `nested_constraints`: Specifies how many times a combination of operators can be nested. For example,
`[sin => [cos => 0], cos => [cos => 2]]` specifies that `cos` may never appear within a `sin`,
but `sin` can be nested with itself an unlimited number of times. The second term specifies that `cos`
@@ -416,14 +414,15 @@ function Options end
timeout_in_seconds::Union{Nothing,Real}=nothing,
max_evals::Union{Nothing,Integer}=nothing,
skip_mutation_failures::Bool=true,
enable_autodiff::Bool=false,
enable_enzyme::Bool=false,
nested_constraints=nothing,
deterministic::Bool=false,
# Not search options; just construction options:
define_helper_functions::Bool=true,
deprecated_return_state=nothing,
# Deprecated args:
fast_cycle::Bool=false,
enable_autodiff=nothing,
npopulations::Union{Nothing,Integer}=nothing,
npop::Union{Nothing,Integer}=nothing,
kws...,
@@ -491,6 +490,13 @@ function Options end
Base.depwarn("`npopulations` is deprecated. Use `populations` instead.", :Options)
populations = npopulations
end
if enable_autodiff !== nothing
Base.depwarn(
"`enable_autodiff` is deprecated and has no effect. " *
"Simply loading `Zygote` will enable differentiation.",
:Options,
)
end

if elementwise_loss === nothing
elementwise_loss = L2DistLoss()
@@ -669,7 +675,6 @@ function Options end
OperatorEnum(;
binary_operators=binary_operators,
unary_operators=unary_operators,
enable_autodiff=false, # Not needed; we just want the constructors
define_helper_functions=true,
empty_old_operators=true,
)
@@ -681,7 +686,6 @@ function Options end
operators = OperatorEnum(;
binary_operators=binary_operators,
unary_operators=unary_operators,
enable_autodiff=enable_autodiff,
define_helper_functions=define_helper_functions,
empty_old_operators=false,
)
@@ -728,6 +732,8 @@ function Options end
mutation_weights
end

v_enable_enzyme = Val(enable_enzyme)

@assert print_precision > 0

options = Options{
@@ -736,6 +742,8 @@ function Options end
use_recorder,
typeof(optimizer_options),
typeof(tournament_selection_weights),
turbo,
enable_enzyme,
}(
operators,
bin_constraints,
@@ -750,6 +758,7 @@ function Options end
maxsize,
maxdepth,
turbo,
turbo ? Val(true) : Val(false),
migration,
hof_migration,
should_simplify,
@@ -796,6 +805,7 @@ function Options end
nested_constraints,
deterministic,
define_helper_functions,
v_enable_enzyme,
)

return options
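After the changes above, `enable_autodiff` survives only as a deprecated keyword that emits a warning, while the new `enable_enzyme` flag is threaded through as both a plain `Bool` and a `Val` (`v_enable_enzyme`) for dispatch. A sketch of constructing options the way this diff intends (operator choices are arbitrary):

    using SymbolicRegression
    using Enzyme    # needed so the extension defines opt_func_g!

    options = Options(;
        binary_operators=[+, -, *, /],
        unary_operators=[cos],
        enable_enzyme=true,    # new flag from this PR; defaults to false
    )
    options.v_enable_enzyme    # Val{true}(), used for compile-time dispatch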
6 changes: 5 additions & 1 deletion src/OptionsStruct.jl
@@ -129,7 +129,9 @@ function ComplexityMapping(;
)
end

struct Options{CT,OP<:AbstractOperatorEnum,use_recorder,OPT<:Optim.Options,W}
struct Options{
CT,OP<:AbstractOperatorEnum,use_recorder,OPT<:Optim.Options,W,_turbo,enable_enzyme
}
operators::OP
bin_constraints::Vector{Tuple{Int,Int}}
una_constraints::Vector{Int}
@@ -143,6 +145,7 @@ struct Options{CT,OP<:AbstractOperatorEnum,use_recorder,OPT<:Optim.Options,W}
maxsize::Int
maxdepth::Int
turbo::Bool
v_turbo::Val{_turbo}
migration::Bool
hof_migration::Bool
should_simplify::Bool
@@ -189,6 +192,7 @@ struct Options{CT,OP<:AbstractOperatorEnum,use_recorder,OPT<:Optim.Options,W}
nested_constraints::Union{Vector{Tuple{Int,Int,Vector{Tuple{Int,Int,Int}}}},Nothing}
deterministic::Bool
define_helper_functions::Bool
v_enable_enzyme::Val{enable_enzyme}
end

function Base.print(io::IO, options::Options)
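Storing `v_turbo::Val{_turbo}` and `v_enable_enzyme::Val{enable_enzyme}` with their values mirrored in the struct's type parameters lets downstream functions branch on these flags at compile time rather than at runtime. A toy illustration of that pattern, not the package's actual code:

    # Toy example of the Val-as-type-parameter pattern used by Options above.
    struct Config{flag}
        v_flag::Val{flag}
    end
    Config(flag::Bool) = Config{flag}(Val(flag))

    # Dispatch resolves the branch at compile time:
    process(::Config{true}, x) = x .* 2            # "fast path"
    process(::Config{false}, x) = map(v -> 2v, x)  # fallback path

    process(Config(true), [1, 2, 3])    # uses the Config{true} method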