|
| 1 | +module DynamicExpressionsOptimExt |
| 2 | + |
| 3 | +using DynamicExpressions: AbstractExpressionNode, eval_tree_array |
| 4 | +using Compat: @inline |
| 5 | + |
| 6 | +import Optim: Optim, OptimizationResults, NLSolversBase |
| 7 | + |
| 8 | +#! format: off |
| 9 | +""" |
| 10 | + ExpressionOptimizationResults{R,N<:AbstractExpressionNode} |
| 11 | +
|
| 12 | +Optimization results for an expression, which wraps the base optimization results |
| 13 | +on a vector of constants. |
| 14 | +""" |
| 15 | +struct ExpressionOptimizationResults{R<:OptimizationResults,N<:AbstractExpressionNode} <: OptimizationResults |
| 16 | + _results::R # The raw results from Optim. |
| 17 | + tree::N # The final expression tree |
| 18 | +end |
| 19 | +#! format: on |
| 20 | +function Base.getproperty(r::ExpressionOptimizationResults, s::Symbol) |
| 21 | + if s == :tree || s == :minimizer |
| 22 | + return getfield(r, :tree) |
| 23 | + else |
| 24 | + return getproperty(getfield(r, :_results), s) |
| 25 | + end |
| 26 | +end |
| 27 | +function Base.propertynames(r::ExpressionOptimizationResults) |
| 28 | + return (:tree, propertynames(getfield(r, :_results))...) |
| 29 | +end |
| 30 | +function Optim.minimizer(r::ExpressionOptimizationResults) |
| 31 | + return r.tree |
| 32 | +end |
| 33 | + |
| 34 | +function set_constant_nodes!( |
| 35 | + constant_nodes::AbstractArray{N}, x |
| 36 | +) where {T,N<:AbstractExpressionNode{T}} |
| 37 | + for (ci, xi) in zip(constant_nodes, x) |
| 38 | + ci.val::T = xi::T |
| 39 | + end |
| 40 | + return nothing |
| 41 | +end |
| 42 | + |
| 43 | +"""Wrap function or objective with insertion of values of the constant nodes.""" |
| 44 | +function wrap_func( |
| 45 | + f::F, tree::N, constant_nodes::AbstractArray{N} |
| 46 | +) where {F<:Function,T,N<:AbstractExpressionNode{T}} |
| 47 | + function wrapped_f(args::Vararg{Any,M}) where {M} |
| 48 | + first_args = args[1:(end - 1)] |
| 49 | + x = last(args) |
| 50 | + set_constant_nodes!(constant_nodes, x) |
| 51 | + return @inline(f(first_args..., tree)) |
| 52 | + end |
| 53 | + return wrapped_f |
| 54 | +end |
| 55 | +function wrap_func( |
| 56 | + ::Nothing, tree::N, constant_nodes::AbstractArray{N} |
| 57 | +) where {N<:AbstractExpressionNode} |
| 58 | + return nothing |
| 59 | +end |
| 60 | +function wrap_func( |
| 61 | + f::NLSolversBase.InplaceObjective, tree::N, constant_nodes::AbstractArray{N} |
| 62 | +) where {N<:AbstractExpressionNode} |
| 63 | + # Some objectives, like `Optim.only_fg!(fg!)`, are not functions but instead |
| 64 | + # `InplaceObjective`. These contain multiple functions, each of which needs to be |
| 65 | + # wrapped. Some functions are `nothing`; those can be left as-is. |
| 66 | + @assert fieldnames(NLSolversBase.InplaceObjective) == (:df, :fdf, :fgh, :hv, :fghv) |
| 67 | + return NLSolversBase.InplaceObjective( |
| 68 | + wrap_func(f.df, tree, constant_nodes), |
| 69 | + wrap_func(f.fdf, tree, constant_nodes), |
| 70 | + wrap_func(f.fgh, tree, constant_nodes), |
| 71 | + wrap_func(f.hv, tree, constant_nodes), |
| 72 | + wrap_func(f.fghv, tree, constant_nodes), |
| 73 | + ) |
| 74 | +end |
| 75 | + |
| 76 | +""" |
| 77 | + optimize(f, [g!, [h!,]] tree, args...; kwargs...) |
| 78 | +
|
| 79 | +Optimize an expression tree with respect to the constants in the tree. |
| 80 | +Returns an `ExpressionOptimizationResults` object, which wraps the base |
| 81 | +optimization results on a vector of constants. You may use `res.minimizer` |
| 82 | +to view the optimized expression tree. |
| 83 | +""" |
| 84 | +function Optim.optimize(f::F, tree::AbstractExpressionNode, args...; kwargs...) where {F} |
| 85 | + return Optim.optimize(f, nothing, tree, args...; kwargs...) |
| 86 | +end |
| 87 | +function Optim.optimize( |
| 88 | + f::F, g!::G, tree::AbstractExpressionNode, args...; kwargs... |
| 89 | +) where {F,G} |
| 90 | + return Optim.optimize(f, g!, nothing, tree, args...; kwargs...) |
| 91 | +end |
| 92 | +function Optim.optimize( |
| 93 | + f::F, g!::G, h!::H, tree::AbstractExpressionNode{T}, args...; make_copy=true, kwargs... |
| 94 | +) where {F,G,H,T} |
| 95 | + if make_copy |
| 96 | + tree = copy(tree) |
| 97 | + end |
| 98 | + constant_nodes = filter(t -> t.degree == 0 && t.constant, tree) |
| 99 | + x0 = T[t.val::T for t in constant_nodes] |
| 100 | + if !isnothing(h!) |
| 101 | + throw( |
| 102 | + ArgumentError( |
| 103 | + "Optim.optimize does not yet support Hessians on `AbstractExpressionNode`. " * |
| 104 | + "Please raise an issue at github.com/SymbolicML/DynamicExpressions.jl.", |
| 105 | + ), |
| 106 | + ) |
| 107 | + end |
| 108 | + base_res = if isnothing(g!) |
| 109 | + Optim.optimize(wrap_func(f, tree, constant_nodes), x0, args...; kwargs...) |
| 110 | + else |
| 111 | + Optim.optimize( |
| 112 | + wrap_func(f, tree, constant_nodes), |
| 113 | + wrap_func(g!, tree, constant_nodes), |
| 114 | + x0, |
| 115 | + args...; |
| 116 | + kwargs..., |
| 117 | + ) |
| 118 | + end |
| 119 | + set_constant_nodes!(constant_nodes, Optim.minimizer(base_res)) |
| 120 | + return ExpressionOptimizationResults(base_res, tree) |
| 121 | +end |
| 122 | + |
| 123 | +end |
0 commit comments