@@ -101,7 +101,9 @@ or `cur_operators` if it is not `nothing`. If left as default,
101101it requires `cur_operators` to not be `nothing`.
102102`cur_operators` would typically be an `OperatorEnum`.
103103"""
104- function get_operators (ex:: AbstractExpression , operators)
104+ function get_operators (
105+ ex:: AbstractExpression , operators:: Union{AbstractOperatorEnum,Nothing} = nothing
106+ )
105107 return error (" `get_operators` function must be implemented for $(typeof (ex)) types." )
106108end
107109
110112
111113The same as `operators`, but for variable names.
112114"""
113- function get_variable_names (ex:: AbstractExpression , variable_names)
115+ function get_variable_names (
116+ ex:: AbstractExpression ,
117+ variable_names:: Union{Nothing,AbstractVector{<:AbstractString}} = nothing ,
118+ )
114119 return error (
115120 " `get_variable_names` function must be implemented for $(typeof (ex)) types."
116121 )
@@ -179,10 +184,23 @@ function preserve_sharing(::Union{E,Type{E}}) where {T,N,E<:AbstractExpression{T
179184 return preserve_sharing (N)
180185end
181186
182- function get_operators (ex:: Expression , operators= nothing )
187+ function get_operators (
188+ tree:: AbstractExpressionNode , operators:: Union{AbstractOperatorEnum,Nothing} = nothing
189+ )
190+ if operators === nothing
191+ throw (ArgumentError (" `operators` must be provided for $(typeof (tree)) types." ))
192+ else
193+ return operators
194+ end
195+ end
196+ function get_operators (
197+ ex:: Expression , operators:: Union{AbstractOperatorEnum,Nothing} = nothing
198+ )
183199 return operators === nothing ? ex. metadata. operators : operators
184200end
185- function get_variable_names (ex:: Expression , variable_names= nothing )
201+ function get_variable_names (
202+ ex:: Expression , variable_names:: Union{Nothing,AbstractVector{<:AbstractString}} = nothing
203+ )
186204 return variable_names === nothing ? ex. metadata. variable_names : variable_names
187205end
188206function get_tree (ex:: Expression )
249267import .. StringsModule: string_tree, print_tree
250268
251269function string_tree (
252- ex:: AbstractExpression , operators= nothing ; variable_names= nothing , kws...
270+ ex:: AbstractExpression ,
271+ operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
272+ variable_names= nothing ,
273+ kws... ,
253274)
254275 return string_tree (
255276 get_tree (ex),
@@ -260,7 +281,11 @@ function string_tree(
260281end
261282for io in ((), (:(io:: IO ),))
262283 @eval function print_tree (
263- $ (io... ), ex:: AbstractExpression , operators= nothing ; variable_names= nothing , kws...
284+ $ (io... ),
285+ ex:: AbstractExpression ,
286+ operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
287+ variable_names= nothing ,
288+ kws... ,
264289 )
265290 return println ($ (io... ), string_tree (ex, operators; variable_names, kws... ))
266291 end
@@ -283,7 +308,9 @@ function max_feature(ex::AbstractExpression)
283308 )
284309end
285310
286- function _validate_input (ex:: AbstractExpression , X, operators)
311+ function _validate_input (
312+ ex:: AbstractExpression , X, operators:: Union{AbstractOperatorEnum,Nothing}
313+ )
287314 if get_operators (ex, operators) isa OperatorEnum
288315 @assert X isa AbstractMatrix
289316 @assert max_feature (ex) <= size (X, 1 )
@@ -292,7 +319,10 @@ function _validate_input(ex::AbstractExpression, X, operators)
292319end
293320
294321function eval_tree_array (
295- ex:: AbstractExpression , cX:: AbstractMatrix , operators= nothing ; kws...
322+ ex:: AbstractExpression ,
323+ cX:: AbstractMatrix ,
324+ operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
325+ kws... ,
296326)
297327 _validate_input (ex, cX, operators)
298328 return eval_tree_array (get_tree (ex), cX, get_operators (ex, operators); kws... )
@@ -305,7 +335,10 @@ import ..EvaluateDerivativeModule: eval_grad_tree_array
305335# - differentiable_eval_tree_array
306336
307337function eval_grad_tree_array (
308- ex:: AbstractExpression , cX:: AbstractMatrix , operators= nothing ; kws...
338+ ex:: AbstractExpression ,
339+ cX:: AbstractMatrix ,
340+ operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
341+ kws... ,
309342)
310343 _validate_input (ex, cX, operators)
311344 return eval_grad_tree_array (get_tree (ex), cX, get_operators (ex, operators); kws... )
@@ -319,14 +352,16 @@ end
319352function _grad_evaluator (
320353 ex:: AbstractExpression ,
321354 cX:: AbstractMatrix ,
322- operators= nothing ;
355+ operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
323356 variable= Val (true ),
324357 kws... ,
325358)
326359 _validate_input (ex, cX, operators)
327360 return _grad_evaluator (get_tree (ex), cX, get_operators (ex, operators); variable, kws... )
328361end
329- function (ex:: AbstractExpression )(X, operators= nothing ; kws... )
362+ function (ex:: AbstractExpression )(
363+ X, operators:: Union{AbstractOperatorEnum,Nothing} = nothing ; kws...
364+ )
330365 _validate_input (ex, X, operators)
331366 return get_tree (ex)(X, get_operators (ex, operators); kws... )
332367end
0 commit comments