11module DynamicExpressionsSymbolicUtilsExt
22
33using SymbolicUtils
4- import DynamicExpressions. EquationModule: Node, DEFAULT_NODE_TYPE
4+ import DynamicExpressions. EquationModule:
5+ AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE
56import DynamicExpressions. OperatorEnumModule: AbstractOperatorEnum
67import DynamicExpressions. UtilsModule: isgood, isbad, @return_on_false , deprecate_varmap
78import DynamicExpressions. ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
1920subs_bad (x) = isgood (x) ? x : Inf
2021
2122function parse_tree_to_eqs (
22- tree:: Node{T} , operators:: AbstractOperatorEnum , index_functions:: Bool = false
23+ tree:: AbstractExpressionNode{T} ,
24+ operators:: AbstractOperatorEnum ,
25+ index_functions:: Bool = false ,
2326) where {T}
2427 if tree. degree == 0
2528 # Return constant if needed
2629 tree. constant && return subs_bad (tree. val:: T )
2730 return SymbolicUtils. Sym {LiteralReal} (Symbol (" x$(tree. feature) " ))
2831 end
2932 # Collect the next children
33+ # TODO : Type instability!
3034 children = tree. degree == 2 ? (tree. l, tree. r) : (tree. l,)
3135 # Get the operation
3236 op = tree. degree == 2 ? operators. binops[tree. op] : operators. unaops[tree. op]
@@ -66,11 +70,12 @@ convert_to_function(x, operators::AbstractOperatorEnum) = x
6670function split_eq (
6771 op,
6872 args,
69- operators:: AbstractOperatorEnum ;
73+ operators:: AbstractOperatorEnum ,
74+ :: Type{N} = Node;
7075 variable_names:: Union{Array{String,1},Nothing} = nothing ,
7176 # Deprecated:
7277 varMap= nothing ,
73- )
78+ ) where {N <: AbstractExpressionNode }
7479 variable_names = deprecate_varmap (variable_names, varMap, :split_eq )
7580 ! (op ∈ (sum, prod, + , * )) && throw (error (" Unsupported operation $op in expression!" ))
7681 if Symbol (op) == Symbol (sum)
@@ -80,10 +85,10 @@ function split_eq(
8085 else
8186 ind = findoperation (op, operators. binops)
8287 end
83- return Node (
88+ return constructorof (N) (
8489 ind,
85- convert (Node , args[1 ], operators; variable_names= variable_names),
86- convert (Node , op (args[2 : end ]. .. ), operators; variable_names= variable_names),
90+ convert (N , args[1 ], operators; variable_names= variable_names),
91+ convert (N , op (args[2 : end ]. .. ), operators; variable_names= variable_names),
8792 )
8893end
8994
96101
97102function Base. convert (
98103 :: typeof (SymbolicUtils. Symbolic),
99- tree:: Node ,
104+ tree:: AbstractExpressionNode ,
100105 operators:: AbstractOperatorEnum ;
101106 variable_names:: Union{Array{String,1},Nothing} = nothing ,
102107 index_functions:: Bool = false ,
@@ -109,20 +114,22 @@ function Base.convert(
109114 )
110115end
111116
112- function Base. convert (:: typeof (Node), x:: Number , operators:: AbstractOperatorEnum ; kws... )
113- return Node (; val= DEFAULT_NODE_TYPE (x))
117+ function Base. convert (
118+ :: Type{N} , x:: Number , operators:: AbstractOperatorEnum ; kws...
119+ ) where {N<: AbstractExpressionNode }
120+ return constructorof (N)(; val= DEFAULT_NODE_TYPE (x))
114121end
115122
116123function Base. convert (
117- :: typeof (Node) ,
124+ :: Type{N} ,
118125 expr:: SymbolicUtils.Symbolic ,
119126 operators:: AbstractOperatorEnum ;
120127 variable_names:: Union{Array{String,1},Nothing} = nothing ,
121- )
128+ ) where {N <: AbstractExpressionNode }
122129 variable_names = deprecate_varmap (variable_names, nothing , :convert )
123130 if ! SymbolicUtils. istree (expr)
124- variable_names === nothing && return Node (String (expr. name))
125- return Node (String (expr. name), variable_names)
131+ variable_names === nothing && return constructorof (N) (String (expr. name))
132+ return constructorof (N) (String (expr. name), variable_names)
126133 end
127134
128135 # First, we remove integer powers:
@@ -134,20 +141,21 @@ function Base.convert(
134141 op = convert_to_function (SymbolicUtils. operation (expr), operators)
135142 args = SymbolicUtils. arguments (expr)
136143
137- length (args) > 2 && return split_eq (op, args, operators; variable_names= variable_names)
144+ length (args) > 2 &&
145+ return split_eq (op, args, operators, N; variable_names= variable_names)
138146 ind = if length (args) == 2
139147 findoperation (op, operators. binops)
140148 else
141149 findoperation (op, operators. unaops)
142150 end
143151
144- return Node (
145- ind, map (x -> convert (Node , x, operators; variable_names= variable_names), args)...
152+ return constructorof (N) (
153+ ind, map (x -> convert (N , x, operators; variable_names= variable_names), args)...
146154 )
147155end
148156
149157"""
150- node_to_symbolic(tree::Node , operators::AbstractOperatorEnum;
158+ node_to_symbolic(tree::AbstractExpressionNode , operators::AbstractOperatorEnum;
151159 variable_names::Union{Array{String, 1}, Nothing}=nothing,
152160 index_functions::Bool=false)
153161
@@ -156,17 +164,17 @@ will generate a symbolic equation in SymbolicUtils.jl format.
156164
157165## Arguments
158166
159- - `tree::Node `: The equation to convert.
167+ - `tree::AbstractExpressionNode `: The equation to convert.
160168- `operators::AbstractOperatorEnum`: OperatorEnum, which contains the operators used in the equation.
161169- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: What variable names to use for
162170 each feature. Default is [x1, x2, x3, ...].
163171- `index_functions::Bool=false`: Whether to generate special names for the
164- operators, which then allows one to convert back to a `Node ` format
172+ operators, which then allows one to convert back to a `AbstractExpressionNode ` format
165173 using `symbolic_to_node`.
166174 (CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84).
167175"""
168176function node_to_symbolic (
169- tree:: Node ,
177+ tree:: AbstractExpressionNode ,
170178 operators:: AbstractOperatorEnum ;
171179 variable_names:: Union{Array{String,1},Nothing} = nothing ,
172180 index_functions:: Bool = false ,
@@ -192,13 +200,14 @@ end
192200
193201function symbolic_to_node (
194202 eqn:: SymbolicUtils.Symbolic ,
195- operators:: AbstractOperatorEnum ;
203+ operators:: AbstractOperatorEnum ,
204+ :: Type{N} = Node;
196205 variable_names:: Union{Array{String,1},Nothing} = nothing ,
197206 # Deprecated:
198207 varMap= nothing ,
199- ):: Node
208+ ) where {N <: AbstractExpressionNode }
200209 variable_names = deprecate_varmap (variable_names, varMap, :symbolic_to_node )
201- return convert (Node , eqn, operators; variable_names= variable_names)
210+ return convert (N , eqn, operators; variable_names= variable_names)
202211end
203212
204213function multiply_powers (eqn:: Number ):: Tuple{SYMBOLIC_UTILS_TYPES,Bool}
0 commit comments