Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DynamicExpressions"
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
version = "0.14.0"
version = "0.14.1"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
4 changes: 4 additions & 0 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ function benchmark_utilities()
:get_set_constants!,
:index_constants,
:string_tree,
:hash,
)
has_both_modes = [:copy, :convert]
if PACKAGE_VERSION >= v"0.14.0"
Expand All @@ -122,6 +123,9 @@ function benchmark_utilities()
],
)
end
if PACKAGE_VERSION >= v"0.14.1"
append!(has_both_modes, [:hash])
end

operators = OperatorEnum(; binary_operators=[+, -, /, *], unary_operators=[cos, exp])
for func_k in all_funcs
Expand Down
2 changes: 2 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
using Documenter
using DynamicExpressions
using Random: AbstractRNG

# Render the manual as a static HTML site.
html_format = Documenter.HTML()

makedocs(;
    sitename="DynamicExpressions.jl",
    authors="Miles Cranmer",
    format=html_format,
    # Start from a clean build directory and do not run doctests.
    clean=true,
    doctest=false,
    # Report documentation problems as warnings rather than failing the build
    # (see Documenter's `makedocs` keyword reference).
    warnonly=true,
)

# Publish the built site for the package repository.
docs_repo = "github.com/SymbolicML/DynamicExpressions.jl.git"
deploydocs(; repo=docs_repo)
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
# Contents

```@contents
Pages = ["types.md", "eval.md"]
Pages = ["utils.md", "types.md", "eval.md"]
```
47 changes: 47 additions & 0 deletions docs/src/utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Node utilities

## `Base`

Various functions in `Base` are overloaded to treat an `AbstractNode` as a
collection of its nodes.

```@docs
copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
collect(tree::AbstractNode; break_sharing::Val=Val(false))
filter(f::Function, tree::AbstractNode; break_sharing::Val=Val(false))
count(f::Function, tree::AbstractNode; init=0, break_sharing::Val=Val(false))
foreach(f::Function, tree::AbstractNode; break_sharing::Val=Val(false))
sum(f::F, tree::AbstractNode; init=0, return_type=Undefined, f_on_shared=_default_shared_aggregation, break_sharing::Val=Val(false)) where {F<:Function}
mapreduce(f::F, op::G, tree::AbstractNode; return_type, f_on_shared, break_sharing) where {F<:Function,G<:Function}
any(f::F, tree::AbstractNode) where {F<:Function}
all(f::F, tree::AbstractNode) where {F<:Function}
map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val=Val(false)) where {F<:Function,RT}
convert(::Type{<:AbstractExpressionNode{T1}}, n::AbstractExpressionNode{T2}) where {T1,T2}
hash(tree::AbstractExpressionNode{T}, h::UInt; break_sharing::Val=Val(false)) where {T}
```

## Sampling

There are also methods for random sampling of nodes:

```@docs
NodeSampler
rand(rng::AbstractRNG, tree::AbstractNode)
rand(rng::AbstractRNG, sampler::NodeSampler{N,F,Nothing}) where {N,F}
```

## Internal utilities

Almost all node utilities are crafted using the `tree_mapreduce` function,
which evaluates a mapreduce over a tree-like (or graph-like) structure:

```@docs
tree_mapreduce
```

Various other utility functions include the following:

```@docs
filter_map
filter_map!
```
5 changes: 4 additions & 1 deletion src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ include("EvaluationHelpers.jl")
include("SimplifyEquation.jl")
include("OperatorEnumConstruction.jl")
include("ExtensionInterface.jl")
include("Random.jl")

import PackageExtensionCompat: @require_extensions
import Reexport: @reexport
Expand All @@ -23,7 +24,8 @@ import Reexport: @reexport
copy_node,
set_node!,
tree_mapreduce,
filter_map
filter_map,
filter_map!
import .EquationModule: constructorof, preserve_sharing
@reexport import .EquationUtilsModule:
count_nodes,
Expand All @@ -44,6 +46,7 @@ import .EquationModule: constructorof, preserve_sharing
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree!
@reexport import .EvaluationHelpersModule
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
@reexport import .RandomModule: NodeSampler

function __init__()
@require_extensions
Expand Down
84 changes: 84 additions & 0 deletions src/Random.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
module RandomModule

import Compat: Returns, @inline
import Random: AbstractRNG
import Base: rand
import ..EquationModule: AbstractNode, tree_mapreduce, filter_map

"""
    NodeSampler(; tree, filter::Function=Returns(true), weighting::Union{Nothing,Function}=nothing, break_sharing::Val=Val(false))

Defines a sampler of nodes in a tree. Draw a node from it with
`rand(rng, sampler)`.

# Arguments

- `tree`: The tree to sample nodes from. For a regular `Node`,
    nodes are sampled uniformly. For a `GraphNode`, nodes are also
    sampled uniformly (e.g., in `sin(x) + {x}`, the `x` has equal
    probability of being sampled from the `sin` or the `+` node, because
    it is shared), unless `break_sharing` is set to `Val(true)`.
- `filter::Function`: A predicate that takes a node and returns a boolean
    indicating whether the node is eligible to be sampled. Defaults to
    `Returns(true)`, i.e., every node is eligible.
- `weighting::Union{Nothing,Function}`: A function that takes a node which
    passed `filter` and returns its weight; the node is then sampled with
    probability proportional to that weight. The weights are gathered as
    `Float64` values. If `nothing` (the default), all eligible nodes are
    sampled uniformly.
- `break_sharing::Val`: If `Val(true)`, the
    sampler will break sharing in the tree, and sample nodes uniformly
    from the tree. Defaults to `Val(false)`.
"""
Base.@kwdef struct NodeSampler{
N<:AbstractNode,F<:Function,W<:Union{Nothing,Function},B<:Val
}
# Root of the tree that nodes are drawn from.
tree::N
# Optional per-node weight function; `nothing` selects the uniform `rand` method.
weighting::W = nothing
# Predicate deciding which nodes may be drawn.
filter::F = Returns(true)
# Forwarded as the `break_sharing` keyword of the tree traversals used for sampling.
break_sharing::B = Val(false)
end

"""
    rand(rng::AbstractRNG, tree::AbstractNode)

Draw one node from `tree` using the default sampler, `NodeSampler(; tree)`
(no filter, no weighting).
"""
function rand(rng::AbstractRNG, tree::AbstractNode)
    default_sampler = NodeSampler(; tree=tree)
    return rand(rng, default_sampler)
end

"""
    rand(rng::AbstractRNG, sampler::NodeSampler)

Sample a node from a tree according to the sampler `sampler`.

When `sampler.weighting` is `nothing`, every node that passes `sampler.filter`
is equally likely to be returned.

# Throws

- `ArgumentError`: if no node in the tree passes `sampler.filter`.
"""
function rand(rng::AbstractRNG, sampler::NodeSampler{N,F,Nothing}) where {N,F}
    n = count(sampler.filter, sampler.tree; sampler.break_sharing)
    # Without this guard the failure would surface from `rand(rng, 1:0)` with a
    # message about an empty range, which hides the real cause.
    if n <= 0
        throw(
            ArgumentError(
                "cannot sample a node: no node in the tree passes the sampler's filter"
            ),
        )
    end
    idx = rand(rng, 1:n)
    return _get_node(sampler.tree, sampler.filter, idx, sampler.break_sharing)
end
# Weighted sampling: evaluate `sampler.weighting` on every node that passes
# `sampler.filter`, draw a position in proportion to those weights, and return
# the node found at that position among the filtered nodes.
function rand(rng::AbstractRNG, sampler::NodeSampler{N,F,W}) where {N,F,W<:Function}
    tree = sampler.tree
    keep = sampler.filter
    sharing = sampler.break_sharing
    weights = filter_map(keep, sampler.weighting, tree, Float64; break_sharing=sharing)
    chosen = _sample_idx(rng, weights)
    return _get_node(tree, keep, chosen, sharing)
end

# Return the `idx`-th node (1-based, in `foreach` traversal order) among the
# nodes of `tree` for which `filter_f` returns `true`. If no node reaches that
# position, `tree` itself is returned.
function _get_node(
    tree, filter_f::F, idx::Int, ::Val{break_sharing}
) where {F,break_sharing}
    n_seen = Ref(0)
    selected = Ref(tree)
    foreach(tree; break_sharing=Val(break_sharing)) do node
        if @inline(filter_f(node))
            n_seen[] += 1
            if n_seen[] == idx
                selected[] = node
            end
        end
        nothing
    end
    return selected[]
end

# Draw an index into `weights` with probability proportional to each weight,
# by inverting the cumulative sum at a uniformly drawn point. Entries with zero
# weight are never selected. Intended for floating-point weights, since the
# uniform draw is `rand(rng, eltype(weights))`.
#
# Throws `ArgumentError` if `weights` is empty or if the total weight is not a
# positive, finite number (in those cases no index can be drawn).
function _sample_idx(rng::AbstractRNG, weights)
    if isempty(weights)
        throw(ArgumentError("cannot sample an index: no weights were provided"))
    end
    csum = cumsum(weights)
    total = csum[end]
    # `total > 0` is also false for NaN, so NaN totals are rejected here too.
    if !(total > zero(total)) || !isfinite(total)
        throw(
            ArgumentError(
                "cannot sample an index: total weight must be positive and finite, got $(total)",
            ),
        )
    end
    r = rand(rng, eltype(weights)) * total
    return findfirst(ci -> ci > r, csum)::Int
end

end
82 changes: 64 additions & 18 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,9 @@ function count_nodes(tree::AbstractNode; break_sharing=Val(false))
end

"""
foreach(f::Function, tree::AbstractNode)
foreach(f::Function, tree::AbstractNode; break_sharing::Val=Val(false))

Apply a function to each node in a tree.
Apply a function to each node in a tree without returning the results.
"""
function foreach(
f::F, tree::AbstractNode; break_sharing::Val=Val(false)
Expand All @@ -244,7 +244,7 @@ function foreach(
end

"""
filter_map(filter_fnc::Function, map_fnc::Function, tree::AbstractNode, result_type::Type)
filter_map(filter_fnc::Function, map_fnc::Function, tree::AbstractNode, result_type::Type; break_sharing::Val=Val(false))

A faster equivalent to `map(map_fnc, filter(filter_fnc, tree))`
that avoids the intermediate allocation. However, using this requires
Expand Down Expand Up @@ -286,20 +286,25 @@ function filter_map!(
end

"""
filter(f::Function, tree::AbstractNode)
filter(f::Function, tree::AbstractNode; break_sharing::Val=Val(false))

Filter nodes of a tree, returning a flat array of the nodes for which the function returns `true`.
"""
function filter(
    f::F, tree::AbstractNode; break_sharing::Val=Val(false)
) where {F<:Function}
    # Mapping with `identity` keeps the selected nodes themselves; the type of
    # the root node is passed as the result type.
    node_type = typeof(tree)
    return filter_map(f, identity, tree, node_type; break_sharing=break_sharing)
end

"""
    collect(tree::AbstractNode; break_sharing::Val=Val(false))

Return every node of `tree` as a flat array, in depth-first order.
Implemented as a `filter` that accepts all nodes.
"""
function collect(tree::AbstractNode; break_sharing::Val=Val(false))
    accept_all = Returns(true)
    return filter(accept_all, tree; break_sharing=break_sharing)
end

"""
map(f::Function, tree::AbstractNode, result_type::Type{RT}=Nothing)
map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val=Val(false)) where {F<:Function,RT}

Map a function over a tree and return a flat array of the results in depth-first order.
Pre-specifying the `result_type` of the function can be used to avoid extra allocations.
Expand All @@ -314,6 +319,11 @@ function map(
end
end

"""
count(f::F, tree::AbstractNode; init=0, break_sharing::Val=Val(false)) where {F<:Function}

Count the number of nodes in a tree for which the function returns `true`.
"""
function count(
f::F, tree::AbstractNode; init=0, break_sharing::Val=Val(false)
) where {F<:Function}
Expand All @@ -327,22 +337,44 @@ function count(
) + init
end

"""
    sum(f::F, tree::AbstractNode; init=0, return_type=Undefined, f_on_shared=_default_shared_aggregation, break_sharing::Val=Val(false)) where {F<:Function}

Sum the results of a function over a tree. For graphs with shared nodes
such as `GraphNode`, the function `f_on_shared` is called on the result
of each shared node. This is used to avoid double-counting shared nodes (default
behavior).

# Keywords

- `init`: Value added to the aggregated sum.
- `return_type`: Result type forwarded to `tree_mapreduce`. Must be given
    explicitly when the tree type preserves sharing and sharing is not broken.
- `f_on_shared`: Called as `f_on_shared(result, is_shared)` on node results.
- `break_sharing`: If `Val(true)`, sharing in the tree is ignored.

# Throws

- `ArgumentError`: if sharing is preserved for this tree and `return_type`
    was left at its default `Undefined`.
"""
function sum(
    f::F,
    tree::AbstractNode;
    init=0,
    return_type=Undefined,
    f_on_shared=_default_shared_aggregation,
    break_sharing::Val=Val(false),
) where {F<:Function}
    # Compare `return_type` itself against the sentinel: `typeof(return_type)`
    # is the type of a type and can never be `Undefined`, so a check on it
    # would never fire.
    sharing_is_kept = preserve_sharing(typeof(tree)) && !(break_sharing isa Val{true})
    if sharing_is_kept && return_type === Undefined
        throw(
            ArgumentError(
                "Must specify `return_type` as a keyword argument to `sum` if `preserve_sharing` is true.",
            ),
        )
    end
    return tree_mapreduce(f, +, tree, return_type; f_on_shared, break_sharing) + init
end
# Default `f_on_shared` for `sum`: the contribution `c` of a shared node is
# replaced by `false * c` (a zero obtained from `c` by multiplication), while
# the contribution of an unshared node is passed through unchanged.
function _default_shared_aggregation(c, is_shared)
    if is_shared
        return false * c
    end
    return c
end

"""
    all(f::Function, tree::AbstractNode)

Return `true` if `f` returns `true` for every node of `tree`, and `false`
otherwise. Evaluated as "no node violates `f`" via `any`.
"""
function all(f::F, tree::AbstractNode) where {F<:Function}
    violates(node) = !@inline(f(node))
    return !any(violates, tree)
end

"""
mapreduce(f::Function, op::Function, tree::AbstractNode; return_type, f_on_shared, break_sharing)

Map a function over a tree and aggregate the result using an operator `op`.
"""
function mapreduce(
f::F,
op::G,
Expand All @@ -369,28 +401,34 @@ function length(tree::AbstractNode; break_sharing::Val=Val(false))
return count_nodes(tree; break_sharing)
end

function hash(tree::AbstractExpressionNode{T}) where {T}
"""
hash(tree::AbstractExpressionNode{T}[, h::UInt]; break_sharing::Val=Val(false)) where {T}

Compute a hash of a tree. By default, a node that is shared within the tree is hashed
differently from an unshared copy of it; set `break_sharing` to `Val(true)` to ignore sharing.
"""
function hash(
tree::AbstractExpressionNode{T}, h::UInt=zero(UInt); break_sharing::Val=Val(false)
) where {T}
return tree_mapreduce(
t -> t.constant ? hash((0, t.val::T)) : hash((1, t.feature)),
t -> hash((t.degree + 1, t.op)),
(n...) -> hash(n),
t -> t.constant ? hash((0, t.val::T), h) : hash((1, t.feature), h),
t -> hash((t.degree + 1, t.op), h),
(n...) -> hash(n, h),
tree,
UInt64;
UInt;
f_on_shared=(cur_hash, is_shared) ->
is_shared ? hash((:shared, cur_hash)) : cur_hash,
is_shared ? hash((:shared, cur_hash), h) : cur_hash,
break_sharing,
)
end

"""
copy_node(tree::AbstractExpressionNode)
copy_node(tree::AbstractExpressionNode; break_sharing::Val=Val(false))

Copy a node, recursively copying all children nodes.
This is more efficient than the built-in copy.

id_map is a map from `objectid(tree)` to `copy(tree)`.
We check against the map before making a new copy; otherwise
we can simply reference the existing copy.
[Thanks to Ted Hopp.](https://stackoverflow.com/questions/49285475/how-to-copy-a-full-non-binary-tree-including-loops)
If `break_sharing` is set to `Val(true)`, sharing in a tree will be ignored.
"""
function copy_node(
tree::N; break_sharing::Val=Val(false)
Expand All @@ -409,12 +447,20 @@ function copy_node(
)
end

"""
    copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))

Copy a node, recursively copying all of its children. This forwards to
`copy_node`, which is more efficient than the built-in copy.

If `break_sharing` is set to `Val(true)`, sharing in a tree will be ignored.
"""
function copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
    return copy_node(tree; break_sharing=break_sharing)
end

"""
convert(::Type{AbstractExpressionNode{T1}}, n::AbstractExpressionNode{T2}) where {T1,T2}
convert(::Type{<:AbstractExpressionNode{T1}}, n::AbstractExpressionNode{T2}) where {T1,T2}

Convert an `AbstractExpressionNode{T2}` to an `AbstractExpressionNode{T1}`.
This will recursively convert all children nodes to `AbstractExpressionNode{T1}`,
Expand Down
Loading