Skip to content

Commit 11af6b1

Browse files
committed
Speed up mapreduce with closure
1 parent 32bc1ae commit 11af6b1

File tree

1 file changed

+20
-22
lines changed

1 file changed

+20
-22
lines changed

src/tree_map.jl

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -78,30 +78,28 @@ function tree_mapreduce(
7878
result_type::Type{RT}=Nothing;
7979
preserve_sharing::Bool=false,
8080
) where {T,N<:Node{T},F1<:Function,F2<:Function,G<:Function,RT}
81-
if preserve_sharing && RT != Nothing
82-
return @with_memoization _tree_mapreduce(f_leaf, f_branch, op, tree) IdDict{N,RT}()
83-
elseif preserve_sharing
84-
throw(ArgumentError("Need to specify `result_type` if you use `preserve_sharing`."))
81+
82+
# Trick taken from here:
83+
# https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
84+
# to speed up recursive closure
85+
@memoize_on t function inner(inner, t::Node)
86+
if t.degree == 0
87+
return @inline(f_leaf(t))
88+
elseif t.degree == 1
89+
return @inline(op(@inline(f_branch(t)), inner(inner, t.l)))
90+
else
91+
return @inline(op(@inline(f_branch(t)), inner(inner, t.l), inner(inner, t.r)))
92+
end
8593
end
86-
return _tree_mapreduce(f_leaf, f_branch, op, tree)
87-
end
88-
@memoize_on tree function _tree_mapreduce(
89-
f_leaf::F1, f_branch::F2, op::G, tree::Node
90-
) where {F1<:Function,F2<:Function,G<:Function}
91-
if tree.degree == 0
92-
return @inline(f_leaf(tree))
93-
elseif tree.degree == 1
94-
return @inline(
95-
op(@inline(f_branch(tree)), _tree_mapreduce(f_leaf, f_branch, op, tree.l))
96-
)
94+
95+
RT == Nothing &&
96+
preserve_sharing &&
97+
throw(ArgumentError("Need to specify `result_type` if you use `preserve_sharing`."))
98+
99+
if preserve_sharing
100+
return @with_memoization inner(inner, tree) IdDict{N,RT}()
97101
else
98-
return @inline(
99-
op(
100-
@inline(f_branch(tree)),
101-
_tree_mapreduce(f_leaf, f_branch, op, tree.l),
102-
_tree_mapreduce(f_leaf, f_branch, op, tree.r),
103-
)
104-
)
102+
return inner(inner, tree)
105103
end
106104
end
107105

0 commit comments

Comments
 (0)