Skip to content

Commit cecad13

Browse files
authored
Merge pull request #57 from SymbolicML/builtin-sampling
Add built-in random sampling
2 parents 8109f9c + 6d591ed commit cecad13

File tree

10 files changed

+292
-21
lines changed

10 files changed

+292
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.14.0"
4+
version = "0.14.1"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

benchmark/benchmarks.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ function benchmark_utilities()
107107
:get_set_constants!,
108108
:index_constants,
109109
:string_tree,
110+
:hash,
110111
)
111112
has_both_modes = [:copy, :convert]
112113
if PACKAGE_VERSION >= v"0.14.0"
@@ -122,6 +123,9 @@ function benchmark_utilities()
122123
],
123124
)
124125
end
126+
if PACKAGE_VERSION >= v"0.14.1"
127+
append!(has_both_modes, [:hash])
128+
end
125129

126130
operators = OperatorEnum(; binary_operators=[+, -, /, *], unary_operators=[cos, exp])
127131
for func_k in all_funcs

docs/make.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
using Documenter
22
using DynamicExpressions
3+
using Random: AbstractRNG
34

45
makedocs(;
56
sitename="DynamicExpressions.jl",
67
authors="Miles Cranmer",
78
doctest=false,
89
clean=true,
910
format=Documenter.HTML(),
11+
warnonly=true,
1012
)
1113

1214
deploydocs(; repo="github.com/SymbolicML/DynamicExpressions.jl.git")

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
# Contents
33

44
```@contents
5-
Pages = ["types.md", "eval.md"]
5+
Pages = ["utils.md", "types.md", "eval.md"]
66
```

docs/src/utils.md

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Node utilities
2+
3+
## `Base`
4+
5+
Various functions in `Base` are overloaded to treat an `AbstractNode` as a
6+
collection of its nodes.
7+
8+
```@docs
9+
copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
10+
collect(tree::AbstractNode; break_sharing::Val=Val(false))
11+
filter(f::Function, tree::AbstractNode; break_sharing::Val=Val(false))
12+
count(f::Function, tree::AbstractNode; init=0, break_sharing::Val=Val(false))
13+
foreach(f::Function, tree::AbstractNode; break_sharing::Val=Val(false))
14+
sum(f::F, tree::AbstractNode; init=0, return_type=Undefined, f_on_shared=_default_shared_aggregation, break_sharing::Val=Val(false)) where {F<:Function}
15+
mapreduce(f::F, op::G, tree::AbstractNode; return_type, f_on_shared, break_sharing) where {F<:Function,G<:Function}
16+
any(f::F, tree::AbstractNode) where {F<:Function}
17+
all(f::F, tree::AbstractNode) where {F<:Function}
18+
map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val=Val(false)) where {F<:Function,RT}
19+
convert(::Type{<:AbstractExpressionNode{T1}}, n::AbstractExpressionNode{T2}) where {T1,T2}
20+
hash(tree::AbstractExpressionNode{T}, h::UInt; break_sharing::Val=Val(false)) where {T}
21+
```
22+
23+
## Sampling
24+
25+
There are also methods for random sampling of nodes:
26+
27+
```@docs
28+
NodeSampler
29+
rand(rng::AbstractRNG, tree::AbstractNode)
30+
rand(rng::AbstractRNG, sampler::NodeSampler{N,F,Nothing}) where {N,F}
31+
```
32+
33+
## Internal utilities
34+
35+
Almost all node utilities are crafted using the `tree_mapreduce` function,
36+
which evaluates a mapreduce over a tree-like (or graph-like) structure:
37+
38+
```@docs
39+
tree_mapreduce
40+
```
41+
42+
Various other utility functions include the following:
43+
44+
```@docs
45+
filter_map
46+
filter_map!
47+
```

src/DynamicExpressions.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ include("EvaluationHelpers.jl")
1010
include("SimplifyEquation.jl")
1111
include("OperatorEnumConstruction.jl")
1212
include("ExtensionInterface.jl")
13+
include("Random.jl")
1314

1415
import PackageExtensionCompat: @require_extensions
1516
import Reexport: @reexport
@@ -23,7 +24,8 @@ import Reexport: @reexport
2324
copy_node,
2425
set_node!,
2526
tree_mapreduce,
26-
filter_map
27+
filter_map,
28+
filter_map!
2729
import .EquationModule: constructorof, preserve_sharing
2830
@reexport import .EquationUtilsModule:
2931
count_nodes,
@@ -44,6 +46,7 @@ import .EquationModule: constructorof, preserve_sharing
4446
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree!
4547
@reexport import .EvaluationHelpersModule
4648
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
49+
@reexport import .RandomModule: NodeSampler
4750

4851
function __init__()
4952
@require_extensions

src/Random.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
module RandomModule
2+
3+
import Compat: Returns, @inline
4+
import Random: AbstractRNG
5+
import Base: rand
6+
import ..EquationModule: AbstractNode, tree_mapreduce, filter_map
7+
8+
"""
9+
NodeSampler(; tree, filter::Function=Returns(true), weighting::Union{Nothing,Function}=nothing, break_sharing::Val=Val(false))
10+
11+
Defines a sampler of nodes in a tree.
12+
13+
# Arguments
14+
15+
- `tree`: The tree to sample nodes from. For a regular `Node`,
16+
nodes are sampled uniformly. For a `GraphNode`, nodes are also
17+
sampled uniformly (e.g., in `sin(x) + {x}`, the `x` has equal
18+
probability of being sampled from the `sin` or the `+` node, because
19+
it is shared), unless `break_sharing` is set to `Val(true)`.
20+
- `filter::Function`: A function that takes a node and returns a boolean
21+
indicating whether the node should be sampled. Defaults to `Returns(true)`.
22+
- `weighting::Union{Nothing,Function}`: A function that takes a node and
23+
returns a weight for the node, if it passes the filter, proportional
24+
to the probability of sampling the node. If `nothing`, all nodes are
25+
sampled uniformly.
26+
- `break_sharing::Val`: If `Val(true)`, the
27+
sampler will break sharing in the tree, and sample nodes uniformly
28+
from the tree.
29+
"""
30+
Base.@kwdef struct NodeSampler{
31+
N<:AbstractNode,F<:Function,W<:Union{Nothing,Function},B<:Val
32+
}
33+
tree::N
34+
weighting::W = nothing
35+
filter::F = Returns(true)
36+
break_sharing::B = Val(false)
37+
end
38+
39+
"""
40+
rand(rng::AbstractRNG, tree::AbstractNode)
41+
42+
Sample a node from a tree according to the default sampler `NodeSampler(; tree)`.
43+
"""
44+
rand(rng::AbstractRNG, tree::AbstractNode) = rand(rng, NodeSampler(; tree))
45+
46+
"""
47+
rand(rng::AbstractRNG, sampler::NodeSampler)
48+
49+
Sample a node from a tree according to the sampler `sampler`.
50+
"""
51+
function rand(rng::AbstractRNG, sampler::NodeSampler{N,F,Nothing}) where {N,F}
52+
n = count(sampler.filter, sampler.tree; sampler.break_sharing)
53+
idx = rand(rng, 1:n)
54+
return _get_node(sampler.tree, sampler.filter, idx, sampler.break_sharing)
55+
end
56+
function rand(rng::AbstractRNG, sampler::NodeSampler{N,F,W}) where {N,F,W<:Function}
57+
weights = filter_map(
58+
sampler.filter, sampler.weighting, sampler.tree, Float64; sampler.break_sharing
59+
)
60+
idx = _sample_idx(rng, weights)
61+
return _get_node(sampler.tree, sampler.filter, idx, sampler.break_sharing)
62+
end
63+
64+
function _get_node(
65+
tree, filter_f::F, idx::Int, ::Val{break_sharing}
66+
) where {F,break_sharing}
67+
i = Ref(0)
68+
out = Ref(tree)
69+
foreach(tree; break_sharing=Val(break_sharing)) do node
70+
if @inline(filter_f(node)) && (i[] += 1) == idx
71+
out[] = node
72+
end
73+
nothing
74+
end
75+
return out[]
76+
end
77+
78+
function _sample_idx(rng::AbstractRNG, weights)
79+
csum = cumsum(weights)
80+
r = rand(rng, eltype(weights)) * csum[end]
81+
return findfirst(ci -> ci > r, csum)::Int
82+
end
83+
84+
end

src/base.jl

Lines changed: 64 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,9 @@ function count_nodes(tree::AbstractNode; break_sharing=Val(false))
230230
end
231231

232232
"""
233-
foreach(f::Function, tree::AbstractNode)
233+
foreach(f::Function, tree::AbstractNode; break_sharing::Val=Val(false))
234234
235-
Apply a function to each node in a tree.
235+
Apply a function to each node in a tree without returning the results.
236236
"""
237237
function foreach(
238238
f::F, tree::AbstractNode; break_sharing::Val=Val(false)
@@ -244,7 +244,7 @@ function foreach(
244244
end
245245

246246
"""
247-
filter_map(filter_fnc::Function, map_fnc::Function, tree::AbstractNode, result_type::Type)
247+
filter_map(filter_fnc::Function, map_fnc::Function, tree::AbstractNode, result_type::Type, break_sharing::Val=Val(false))
248248
249249
A faster equivalent to `map(map_fnc, filter(filter_fnc, tree))`
250250
that avoids the intermediate allocation. However, using this requires
@@ -286,20 +286,25 @@ function filter_map!(
286286
end
287287

288288
"""
289-
filter(f::Function, tree::AbstractNode)
289+
filter(f::Function, tree::AbstractNode; break_sharing::Val=Val(false))
290290
291291
Filter nodes of a tree, returning a flat array of the nodes for which the function returns `true`.
292292
"""
293293
function filter(f::F, tree::AbstractNode; break_sharing::Val=Val(false)) where {F<:Function}
294294
return filter_map(f, identity, tree, typeof(tree); break_sharing)
295295
end
296296

297+
"""
298+
collect(tree::AbstractNode; break_sharing::Val=Val(false))
299+
300+
Collect all nodes in a tree into a flat array in depth-first order.
301+
"""
297302
function collect(tree::AbstractNode; break_sharing::Val=Val(false))
298303
return filter(Returns(true), tree; break_sharing)
299304
end
300305

301306
"""
302-
map(f::Function, tree::AbstractNode, result_type::Type{RT}=Nothing)
307+
map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val=Val(false)) where {F<:Function,RT}
303308
304309
Map a function over a tree and return a flat array of the results in depth-first order.
305310
Pre-specifying the `result_type` of the function can be used to avoid extra allocations.
@@ -314,6 +319,11 @@ function map(
314319
end
315320
end
316321

322+
"""
323+
count(f::F, tree::AbstractNode; init=0, break_sharing::Val=Val(false)) where {F<:Function}
324+
325+
Count the number of nodes in a tree for which the function returns `true`.
326+
"""
317327
function count(
318328
f::F, tree::AbstractNode; init=0, break_sharing::Val=Val(false)
319329
) where {F<:Function}
@@ -327,22 +337,44 @@ function count(
327337
) + init
328338
end
329339

340+
"""
341+
sum(f::Function, tree::AbstractNode; init=0, return_type=Undefined, f_on_shared=_default_shared_aggregation, break_sharing::Val=Val(false)) where {F<:Function}
342+
343+
Sum the results of a function over a tree. For graphs with shared nodes
344+
such as `GraphNode`, the function `f_on_shared` is called on the result
345+
of each shared node. This is used to avoid double-counting shared nodes (default
346+
behavior).
347+
"""
330348
function sum(
331349
f::F,
332350
tree::AbstractNode;
333351
init=0,
334352
return_type=Undefined,
335-
f_on_shared=(c, is_shared) -> is_shared ? (false * c) : c,
353+
f_on_shared=_default_shared_aggregation,
336354
break_sharing::Val=Val(false),
337355
) where {F<:Function}
338356
if preserve_sharing(typeof(tree))
339357
@assert typeof(return_type) !== Undefined "Must specify `return_type` as a keyword argument to `sum` if `preserve_sharing` is true."
340358
end
341359
return tree_mapreduce(f, +, tree, return_type; f_on_shared, break_sharing) + init
342360
end
361+
function _default_shared_aggregation(c, is_shared)
362+
return is_shared ? (false * c) : c
363+
end
343364

365+
"""
366+
all(f::Function, tree::AbstractNode)
367+
368+
Reduce a flag function over a tree, returning `true` if the
369+
function returns `true` for all nodes, `false` otherwise.
370+
"""
344371
all(f::F, tree::AbstractNode) where {F<:Function} = !any(t -> !@inline(f(t)), tree)
345372

373+
"""
374+
mapreduce(f::Function, op::Function, tree::AbstractNode; return_type, f_on_shared, break_sharing)
375+
376+
Map a function over a tree and aggregate the result using an operator `op`.
377+
"""
346378
function mapreduce(
347379
f::F,
348380
op::G,
@@ -369,28 +401,34 @@ function length(tree::AbstractNode; break_sharing::Val=Val(false))
369401
return count_nodes(tree; break_sharing)
370402
end
371403

372-
function hash(tree::AbstractExpressionNode{T}) where {T}
404+
"""
405+
hash(tree::AbstractExpressionNode{T}[, h::UInt]; break_sharing::Val=Val(false)) where {T}
406+
407+
Compute a hash of a tree. This will compute a hash differently
408+
if nodes are shared in a tree. This is ignored if `break_sharing` is set to `Val(true)`.
409+
"""
410+
function hash(
411+
tree::AbstractExpressionNode{T}, h::UInt=zero(UInt); break_sharing::Val=Val(false)
412+
) where {T}
373413
return tree_mapreduce(
374-
t -> t.constant ? hash((0, t.val::T)) : hash((1, t.feature)),
375-
t -> hash((t.degree + 1, t.op)),
376-
(n...) -> hash(n),
414+
t -> t.constant ? hash((0, t.val::T), h) : hash((1, t.feature), h),
415+
t -> hash((t.degree + 1, t.op), h),
416+
(n...) -> hash(n, h),
377417
tree,
378-
UInt64;
418+
UInt;
379419
f_on_shared=(cur_hash, is_shared) ->
380-
is_shared ? hash((:shared, cur_hash)) : cur_hash,
420+
is_shared ? hash((:shared, cur_hash), h) : cur_hash,
421+
break_sharing,
381422
)
382423
end
383424

384425
"""
385-
copy_node(tree::AbstractExpressionNode)
426+
copy_node(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
386427
387428
Copy a node, recursively copying all children nodes.
388429
This is more efficient than the built-in copy.
389430
390-
id_map is a map from `objectid(tree)` to `copy(tree)`.
391-
We check against the map before making a new copy; otherwise
392-
we can simply reference the existing copy.
393-
[Thanks to Ted Hopp.](https://stackoverflow.com/questions/49285475/how-to-copy-a-full-non-binary-tree-including-loops)
431+
If `break_sharing` is set to `Val(true)`, sharing in a tree will be ignored.
394432
"""
395433
function copy_node(
396434
tree::N; break_sharing::Val=Val(false)
@@ -409,12 +447,20 @@ function copy_node(
409447
)
410448
end
411449

450+
"""
451+
copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
452+
453+
Copy a node, recursively copying all children nodes.
454+
This is more efficient than the built-in copy.
455+
456+
If `break_sharing` is set to `Val(true)`, sharing in a tree will be ignored.
457+
"""
412458
function copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
413459
return copy_node(tree; break_sharing)
414460
end
415461

416462
"""
417-
convert(::Type{AbstractExpressionNode{T1}}, n::AbstractExpressionNode{T2}) where {T1,T2}
463+
convert(::Type{<:AbstractExpressionNode{T1}}, n::AbstractExpressionNode{T2}) where {T1,T2}
418464
419465
Convert a `AbstractExpressionNode{T2}` to a `AbstractExpressionNode{T1}`.
420466
This will recursively convert all children nodes to `AbstractExpressionNode{T1}`,

0 commit comments

Comments
 (0)