Skip to content

Commit b426191

Browse files
committed
Update docstrings
1 parent c22ac5c commit b426191

File tree

4 files changed

+35
-32
lines changed

4 files changed

+35
-32
lines changed

docs/src/eval.md

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,23 @@ Assuming you are only using a single `OperatorEnum`, you can also use
1313
the following shorthand by using the expression as a function:
1414

1515
```
16-
(tree::Node)(X::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true)
16+
(tree::AbstractExpressionNode)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=false, bumper::Union{Bool,Val}=Val(false))
17+
18+
Evaluate a binary tree (equation) over a given input data matrix. The
19+
operators contain all of the operators used. This function fuses doublets
20+
and triplets of operations for lower memory usage.
1721
1822
# Arguments
19-
- `X::AbstractArray`: The input data to evaluate the tree on.
20-
- `operators::GenericOperatorEnum`: The operators used in the tree.
21-
- `throw_errors::Bool=true`: Whether to throw errors
22-
if they occur during evaluation. Otherwise,
23-
MethodErrors will be caught before they happen and
24-
evaluation will return `nothing`,
25-
rather than throwing an error. This is useful in cases
26-
where you are unsure if a particular tree is valid or not,
27-
and would prefer to work with `nothing` as an output.
23+
- `tree::AbstractExpressionNode`: The root node of the tree to evaluate.
24+
- `cX::AbstractMatrix{T}`: The input data to evaluate the tree on.
25+
- `operators::OperatorEnum`: The operators used in the tree.
26+
- `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation.
27+
- `bumper::Union{Bool,Val}`: Use Bumper.jl for faster evaluation.
2828
2929
# Returns
30-
- `output`: the result of the evaluation.
31-
If evaluation failed, `nothing` will be returned for the first argument.
32-
A `false` complete means an operator was called on input types
33-
that it was not defined for. You can change this behavior by
34-
setting `throw_errors=false`.
30+
- `output::AbstractVector{T}`: the result, which is a 1D array.
31+
Any NaN, Inf, or other failure during the evaluation will result in the entire
32+
output array being set to NaN.
3533
```
3634

3735
For example,
@@ -98,7 +96,7 @@ all variables (or, all constants). Both use forward-mode automatic, but use
9896

9997
```@docs
10098
eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer) where {T<:Number}
101-
eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=false) where {T<:Number}
99+
eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) where {T<:Number}
102100
```
103101

104102
You can compute gradients this with shorthand notation as well (which by default computes

src/EvaluateEquation.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ macro return_on_nonfinite_array(array)
2525
end
2626

2727
"""
28-
eval_tree_array(tree::AbstractExpressionNode, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=Val(false))
28+
eval_tree_array(tree::AbstractExpressionNode, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=Val(false), bumper::Union{Bool,Val}=Val(false))
2929
3030
Evaluate a binary tree (equation) over a given input data matrix. The
3131
operators contain all of the operators used. This function fuses doublets
@@ -35,7 +35,8 @@ and triplets of operations for lower memory usage.
3535
- `tree::AbstractExpressionNode`: The root node of the tree to evaluate.
3636
- `cX::AbstractMatrix{T}`: The input data to evaluate the tree on.
3737
- `operators::OperatorEnum`: The operators used in the tree.
38-
- `turbo::Union{Bool,Val}`: Use `LoopVectorization.@turbo` for faster evaluation.
38+
- `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation.
39+
- `bumper::Union{Bool,Val}`: Use Bumper.jl for faster evaluation.
3940
4041
# Returns
4142
- `(output, complete)::Tuple{AbstractVector{T}, Bool}`: the result,
@@ -71,6 +72,7 @@ function eval_tree_array(
7172
if v_turbo isa Val{true} || v_bumper isa Val{true}
7273
@assert T in (Float32, Float64)
7374
end
75+
@assert !(v_turbo isa Val{true} && v_bumper isa Val{true})
7476
if v_bumper isa Val{true}
7577
return bumper_eval_tree_array(tree, cX, operators)
7678
end

src/EvaluateEquationDerivative.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import ..EvaluateEquationModule:
99
import ..ExtensionInterfaceModule: _zygote_gradient
1010

1111
"""
12-
eval_diff_tree_array(tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer; turbo::Bool=false)
12+
eval_diff_tree_array(tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer; turbo::Union{Bool,Val}=Val(false))
1313
1414
Compute the forward derivative of an expression, using a similar
1515
structure and optimization to eval_tree_array. `direction` is the index of a particular
@@ -22,7 +22,8 @@ respect to `x1`.
2222
- `cX::AbstractMatrix{T}`: The data matrix, with each column being a data point.
2323
- `operators::OperatorEnum`: The operators used to create the `tree`.
2424
- `direction::Integer`: The index of the variable to take the derivative with respect to.
25-
- `turbo::Union{Val,Bool}`: Use `LoopVectorization.@turbo` for faster evaluation.
25+
- `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation. Currently this does not have
26+
any effect.
2627
2728
# Returns
2829
@@ -34,7 +35,7 @@ function eval_diff_tree_array(
3435
cX::AbstractMatrix{T},
3536
operators::OperatorEnum,
3637
direction::Integer;
37-
turbo::Union{Val,Bool}=Val(false),
38+
turbo::Union{Bool,Val}=Val(false),
3839
) where {T<:Number}
3940
# TODO: Implement quick check for whether the variable is actually used
4041
# in this tree. Otherwise, return zero.
@@ -48,7 +49,7 @@ function eval_diff_tree_array(
4849
cX::AbstractMatrix{T2},
4950
operators::OperatorEnum,
5051
direction::Integer;
51-
turbo::Bool=false,
52+
turbo::Union{Bool,Val}=Val(false),
5253
) where {T1<:Number,T2<:Number}
5354
T = promote_type(T1, T2)
5455
@warn "Warning: eval_diff_tree_array received mixed types: tree=$(T1) and data=$(T2)."
@@ -175,7 +176,7 @@ function diff_deg2_eval(
175176
end
176177

177178
"""
178-
eval_grad_tree_array(tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false, turbo::Bool=false)
179+
eval_grad_tree_array(tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Union{Bool,Val}=Val(false), turbo::Union{Bool,Val}=Val(false))
179180
180181
Compute the forward-mode derivative of an expression, using a similar
181182
structure and optimization to eval_tree_array. `variable` specifies whether
@@ -187,9 +188,9 @@ to every constant in the expression.
187188
- `tree::AbstractExpressionNode{T}`: The expression tree to evaluate.
188189
- `cX::AbstractMatrix{T}`: The data matrix, with each column being a data point.
189190
- `operators::OperatorEnum`: The operators used to create the `tree`.
190-
- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `cX` - with `variable=true`),
191+
- `variable::Union{Bool,Val}`: Whether to take derivatives with respect to features (i.e., `cX` - with `variable=true`),
191192
or with respect to every constant in the expression (`variable=false`).
192-
- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation. Currently this does not have
193+
- `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation. Currently this does not have
193194
any effect.
194195
195196
# Returns
@@ -201,8 +202,8 @@ function eval_grad_tree_array(
201202
tree::AbstractExpressionNode{T},
202203
cX::AbstractMatrix{T},
203204
operators::OperatorEnum;
204-
variable::Union{Val,Bool}=Val{false}(),
205-
turbo::Union{Val,Bool}=Val{false}(),
205+
variable::Union{Bool,Val}=Val(false),
206+
turbo::Union{Bool,Val}=Val(false),
206207
) where {T<:Number}
207208
n_gradients = if isa(variable, Val{true}) || (isa(variable, Bool) && variable)
208209
size(cX, 1)::Int
@@ -239,8 +240,8 @@ function eval_grad_tree_array(
239240
tree::AbstractExpressionNode{T1},
240241
cX::AbstractMatrix{T2},
241242
operators::OperatorEnum;
242-
variable::Union{Val,Bool}=Val{false}(),
243-
turbo::Union{Val,Bool}=Val{false}(),
243+
variable::Union{Bool,Val}=Val(false),
244+
turbo::Union{Bool,Val}=Val(false),
244245
) where {T1<:Number,T2<:Number}
245246
T = promote_type(T1, T2)
246247
return eval_grad_tree_array(

src/EvaluationHelpers.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import ..EvaluateEquationDerivativeModule: eval_grad_tree_array
88

99
# Evaluation:
1010
"""
11-
(tree::AbstractExpressionNode)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=false)
11+
(tree::AbstractExpressionNode)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=false, bumper::Union{Bool,Val}=Val(false))
1212
1313
Evaluate a binary tree (equation) over a given input data matrix. The
1414
operators contain all of the operators used. This function fuses doublets
@@ -18,7 +18,8 @@ and triplets of operations for lower memory usage.
1818
- `tree::AbstractExpressionNode`: The root node of the tree to evaluate.
1919
- `cX::AbstractMatrix{T}`: The input data to evaluate the tree on.
2020
- `operators::OperatorEnum`: The operators used in the tree.
21-
- `turbo::Union{Bool,Val}`: Use `LoopVectorization.@turbo` for faster evaluation.
21+
- `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation.
22+
- `bumper::Union{Bool,Val}`: Use Bumper.jl for faster evaluation.
2223
2324
# Returns
2425
- `output::AbstractVector{T}`: the result, which is a 1D array.
@@ -84,7 +85,8 @@ to every constant in the expression.
8485
- `operators::OperatorEnum`: The operators used to create the `tree`.
8586
- `variable::Union{Bool,Val}`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`),
8687
or with respect to every constant in the expression (`variable=false`).
87-
- `turbo::Union{Bool,Val}`: Use `LoopVectorization.@turbo` for faster evaluation.
88+
- `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation. Currently this does not have
89+
any effect.
8890
8991
# Returns
9092

0 commit comments

Comments
 (0)