11module DynamicExpressionsOptimExt
22
3- using DynamicExpressions: AbstractExpressionNode, filter_map, eval_tree_array
3+ using DynamicExpressions:
4+ AbstractExpression,
5+ AbstractExpressionNode,
6+ filter_map,
7+ eval_tree_array,
8+ get_constants,
9+ set_constants!
410using Compat: @inline
511
612import Optim: Optim, OptimizationResults, NLSolversBase
@@ -12,7 +18,7 @@ import Optim: Optim, OptimizationResults, NLSolversBase
1218Optimization results for an expression, which wraps the base optimization results
1319on a vector of constants.
1420"""
15- struct ExpressionOptimizationResults{R<: OptimizationResults ,N<: AbstractExpressionNode } <: OptimizationResults
21+ struct ExpressionOptimizationResults{R<: OptimizationResults ,N<: Union{ AbstractExpressionNode,AbstractExpression} } <: OptimizationResults
1622 _results:: R # The raw results from Optim.
1723 tree:: N # The final expression tree
1824end
3339
3440""" Wrap function or objective with insertion of values of the constant nodes."""
3541function wrap_func (
36- f:: F , tree:: N , constant_refs :: AbstractArray
37- ) where {F<: Function ,T,N<: AbstractExpressionNode{T} }
42+ f:: F , tree:: N , refs
43+ ) where {F<: Function ,T,N<: Union{ AbstractExpressionNode{T},AbstractExpression{T} } }
3844 function wrapped_f (args:: Vararg{Any,M} ) where {M}
39- first_args = args[1 : (end - 1 )]
40- x = last (args)
41- @inbounds for i in eachindex (constant_refs, x)
42- constant_refs[i][]. val = x[i]
43- end
45+ first_args = args[begin : (end - 1 )]
46+ x = args[end ]
47+ set_constants! (tree, x, refs)
4448 return @inline (f (first_args... , tree))
4549 end
4650 return wrapped_f
4751end
4852function wrap_func (
49- :: Nothing , tree:: N , constant_refs :: AbstractArray
50- ) where {N<: AbstractExpressionNode }
53+ :: Nothing , tree:: N , refs
54+ ) where {N<: Union{ AbstractExpressionNode,AbstractExpression} }
5155 return nothing
5256end
5357function wrap_func (
54- f:: NLSolversBase.InplaceObjective , tree:: N , constant_refs :: AbstractArray
55- ) where {N<: AbstractExpressionNode }
58+ f:: NLSolversBase.InplaceObjective , tree:: N , refs
59+ ) where {N<: Union{ AbstractExpressionNode,AbstractExpression} }
5660 # Some objectives, like `Optim.only_fg!(fg!)`, are not functions but instead
5761 # `InplaceObjective`. These contain multiple functions, each of which needs to be
5862 # wrapped. Some functions are `nothing`; those can be left as-is.
5963 @assert fieldnames (NLSolversBase. InplaceObjective) == (:df , :fdf , :fgh , :hv , :fghv )
6064 return NLSolversBase. InplaceObjective (
61- wrap_func (f. df, tree, constant_refs ),
62- wrap_func (f. fdf, tree, constant_refs ),
63- wrap_func (f. fgh, tree, constant_refs ),
64- wrap_func (f. hv, tree, constant_refs ),
65- wrap_func (f. fghv, tree, constant_refs ),
65+ wrap_func (f. df, tree, refs ),
66+ wrap_func (f. fdf, tree, refs ),
67+ wrap_func (f. fgh, tree, refs ),
68+ wrap_func (f. hv, tree, refs ),
69+ wrap_func (f. fghv, tree, refs ),
6670 )
6771end
6872
@@ -74,24 +78,29 @@ Returns an `ExpressionOptimizationResults` object, which wraps the base
7478optimization results on a vector of constants. You may use `res.minimizer`
7579to view the optimized expression tree.
7680"""
77- function Optim. optimize (f:: F , tree:: AbstractExpressionNode , args... ; kwargs... ) where {F}
81+ function Optim. optimize (
82+ f:: F , tree:: Union{AbstractExpressionNode,AbstractExpression} , args... ; kwargs...
83+ ) where {F}
7884 return Optim. optimize (f, nothing , tree, args... ; kwargs... )
7985end
8086function Optim. optimize (
81- f:: F , g!:: G , tree:: AbstractExpressionNode , args... ; kwargs...
87+ f:: F , g!:: G , tree:: Union{ AbstractExpressionNode,AbstractExpression} , args... ; kwargs...
8288) where {F,G}
8389 return Optim. optimize (f, g!, nothing , tree, args... ; kwargs... )
8490end
8591function Optim. optimize (
86- f:: F , g!:: G , h!:: H , tree:: AbstractExpressionNode{T} , args... ; make_copy= true , kwargs...
92+ f:: F ,
93+ g!:: G ,
94+ h!:: H ,
95+ tree:: Union{AbstractExpressionNode{T},AbstractExpression{T}} ,
96+ args... ;
97+ make_copy= true ,
98+ kwargs... ,
8799) where {F,G,H,T}
88100 if make_copy
89101 tree = copy (tree)
90102 end
91- constant_refs = filter_map (
92- t -> t. degree == 0 && t. constant, t -> Ref (t), tree, Ref{typeof (tree)}
93- )
94- x0 = T[copy (t[]. val) for t in constant_refs]
103+ x0, refs = get_constants (tree)
95104 if ! isnothing (h!)
96105 throw (
97106 ArgumentError (
@@ -101,20 +110,14 @@ function Optim.optimize(
101110 )
102111 end
103112 base_res = if isnothing (g!)
104- Optim. optimize (wrap_func (f, tree, constant_refs ), x0, args... ; kwargs... )
113+ Optim. optimize (wrap_func (f, tree, refs ), x0, args... ; kwargs... )
105114 else
106115 Optim. optimize (
107- wrap_func (f, tree, constant_refs),
108- wrap_func (g!, tree, constant_refs),
109- x0,
110- args... ;
111- kwargs... ,
116+ wrap_func (f, tree, refs), wrap_func (g!, tree, refs), x0, args... ; kwargs...
112117 )
113118 end
114119 minimizer = Optim. minimizer (base_res)
115- @inbounds for i in eachindex (constant_refs, minimizer)
116- constant_refs[i][]. val = minimizer[i]
117- end
120+ set_constants! (tree, minimizer, refs)
118121 return ExpressionOptimizationResults (base_res, tree)
119122end
120123
0 commit comments