
Commit 17f04ad

Merge pull request #30 from SymbolicML/constant-optimization
Overload `Optim.optimize` for `::Node`
2 parents 160c30a + 80c370f commit 17f04ad

File tree: 4 files changed (+224, −1 lines)

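In short, this PR adds a package extension so that `Optim.optimize` can be called directly on an expression tree, optimizing the constants it contains. A minimal usage sketch, adapted from the tests added in this PR (the `operators`, tree, and loss definitions mirror `test/test_optim.jl`; the single-feature toy data here is a stand-in):

```julia
using DynamicExpressions, Optim

operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,))
x1 = Node(Float64; feature=1)

X = rand(Float64, 1, 100)
y = @. exp(X[1, :] * 2.1 - 0.9)  # toy target

tree = exp(x1 * 0.8 - 0.0)  # initial guess, with tunable constants 0.8 and 0.0
loss(t) = sum(abs2, t(X, operators) .- y)

res = optimize(loss, tree)  # fits the constants inside a copy of `tree`
res.minimizer               # the optimized expression tree
```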

Project.toml

Lines changed: 5 additions & 1 deletion
@@ -16,10 +16,12 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
 
 [weakdeps]
+Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [extensions]
+DynamicExpressionsOptimExt = "Optim"
 DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
 DynamicExpressionsZygoteExt = "Zygote"
 
@@ -28,6 +30,7 @@ Aqua = "0.7"
 Compat = "3.37, 4"
 Enzyme = "^0.11.12"
 LoopVectorization = "0.12"
+Optim = "0.19, 1"
 MacroTools = "0.4, 0.5"
 PackageExtensionCompat = "1"
 PrecompileTools = "1"
@@ -40,6 +43,7 @@ julia = "1.6"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -48,4 +52,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test", "SafeTestsets", "Aqua", "Enzyme", "ForwardDiff", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]
+test = ["Test", "SafeTestsets", "Aqua", "Enzyme", "Optim", "ForwardDiff", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]

ext/DynamicExpressionsOptimExt.jl

Lines changed: 123 additions & 0 deletions
module DynamicExpressionsOptimExt

using DynamicExpressions: AbstractExpressionNode, eval_tree_array
using Compat: @inline

import Optim: Optim, OptimizationResults, NLSolversBase

#! format: off
"""
    ExpressionOptimizationResults{R,N<:AbstractExpressionNode}

Optimization results for an expression, which wraps the base optimization results
on a vector of constants.
"""
struct ExpressionOptimizationResults{R<:OptimizationResults,N<:AbstractExpressionNode} <: OptimizationResults
    _results::R  # The raw results from Optim.
    tree::N  # The final expression tree
end
#! format: on
function Base.getproperty(r::ExpressionOptimizationResults, s::Symbol)
    if s == :tree || s == :minimizer
        return getfield(r, :tree)
    else
        return getproperty(getfield(r, :_results), s)
    end
end
function Base.propertynames(r::ExpressionOptimizationResults)
    return (:tree, propertynames(getfield(r, :_results))...)
end
function Optim.minimizer(r::ExpressionOptimizationResults)
    return r.tree
end

function set_constant_nodes!(
    constant_nodes::AbstractArray{N}, x
) where {T,N<:AbstractExpressionNode{T}}
    for (ci, xi) in zip(constant_nodes, x)
        ci.val::T = xi::T
    end
    return nothing
end

"""Wrap function or objective with insertion of values of the constant nodes."""
function wrap_func(
    f::F, tree::N, constant_nodes::AbstractArray{N}
) where {F<:Function,T,N<:AbstractExpressionNode{T}}
    function wrapped_f(args::Vararg{Any,M}) where {M}
        first_args = args[1:(end - 1)]
        x = last(args)
        set_constant_nodes!(constant_nodes, x)
        return @inline(f(first_args..., tree))
    end
    return wrapped_f
end
function wrap_func(
    ::Nothing, tree::N, constant_nodes::AbstractArray{N}
) where {N<:AbstractExpressionNode}
    return nothing
end
function wrap_func(
    f::NLSolversBase.InplaceObjective, tree::N, constant_nodes::AbstractArray{N}
) where {N<:AbstractExpressionNode}
    # Some objectives, like `Optim.only_fg!(fg!)`, are not functions but instead
    # `InplaceObjective`. These contain multiple functions, each of which needs to be
    # wrapped. Some functions are `nothing`; those can be left as-is.
    @assert fieldnames(NLSolversBase.InplaceObjective) == (:df, :fdf, :fgh, :hv, :fghv)
    return NLSolversBase.InplaceObjective(
        wrap_func(f.df, tree, constant_nodes),
        wrap_func(f.fdf, tree, constant_nodes),
        wrap_func(f.fgh, tree, constant_nodes),
        wrap_func(f.hv, tree, constant_nodes),
        wrap_func(f.fghv, tree, constant_nodes),
    )
end

"""
    optimize(f, [g!, [h!,]] tree, args...; kwargs...)

Optimize an expression tree with respect to the constants in the tree.
Returns an `ExpressionOptimizationResults` object, which wraps the base
optimization results on a vector of constants. You may use `res.minimizer`
to view the optimized expression tree.
"""
function Optim.optimize(f::F, tree::AbstractExpressionNode, args...; kwargs...) where {F}
    return Optim.optimize(f, nothing, tree, args...; kwargs...)
end
function Optim.optimize(
    f::F, g!::G, tree::AbstractExpressionNode, args...; kwargs...
) where {F,G}
    return Optim.optimize(f, g!, nothing, tree, args...; kwargs...)
end
function Optim.optimize(
    f::F, g!::G, h!::H, tree::AbstractExpressionNode{T}, args...; make_copy=true, kwargs...
) where {F,G,H,T}
    if make_copy
        tree = copy(tree)
    end
    constant_nodes = filter(t -> t.degree == 0 && t.constant, tree)
    x0 = T[t.val::T for t in constant_nodes]
    if !isnothing(h!)
        throw(
            ArgumentError(
                "Optim.optimize does not yet support Hessians on `AbstractExpressionNode`. " *
                "Please raise an issue at github.com/SymbolicML/DynamicExpressions.jl.",
            ),
        )
    end
    base_res = if isnothing(g!)
        Optim.optimize(wrap_func(f, tree, constant_nodes), x0, args...; kwargs...)
    else
        Optim.optimize(
            wrap_func(f, tree, constant_nodes),
            wrap_func(g!, tree, constant_nodes),
            x0,
            args...;
            kwargs...,
        )
    end
    set_constant_nodes!(constant_nodes, Optim.minimizer(base_res))
    return ExpressionOptimizationResults(base_res, tree)
end

end
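A note on the design: because `Base.getproperty` redirects `:minimizer` (and `:tree`) to the stored tree and forwards everything else to the wrapped Optim results, callers can treat the wrapper like an ordinary results object. A small sketch of what this means in practice (the `res` value is hypothetical):

```julia
res = optimize(f, tree)   # an ExpressionOptimizationResults

res.minimizer             # the optimized tree, not the raw constant vector
res.tree                  # the same tree
Optim.minimizer(res)      # also the tree
res.f_calls               # any other property is forwarded to the raw Optim results
```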

test/test_optim.jl

Lines changed: 92 additions & 0 deletions
using Test  # needed if this file is run outside the `@safetestset` wrapper
using DynamicExpressions, Optim, Zygote
using Random: MersenneTwister as RNG

operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,))
x1, x2 = (i -> Node(Float64; feature=i)).(1:2)

X = rand(RNG(0), Float64, 2, 100)
y = @. exp(X[1, :] * 2.1 - 0.9) + X[2, :] * -0.9

original_tree = exp(x1 * 0.8 - 0.0) + 5.2 * x2
target_tree = exp(x1 * 2.1 - 0.9) + -0.9 * x2

f(tree) = sum(abs2, tree(X, operators) .- y)

@testset "Basic optimization" begin
    tree = copy(original_tree)
    res = optimize(f, tree)

    if VERSION >= v"1.9"
        ext = Base.get_extension(DynamicExpressions, :DynamicExpressionsOptimExt)
        @test res isa ext.ExpressionOptimizationResults
    end
    # Input tree should be unchanged by default (`make_copy=true`)
    @test tree == original_tree
    @test isapprox(get_constants(res.minimizer), get_constants(target_tree); atol=0.01)
end

@testset "With gradients" begin
    tree = copy(original_tree)
    did_i_run = Ref(false)
    # Now, try with gradients too (via Zygote and our hand-rolled forward-mode AD)
    g!(G, tree) =
        let
            ŷ, dŷ_dconstants, _ = eval_grad_tree_array(tree, X, operators; variable=false)
            dresult_dŷ = @. 2 * (ŷ - y)
            for i in eachindex(G)
                G[i] = sum(
                    j -> dresult_dŷ[j] * dŷ_dconstants[i, j],
                    eachindex(axes(dŷ_dconstants, 2), axes(dresult_dŷ, 1)),
                )
            end
            did_i_run[] = true
            return nothing
        end

    res = optimize(f, g!, tree, BFGS())
    @test did_i_run[]
    @test res.f_calls > 0
    @test isapprox(get_constants(res.minimizer), get_constants(target_tree); atol=0.01)
    @test Optim.minimizer(res) === res.minimizer
    @test propertynames(res) == (:tree, propertynames(getfield(res, :_results))...)

    @testset "Hessians not implemented" begin
        @test_throws ArgumentError optimize(f, g!, t -> t, tree, BFGS())
        VERSION >= v"1.9" && @test_throws(
            "Optim.optimize does not yet support Hessians on `AbstractExpressionNode`",
            optimize(f, g!, t -> t, tree, BFGS())
        )
    end
end

# Now, try combined
@testset "Combined evaluation with gradient" begin
    tree = copy(original_tree)
    did_i_run_2 = Ref(false)
    fg!(F, G, tree) =
        let
            if G !== nothing
                ŷ, dŷ_dconstants, _ = eval_grad_tree_array(
                    tree, X, operators; variable=false
                )
                dresult_dŷ = @. 2 * (ŷ - y)
                for i in eachindex(G)
                    G[i] = sum(
                        j -> dresult_dŷ[j] * dŷ_dconstants[i, j],
                        eachindex(axes(dŷ_dconstants, 2), axes(dresult_dŷ, 1)),
                    )
                end
                if F !== nothing
                    did_i_run_2[] = true
                    return sum(abs2, ŷ .- y)
                end
            elseif F !== nothing
                # Only f
                return sum(abs2, tree(X, operators) .- y)
            end
        end
    res = optimize(Optim.only_fg!(fg!), tree, BFGS())

    @test did_i_run_2[]
    @test isapprox(get_constants(res.minimizer), get_constants(target_tree); atol=0.01)
end
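One detail worth highlighting in these tests: `eval_grad_tree_array(tree, X, operators; variable=false)` differentiates the tree's output with respect to its constants (rather than its input features), which is exactly the Jacobian the gradient callbacks need. A sketch of the call shape as used above (the name `completed` for the third return value is illustrative; the tests discard it):

```julia
# ŷ: predictions over the dataset; dŷ_dconstants: one row per constant in `tree`
ŷ, dŷ_dconstants, completed = eval_grad_tree_array(tree, X, operators; variable=false)
```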

test/unittest.jl

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,10 @@ end
     include("test_deprecations.jl")
 end
 
+@safetestset "Test Optim.jl" begin
+    include("test_optim.jl")
+end
+
 @safetestset "Test tree construction and scoring" begin
     include("test_tree_construction.jl")
 end
