Skip to content

Commit 2945731

Browse files
Improve type stability of cached walks (#82)
* improve type stability of cached walks * fix doctest format * handle old julia version
1 parent c0936a5 commit 2945731

File tree

4 files changed

+67
-8
lines changed

4 files changed

+67
-8
lines changed

src/Functors.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ export @leaf, @functor, @flexiblefunctor,
1313
include("functor.jl")
1414
include("keypath.jl")
1515
include("walks.jl")
16+
include("cache.jl")
1617
include("maps.jl")
1718
include("base.jl")
1819

src/cache.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
struct WalkCache{K, V, W <: AbstractWalk, C <: AbstractDict{K, V}} <: AbstractDict{K, V}
2+
walk::W
3+
cache::C
4+
WalkCache(walk, cache::AbstractDict{K, V} = IdDict()) where {K, V} = new{K, V, typeof(walk), typeof(cache)}(walk, cache)
5+
end
6+
Base.length(cache::WalkCache) = length(cache.cache)
7+
Base.empty!(cache::WalkCache) = empty!(cache.cache)
8+
Base.haskey(cache::WalkCache, x) = haskey(cache.cache, x)
9+
Base.get(cache::WalkCache, x, default) = haskey(cache.cache, x) ? cache[x] : default
10+
Base.iterate(cache::WalkCache, state...) = iterate(cache.cache, state...)
11+
Base.setindex!(cache::WalkCache, value, key) = setindex!(cache.cache, value, key)
12+
Base.getindex(cache::WalkCache, x) = cache.cache[x]
13+
14+
@static if VERSION >= v"1.10.0-DEV.609"
15+
function __cacheget_generator__(world, source, self, cache, x, args #= for `return_type` only =#)
16+
# :(return cache.cache[x]::(return_type(cache.walk, typeof(args))))
17+
walk = cache.parameters[3]
18+
RT = Core.Compiler.return_type(Tuple{walk, args...}, world)
19+
body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x)
20+
if RT != Any
21+
body = Expr(:(::), body, RT)
22+
end
23+
expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args],
24+
Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body))))
25+
ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__)
26+
ci.inlineable = true
27+
return ci
28+
end
29+
@eval function cacheget(cache::WalkCache, x, args...)
30+
$(Expr(:meta, :generated, __cacheget_generator__))
31+
$(Expr(:meta, :generated_only))
32+
end
33+
else
34+
@generated function cacheget(cache::WalkCache, x, args...)
35+
walk = cache.parameters[3]
36+
world = typemax(UInt)
37+
@static if VERSION >= v"1.8"
38+
RT = Core.Compiler.return_type(Tuple{walk, args...}, world)
39+
else
40+
if isdefined(walk, :instance)
41+
RT = Core.Compiler.return_type(walk.instance, Tuple{args...}, world)
42+
else
43+
RT = Any
44+
end
45+
end
46+
body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x)
47+
if RT != Any
48+
body = Expr(:(::), body, RT)
49+
end
50+
expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args],
51+
Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body))))
52+
ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__)
53+
ci.inlineable = true
54+
return ci
55+
end
56+
end
57+
# fallback behavior that only lookup for `x`
58+
@inline cacheget(cache::AbstractDict, x, args...) = cache[x]

src/maps.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ function fmap(f, x, ys...; exclude = isleaf,
66
prune = NoKeyword())
77
_walk = ExcludeWalk(AnonymousWalk(walk), f, exclude)
88
if !isnothing(cache)
9-
_walk = CachedWalk(_walk, prune, cache)
9+
_walk = CachedWalk(_walk, prune, WalkCache(_walk, cache))
1010
end
1111
execute(_walk, x, ys...)
1212
end
@@ -18,7 +18,7 @@ function fmap_with_path(f, x, ys...; exclude = isleaf,
1818

1919
_walk = ExcludeWalkWithKeyPath(walk, f, exclude)
2020
if !isnothing(cache)
21-
_walk = CachedWalkWithPath(_walk, prune, cache)
21+
_walk = CachedWalkWithPath(_walk, prune, WalkCache(_walk, cache))
2222
end
2323
return execute(_walk, KeyPath(), x, ys...)
2424
end

src/walks.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,18 +181,18 @@ Whenever the cache already contains `x`, either:
181181
182182
Typically wraps an existing `walk` for use with [`fmap`](@ref).
183183
"""
184-
struct CachedWalk{T, S} <: AbstractWalk
184+
struct CachedWalk{T, S, C <: AbstractDict} <: AbstractWalk
185185
walk::T
186186
prune::S
187-
cache::IdDict{Any, Any}
187+
cache::C
188188
end
189189
CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) =
190190
CachedWalk(walk, prune, cache)
191191

192192
function (walk::CachedWalk)(recurse, x, ys...)
193193
should_cache = usecache(walk.cache, x)
194194
if should_cache && haskey(walk.cache, x)
195-
return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune
195+
return walk.prune isa NoKeyword ? cacheget(walk.cache, x, recurse, x, ys...) : walk.prune
196196
else
197197
ret = walk.walk(recurse, x, ys...)
198198
if should_cache
@@ -202,10 +202,10 @@ function (walk::CachedWalk)(recurse, x, ys...)
202202
end
203203
end
204204

205-
struct CachedWalkWithPath{T, S} <: AbstractWalk
205+
struct CachedWalkWithPath{T, S, C <: AbstractDict} <: AbstractWalk
206206
walk::T
207207
prune::S
208-
cache::IdDict{Any, Any}
208+
cache::C
209209
end
210210

211211
CachedWalkWithPath(walk; prune = NoKeyword(), cache = IdDict()) =
@@ -214,7 +214,7 @@ CachedWalkWithPath(walk; prune = NoKeyword(), cache = IdDict()) =
214214
function (walk::CachedWalkWithPath)(recurse, kp::KeyPath, x, ys...)
215215
should_cache = usecache(walk.cache, x)
216216
if should_cache && haskey(walk.cache, x)
217-
return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune
217+
return walk.prune isa NoKeyword ? cacheget(walk.cache, x, recurse, kp, x, ys...) : walk.prune
218218
else
219219
ret = walk.walk(recurse, kp, x, ys...)
220220
if should_cache

0 commit comments

Comments
 (0)