@@ -3,7 +3,7 @@ module DynamicExpressionsOptimExt
33using DynamicExpressions: AbstractExpressionNode, eval_tree_array
44using Compat: @inline
55
6- import Optim: Optim, OptimizationResults
6+ import Optim: Optim, OptimizationResults, NLSolversBase
77
88# ! format: off
99"""
@@ -40,29 +40,40 @@ function set_constant_nodes!(
4040 for (ci, xi) in zip (constant_nodes, x)
4141 ci. val:: T = xi:: T
4242 end
43+ return nothing
4344end
4445
45- """ Wrap f with insertion of values of the constant nodes."""
46- function get_wrapped_f (
46+ """ Wrap function or objective with insertion of values of the constant nodes."""
47+ function wrap_func (
4748 f:: F , tree:: N , constant_nodes:: AbstractArray{N}
48- ) where {F,T,N<: AbstractExpressionNode{T} }
49- function wrapped_f (x)
49+ ) where {F<: Function ,T,N<: AbstractExpressionNode{T} }
50+ function wrapped_f (args:: Vararg{Any,M} ) where {M}
51+ first_args = args[1 : (end - 1 )]
52+ x = last (args)
5053 set_constant_nodes! (constant_nodes, x)
51- return @inline (f (tree))
54+ return @inline (f (first_args ... , tree))
5255 end
5356 return wrapped_f
5457end
55-
56- """ Wrap g! or h! with insertion of values of the constant nodes."""
57- function get_wrapped_gh! (
58- gh!:: GH , tree:: N , constant_nodes:: AbstractArray{N}
59- ) where {GH,T,N<: AbstractExpressionNode{T} }
60- function wrapped_gh! (G, x)
61- set_constant_nodes! (constant_nodes, x)
62- @inline (gh! (G, tree))
63- return nothing
64- end
65- return wrapped_gh!
58+ function wrap_func (
59+ :: Nothing , tree:: N , constant_nodes:: AbstractArray{N}
60+ ) where {N<: AbstractExpressionNode }
61+ return nothing
62+ end
63+ function wrap_func (
64+ f:: NLSolversBase.InplaceObjective , tree:: N , constant_nodes:: AbstractArray{N}
65+ ) where {N<: AbstractExpressionNode }
66+ # Some objectives, like `Optim.only_fg!(fg!)`, are not functions but instead
67+ # `InplaceObjective`. These contain multiple functions, each of which needs to be
68+ # wrapped. Some functions are `nothing`; those can be left as-is.
69+ @assert fieldnames (NLSolversBase. InplaceObjective) == (:df , :fdf , :fgh , :hv , :fghv )
70+ return NLSolversBase. InplaceObjective (
71+ wrap_func (f. df, tree, constant_nodes),
72+ wrap_func (f. fdf, tree, constant_nodes),
73+ wrap_func (f. fgh, tree, constant_nodes),
74+ wrap_func (f. hv, tree, constant_nodes),
75+ wrap_func (f. fghv, tree, constant_nodes),
76+ )
6677end
6778
6879"""
@@ -73,40 +84,38 @@ Returns an `ExpressionOptimizationResults` object, which wraps the base
7384optimization results on a vector of constants. You may use `res.minimizer`
7485to view the optimized expression tree.
7586"""
76- function Optim. optimize (
77- f:: F , tree:: AbstractExpressionNode , args... ; kwargs...
78- ) where {F<: Function }
87+ function Optim. optimize (f:: F , tree:: AbstractExpressionNode , args... ; kwargs... ) where {F}
7988 return Optim. optimize (f, nothing , tree, args... ; kwargs... )
8089end
8190function Optim. optimize (
8291 f:: F , g!:: G , tree:: AbstractExpressionNode , args... ; kwargs...
83- ) where {F,G<: Union{Function,Nothing} }
92+ ) where {F,G}
8493 return Optim. optimize (f, g!, nothing , tree, args... ; kwargs... )
8594end
8695function Optim. optimize (
8796 f:: F , g!:: G , h!:: H , tree:: AbstractExpressionNode{T} , args... ; make_copy= true , kwargs...
88- ) where {F,G<: Union{Function,Nothing} ,H <: Union{Function,Nothing} ,T}
97+ ) where {F,G,H ,T}
8998 if make_copy
9099 tree = copy (tree)
91100 end
92101 constant_nodes = filter (t -> t. degree == 0 && t. constant, tree)
93102 x0 = T[t. val:: T for t in constant_nodes]
94103 base_res = if g! === nothing
95104 @assert h! === nothing
96- Optim. optimize (get_wrapped_f (f, tree, constant_nodes), x0, args... ; kwargs... )
105+ Optim. optimize (wrap_func (f, tree, constant_nodes), x0, args... ; kwargs... )
97106 elseif h! === nothing
98107 Optim. optimize (
99- get_wrapped_f (f, tree, constant_nodes),
100- get_wrapped_gh! (g!, tree, constant_nodes),
108+ wrap_func (f, tree, constant_nodes),
109+ wrap_func (g!, tree, constant_nodes),
101110 x0,
102111 args... ;
103112 kwargs... ,
104113 )
105114 else
106115 Optim. optimize (
107- get_wrapped_f (f, tree, constant_nodes),
108- get_wrapped_gh! (g!, tree, constant_nodes),
109- get_wrapped_gh! (h!, tree, constant_nodes),
116+ wrap_func (f, tree, constant_nodes),
117+ wrap_func (g!, tree, constant_nodes),
118+ wrap_func (h!, tree, constant_nodes),
110119 x0,
111120 args... ;
112121 kwargs... ,
0 commit comments