# Evaluation & Derivatives

## Evaluation

Given an expression tree specified with a `Node` type, you may evaluate the expression
over an array of data with the following command:
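
```@docs
eval_tree_array(tree::Node, cX::AbstractMatrix{T}, operators::OperatorEnum) where {T<:Number}
```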

Assuming you are only using a single `OperatorEnum`, you can also use
the following shorthand by using the expression as a function:

```
(tree::Node)(X::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true)

# Arguments
- `X::AbstractArray`: The input data to evaluate the tree on.
- `operators::GenericOperatorEnum`: The operators used in the tree.
- `throw_errors::Bool=true`: Whether to throw errors
    if they occur during evaluation. Otherwise, potential
    `MethodError`s will be caught before they are thrown, and
    evaluation will return `nothing` instead. This is useful in cases
    where you are unsure whether a particular tree is valid,
    and would prefer to work with `nothing` as an output.

# Returns
- `output`: the result of the evaluation.
    If an operator is called on input types for which it is not defined,
    an error will be thrown; set `throw_errors=false` to instead
    have `nothing` returned in that case.
```

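For instance, a rough sketch of guarding an evaluation this way (assuming `tree`, `X`, and a `GenericOperatorEnum` named `operators` are already defined):

```julia
out = tree(X, operators; throw_errors=false)
if out === nothing
    # An operator was called on input types it was not defined for;
    # handle the invalid tree here (e.g., skip it, or assign a penalty).
end
```
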
For example, you might evaluate an expression as follows (a minimal sketch; the operator set and tree are illustrative, mirroring the Enzyme example later on this page):
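
```julia
using DynamicExpressions

operators = OperatorEnum(binary_operators=(+, -, *, /), unary_operators=(cos, sin))

x1 = Node{Float64}(feature=1)
x2 = Node{Float64}(feature=2)

tree = 0.5 * x1 + cos(x2 - 0.2)

X = [1.0 2.0 3.0; 4.0 5.0 6.0]  # 2 features, 3 data points (columns)

# Evaluate the expression at each column of X:
tree(X, operators)  # returns a length-3 vector
```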

## Derivatives

You may also compute derivatives of an expression, either with respect to every feature or
to every constant in the expression, by calling the expression's adjoint as a function
(see the sketch after this list). Its arguments are:

- `X::AbstractMatrix{T}`: The data matrix, with each column being a data point.
- `operators::OperatorEnum`: The operators used to create the `tree`.
- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X`, with `variable=true`),
    or with respect to every constant in the expression (`variable=false`).
- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
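
For instance, a minimal sketch, reusing the illustrative `tree`, `X`, and `operators` from the evaluation example above:

```julia
# Derivative of the expression with respect to each feature, at each data point
# (with `variable=true`, the result has one row per feature):
grad = tree'(X, operators; variable=true)  # 2×3 matrix
```

For this expression, the result should match the reverse-mode gradient computed in the Enzyme section below.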

If you need the evaluation itself to be differentiable by an external
automatic differentiation package, you can use
the function `differentiable_eval_tree_array`, although this will be slower.

```@docs
differentiable_eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) where {T<:Number}
```
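
A rough sketch of how this might look with `Zygote.jl` (assuming that, like `eval_tree_array`, it returns the output array as its first return value):

```julia
using Zygote

# Gradient of a scalar loss with respect to the data matrix `X`:
(dX,) = Zygote.gradient(X) do X
    y, _ = differentiable_eval_tree_array(tree, X, operators)
    sum(y)
end
```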

### Enzyme

`DynamicExpressions.jl` also supports automatic differentiation with
[`Enzyme.jl`](https://github.com/EnzymeAD/Enzyme.jl). Note that this is
**extremely experimental**.
You should expect to see occasional incorrect gradients.
Be sure to explicitly verify that gradients are correct for a particular
space of operators (e.g., with finite differences).

Let's look at an example. First, let's create a tree:

```julia
using DynamicExpressions

operators = OperatorEnum(binary_operators=(+, -, *, /), unary_operators=(cos, sin))

x1 = Node{Float64}(feature=1)
x2 = Node{Float64}(feature=2)

tree = 0.5 * x1 + cos(x2 - 0.2)
```

Now, say we want to take the derivative of this expression with respect to `x1` and `x2`.
First, let's evaluate it normally:
```julia
X = [1.0 2.0 3.0; 4.0 5.0 6.0]  # 2×3 matrix (2 features, 3 data points)

tree(X, operators)
```

Now, let's use `Enzyme.jl` to compute the derivative of the outputs
with respect to `x1` and `x2`, using reverse-mode autodiff:

```julia
using Enzyme

function my_loss_function(tree, X, operators)
    # Get the outputs
    y = tree(X, operators)
    # Sum them (so we can take a gradient, rather than a Jacobian)
    return sum(y)
end

dX = begin
    storage = zero(X)
    autodiff(
        Reverse,
        my_loss_function,
        Active,
        ## Actual arguments to function:
        Const(tree),
        Duplicated(X, storage),
        Const(operators),
    )
    storage
end
```

This will be returned as:

```text
2×3 Matrix{Float64}:
 0.5       0.5       0.5
 0.611858  0.996165  0.464602
```

which one can confirm is the correct gradient!
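
As recommended above, it is worth verifying such gradients against finite differences. Here is a quick sketch, checking a single entry with a central difference:

```julia
# Central finite difference for the derivative with respect to X[1, 1]:
h = 1e-6
Xp = copy(X); Xp[1, 1] += h
Xm = copy(X); Xm[1, 1] -= h
fd = (my_loss_function(tree, Xp, operators) - my_loss_function(tree, Xm, operators)) / (2h)

isapprox(fd, dX[1, 1]; atol=1e-6)  # should be true (both ≈ 0.5)
```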

This will take a while the first time you run it, as Enzyme needs to take
gradients of the actual LLVM IR code. Subsequent runs won't spend any time compiling,
and will be much faster.

Some general notes about this:

1. We want to take a reverse-mode gradient, so we pass `Reverse` to `autodiff`.
2. Since we want the gradient of the _output_ of `my_loss_function`,
   we declare `Active` as the third argument.
3. Following this, we pass the actual arguments to the function.
   - Objects which we don't want to take gradients with respect to,
     and which don't temporarily store any data during the computation
     (such as `tree` and `operators` here), should be wrapped with `Const`.
   - Objects which we wish to take derivatives with respect to should be
     wrapped in `Duplicated`, along with an explicitly-created copy whose
     numerical values are all set to zero. Enzyme will then store the
     derivatives in this copy.

Note that you should never use anything other than `turbo=Val(false)` with Enzyme,
as Enzyme and LoopVectorization are not compatible, and will cause a segfault.
_Even using `turbo=false` will not work, because it would cause Enzyme to trace the (unused) LoopVectorization code!_
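
For instance, if you call `eval_tree_array` directly inside the function being differentiated, a safe form might look like this sketch (`my_safe_loss_function` is hypothetical):

```julia
function my_safe_loss_function(tree, X, operators)
    # Passing `turbo` as `Val(false)` (rather than `false`) keeps LoopVectorization
    # entirely out of the code that Enzyme has to trace:
    y, _ = eval_tree_array(tree, X, operators; turbo=Val(false))
    return sum(y)
end
```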