-
-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Extract common functionality into fold #32
base: master
Are you sure you want to change the base?
Conversation
return x | ||
end | ||
|
||
### WARNING: the following is unstable internal functionality. Use at your own risk! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I considered _
-prefixing everything but didn't for now because these will probably become part of the public API at some point. Happy to prefix them if that's desired though.
|
||
# Convenience function for working with `Fold` | ||
function fold(f, x; isleaf = isleaf, cache = false, walk = _default_walk) | ||
if cache === true |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that fmap
and fcollect
are still passing through IdDict()
by default. The hope is that new functions either don't allow customization of the cache or use a "safer" default.
and collecting the results into a flat array, ordered by a breadth-first | ||
traversal of `x`, respecting the iteration order of `children` calls. | ||
and collecting the results into a flat array, ordered by a depth-first, | ||
post-order traversal of `x` that respects the iteration order of `children` calls. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is technically breaking (as can be seen from the test changes), but it's unclear to me how many users were relying on the breadth-first iteration order. The one example I found on JuliaHub, AlphaZero, was not. It's also odd that the original worked breadth-first when pretty much all other traversals based on functors are depth-first. So RFC, and happy to back this out if we consider it too breaking.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's fine to have this breaking change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that nothing now in Optimisers.jl now cares about the order taken by fmap
.
Of course the two halves of destructure
care a lot and must match. I haven't completely understood how #31 would have to change; it currently uses fmap
for one half and something hand-written for the other.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The order in fmap
has always been stable, things brings fcollect
in line as it was the odd one out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a great improvement!
src/functor.jl
Outdated
end | ||
Cache() = Cache(IdDict()) | ||
|
||
iscachesafe(x) = !isbits(x) && ismutable(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can I suggest calling this usecache
? It's not so clear to me which way "safe" points, safe that we don't need the cache, or safe that we can use it?
Also, why ismutable
? This catches only the outermost wrapper, which I don't think ought to matter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
usecache
sounds better, changed.
ismutable
is required because this function is called at every level of the tree. So if there is a string in a (possibly deeply) nested struct that otherwise contains all bitstype values, isbits
will fail and we will get a false positive. In fact ismutable
may be sufficient on its own, but in the spirit of not making too many assumptions around underspecified language behaviour (namely that all mutable structs have stable objectids) I've kept it in to be safe.
src/functor.jl
Outdated
Base.get!(f, c::Cache, x) = iscachesafe(x) ? get!(f, c.inner, x) : f() | ||
|
||
# Passthrough used to disable caching (e.g. when passing `cache=false`) | ||
struct NoCache end | ||
Base.get!(f, ::NoCache, _) = f() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Style not logic, but could this be more tidily done as a safeget!
function, rather than types on which Base's functions act? There is I think exactly one line which calls this get!
on these types.
Something like
safeget!(f, d::AbstractDict, k) = usecache(k) ? get!(f, d, k) : f()
safeget!(f, d, k) = f()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had exactly this in an earlier draft, but unfortunately it's a breaking change (i.e. tests fail) because the default before was to unconditionally trust IdDict/objectid/===. So value types were being deduped via the cache. If we decide to call that a bug, then I can put this suggestion in place.
Another theoretical benefit of not hard-coding usecache
would be allowing users to customize caching behaviour on a case-by-case basis. I haven't thought of a scenario that would require such a degree of customization yet (e.g. TiedArrays
could simply overload usecache
), so we can cross that bridge when we come to it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, my suggestion still leaves usecache
as an overloadable function. It's not intended as a change of logic, just more compact & fewer types?
Testing isbits
etc. seems at all seems like a potentially breaking change, as it changes how many times f
will be called (e.g. what does fmap(println, (1,2,3,1,2,3))
print?). Maybe it depends how firmly you believe f
must be pure... certainly fcollect
is going to change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Checking for isbits
and ismutable
is a hedge against an impure f
. IMO users would be more surprised if fmap(println, (1,2,3,1,2,3))
printed 1\n2\n3\n
as it currently does.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I'm in favour of some such usecache
check. But perhaps this can be factored out, and we can discuss it this one has the right rules. (Should it be any or all?) This seems largely orthogonal.
The point of this comment though was to wonder whether all these structs are necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So to clarify, yea to calling the current behaviour a bug and unconditionally calling usecache
if it helps get rid of the structs?
Can you sketch how this would work? Optimisers does some tricky things, and it often has to treat the gradient asymmetrically to the model. Is it clear that this is a good fit? |
Haven't tested this, but this a rough sketch of the approach. Probably needs more iterations in the oven, but the more we take Functors.jl in this direction, the better. Even without Optimisers.jl, I will be able to use these improvements for model pruning. function update(tree, x, dxs...)
function opt_walk(f, leaf, x, dxs...)
any(isnothing, (leaf, x, dxs...)) && return first(dxs)
stree = children(leaf)
sx = children(x)
sdxs = map(children, dxs...)
foreach((_tree, _x, _dxs) -> f(_tree, _x, _dxs...), stree, sx, zip(sdxs))
end
fmap(apply, tree, x, dxs...; walk = opt_walk)
end |
Good questions. Handling parallel traversal (i.e. zip then traverse) over asymmetric trees will require better walk functions than we currently have in Functors. I started prototyping that as part of #27, but it's going to be a challenge to get working without completely changing core functionality such as the behaviour of Footnotes
|
Ok. It might be worth trying to write FluxML/Optimisers.jl#42 using this as a logic stress-test? Although need not hold up this PR I guess. (I am not happy yet with how to handle transposed tied arrays, there, for which the gradient might not have .parent, but might be a thunk or a Broadcasted....) Longer term I think we should be open to just scrapping Functors and starting over. The entire library fits on a page, really, it's just a matter of choosing logic which matches the problems we want to solve. And a clean start (or two or three) might be much easier than contortions to modify it while keeping backward compat. |
The major breakages will be w.r.t. Flux's uses of Functors. There are probably few users directly using Functors outside of |
We could contemplate having Flux provide an |
The entire library fitting on a page is both part of the allure and the fundamental problem with Functors. It makes the problem look so easy, whereas as we've discovered it is anything but.
This could be done today by moving to an alternative like https://github.com/chengchingwen/StructWalk.jl.
cc @chengchingwen for his thoughts on this.
JuliaHub tells me that a number of libraries rely on Functors but not Flux, so it would help a bit but not completely solve the migration issue. |
@ToucheSir Do you have some small test cases in mind for those missing pieces? It's not obvious to me how they should work in general. For example, what would be considered to be similarly structured: Is |
@chengchingwen those are all good examples, and honestly I'm not sure either. One idea I had is that structural similarity would be determined by the LHS. So
That too I'm not sure on. Do we error/warn about them, if at all? My thought was to take the intersection of all properties for keyed types and walk over only those. Some thinking out loud with pseudocode: function walk(f, nt, nts...) # single return, f(xs...) = y
common_props = intersect(propertynames(nt), map(propertynames, nts...)...)
patch = map(f, nt[common_props], map(t -> t[common_props])...)
return merge(nt, patch)
end
function walk(f, nt, nts...) # multiple returns, f(xs...) = ys..., length(ys) == length(xs)
common_props = intersect(propertynames(nt), map(propertynames, nts...)...)
patch, patches... = map(f, nt[common_props], map(t -> t[common_props])...)
return (merge(nt, patch), map(merge, nts, patches)...)
end Maybe that's too surprising? Linearly-indexed types should probably just fail on a length mismatch? I see the 2-arg case as an instance of the multi-arg case, but perhaps that's the wrong way to look at it.
Yes, the question is how to create a tuple of trees given a tuple of trees and a user-specified varargs function. AIUI applying the function at each node will result in a tree of tuples instead, so the library would need to know how to unwrap those tuples either during or after traversal. |
This might be problematic since the similar operator is asymmetric, and I'm not sure we always want to align the children just according to the first argument (or just the intersect of property names as in the pseudocode). There are cases we would align children with different name or even different depths to ignore some wrapper. All in all, I think we need to clarify what the data do we really need during the walk/transformation. In the "walking over subtrees" paradigm, we want the "zip"ed data in each tree like running several BFS in parallel. OTOH, In the "traverse multiple trees" paradigm, we might want different part of data. So IMPO no matter it's a single/multiple input tree single/multiple output tree function, the real question is in what order do we want to access those subtrees? |
That's true, is there a way to customize this alignment behaviour? To be clear, I only think basing the output structure off the leftmost tree makes sense when a single tree is returned from the multi-tree walk, as it aligns somewhat with how
I think "walking over multiple subtrees" was bad wording on my part. A better term would've been "dissimilar trees" to mirror the 1st point about similarly-structured trees. Presumably that narrows down the design space dramatically, as I can't think of a use-case in e.g. Optimisers.jl that wouldn't work with either |
One thing it wants is the tree equivalent of Like Peter, I'm a little sceptical that we can design the ideal multi-tree walker, without close reference to the problems it's meant to solve, and their weird edge cases. Optimisers.jl gets a lot of mileage out of |
Some side notes about |
# _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr) ...
function _trainable_fillnothing_walk(f, x)
func, re = functor(x)
tfunc = trainable(x)
return re(_trainable(func, map(f, tfunc)))
end
To Peter's latest point, all this is why I became interested in StructWalk.jl. |
Traversal functions in Functors.jl currently handle traversal via manual recursion. This isn't the end of the world, but it results in a good amount of code duplication as well as parameter explosion from having to plumb auxiliary arguments through the call stack. With #31 looking to add a number of additional functions and Optimisers.jl not getting a lot of mileage out of
fmap
, it's a good time to consider whether we can cut down on the boilerplate.This PR introduces helper traversal functionality in the form of
[fF]old
. In the language of recursion schemes,fold
is a "catamorphism" which performs a generalized structural reduction over a tree (or in our case, DAG). Also added are a couple of caching-related helpers that may be useful downstream.This is the first step in an effort to implement the vision of #27 while maintaining as much backwards compatibility as possible. A non-exhaustive list of objectives for future PRs include:
re
(construct) closure. This can lead to some unfortunate gymnastics, so getting rid of it would be great.fmap(f, x, xs...)
. This would help Optimisers.jl and any other downstream library that have rolled their ownfmap
variants withfunctor
.