@@ -9,7 +9,7 @@ import ..EvaluateEquationModule:
99import .. ExtensionInterfaceModule: _zygote_gradient
1010
1111"""
12- eval_diff_tree_array(tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer; turbo::Bool= false)
12+ eval_diff_tree_array(tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer; turbo::Union{ Bool,Val}=Val( false) )
1313
1414Compute the forward derivative of an expression, using a similar
1515structure and optimization to eval_tree_array. `direction` is the index of a particular
@@ -22,7 +22,8 @@ respect to `x1`.
2222- `cX::AbstractMatrix{T}`: The data matrix, with each column being a data point.
2323- `operators::OperatorEnum`: The operators used to create the `tree`.
2424- `direction::Integer`: The index of the variable to take the derivative with respect to.
25- - `turbo::Union{Val,Bool}`: Use `LoopVectorization.@turbo` for faster evaluation.
25+ - `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation. Currently this does not have
26+ any effect.
2627
2728# Returns
2829
@@ -34,7 +35,7 @@ function eval_diff_tree_array(
3435 cX:: AbstractMatrix{T} ,
3536 operators:: OperatorEnum ,
3637 direction:: Integer ;
37- turbo:: Union{Val, Bool} = Val (false ),
38+ turbo:: Union{Bool,Val } = Val (false ),
3839) where {T<: Number }
3940 # TODO : Implement quick check for whether the variable is actually used
4041 # in this tree. Otherwise, return zero.
@@ -48,7 +49,7 @@ function eval_diff_tree_array(
4849 cX:: AbstractMatrix{T2} ,
4950 operators:: OperatorEnum ,
5051 direction:: Integer ;
51- turbo:: Bool = false ,
52+ turbo:: Union{ Bool,Val} = Val ( false ) ,
5253) where {T1<: Number ,T2<: Number }
5354 T = promote_type (T1, T2)
5455 @warn " Warning: eval_diff_tree_array received mixed types: tree=$(T1) and data=$(T2) ."
@@ -175,7 +176,7 @@ function diff_deg2_eval(
175176end
176177
177178"""
178- eval_grad_tree_array(tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool= false, turbo::Bool= false)
179+ eval_grad_tree_array(tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Union{ Bool,Val}=Val( false) , turbo::Union{ Bool,Val}=Val( false) )
179180
180181Compute the forward-mode derivative of an expression, using a similar
181182structure and optimization to eval_tree_array. `variable` specifies whether
@@ -187,9 +188,9 @@ to every constant in the expression.
187188- `tree::AbstractExpressionNode{T}`: The expression tree to evaluate.
188189- `cX::AbstractMatrix{T}`: The data matrix, with each column being a data point.
189190- `operators::OperatorEnum`: The operators used to create the `tree`.
190- - `variable::Bool`: Whether to take derivatives with respect to features (i.e., `cX` - with `variable=true`),
191+ - `variable::Union{ Bool,Val} `: Whether to take derivatives with respect to features (i.e., `cX` - with `variable=true`),
191192 or with respect to every constant in the expression (`variable=false`).
192- - `turbo::Bool`: Use ` LoopVectorization.@turbo` for faster evaluation. Currently this does not have
193+ - `turbo::Union{ Bool,Val} `: Use LoopVectorization.jl for faster evaluation. Currently this does not have
193194 any effect.
194195
195196# Returns
@@ -201,8 +202,8 @@ function eval_grad_tree_array(
201202 tree:: AbstractExpressionNode{T} ,
202203 cX:: AbstractMatrix{T} ,
203204 operators:: OperatorEnum ;
204- variable:: Union{Val, Bool} = Val {false} ( ),
205- turbo:: Union{Val, Bool} = Val {false} ( ),
205+ variable:: Union{Bool,Val } = Val ( false ),
206+ turbo:: Union{Bool,Val } = Val ( false ),
206207) where {T<: Number }
207208 n_gradients = if isa (variable, Val{true }) || (isa (variable, Bool) && variable)
208209 size (cX, 1 ):: Int
@@ -239,8 +240,8 @@ function eval_grad_tree_array(
239240 tree:: AbstractExpressionNode{T1} ,
240241 cX:: AbstractMatrix{T2} ,
241242 operators:: OperatorEnum ;
242- variable:: Union{Val, Bool} = Val {false} ( ),
243- turbo:: Union{Val, Bool} = Val {false} ( ),
243+ variable:: Union{Bool,Val } = Val ( false ),
244+ turbo:: Union{Bool,Val } = Val ( false ),
244245) where {T1<: Number ,T2<: Number }
245246 T = promote_type (T1, T2)
246247 return eval_grad_tree_array (
0 commit comments