Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract common functionality into fold #32

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 64 additions & 24 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function makefunctor(m::Module, T, fs = fieldnames(T))
escargs = map(fieldnames(T)) do f
f in fs ? :(y[$(yᵢ += 1)]) : :(x.$f)
end
escfs = [:($f=x.$f) for f in fs]
escfs = [:($f = x.$f) for f in fs]

@eval m begin
$Functors.functor(::Type{<:$T}, x) = ($(escfs...),), y -> $T($(escargs...))
Expand Down Expand Up @@ -169,7 +169,54 @@ function _default_walk(f, x)
func, re = functor(x)
re(map(f, func))
end
_default_walk(f, ::Nothing, ::Nothing) = nothing
_default_walk(_, ::Nothing, ::Nothing) = nothing

# Side effects only, saves a restructure
function _foreach_walk(f, x)
foreach(f, children(x))
return x
end
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

### WARNING: the following is unstable internal functionality. Use at your own risk!
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered _-prefixing everything but didn't for now because these will probably become part of the public API at some point. Happy to prefix them if that's desired though.

# Wrapper over an IdDict which only saves values with a stable object identity
struct Cache{K,V}
inner::IdDict{K,V}
end
Cache() = Cache(IdDict())

iscachesafe(x) = !isbits(x) && ismutable(x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I suggest calling this usecache? It's not so clear to me which way "safe" points, safe that we don't need the cache, or safe that we can use it?

Also, why ismutable? This catches only the outermost wrapper, which I don't think ought to matter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usecache sounds better, changed.

ismutable is required because this function is called at every level of the tree. So if there is a string in a (possibly deeply) nested struct that otherwise contains all bitstype values, isbits will fail and we will get a false positive. In fact ismutable may be sufficient on its own, but in the spirit of not making too many assumptions around underspecified language behaviour (namely that all mutable structs have stable objectids) I've kept it in to be safe.

# Functionally immutable and observe value semantics, but still `ismutable` and not `isbits`
iscachesafe(::Union{String,Symbol}) = false
# For varargs
iscachesafe(xs::Tuple) = all(iscachesafe, xs)
Base.get!(f, c::Cache, x) = iscachesafe(x) ? get!(f, c.inner, x) : f()
ToucheSir marked this conversation as resolved.
Show resolved Hide resolved

# Passthrough used to disable caching (e.g. when passing `cache=false`)
struct NoCache end
Base.get!(f, ::NoCache, _) = f()
Copy link
Member

@mcabbott mcabbott Jan 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style not logic, but could this be more tidily done as a safeget! function, rather than types on which Base's functions act? There is I think exactly one line which calls this get! on these types.

Something like

safeget!(f, d::AbstractDict, k) = usecache(k) ? get!(f, d, k) : f()
safeget!(f, d, k) = f()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had exactly this in an earlier draft, but unfortunately it's a breaking change (i.e. tests fail) because the default before was to unconditionally trust IdDict/objectid/===. So value types were being deduped via the cache. If we decide to call that a bug, then I can put this suggestion in place.

Another theoretical benefit of not hard-coding usecache would be allowing users to customize caching behaviour on a case-by-case basis. I haven't thought of a scenario that would require such a degree of customization yet (e.g. TiedArrays could simply overload usecache), so we can cross that bridge when we come to it.

Copy link
Member

@mcabbott mcabbott Jan 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, my suggestion still leaves usecache as an overloadable function. It's not intended as a change of logic, just more compact & fewer types?

Testing isbits etc. seems at all seems like a potentially breaking change, as it changes how many times f will be called (e.g. what does fmap(println, (1,2,3,1,2,3)) print?). Maybe it depends how firmly you believe f must be pure... certainly fcollect is going to change.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking for isbits and ismutable is a hedge against an impure f. IMO users would be more surprised if fmap(println, (1,2,3,1,2,3)) printed 1\n2\n3\n as it currently does.

Copy link
Member

@mcabbott mcabbott Feb 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I'm in favour of some such usecache check. But perhaps this can be factored out, and we can discuss it this one has the right rules. (Should it be any or all?) This seems largely orthogonal.

The point of this comment though was to wonder whether all these structs are necessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So to clarify, yea to calling the current behaviour a bug and unconditionally calling usecache if it helps get rid of the structs?


# Encapsulates the self-recursive part of a recursive tree reduction (fold).
# This allows calling functions to remove any self-calls or nested callback closures.
struct Fold{F,L,C,W}
fn::F
isleaf::L
cache::C
walk::W
end
(fld::Fold)(x) = get!(fld.cache, x) do
fld.fn(fld.isleaf(x) ? x : fld.walk(fld, x))
end

# Convenience function for working with `Fold`
function fold(f, x; isleaf = isleaf, cache = false, walk = _default_walk)
if cache === true
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that fmap and fcollect are still passing through IdDict() by default. The hope is that new functions either don't allow customization of the cache or use a "safer" default.

cache = Cache()
elseif cache === false
cache = NoCache()
end
return Fold(f, isleaf, cache, walk)(x)
end
### end of unstable internal functionality

"""
fmap(f, x; exclude = Functors.isleaf, walk = Functors._default_walk)
Expand Down Expand Up @@ -253,11 +300,10 @@ Foo(Bar([1, 2, 3]), (40, 50, Bar(Foo(6, 7))))
```
"""
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict())
haskey(cache, x) && return cache[x]
y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x)
cache[x] = y

return y
return fold(x; cache, walk, isleaf = exclude) do node
!exclude(node) && return node
ToucheSir marked this conversation as resolved.
Show resolved Hide resolved
return f(node)
ToucheSir marked this conversation as resolved.
Show resolved Hide resolved
end
end

"""
Expand Down Expand Up @@ -296,8 +342,8 @@ fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x))
fcollect(x; exclude = v -> false)

Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref)
and collecting the results into a flat array, ordered by a breadth-first
traversal of `x`, respecting the iteration order of `children` calls.
and collecting the results into a flat array, ordered by a depth-first,
post-order traversal of `x` that respects the iteration order of `children` calls.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is technically breaking (as can be seen from the test changes), but it's unclear to me how many users were relying on the breadth-first iteration order. The one example I found on JuliaHub, AlphaZero, was not. It's also odd that the original worked breadth-first when pretty much all other traversals based on functors are depth-first. So RFC, and happy to back this out if we consider it too breaking.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to have this breaking change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that nothing now in Optimisers.jl now cares about the order taken by fmap.

Of course the two halves of destructure care a lot and must match. I haven't completely understood how #31 would have to change; it currently uses fmap for one half and something hand-written for the other.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The order in fmap has always been stable, things brings fcollect in line as it was the odd one out.


Doesn't recurse inside branches rooted at nodes `v`
for which `exclude(v) == true`.
Expand All @@ -324,33 +370,27 @@ Foo(Bar([1, 2, 3]), NoChildren(:a, :b))

julia> fcollect(m)
4-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
Bar([1, 2, 3])
[1, 2, 3]
Bar([1, 2, 3])
NoChildren(:a, :b)
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))

julia> fcollect(m, exclude = v -> v isa Bar)
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
NoChildren(:a, :b)
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))

julia> fcollect(m, exclude = v -> Functors.isleaf(v))
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
Bar([1, 2, 3])
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
```
"""
function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
# note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache
# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector`
# for the results, to preserve traversal order (important downstream!).
x in cache && return output
if !exclude(x)
push!(cache, x)
push!(output, x)
foreach(y -> fcollect(y; cache=cache, output=output, exclude=exclude), children(x))
end
return output
function fcollect(x; output = [], cache = Base.IdDict(), exclude = v -> false)
fold(x; cache, isleaf = exclude, walk = _foreach_walk) do node
exclude(node) || push!(output, node); # always return nothing
end
return output
end

# Allow gradients and other constructs that match the structure of the functor
Expand Down
34 changes: 27 additions & 7 deletions test/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,26 @@ end
end
end

@testset "Folds" begin
arrays = ntuple(i -> [i], 3)
model = Foo(
Foo(arrays[1], arrays[2]),
Foo(arrays[3], arrays[1])
)

total = Ref(0)
Functors.fmap(model, cache = true) do x
total[] += only(x)
end
@test total[] == 6

total = Ref(0)
Functors.fmap(model, cache = false) do x
total[] += only(x)
end
@test total[] == 7
end

@testset "Nested" begin
model = Bar(Foo(1, [1, 2, 3]))

Expand Down Expand Up @@ -72,21 +92,21 @@ end
m2 = 1
m3 = Foo(m1, m2)
m4 = Bar(m3)
@test all(fcollect(m4) .=== [m4, m3, m1, m2])
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3, m2])
@test all(fcollect(m4) .=== [m1, m2, m3, m4])
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m2, m3, m4])
@test all(fcollect(m4, exclude = x -> x isa Foo) .=== [m4])

m1 = [1, 2, 3]
m2 = Bar(m1)
m0 = NoChildren(:a, :b)
m3 = Foo(m2, m0)
m4 = Bar(m3)
@test all(fcollect(m4) .=== [m4, m3, m2, m1, m0])
@test all(fcollect(m4) .=== [m1, m2, m0, m3, m4])

m1 = [1, 2, 3]
m2 = [1, 2, 3]
m3 = Foo(m1, m2)
@test all(fcollect(m3) .=== [m3, m1, m2])
@test all(fcollect(m3) .=== [m1, m2, m3])
end

struct FFoo
Expand Down Expand Up @@ -143,14 +163,14 @@ end
m2 = [1, 2, 3]
m3 = FFoo(m1, m2, (:y, ))
m4 = FBar(m3, (:x,))
@test all(fcollect(m4) .=== [m4, m3, m2])
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3])
@test all(fcollect(m4) .=== [m2, m3, m4])
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m3, m4])
@test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4])

m0 = NoChildren(:a, :b)
m1 = [1, 2, 3]
m2 = FBar(m1, ())
m3 = FFoo(m2, m0, (:x, :y,))
m4 = FBar(m3, (:x,))
@test all(fcollect(m4) .=== [m4, m3, m2, m0])
@test all(fcollect(m4) .=== [m2, m0, m3, m4])
end