-
-
Notifications
You must be signed in to change notification settings - Fork 16
Extract common functionality into fold #32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ function makefunctor(m::Module, T, fs = fieldnames(T)) | |
escargs = map(fieldnames(T)) do f | ||
f in fs ? :(y[$(yᵢ += 1)]) : :(x.$f) | ||
end | ||
escfs = [:($f=x.$f) for f in fs] | ||
escfs = [:($f = x.$f) for f in fs] | ||
|
||
@eval m begin | ||
$Functors.functor(::Type{<:$T}, x) = ($(escfs...),), y -> $T($(escargs...)) | ||
|
@@ -169,7 +169,54 @@ function _default_walk(f, x) | |
func, re = functor(x) | ||
re(map(f, func)) | ||
end | ||
_default_walk(f, ::Nothing, ::Nothing) = nothing | ||
_default_walk(_, ::Nothing, ::Nothing) = nothing | ||
|
||
# Side effects only, saves a restructure | ||
function _foreach_walk(f, x) | ||
foreach(f, children(x)) | ||
return x | ||
end | ||
|
||
### WARNING: the following is unstable internal functionality. Use at your own risk! | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I considered |
||
# Wrapper over an IdDict which only saves values with a stable object identity | ||
struct Cache{K,V} | ||
inner::IdDict{K,V} | ||
end | ||
Cache() = Cache(IdDict()) | ||
|
||
usecache(x) = !isbits(x) && ismutable(x) | ||
# Functionally immutable and observe value semantics, but still `ismutable` and not `isbits` | ||
usecache(::Union{String,Symbol}) = false | ||
# For varargs | ||
usecache(xs::Tuple) = all(usecache, xs) | ||
Base.get!(f, c::Cache, x) = usecache(x) ? get!(f, c.inner, x) : f() | ||
|
||
# Passthrough used to disable caching (e.g. when passing `cache=false`) | ||
struct NoCache end | ||
Base.get!(f, ::NoCache, _) = f() | ||
|
||
# Encapsulates the self-recursive part of a recursive tree reduction (fold). | ||
# This allows calling functions to remove any self-calls or nested callback closures. | ||
struct Fold{F,L,C,W} | ||
fn::F | ||
isleaf::L | ||
cache::C | ||
walk::W | ||
end | ||
(fld::Fold)(x) = get!(fld.cache, x) do | ||
fld.fn(fld.isleaf(x) ? x : fld.walk(fld, x)) | ||
end | ||
|
||
# Convenience function for working with `Fold` | ||
function fold(f, x; isleaf = isleaf, cache = false, walk = _default_walk) | ||
if cache === true | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that |
||
cache = Cache() | ||
elseif cache === false | ||
cache = NoCache() | ||
end | ||
return Fold(f, isleaf, cache, walk)(x) | ||
end | ||
### end of unstable internal functionality | ||
|
||
""" | ||
fmap(f, x; exclude = Functors.isleaf, walk = Functors._default_walk) | ||
|
@@ -253,11 +300,9 @@ Foo(Bar([1, 2, 3]), (40, 50, Bar(Foo(6, 7)))) | |
``` | ||
""" | ||
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict()) | ||
haskey(cache, x) && return cache[x] | ||
y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x) | ||
cache[x] = y | ||
|
||
return y | ||
return fold(x; cache, walk, isleaf = exclude) do node | ||
exclude(node) ? f(node) : node | ||
end | ||
end | ||
|
||
""" | ||
|
@@ -296,8 +341,8 @@ fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)) | |
fcollect(x; exclude = v -> false) | ||
|
||
Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref) | ||
and collecting the results into a flat array, ordered by a breadth-first | ||
traversal of `x`, respecting the iteration order of `children` calls. | ||
and collecting the results into a flat array, ordered by a depth-first, | ||
post-order traversal of `x` that respects the iteration order of `children` calls. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is technically breaking (as can be seen from the test changes), but it's unclear to me how many users were relying on the breadth-first iteration order. The one example I found on JuliaHub, AlphaZero, was not. It's also odd that the original worked breadth-first when pretty much all other traversals based on functors are depth-first. So RFC, and happy to back this out if we consider it too breaking. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's fine to have this breaking change. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that nothing now in Optimisers.jl now cares about the order taken by Of course the two halves of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The order in |
||
|
||
Doesn't recurse inside branches rooted at nodes `v` | ||
for which `exclude(v) == true`. | ||
|
@@ -324,33 +369,27 @@ Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) | |
|
||
julia> fcollect(m) | ||
4-element Vector{Any}: | ||
Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) | ||
Bar([1, 2, 3]) | ||
[1, 2, 3] | ||
Bar([1, 2, 3]) | ||
NoChildren(:a, :b) | ||
Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) | ||
|
||
julia> fcollect(m, exclude = v -> v isa Bar) | ||
2-element Vector{Any}: | ||
Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) | ||
NoChildren(:a, :b) | ||
Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) | ||
|
||
julia> fcollect(m, exclude = v -> Functors.isleaf(v)) | ||
2-element Vector{Any}: | ||
Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) | ||
Bar([1, 2, 3]) | ||
Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) | ||
``` | ||
""" | ||
function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false) | ||
# note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache | ||
# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector` | ||
# for the results, to preserve traversal order (important downstream!). | ||
x in cache && return output | ||
if !exclude(x) | ||
push!(cache, x) | ||
push!(output, x) | ||
foreach(y -> fcollect(y; cache=cache, output=output, exclude=exclude), children(x)) | ||
end | ||
return output | ||
function fcollect(x; output = [], cache = Base.IdDict(), exclude = v -> false) | ||
fold(x; cache, isleaf = exclude, walk = _foreach_walk) do node | ||
exclude(node) || push!(output, node); # always return nothing | ||
end | ||
return output | ||
end | ||
|
||
# Allow gradients and other constructs that match the structure of the functor | ||
|
Uh oh!
There was an error while loading. Please reload this page.