Skip to content

Commit 339c619

Browse files
committed
Get Optim extension working with non-Function objectives
1 parent f3ce75d commit 339c619

File tree

2 files changed

+108
-42
lines changed

2 files changed

+108
-42
lines changed

ext/DynamicExpressionsOptimExt.jl

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module DynamicExpressionsOptimExt
33
using DynamicExpressions: AbstractExpressionNode, eval_tree_array
44
using Compat: @inline
55

6-
import Optim: Optim, OptimizationResults
6+
import Optim: Optim, OptimizationResults, NLSolversBase
77

88
#! format: off
99
"""
@@ -40,29 +40,40 @@ function set_constant_nodes!(
4040
for (ci, xi) in zip(constant_nodes, x)
4141
ci.val::T = xi::T
4242
end
43+
return nothing
4344
end
4445

45-
"""Wrap f with insertion of values of the constant nodes."""
46-
function get_wrapped_f(
46+
"""Wrap function or objective with insertion of values of the constant nodes."""
47+
function wrap_func(
4748
f::F, tree::N, constant_nodes::AbstractArray{N}
48-
) where {F,T,N<:AbstractExpressionNode{T}}
49-
function wrapped_f(x)
49+
) where {F<:Function,T,N<:AbstractExpressionNode{T}}
50+
function wrapped_f(args::Vararg{Any,M}) where {M}
51+
first_args = args[1:(end - 1)]
52+
x = last(args)
5053
set_constant_nodes!(constant_nodes, x)
51-
return @inline(f(tree))
54+
return @inline(f(first_args..., tree))
5255
end
5356
return wrapped_f
5457
end
55-
56-
"""Wrap g! or h! with insertion of values of the constant nodes."""
57-
function get_wrapped_gh!(
58-
gh!::GH, tree::N, constant_nodes::AbstractArray{N}
59-
) where {GH,T,N<:AbstractExpressionNode{T}}
60-
function wrapped_gh!(G, x)
61-
set_constant_nodes!(constant_nodes, x)
62-
@inline(gh!(G, tree))
63-
return nothing
64-
end
65-
return wrapped_gh!
58+
function wrap_func(
59+
::Nothing, tree::N, constant_nodes::AbstractArray{N}
60+
) where {N<:AbstractExpressionNode}
61+
return nothing
62+
end
63+
function wrap_func(
64+
f::NLSolversBase.InplaceObjective, tree::N, constant_nodes::AbstractArray{N}
65+
) where {N<:AbstractExpressionNode}
66+
# Some objectives, like `Optim.only_fg!(fg!)`, are not functions but instead
67+
# `InplaceObjective`. These contain multiple functions, each of which needs to be
68+
# wrapped. Some functions are `nothing`; those can be left as-is.
69+
@assert fieldnames(NLSolversBase.InplaceObjective) == (:df, :fdf, :fgh, :hv, :fghv)
70+
return NLSolversBase.InplaceObjective(
71+
wrap_func(f.df, tree, constant_nodes),
72+
wrap_func(f.fdf, tree, constant_nodes),
73+
wrap_func(f.fgh, tree, constant_nodes),
74+
wrap_func(f.hv, tree, constant_nodes),
75+
wrap_func(f.fghv, tree, constant_nodes),
76+
)
6677
end
6778

6879
"""
@@ -73,40 +84,38 @@ Returns an `ExpressionOptimizationResults` object, which wraps the base
7384
optimization results on a vector of constants. You may use `res.minimizer`
7485
to view the optimized expression tree.
7586
"""
76-
function Optim.optimize(
77-
f::F, tree::AbstractExpressionNode, args...; kwargs...
78-
) where {F<:Function}
87+
function Optim.optimize(f::F, tree::AbstractExpressionNode, args...; kwargs...) where {F}
7988
return Optim.optimize(f, nothing, tree, args...; kwargs...)
8089
end
8190
function Optim.optimize(
8291
f::F, g!::G, tree::AbstractExpressionNode, args...; kwargs...
83-
) where {F,G<:Union{Function,Nothing}}
92+
) where {F,G}
8493
return Optim.optimize(f, g!, nothing, tree, args...; kwargs...)
8594
end
8695
function Optim.optimize(
8796
f::F, g!::G, h!::H, tree::AbstractExpressionNode{T}, args...; make_copy=true, kwargs...
88-
) where {F,G<:Union{Function,Nothing},H<:Union{Function,Nothing},T}
97+
) where {F,G,H,T}
8998
if make_copy
9099
tree = copy(tree)
91100
end
92101
constant_nodes = filter(t -> t.degree == 0 && t.constant, tree)
93102
x0 = T[t.val::T for t in constant_nodes]
94103
base_res = if g! === nothing
95104
@assert h! === nothing
96-
Optim.optimize(get_wrapped_f(f, tree, constant_nodes), x0, args...; kwargs...)
105+
Optim.optimize(wrap_func(f, tree, constant_nodes), x0, args...; kwargs...)
97106
elseif h! === nothing
98107
Optim.optimize(
99-
get_wrapped_f(f, tree, constant_nodes),
100-
get_wrapped_gh!(g!, tree, constant_nodes),
108+
wrap_func(f, tree, constant_nodes),
109+
wrap_func(g!, tree, constant_nodes),
101110
x0,
102111
args...;
103112
kwargs...,
104113
)
105114
else
106115
Optim.optimize(
107-
get_wrapped_f(f, tree, constant_nodes),
108-
get_wrapped_gh!(g!, tree, constant_nodes),
109-
get_wrapped_gh!(h!, tree, constant_nodes),
116+
wrap_func(f, tree, constant_nodes),
117+
wrap_func(g!, tree, constant_nodes),
118+
wrap_func(h!, tree, constant_nodes),
110119
x0,
111120
args...;
112121
kwargs...,

test/test_optim.jl

Lines changed: 71 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,79 @@
11
using DynamicExpressions, Optim, Zygote
22
using Random: Xoshiro
33

4-
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,))
5-
x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3);
4+
@testset "Basic optimization" begin
5+
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,))
6+
x1, x2 = (i -> Node(Float64; feature=i)).(1:2)
67

7-
X = rand(Xoshiro(0), Float64, 3, 100);
8-
y = @. exp(X[1, :] * 2.1 - 0.9) + X[3, :] * -0.9
8+
X = rand(Xoshiro(0), Float64, 2, 100)
9+
y = @. exp(X[1, :] * 2.1 - 0.9) + X[2, :] * -0.9
910

10-
original_tree = exp(x1 * 0.8 - 0.0) + 5.2 * x3
11-
target_tree = exp(x1 * 2.1 - 0.9) + -0.9 * x3
12-
tree = copy(original_tree)
11+
original_tree = exp(x1 * 0.8 - 0.0) + 5.2 * x2
12+
target_tree = exp(x1 * 2.1 - 0.9) + -0.9 * x2
13+
tree = copy(original_tree)
1314

14-
res = optimize(t -> sum(abs2, t(X, operators) .- y), tree)
15+
f(tree) = sum(abs2, tree(X, operators) .- y)
1516

16-
# Should be unchanged by default
17-
if VERSION >= v"1.9"
18-
ext = Base.get_extension(DynamicExpressions, :DynamicExpressionsOptimExt)
19-
@test res isa ext.ExpressionOptimizationResults
17+
res = optimize(f, tree)
18+
19+
# Should be unchanged by default
20+
if VERSION >= v"1.9"
21+
ext = Base.get_extension(DynamicExpressions, :DynamicExpressionsOptimExt)
22+
@test res isa ext.ExpressionOptimizationResults
23+
end
24+
@test tree == original_tree
25+
@test isapprox(get_constants(res.minimizer), get_constants(target_tree); atol=0.01)
26+
end
27+
28+
@testset "With gradients" begin
29+
did_i_run = Ref(false)
30+
# Now, try with gradients too (via Zygote and our hand-rolled forward-mode AD)
31+
g!(G, tree) =
32+
let
33+
ŷ, dŷ_dconstants, _ = eval_grad_tree_array(tree, X, operators; variable=false)
34+
dresult_dŷ = @. 2 * (ŷ - y)
35+
for i in eachindex(G)
36+
G[i] = sum(
37+
j -> dresult_dŷ[j] * dŷ_dconstants[i, j],
38+
eachindex(axes(dŷ_dconstants, 2), axes(dresult_dŷ, 1)),
39+
)
40+
end
41+
did_i_run[] = true
42+
return nothing
43+
end
44+
45+
res = optimize(f, g!, tree, BFGS())
46+
@test did_i_run[]
47+
@test isapprox(get_constants(res.minimizer), get_constants(target_tree); atol=0.01)
48+
end
49+
50+
# Now, try combined
51+
@testset "Combined evaluation with gradient" begin
52+
did_i_run_2 = Ref(false)
53+
fg!(F, G, tree) =
54+
let
55+
if G !== nothing
56+
ŷ, dŷ_dconstants, _ = eval_grad_tree_array(
57+
tree, X, operators; variable=false
58+
)
59+
dresult_dŷ = @. 2 * (ŷ - y)
60+
for i in eachindex(G)
61+
G[i] = sum(
62+
j -> dresult_dŷ[j] * dŷ_dconstants[i, j],
63+
eachindex(axes(dŷ_dconstants, 2), axes(dresult_dŷ, 1)),
64+
)
65+
end
66+
if F !== nothing
67+
did_i_run_2[] = true
68+
return sum(abs2, ŷ .- y)
69+
end
70+
elseif F !== nothing
71+
# Only f
72+
return sum(abs2, tree(X, operators) .- y)
73+
end
74+
end
75+
res = optimize(Optim.only_fg!(fg!), tree, BFGS())
76+
77+
@test did_i_run_2[]
78+
@test isapprox(get_constants(res.minimizer), get_constants(target_tree); atol=0.01)
2079
end
21-
@test tree == original_tree
22-
@test get_constants(res.minimizer) get_constants(target_tree)

0 commit comments

Comments
 (0)