Skip to content

Commit 160c30a

Browse files
authored
Merge pull request #52 from SymbolicML/avoid-iterated-outputs
Enzyme compatibility
2 parents 5eea47a + ecb3574 commit 160c30a

29 files changed

+1024
-534
lines changed

.github/workflows/CI.yml

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,48 @@ jobs:
5959
with:
6060
parallel: true
6161
path-to-lcov: lcov.info
62-
flag-name: julia-${{ matrix.julia-version }}-${{ matrix.os }}-${{ github.event_name }}
62+
flag-name: julia-${{ matrix.julia-version }}-${{ matrix.os }}-main-${{ github.event_name }}
63+
64+
integration_tests:
65+
name: Integration test - ${{ matrix.test_name }} - ${{ matrix.os }}
66+
runs-on: ${{ matrix.os }}
67+
timeout-minutes: 60
68+
strategy:
69+
fail-fast: false
70+
matrix:
71+
os:
72+
- "ubuntu-latest"
73+
julia-version:
74+
- "1"
75+
test_name:
76+
- "enzyme"
77+
steps:
78+
- uses: actions/checkout@v2
79+
- uses: julia-actions/setup-julia@v1
80+
with:
81+
version: ${{ matrix.julia-version }}
82+
- uses: julia-actions/cache@v1
83+
- uses: julia-actions/julia-buildpkg@v1
84+
- name: Run tests
85+
run: |
86+
julia --color=yes -e 'import Pkg; Pkg.add("Coverage")'
87+
SR_TEST=${{ matrix.test_name }} julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user -e 'import Coverage; import Pkg; Pkg.activate("."); Pkg.test(coverage=true)'
88+
julia --color=yes coverage.jl
89+
shell: bash
90+
- name: Coveralls
91+
uses: coverallsapp/github-action@v2
92+
with:
93+
parallel: true
94+
path-to-lcov: lcov.info
95+
flag-name: julia-${{ matrix.julia-version }}-${{ matrix.os }}-${{ matrix.test_name }}-${{ github.event_name }}
96+
6397

6498
coveralls:
6599
name: Indicate completion to coveralls
66100
runs-on: ubuntu-latest
67-
needs: test
101+
needs:
102+
- test
103+
- integration_tests
68104
steps:
69105
- name: Finish
70106
uses: coverallsapp/github-action@v2

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.14.1"
4+
version = "0.15.0"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -26,6 +26,7 @@ DynamicExpressionsZygoteExt = "Zygote"
2626
[compat]
2727
Aqua = "0.7"
2828
Compat = "3.37, 4"
29+
Enzyme = "^0.11.12"
2930
LoopVectorization = "0.12"
3031
MacroTools = "0.4, 0.5"
3132
PackageExtensionCompat = "1"
@@ -37,6 +38,7 @@ julia = "1.6"
3738

3839
[extras]
3940
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
41+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4042
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4143
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4244
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -46,4 +48,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4648
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4749

4850
[targets]
49-
test = ["Test", "SafeTestsets", "Aqua", "SpecialFunctions", "ForwardDiff", "StaticArrays", "SymbolicUtils", "Zygote"]
51+
test = ["Test", "SafeTestsets", "Aqua", "Enzyme", "ForwardDiff", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ using Zygote # trigger extension
105105
operators = OperatorEnum(;
106106
binary_operators=[+, -, *],
107107
unary_operators=[cos],
108-
enable_autodiff=true,
109108
)
110109
x1 = Node(; feature=1)
111110
x2 = Node(; feature=2)

benchmark/benchmarks.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@ else
77
@eval using DynamicExpressions: GraphNode
88
end
99

10-
include("benchmark_utils.jl")
10+
include("../test/tree_gen_utils.jl")
1111

1212
const SUITE = BenchmarkGroup()
1313

1414
function benchmark_evaluation()
1515
suite = BenchmarkGroup()
1616
operators = OperatorEnum(;
17-
binary_operators=[+, -, /, *], unary_operators=[cos, exp], enable_autodiff=true
17+
binary_operators=[+, -, /, *],
18+
unary_operators=[cos, exp],
19+
(PACKAGE_VERSION >= v"0.15" ? () : (; enable_autodiff=true))...,
1820
)
1921
for T in (ComplexF32, ComplexF64, Float32, Float64)
2022
if !(T <: Real) && PACKAGE_VERSION < v"0.5.0" && PACKAGE_VERSION != v"0.0.0"

docs/src/eval.md

Lines changed: 105 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
# Evaluation
1+
# Evaluation & Derivatives
2+
3+
## Evaluation
24

35
Given an expression tree specified with a `Node` type, you may evaluate the expression
46
over an array of data with the following command:
@@ -11,20 +13,25 @@ Assuming you are only using a single `OperatorEnum`, you can also use
1113
the following shorthand by using the expression as a function:
1214

1315
```
14-
(tree::Node)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false)
15-
16-
Evaluate a binary tree (equation) over a given data matrix. The
17-
operators contain all of the operators used in the tree.
16+
(tree::Node)(X::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true)
1817
1918
# Arguments
20-
- `X::AbstractMatrix{T}`: The input data to evaluate the tree on.
21-
- `operators::OperatorEnum`: The operators used in the tree.
22-
- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
19+
- `X::AbstractArray`: The input data to evaluate the tree on.
20+
- `operators::GenericOperatorEnum`: The operators used in the tree.
21+
- `throw_errors::Bool=true`: Whether to throw errors
22+
if they occur during evaluation. Otherwise,
23+
MethodErrors will be caught before they happen and
24+
evaluation will return `nothing`,
25+
rather than throwing an error. This is useful in cases
26+
where you are unsure if a particular tree is valid or not,
27+
and would prefer to work with `nothing` as an output.
2328
2429
# Returns
25-
- `output::AbstractVector{T}`: the result, which is a 1D array.
26-
Any NaN, Inf, or other failure during the evaluation will result in the entire
27-
output array being set to NaN.
30+
- `output`: the result of the evaluation.
31+
If evaluation failed, `nothing` will be returned for the first argument.
32+
A `false` value of the `complete` flag means an operator was called on input types
33+
that it was not defined for. You can change this behavior by
34+
setting `throw_errors=false`.
2835
```
2936

3037
For example,
@@ -66,7 +73,7 @@ Likewise for the shorthand notation:
6673
- `operators::GenericOperatorEnum`: The operators used in the tree.
6774
- `throw_errors::Bool=true`: Whether to throw errors
6875
if they occur during evaluation. Otherwise,
69-
MethodErrors will be caught before they happen and
76+
MethodErrors will be caught before they happen and
7077
evaluation will return `nothing`,
7178
rather than throwing an error. This is useful in cases
7279
where you are unsure if a particular tree is valid or not,
@@ -107,8 +114,7 @@ to every constant in the expression.
107114
108115
# Arguments
109116
- `X::AbstractMatrix{T}`: The data matrix, with each column being a data point.
110-
- `operators::OperatorEnum`: The operators used to create the `tree`. Note that `operators.enable_autodiff`
111-
must be `true`. This is needed to create the derivative operations.
117+
- `operators::OperatorEnum`: The operators used to create the `tree`.
112118
- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`),
113119
or with respect to every constant in the expression (`variable=false`).
114120
- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
@@ -126,13 +132,92 @@ the function `differentiable_eval_tree_array`, although this will be slower.
126132
differentiable_eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) where {T<:Number}
127133
```
128134

129-
## Printing
135+
### Enzyme
130136

131-
You can also print a tree as follows:
137+
`DynamicExpressions.jl` also supports automatic differentiation with
138+
[`Enzyme.jl`](https://github.com/EnzymeAD/Enzyme.jl). Note that this is
139+
**extremely experimental**.
140+
You should expect to see occasional incorrect gradients.
141+
Be sure to explicitly verify gradients are correct for a particular
142+
space of operators (e.g., with finite differences).
132143

133-
```@docs
134-
string_tree(tree::Node, operators::AbstractOperatorEnum)
144+
Let's look at an example. First, let's create a tree:
145+
146+
```julia
147+
using DynamicExpressions
148+
149+
operators = OperatorEnum(binary_operators=(+, -, *, /), unary_operators=(cos, sin))
150+
151+
x1 = Node{Float64}(feature=1)
152+
x2 = Node{Float64}(feature=2)
153+
154+
tree = 0.5 * x1 + cos(x2 - 0.2)
135155
```
136156

137-
When you define an `OperatorEnum`, the standard `show` and `print` methods
138-
will be overwritten to use `string_tree`.
157+
Now, say we want to take the derivative of this expression with respect to x1 and x2.
158+
First, let's evaluate it normally:
159+
```julia
160+
X = [1.0 2.0 3.0; 4.0 5.0 6.0] # 2x3 matrix (2 features, 3 rows)
161+
162+
tree(X, operators)
163+
```
164+
165+
Now, let's use `Enzyme.jl` to compute the derivative of the outputs
166+
with respect to x1 and x2, using reverse-mode autodiff:
167+
168+
```julia
169+
using Enzyme
170+
171+
function my_loss_function(tree, X, operators)
172+
# Get the outputs
173+
y = tree(X, operators)
174+
# Sum them (so we can take a gradient, rather than a jacobian)
175+
return sum(y)
176+
end
177+
178+
179+
dX = begin
180+
storage=zero(X)
181+
autodiff(
182+
Reverse,
183+
my_loss_function,
184+
Active,
185+
## Actual arguments to function:
186+
Const(tree),
187+
Duplicated(X, storage),
188+
Const(operators),
189+
)
190+
storage
191+
end
192+
```
193+
194+
This will get returned as
195+
196+
```text
197+
2×3 Matrix{Float64}:
198+
0.5 0.5 0.5
199+
0.611858 0.996165 0.464602
200+
```
201+
202+
which one can confirm is the correct gradient!
203+
204+
This will take a while the first time you run it, as Enzyme needs to take the
205+
gradients of the actual LLVM IR code. Subsequent runs won't spend any time compiling
206+
and will be much faster.
207+
208+
Some general notes about this:
209+
210+
1. We want to take a reverse-mode gradient, so we pass `Reverse` to `autodiff`.
211+
2. Since we want to take the gradient of the _output_ of `my_loss_function`,
212+
we declare `Active` as the third argument.
213+
3. Following this, we pass our actual arguments to the function.
214+
- Objects which we don't want to take gradients with respect to,
215+
and also don't temporarily store any data during the computation
216+
(such as `tree` and `operators` here) should be wrapped with `Const`.
217+
- Objects which we wish to take derivatives with respect to, we need to use
218+
`Duplicated`, and explicitly create a copy of it, with all numerical values
219+
set to zero. Enzyme will then store the derivatives in this object.
220+
221+
Note that you should never use anything other than `turbo=Val(false)` with Enzyme,
222+
as Enzyme and LoopVectorization are not compatible, and will cause a segfault.
223+
_Even using `turbo=false` will not work, because it would cause Enzyme to trace the (unused) LoopVectorization code!_

docs/src/types.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ OperatorEnum
1313
Construct this operator specification as follows:
1414

1515
```@docs
16-
OperatorEnum(; binary_operators=[], unary_operators=[], enable_autodiff::Bool=false, define_helper_functions::Bool=true)
16+
OperatorEnum(; binary_operators=[], unary_operators=[], define_helper_functions::Bool=true)
1717
```
1818

1919
This is just for scalar operators. However, you can use

docs/src/utils.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@ convert(::Type{<:AbstractExpressionNode{T1}}, n::AbstractExpressionNode{T2}) whe
2020
hash(tree::AbstractExpressionNode{T}, h::UInt; break_sharing::Val=Val(false)) where {T}
2121
```
2222

23+
## Printing
24+
25+
Trees are printed using the `string_tree` function, which is very
26+
configurable:
27+
28+
```@docs
29+
string_tree(tree::Node, operators::AbstractOperatorEnum)
30+
```
31+
32+
The standard `show` and `print` methods will use the most recently-created `OperatorEnum`
33+
in a call to `string_tree`.
34+
2335
## Sampling
2436

2537
There are also methods for random sampling of nodes:

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,20 @@ using SymbolicUtils
44
import DynamicExpressions.EquationModule:
55
AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE
66
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
7-
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
7+
import DynamicExpressions.UtilsModule: isgood, isbad, deprecate_varmap
88
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
99

1010
const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Symbolic{<:Number}}
1111
const SUPPORTED_OPS = (cos, sin, exp, cot, tan, csc, sec, +, -, *, /)
1212

13+
macro return_on_false(flag, retval)
14+
:(
15+
if !$(esc(flag))
16+
return ($(esc(retval)), false)
17+
end
18+
)
19+
end
20+
1321
function isgood(x::SymbolicUtils.Symbolic)
1422
return if SymbolicUtils.istree(x)
1523
all(isgood.([SymbolicUtils.operation(x); SymbolicUtils.arguments(x)]))

ext/DynamicExpressionsZygoteExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module DynamicExpressionsZygoteExt
22

33
import Zygote: gradient
4-
import DynamicExpressions.EvaluateEquationDerivativeModule: _zygote_gradient
4+
import DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient
55

66
function _zygote_gradient(op::F, ::Val{1}) where {F}
77
function (x)

src/DynamicExpressions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module DynamicExpressions
22

33
include("Utils.jl")
4+
include("ExtensionInterface.jl")
45
include("OperatorEnum.jl")
56
include("Equation.jl")
67
include("EquationUtils.jl")
@@ -9,7 +10,6 @@ include("EvaluateEquationDerivative.jl")
910
include("EvaluationHelpers.jl")
1011
include("SimplifyEquation.jl")
1112
include("OperatorEnumConstruction.jl")
12-
include("ExtensionInterface.jl")
1313
include("Random.jl")
1414

1515
import PackageExtensionCompat: @require_extensions

0 commit comments

Comments
 (0)