Skip to content

Commit db556dc

Browse files
authored
Merge pull request #85 from gca30/generic-update
Updated OperatorEnum to use any data type (not just Numbers)
2 parents cab1143 + 7ff927b commit db556dc

32 files changed

+879
-320
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.18.5"
4+
version = "0.18.6"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -14,7 +14,6 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1515
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1616
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
17-
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
1817

1918
[weakdeps]
2019
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
@@ -43,7 +42,6 @@ PackageExtensionCompat = "1"
4342
PrecompileTools = "1"
4443
Reexport = "1"
4544
SymbolicUtils = "0.19, ^1.0.5, 2"
46-
TestItems = "0.1"
4745
Zygote = "0.6"
4846
julia = "1.6"
4947

benchmark/benchmarks.jl

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
using DynamicExpressions, BenchmarkTools, Random
22

33
# Trigger extensions:
4-
using LoopVectorization
5-
using Bumper
6-
using StrideArrays
7-
using Zygote
4+
using LoopVectorization, Bumper, StrideArrays, Zygote
85

96
if PACKAGE_VERSION < v"0.14.0"
107
@eval using DynamicExpressions: Node as GraphNode
@@ -18,6 +15,14 @@ else
1815
@eval using DynamicExpressions.NodeUtilsModule: is_constant
1916
end
2017

18+
if PACKAGE_VERSION < v"0.18.6"
19+
@eval using DynamicExpressions:
20+
index_constants as index_constant_nodes,
21+
count_constants as count_constant_nodes,
22+
get_constants as get_scalar_constants,
23+
set_constants! as set_scalar_constants!
24+
end
25+
2126
include("../test/tree_gen_utils.jl")
2227

2328
const SUITE = BenchmarkGroup()
@@ -113,15 +118,16 @@ end
113118
PACKAGE_VERSION < v"0.14.0" && return :(copy_node(t; preserve_sharing=preserve_sharing))
114119
return :(copy_node(t)) # Assume type used to infer sharing
115120
end
116-
@generated function get_set_constants!(tree::N) where {T,N<:AbstractExpressionNode{T}}
117-
if !(@isdefined set_constants!)
118-
return :(set_constants(tree, get_constants(tree)))
119-
elseif hasmethod(set_constants!, Tuple{N, Vector{T}})
120-
return :(set_constants!(tree, get_constants(tree)))
121+
@generated function get_set_constants!(tree::N) where {N}
122+
T = eltype(N)
123+
if !(@isdefined set_scalar_constants!)
124+
return :(set_scalar_constants(tree, get_scalar_constants(tree)))
125+
elseif hasmethod(set_scalar_constants!, Tuple{N, Vector{T}})
126+
return :(set_scalar_constants!(tree, get_scalar_constants(tree)))
121127
else
122128
return quote
123-
let (x, refs) = get_constants(tree)
124-
set_constants!(tree, x, refs)
129+
let (x, refs) = get_scalar_constants(tree)
130+
set_scalar_constants!(tree, x, refs)
125131
end
126132
end
127133
end
@@ -141,12 +147,12 @@ function benchmark_utilities()
141147
:combine_operators,
142148
:count_nodes,
143149
:count_depth,
144-
:count_constants,
150+
:count_constant_nodes,
145151
:has_constants,
146152
:has_operators,
147153
:is_constant,
148154
:get_set_constants!,
149-
:index_constants,
155+
:index_constant_nodes,
150156
:string_tree,
151157
:hash,
152158
)
@@ -157,9 +163,9 @@ function benchmark_utilities()
157163
[
158164
:simplify_tree,
159165
:count_nodes,
160-
:count_constants,
166+
:count_constant_nodes,
161167
:get_set_constants!,
162-
:index_constants,
168+
:index_constant_nodes,
163169
:string_tree,
164170
],
165171
)
@@ -207,7 +213,8 @@ function benchmark_utilities()
207213
setup=(
208214
ntrees=100;
209215
n=20;
210-
trees=[$preprocess(gen_random_tree_fixed_size(n, $operators, 5, Float32)) for _ in 1:ntrees]
216+
rng=Random.MersenneTwister(0);
217+
trees=[$preprocess(gen_random_tree_fixed_size(n, $operators, 5, Float32, Node, rng)) for _ in 1:ntrees]
211218
)
212219
)
213220
#! format: on
@@ -216,6 +223,37 @@ function benchmark_utilities()
216223
end
217224
end
218225

226+
# Additional methods
227+
@static if PACKAGE_VERSION >= v"0.18.0"
228+
suite["get_set_constants_parametric"] = @benchmarkable(
229+
[get_set_constants!(ex) for ex in exs],
230+
seconds = 10.0,
231+
setup = (
232+
operators = $operators;
233+
ntrees = 100;
234+
n = 20;
235+
n_features = 5;
236+
n_params = 3;
237+
n_param_classes = 10;
238+
rng = Random.MersenneTwister(0);
239+
exs = [
240+
let tree = gen_random_tree_fixed_size(
241+
n, operators, n_features, Float32, ParametricNode, rng
242+
)
243+
ex = ParametricExpression(
244+
tree;
245+
operators,
246+
variable_names=map(i -> "x$i", 1:n_features),
247+
parameters=randn(rng, Float32, n_params, n_param_classes),
248+
parameter_names=map(i -> "p$i", 1:n_params),
249+
)
250+
ex
251+
end for _ in 1:ntrees
252+
]
253+
)
254+
)
255+
end
256+
219257
return suite
220258
end
221259

ext/DynamicExpressionsBumperExt.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module DynamicExpressionsBumperExt
22

33
using Bumper: @no_escape, @alloc
4-
using DynamicExpressions: OperatorEnum, AbstractExpressionNode, tree_mapreduce
5-
using DynamicExpressions.UtilsModule: ResultOk, counttuple, is_bad_array
4+
using DynamicExpressions:
5+
OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array
6+
using DynamicExpressions.UtilsModule: ResultOk, counttuple
67

78
import DynamicExpressions.ExtensionInterfaceModule:
89
bumper_eval_tree_array, bumper_kern1!, bumper_kern2!
@@ -52,7 +53,7 @@ function dispatch_kerns!(operators, branch_node, cumulator, ::Val{turbo}) where
5253
cumulator.ok || return cumulator
5354

5455
out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, Val(turbo))
55-
return ResultOk(out, !is_bad_array(out))
56+
return ResultOk(out, is_valid_array(out))
5657
end
5758
function dispatch_kerns!(
5859
operators, branch_node, cumulator1, cumulator2, ::Val{turbo}
@@ -63,7 +64,7 @@ function dispatch_kerns!(
6364
out = dispatch_kern2!(
6465
operators.binops, branch_node.op, cumulator1.x, cumulator2.x, Val(turbo)
6566
)
66-
return ResultOk(out, !is_bad_array(out))
67+
return ResultOk(out, is_valid_array(out))
6768
end
6869

6970
@generated function dispatch_kern1!(unaops, op_idx, cumulator, ::Val{turbo}) where {turbo}

ext/DynamicExpressionsOptimExt.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ using DynamicExpressions:
55
AbstractExpressionNode,
66
filter_map,
77
eval_tree_array,
8-
get_constants,
9-
set_constants!
8+
get_scalar_constants,
9+
set_scalar_constants!,
10+
get_number_type
1011
using Compat: @inline
1112

1213
import Optim: Optim, OptimizationResults, NLSolversBase
@@ -44,9 +45,14 @@ function wrap_func(
4445
function wrapped_f(args::Vararg{Any,M}) where {M}
4546
first_args = args[begin:(end - 1)]
4647
x = args[end]
47-
set_constants!(tree, x, refs)
48+
set_scalar_constants!(tree, x, refs)
4849
return @inline(f(first_args..., tree))
4950
end
51+
# without first args, it looks like this
52+
# function wrapped_f(x)
53+
# set_scalar_constants!(tree, x, refs)
54+
# return @inline(f(tree))
55+
# end
5056
return wrapped_f
5157
end
5258
function wrap_func(
@@ -100,7 +106,8 @@ function Optim.optimize(
100106
if make_copy
101107
tree = copy(tree)
102108
end
103-
x0, refs = get_constants(tree)
109+
110+
x0, refs = get_scalar_constants(tree)
104111
if !isnothing(h!)
105112
throw(
106113
ArgumentError(
@@ -117,7 +124,7 @@ function Optim.optimize(
117124
)
118125
end
119126
minimizer = Optim.minimizer(base_res)
120-
set_constants!(tree, minimizer, refs)
127+
set_scalar_constants!(tree, minimizer, refs)
121128
return ExpressionOptimizationResults(base_res, tree)
122129
end
123130

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ using SymbolicUtils
44
import DynamicExpressions.NodeModule:
55
AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE
66
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
7-
import DynamicExpressions.UtilsModule: isgood, isbad, deprecate_varmap
7+
import DynamicExpressions.ValueInterfaceModule: is_valid
8+
import DynamicExpressions.UtilsModule: deprecate_varmap
89
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
910
import DynamicExpressions: AbstractExpression, get_tree, get_operators
1011

@@ -19,14 +20,14 @@ macro return_on_false(flag, retval)
1920
)
2021
end
2122

22-
function isgood(x::SymbolicUtils.Symbolic)
23+
function is_valid(x::SymbolicUtils.Symbolic)
2324
return if SymbolicUtils.istree(x)
24-
all(isgood.([SymbolicUtils.operation(x); SymbolicUtils.arguments(x)]))
25+
all(is_valid.([SymbolicUtils.operation(x); SymbolicUtils.arguments(x)]))
2526
else
2627
true
2728
end
2829
end
29-
subs_bad(x) = isgood(x) ? x : Inf
30+
subs_bad(x) = is_valid(x) ? x : Inf
3031

3132
function parse_tree_to_eqs(
3233
tree::AbstractExpressionNode{T},
@@ -197,7 +198,7 @@ function node_to_symbolic(
197198
variable_names = deprecate_varmap(variable_names, varMap, :node_to_symbolic)
198199
expr = subs_bad(parse_tree_to_eqs(tree, operators, index_functions))
199200
# Check for NaN and Inf
200-
@assert isgood(expr) "The recovered equation contains NaN or Inf."
201+
@assert is_valid(expr) "The recovered equation contains NaN or Inf."
201202
# Return if no variable_names is given
202203
variable_names === nothing && return expr
203204
# Create a substitution tuple
@@ -248,12 +249,12 @@ function multiply_powers(
248249
if nargs == 1
249250
l, complete = multiply_powers(args[1])
250251
@return_on_false complete eqn
251-
@return_on_false isgood(l) eqn
252+
@return_on_false is_valid(l) eqn
252253
return op(l), true
253254
elseif op == ^
254255
l, complete = multiply_powers(args[1])
255256
@return_on_false complete eqn
256-
@return_on_false isgood(l) eqn
257+
@return_on_false is_valid(l) eqn
257258
n = args[2]
258259
if typeof(n) <: Integer
259260
if n == 1
@@ -275,23 +276,23 @@ function multiply_powers(
275276
elseif nargs == 2
276277
l, complete = multiply_powers(args[1])
277278
@return_on_false complete eqn
278-
@return_on_false isgood(l) eqn
279+
@return_on_false is_valid(l) eqn
279280
r, complete2 = multiply_powers(args[2])
280281
@return_on_false complete2 eqn
281-
@return_on_false isgood(r) eqn
282+
@return_on_false is_valid(r) eqn
282283
return op(l, r), true
283284
else
284285
# return tree_mapreduce(multiply_powers, op, args)
285286
# ## reduce(op, map(multiply_powers, args))
286287
out = map(multiply_powers, args) #vector of tuples
287288
for i in 1:size(out, 1)
288289
@return_on_false out[i][2] eqn
289-
@return_on_false isgood(out[i][1]) eqn
290+
@return_on_false is_valid(out[i][1]) eqn
290291
end
291292
cumulator = out[1][1]
292293
for i in 2:size(out, 1)
293294
cumulator = op(cumulator, out[i][1])
294-
@return_on_false isgood(cumulator) eqn
295+
@return_on_false is_valid(cumulator) eqn
295296
end
296297
return cumulator, true
297298
end

src/DynamicExpressions.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using DispatchDoctor: @stable, @unstable
44

55
@stable default_mode = "disable" begin
66
include("Utils.jl")
7+
include("ValueInterface.jl")
78
include("ExtensionInterface.jl")
89
include("OperatorEnum.jl")
910
include("Node.jl")
@@ -25,6 +26,13 @@ import PackageExtensionCompat: @require_extensions
2526
import Reexport: @reexport
2627
macro ignore(args...) end
2728

29+
import .ValueInterfaceModule:
30+
is_valid,
31+
is_valid_array,
32+
get_number_type,
33+
pack_scalar_constants!,
34+
unpack_scalar_constants,
35+
ValueInterface
2836
@reexport import .NodeModule:
2937
AbstractNode,
3038
AbstractExpressionNode,
@@ -47,14 +55,15 @@ import .NodeModule:
4755
branch_equal
4856
@reexport import .NodeUtilsModule:
4957
count_nodes,
50-
count_constants,
58+
count_constant_nodes,
5159
count_depth,
5260
NodeIndex,
53-
index_constants,
61+
index_constant_nodes,
5462
has_operators,
5563
has_constants,
56-
get_constants,
57-
set_constants!
64+
count_scalar_constants,
65+
get_scalar_constants,
66+
set_scalar_constants!
5867
@reexport import .StringsModule: string_tree, print_tree
5968
@reexport import .OperatorEnumModule: AbstractOperatorEnum
6069
@reexport import .OperatorEnumConstructionModule:

0 commit comments

Comments
 (0)