Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
a246b11
Initial implementation of `tree_map` and `tree_mapreduce`
MilesCranmer Apr 15, 2023
c718bf8
Simplify `map` as special case of `map_reduce`
MilesCranmer Apr 15, 2023
f1cd9e1
Formatting
MilesCranmer Apr 15, 2023
71ec0cb
Rewrite counting utils with tree_map
MilesCranmer Apr 24, 2023
65141bc
Inline everything
MilesCranmer Apr 24, 2023
d78f1b4
Create `tree_map!` for setting properties
MilesCranmer Apr 24, 2023
618ac18
Create `tree_any` for flagging things
MilesCranmer Apr 24, 2023
cc37acf
Replace more functions with tree map equivalents
MilesCranmer Apr 24, 2023
6e8add6
Rename `set_constants` to `set_constants!`
MilesCranmer Apr 24, 2023
3fd2709
Add docstrings to utility functions
MilesCranmer Apr 24, 2023
de87062
Extra clean up
MilesCranmer Apr 24, 2023
2365c04
Naming conventions
MilesCranmer Apr 24, 2023
dc0738c
Remove count_nodes_with_stack for simplicity
MilesCranmer Apr 24, 2023
21de96b
Embed tree map into Equation.jl
MilesCranmer Apr 24, 2023
b523a13
Define additional Base functions for node
MilesCranmer Apr 25, 2023
bd8b09a
Faster get_constants
MilesCranmer Apr 26, 2023
9a73385
Speed up set constants by return nothing
MilesCranmer Apr 26, 2023
62270bd
Improve code readability
MilesCranmer Apr 26, 2023
bd5d973
Fix inline function on Julia <1.8
MilesCranmer Apr 26, 2023
dd62888
Only overload iterate for nodes
MilesCranmer Apr 26, 2023
004f280
Fix inline call
MilesCranmer Apr 26, 2023
2e561d6
Implement faster getindex and setindex!
MilesCranmer Apr 27, 2023
afacbe1
Benchmark other utility functions
MilesCranmer Apr 27, 2023
594e971
Fix access of module
MilesCranmer Apr 27, 2023
0856bb6
Fix behavior for deprecated set_constants
MilesCranmer Apr 27, 2023
4d20ff0
Fix overloading of `all`
MilesCranmer Apr 27, 2023
d46c5ae
Rename back to tree_mapreduce
MilesCranmer Apr 27, 2023
d623e70
Add additional Base functions
MilesCranmer Apr 27, 2023
10496dc
Fix speed of is_constant
MilesCranmer Apr 27, 2023
03fa330
Define isempty
MilesCranmer Apr 27, 2023
c330340
Fix error in is_constant
MilesCranmer Apr 27, 2023
cbde551
Clean up tree map code
MilesCranmer Apr 27, 2023
84f4e75
Preallocate stack of nodes
MilesCranmer Apr 27, 2023
9bae08e
Faster equality checks
MilesCranmer Apr 27, 2023
dc6eabd
Improve docs for `tree_mapreduce`
MilesCranmer Apr 29, 2023
c281e6b
Add filter_and_mapreduce
MilesCranmer Apr 30, 2023
e085e08
Remove filter_and_mapreduce
MilesCranmer Apr 30, 2023
7f72615
Require result type for `filter_and_map`
MilesCranmer Apr 30, 2023
adfc611
Fix shared node preserving convert in benchmark
MilesCranmer Apr 30, 2023
f3f3ad7
Add `foreach`
MilesCranmer Apr 30, 2023
9737763
Backport with Compat.jl
MilesCranmer Apr 30, 2023
a776662
Move Base overloading functions to single file
MilesCranmer May 5, 2023
a1cd2bb
Fix error in definition of ==
MilesCranmer May 6, 2023
d419994
Make copy_node a derived function
MilesCranmer May 6, 2023
fbc2827
Allow different functions for leaf and branch
MilesCranmer May 6, 2023
298e61d
Maybe `convert` a derived function
MilesCranmer May 6, 2023
5353fc8
Missing return
MilesCranmer May 6, 2023
c471863
Apparently passing type is faster
MilesCranmer May 6, 2023
17d4e7e
Make macro name more intuitive
MilesCranmer May 6, 2023
439c164
Try to inline more calls
MilesCranmer May 6, 2023
bf7f79f
Fix performance issue due to no specialization on kwargs
MilesCranmer May 7, 2023
00f8e7a
Make `map` specialize return type as well
MilesCranmer May 7, 2023
e1ac913
Make `mapreduce` a derived function
MilesCranmer May 7, 2023
527f65c
Make `==` more readable
MilesCranmer May 7, 2023
7f3efb9
Move undefined functions to end
MilesCranmer May 7, 2023
f0a30d5
Make `filter_map` a derived function
MilesCranmer May 7, 2023
d42b570
Clean up `filter_map!`
MilesCranmer May 7, 2023
abf58b2
Remove unused indexing utilities
MilesCranmer May 8, 2023
20a615a
Clean up tree conversion
MilesCranmer May 8, 2023
de7fee7
Reduce intermediate variables
MilesCranmer May 8, 2023
32bc1ae
Avoid intermediate allocation for == with mismatching types
MilesCranmer May 8, 2023
b191482
Speed up mapreduce with closure
MilesCranmer May 8, 2023
a94286d
Rename `tree_map.jl` to `base.jl`
MilesCranmer May 8, 2023
69624ce
Rename `with_memoize` macro
MilesCranmer May 8, 2023
f21c4cf
Plug performance hole
MilesCranmer May 8, 2023
a247c69
Add tests for tree map related functions
MilesCranmer May 8, 2023
cdabf24
Remove unused `map!` code
MilesCranmer May 8, 2023
0adf895
Fix tests of Base
MilesCranmer May 8, 2023
fcc056e
Remove redundant import
MilesCranmer May 8, 2023
c100d50
Test error is thrown for unsupported functions
MilesCranmer May 8, 2023
cfc902f
Test other branch of map
MilesCranmer May 8, 2023
5073e35
Bump version with Base utilities
MilesCranmer May 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
name = "DynamicExpressions"
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
version = "0.7.0"
version = "0.8.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Compat = "3.37, 4"
LoopVectorization = "0.12"
MacroTools = "0.4, 0.5"
Reexport = "1"
PrecompileTools = "1"
Reexport = "1"
SymbolicUtils = "0.19, ^1.0.5"
Zygote = "0.6"
julia = "1.6"
Expand Down
52 changes: 38 additions & 14 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using DynamicExpressions, BenchmarkTools, Random
using DynamicExpressions: copy_node
using DynamicExpressions.EquationUtilsModule: is_constant

include("benchmark_utils.jl")

Expand Down Expand Up @@ -73,30 +73,54 @@ end
PACKAGE_VERSION < v"0.7.0" && return :(copy_node(t; preserve_topology=preserve_sharing))
return :(copy_node(t; preserve_sharing=preserve_sharing))
end
@generated function get_set_constants!(tree)
!(@isdefined set_constants!) && return :(set_constants(tree, get_constants(tree)))
return :(set_constants!(tree, get_constants(tree)))
end
#! format: on

f_tree_op(f::F, tree, operators) where {F} = f(tree, operators)
f_tree_op(f::F, tree) where {F} = f(tree)

function benchmark_utilities()
suite = BenchmarkGroup()

all_funcs = (
:copy,
:convert,
:simplify_tree,
:combine_operators,
:count_nodes,
:count_depth,
:count_constants,
:has_constants,
:has_operators,
:is_constant,
:get_set_constants!,
:index_constants,
)

operators = OperatorEnum(; binary_operators=[+, -, /, *], unary_operators=[cos, exp])
for func_k in ("copy", "convert", "simplify_tree", "combine_operators")

for func_k in all_funcs
suite[func_k] = let s = BenchmarkGroup()
for k in ("break_sharing", "preserve_sharing")
k == "preserve_sharing" &&
func_k in ("simplify_tree", "combine_operators") &&
continue
for k in (:break_sharing, :preserve_sharing)
k == :preserve_sharing && !(func_k in (:copy, :convert)) && continue

f = if func_k == "copy"
tree -> _copy_node(tree; preserve_sharing=(k == "preserve_sharing"))
elseif func_k == "convert"
f = if func_k == :copy
tree -> _copy_node(tree; preserve_sharing=(k == :preserve_sharing))
elseif func_k == :convert
tree -> _convert(
Node{Float64},
tree;
preserve_sharing=(k == "preserve_sharing"),
preserve_sharing=(k == :preserve_sharing),
)
elseif func_k == "simplify_tree"
tree -> simplify_tree(tree, operators)
elseif func_k == "combine_operators"
tree -> combine_operators(tree, operators)
elseif func_k in (:simplify_tree, :combine_operators)
g = getfield(@__MODULE__, func_k)
tree -> f_tree_op(g, tree, operators)
else
g = getfield(@__MODULE__, func_k)
tree -> f_tree_op(g, tree)
end

#! format: off
Expand Down
8 changes: 5 additions & 3 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ include("SimplifyEquation.jl")
include("OperatorEnumConstruction.jl")

using Reexport
@reexport import .EquationModule: Node, string_tree, print_tree, copy_node, set_node!
@reexport import .EquationModule:
Node, string_tree, print_tree, copy_node, set_node!, tree_mapreduce, filter_map
@reexport import .EquationUtilsModule:
count_nodes,
count_nodes_with_stack,
count_constants,
count_depth,
NodeIndex,
index_constants,
has_operators,
has_constants,
get_constants,
set_constants
set_constants!
@reexport import .OperatorEnumModule: AbstractOperatorEnum
@reexport import .OperatorEnumConstructionModule:
OperatorEnum, GenericOperatorEnum, @extend_operators
Expand All @@ -34,6 +34,8 @@ using Reexport
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree
@reexport import .EvaluationHelpersModule

include("deprecated.jl")

import TOML: parsefile

const PACKAGE_VERSION = let
Expand Down
134 changes: 2 additions & 132 deletions src/Equation.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module EquationModule

import ..OperatorEnumModule: AbstractOperatorEnum
import ..UtilsModule: @generate_idmap, @use_idmap
import ..UtilsModule: @memoize_on, @with_memoize

const DEFAULT_NODE_TYPE = Float32

Expand Down Expand Up @@ -62,51 +62,7 @@ mutable struct Node{T}
end
################################################################################

"""
convert(::Type{Node{T1}}, n::Node{T2}) where {T1,T2}

Convert a `Node{T2}` to a `Node{T1}`.
This will recursively convert all children nodes to `Node{T1}`,
using `convert(T1, tree.val)` at constant nodes.

# Arguments
- `::Type{Node{T1}}`: Type to convert to.
- `tree::Node{T2}`: Node to convert.
"""
function Base.convert(
::Type{Node{T1}}, tree::Node{T2}; preserve_sharing::Bool=false
) where {T1,T2}
if T1 == T2
return tree
end
if preserve_sharing
@use_idmap(_convert(Node{T1}, tree), IdDict{Node{T2},Node{T1}}())
else
_convert(Node{T1}, tree)
end
end

@generate_idmap tree function _convert(::Type{Node{T1}}, tree::Node{T2}) where {T1,T2}
if tree.degree == 0
if tree.constant
val = tree.val::T2
if !(T2 <: T1)
# e.g., we don't want to convert Float32 to Union{Float32,Vector{Float32}}!
val = convert(T1, val)
end
Node(T1, 0, tree.constant, val)
else
Node(T1, 0, tree.constant, nothing, tree.feature)
end
elseif tree.degree == 1
l = _convert(Node{T1}, tree.l)
Node(1, tree.constant, nothing, tree.feature, tree.op, l)
else
l = _convert(Node{T1}, tree.l)
r = _convert(Node{T1}, tree.r)
Node(2, tree.constant, nothing, tree.feature, tree.op, l, r)
end
end
include("base.jl")

"""
Node([::Type{T}]; val=nothing, feature::Int=nothing) where {T}
Expand Down Expand Up @@ -224,45 +180,6 @@ function set_node!(tree::Node{T}, new_tree::Node{T}) where {T}
return nothing
end

"""
copy_node(tree::Node; preserve_sharing::Bool=false)

Copy a node, recursively copying all children nodes.
This is more efficient than the built-in copy.
With `preserve_sharing=true`, this will also
preserve linkage between a node and
multiple parents, whereas without, this would create
duplicate child node copies.

id_map is a map from `objectid(tree)` to `copy(tree)`.
We check against the map before making a new copy; otherwise
we can simply reference the existing copy.
[Thanks to Ted Hopp.](https://stackoverflow.com/questions/49285475/how-to-copy-a-full-non-binary-tree-including-loops)

Note that this will *not* preserve loops in graphs.
"""
function copy_node(tree::Node{T}; preserve_sharing::Bool=false)::Node{T} where {T}
if preserve_sharing
@use_idmap(_copy_node(tree), IdDict{Node{T},Node{T}}())
else
_copy_node(tree)
end
end

@generate_idmap tree function _copy_node(tree::Node{T})::Node{T} where {T}
if tree.degree == 0
if tree.constant
Node(; val=copy(tree.val::T))
else
Node(T; feature=copy(tree.feature))
end
elseif tree.degree == 1
Node(copy(tree.op), _copy_node(tree.l))
else
Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r))
end
end

const OP_NAMES = Dict(
"safe_log" => "log",
"safe_log2" => "log2",
Expand Down Expand Up @@ -363,51 +280,4 @@ function print_tree(
return println(string_tree(tree, operators; varMap=varMap))
end

function Base.hash(tree::Node{T})::UInt where {T}
if tree.degree == 0
if tree.constant
# tree.val used.
return hash((0, tree.val::T))
else
# tree.feature used.
return hash((1, tree.feature))
end
elseif tree.degree == 1
return hash((1, tree.op, hash(tree.l)))
else
return hash((2, tree.op, hash(tree.l), hash(tree.r)))
end
end

function is_equal(a::Node{T}, b::Node{T})::Bool where {T}
if a.degree == 0
b.degree != 0 && return false
if a.constant
!(b.constant) && return false
return a.val::T == b.val::T
else
b.constant && return false
return a.feature == b.feature
end
elseif a.degree == 1
b.degree != 1 && return false
a.op != b.op && return false
return is_equal(a.l, b.l)
else
b.degree != 2 && return false
a.op != b.op && return false
return is_equal(a.l, b.l) && is_equal(a.r, b.r)
end
end

function Base.:(==)(a::Node{T}, b::Node{T})::Bool where {T}
return is_equal(a, b)
end

function Base.:(==)(a::Node{T1}, b::Node{T2})::Bool where {T1,T2}
T = promote_type(T1, T2)
# TODO: Should also have preserve_sharing check...
return is_equal(convert(Node{T}, a), convert(Node{T}, b))
end

end
Loading