Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract common functionality into fold #32

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open

Extract common functionality into fold #32

wants to merge 3 commits into from

Conversation

ToucheSir
Copy link
Member

@ToucheSir ToucheSir commented Jan 30, 2022

Traversal functions in Functors.jl currently handle traversal via manual recursion. This isn't the end of the world, but it results in a good amount of code duplication as well as parameter explosion from having to plumb auxiliary arguments through the call stack. With #31 looking to add a number of additional functions and Optimisers.jl not getting a lot of mileage out of fmap, it's a good time to consider whether we can cut down on the boilerplate.

This PR introduces helper traversal functionality in the form of [fF]old. In the language of recursion schemes, fold is a "catamorphism" which performs a generalized structural reduction over a tree (or in our case, DAG). Also added are a couple of caching-related helpers that may be useful downstream.

This is the first step in an effort to implement the vision of #27 while maintaining as much backwards compatibility as possible. A non-exhaustive list of objectives for future PRs include:

  1. Reducing the reliance on custom walks for downstream code
  2. Removing the requirement to carry around the re(construct) closure. This can lead to some unfortunate gymnastics, so getting rid of it would be great.
  3. Accomodating multiple inputs, e.g. fmap(f, x, xs...). This would help Optimisers.jl and any other downstream library that have rolled their own fmap variants with functor.

return x
end

### WARNING: the following is unstable internal functionality. Use at your own risk!
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered _-prefixing everything but didn't for now because these will probably become part of the public API at some point. Happy to prefix them if that's desired though.


# Convenience function for working with `Fold`
function fold(f, x; isleaf = isleaf, cache = false, walk = _default_walk)
if cache === true
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that fmap and fcollect are still passing through IdDict() by default. The hope is that new functions either don't allow customization of the cache or use a "safer" default.

and collecting the results into a flat array, ordered by a breadth-first
traversal of `x`, respecting the iteration order of `children` calls.
and collecting the results into a flat array, ordered by a depth-first,
post-order traversal of `x` that respects the iteration order of `children` calls.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is technically breaking (as can be seen from the test changes), but it's unclear to me how many users were relying on the breadth-first iteration order. The one example I found on JuliaHub, AlphaZero, was not. It's also odd that the original worked breadth-first when pretty much all other traversals based on functors are depth-first. So RFC, and happy to back this out if we consider it too breaking.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to have this breaking change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that nothing now in Optimisers.jl now cares about the order taken by fmap.

Of course the two halves of destructure care a lot and must match. I haven't completely understood how #31 would have to change; it currently uses fmap for one half and something hand-written for the other.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The order in fmap has always been stable, things brings fcollect in line as it was the odd one out.

Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a great improvement!

src/functor.jl Outdated Show resolved Hide resolved
src/functor.jl Outdated
end
Cache() = Cache(IdDict())

iscachesafe(x) = !isbits(x) && ismutable(x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I suggest calling this usecache? It's not so clear to me which way "safe" points, safe that we don't need the cache, or safe that we can use it?

Also, why ismutable? This catches only the outermost wrapper, which I don't think ought to matter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usecache sounds better, changed.

ismutable is required because this function is called at every level of the tree. So if there is a string in a (possibly deeply) nested struct that otherwise contains all bitstype values, isbits will fail and we will get a false positive. In fact ismutable may be sufficient on its own, but in the spirit of not making too many assumptions around underspecified language behaviour (namely that all mutable structs have stable objectids) I've kept it in to be safe.

src/functor.jl Outdated
Comment on lines 192 to 196
Base.get!(f, c::Cache, x) = iscachesafe(x) ? get!(f, c.inner, x) : f()

# Passthrough used to disable caching (e.g. when passing `cache=false`)
struct NoCache end
Base.get!(f, ::NoCache, _) = f()
Copy link
Member

@mcabbott mcabbott Jan 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style not logic, but could this be more tidily done as a safeget! function, rather than types on which Base's functions act? There is I think exactly one line which calls this get! on these types.

Something like

safeget!(f, d::AbstractDict, k) = usecache(k) ? get!(f, d, k) : f()
safeget!(f, d, k) = f()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had exactly this in an earlier draft, but unfortunately it's a breaking change (i.e. tests fail) because the default before was to unconditionally trust IdDict/objectid/===. So value types were being deduped via the cache. If we decide to call that a bug, then I can put this suggestion in place.

Another theoretical benefit of not hard-coding usecache would be allowing users to customize caching behaviour on a case-by-case basis. I haven't thought of a scenario that would require such a degree of customization yet (e.g. TiedArrays could simply overload usecache), so we can cross that bridge when we come to it.

Copy link
Member

@mcabbott mcabbott Jan 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, my suggestion still leaves usecache as an overloadable function. It's not intended as a change of logic, just more compact & fewer types?

Testing isbits etc. seems at all seems like a potentially breaking change, as it changes how many times f will be called (e.g. what does fmap(println, (1,2,3,1,2,3)) print?). Maybe it depends how firmly you believe f must be pure... certainly fcollect is going to change.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking for isbits and ismutable is a hedge against an impure f. IMO users would be more surprised if fmap(println, (1,2,3,1,2,3)) printed 1\n2\n3\n as it currently does.

Copy link
Member

@mcabbott mcabbott Feb 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I'm in favour of some such usecache check. But perhaps this can be factored out, and we can discuss it this one has the right rules. (Should it be any or all?) This seems largely orthogonal.

The point of this comment though was to wonder whether all these structs are necessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So to clarify, yea to calling the current behaviour a bug and unconditionally calling usecache if it helps get rid of the structs?

@mcabbott
Copy link
Member

e.g. fmap(f, x, xs...). This would help Optimisers.jl and any other downstream library that have rolled their own fmap variants with functor.

Can you sketch how this would work? Optimisers does some tricky things, and it often has to treat the gradient asymmetrically to the model. Is it clear that this is a good fit?

src/functor.jl Outdated Show resolved Hide resolved
src/functor.jl Outdated Show resolved Hide resolved
@darsnack
Copy link
Member

darsnack commented Jan 30, 2022

Haven't tested this, but this a rough sketch of the approach. Probably needs more iterations in the oven, but the more we take Functors.jl in this direction, the better. Even without Optimisers.jl, I will be able to use these improvements for model pruning.

function update(tree, x, dxs...)
  function opt_walk(f, leaf, x, dxs...)
    any(isnothing, (leaf, x, dxs...)) && return first(dxs)
    stree = children(leaf)
    sx = children(x)
    sdxs = map(children, dxs...)
    foreach((_tree, _x, _dxs) -> f(_tree, _x, _dxs...), stree, sx, zip(sdxs))
  end

  fmap(apply, tree, x, dxs...; walk = opt_walk)
end

@ToucheSir
Copy link
Member Author

ToucheSir commented Jan 30, 2022

Can you sketch how this would work? Optimisers does some tricky things, and it often has to treat the gradient asymmetrically to the model. Is it clear that this is a good fit?

Good questions. fmap will never modify the structure of the tree, so using it that way is a non-starter. I intend to rename or alias it to mapleaves to clearly reflect that 1. fold, on the other hand, will traverse every level of the tree. So you can add or prune subtrees at will using it. This is overkill for most of Optimisers, but you can imagine how it might be useful for FluxML/Optimisers.jl#42 with a little tweaking.

Handling parallel traversal (i.e. zip then traverse) over asymmetric trees will require better walk functions than we currently have in Functors. I started prototyping that as part of #27, but it's going to be a challenge to get working without completely changing core functionality such as the behaviour of functor itself. Hence the incremental approach described up top: keep the user-facing interface as stable as possible while updating the internals. If that requires maintaining two parallel sets of functionality (one for back compat and one for the new stuff), so be it.

Footnotes

  1. What's old is new again: https://github.com/FluxML/Flux.jl/blob/e7d76b8423818c5a165e388dd3b090cc5bf42cbb/src/treelike.jl#L27. But seriously, removing jargon from functors is a good thing, especially when some of it (e.g. fmap) is not strictly correct.

@mcabbott
Copy link
Member

Ok. It might be worth trying to write FluxML/Optimisers.jl#42 using this as a logic stress-test? Although need not hold up this PR I guess. (I am not happy yet with how to handle transposed tied arrays, there, for which the gradient might not have .parent, but might be a thunk or a Broadcasted....)

Longer term I think we should be open to just scrapping Functors and starting over. The entire library fits on a page, really, it's just a matter of choosing logic which matches the problems we want to solve. And a clean start (or two or three) might be much easier than contortions to modify it while keeping backward compat.

@darsnack
Copy link
Member

The major breakages will be w.r.t. Flux's uses of Functors. There are probably few users directly using Functors outside of @functor. So, if we need to scrap one day, then the contortions can always live in Flux.jl for a deprecation cycle.

@mcabbott
Copy link
Member

We could contemplate having Flux provide an @functor which for now calls this one, to ease transitions.

@ToucheSir
Copy link
Member Author

The entire library fitting on a page is both part of the allure and the fundamental problem with Functors. It makes the problem look so easy, whereas as we've discovered it is anything but.

Longer term I think we should be open to just scrapping Functors and starting over.

This could be done today by moving to an alternative like https://github.com/chengchingwen/StructWalk.jl. fold is basically postwalk + cache and overriding of isleaf, after all. IMO WalkStyle is a cleaner and more idiomatic paradigm than custom walk functions too. In any case, the missing pieces for any contender (including Functors itself) are:

  1. Walking over similarly structured trees, walk(f, x, xs...)
  2. Walking over multiple subtrees with different branches, walk(f, (a=1, b=2), (b=3, c=4))
  3. Functionality to traverse multiple trees and return a single tree, fmap(f, x, xs...) -> tree
  4. Functionality to traverse multiple trees and return multiple trees, fmultimap(f, x, xs...) -> tree, trees.... This could also be done as an unzip-like post-processing step, see jax.tree_util.tree_transpose.

cc @chengchingwen for his thoughts on this.

We could contemplate having Flux provide an @functor which for now calls this one, to ease transitions.

JuliaHub tells me that a number of libraries rely on Functors but not Flux, so it would help a bit but not completely solve the migration issue.

@chengchingwen
Copy link
Member

@ToucheSir Do you have some small test cases in mind for those missing pieces? It's not obvious to me how they should work in general. For example, what would be considered to be similarly structured: Is 1//2 similar to (a=1,b=2)? what about (a, b) and (b=3, c=4) (they can both be viewed as NTuple{2})? Or maybe we want an interface to define what is similar and how should they aligned? I used to tackle a similar problem in loading weights from state_dict by using the field names (i.e. hasfield/getfield/keys/...) and functor, and StructWalk.jl was originally built for rewriting that part of code. OTOH, what is the different between "walking over subtrees" (1./2.) and "traverse multiple trees" (3./4.)? Returning multiple trees (4.) is the same as returning a tuple of trees, which is a single tree (3.). And if we are traversing multiple trees that are not similar / cannot be aligned (3. but not 1./2.), what would the traversal algorithm be? Priority first search with custom priority?

@ToucheSir
Copy link
Member Author

@chengchingwen those are all good examples, and honestly I'm not sure either. One idea I had is that structural similarity would be determined by the LHS. So (1, 2), (a=1,b=2), ... uses the tuple path, while (a=1,b=2), (1, 2), ... uses the namedtuple path. Arrays will be similar to tuples while structs will act like namedtuples (I don't want to think about Dicts...).

And if we are traversing multiple trees that are not similar / cannot be aligned (3. but not 1./2.), what would the traversal algorithm be? Priority first search with custom priority?

That too I'm not sure on. Do we error/warn about them, if at all? My thought was to take the intersection of all properties for keyed types and walk over only those. Some thinking out loud with pseudocode:

function walk(f, nt, nts...) # single return, f(xs...) = y
  common_props = intersect(propertynames(nt), map(propertynames, nts...)...)
  patch = map(f, nt[common_props], map(t -> t[common_props])...)
  return merge(nt, patch)
end

function walk(f, nt, nts...) # multiple returns, f(xs...) = ys..., length(ys) == length(xs)
  common_props = intersect(propertynames(nt), map(propertynames, nts...)...)
  patch, patches... = map(f, nt[common_props], map(t -> t[common_props])...)
  return (merge(nt, patch), map(merge, nts, patches)...)
end

Maybe that's too surprising?

Linearly-indexed types should probably just fail on a length mismatch? I see the 2-arg case as an instance of the multi-arg case, but perhaps that's the wrong way to look at it.

Returning multiple trees (4.) is the same as returning a tuple of trees, which is a single tree (3.)

Yes, the question is how to create a tuple of trees given a tuple of trees and a user-specified varargs function. AIUI applying the function at each node will result in a tree of tuples instead, so the library would need to know how to unwrap those tuples either during or after traversal.

@ToucheSir ToucheSir mentioned this pull request Jan 31, 2022
@chengchingwen
Copy link
Member

@ToucheSir

One idea I had is that structural similarity would be determined by the LHS

This might be problematic since the similar operator is asymmetric, and I'm not sure we always want to align the children just according to the first argument (or just the intersect of property names as in the pseudocode). There are cases we would align children with different name or even different depths to ignore some wrapper.

All in all, I think we need to clarify what the data do we really need during the walk/transformation. In the "walking over subtrees" paradigm, we want the "zip"ed data in each tree like running several BFS in parallel. OTOH, In the "traverse multiple trees" paradigm, we might want different part of data. So IMPO no matter it's a single/multiple input tree single/multiple output tree function, the real question is in what order do we want to access those subtrees?

@ToucheSir
Copy link
Member Author

This might be problematic since the similar operator is asymmetric, and I'm not sure we always want to align the children just according to the first argument (or just the intersect of property names as in the pseudocode). There are cases we would align children with different name or even different depths to ignore some wrapper.

That's true, is there a way to customize this alignment behaviour? To be clear, I only think basing the output structure off the leftmost tree makes sense when a single tree is returned from the multi-tree walk, as it aligns somewhat with how map works on collections now. Having it be pluggable with a sane default (e.g. force exact structural similarity as liftA2 does) would be amazing.

All in all, I think we need to clarify what the data do we really need during the walk/transformation. In the "walking over subtrees" paradigm, we want the "zip"ed data in each tree like running several BFS in parallel. OTOH, In the "traverse multiple trees" paradigm, we might want different part of data. So IMPO no matter it's a single/multiple input tree single/multiple output tree function, the real question is in what order do we want to access those subtrees?

I think "walking over multiple subtrees" was bad wording on my part. A better term would've been "dissimilar trees" to mirror the 1st point about similarly-structured trees. Presumably that narrows down the design space dramatically, as I can't think of a use-case in e.g. Optimisers.jl that wouldn't work with either prewalk or postwalk.

@mcabbott
Copy link
Member

mcabbott commented Feb 1, 2022

as I can't think of a use-case in e.g. Optimisers.jl that wouldn't work with either prewalk or postwalk.

One thing it wants is the tree equivalent of map(f, pairs(x)), with cache=false, over only trainable nodes. As a concrete example, can this be written neatly with pre/post-walk? Or without, is it made easier / neater by this PR?

Like Peter, I'm a little sceptical that we can design the ideal multi-tree walker, without close reference to the problems it's meant to solve, and their weird edge cases. Optimisers.jl gets a lot of mileage out of functor(typeof(x), y), and perhaps that pattern can be further abstracted somewhere here... but whether it can be done without making everything harder to understand I don't know. Knowing all the knobs on a configurable walking machine at some point becomes more complex than just writing the one you need in Julia code.

@chengchingwen
Copy link
Member

Some side notes about StructWalk.jl is that it has a interface for define custom WalkStyle, so we can have different kinds of children for different types according to the style. This is something that I think is missing in the current implementation of Functors.jl, as we are all based on functor to define the tree-equivalent map (trainable/gpu/f32/...), but those function might actually need different sets of child node. For example, not all array fields are trainable, but some should still be gpu-able. So the WalkStyle interface provide a way do modified what is being treated as a child node. Besides, walk are more general than fmap as they not only apply function on the leaf nodes, it also apply on the non-leaf nodes. I'm not sure if we really need this for the optimiser case, but should totally be ignored if not needed.

@ToucheSir
Copy link
Member Author

One thing it wants is the tree equivalent of map(f, pairs(x)), with cache=false, over only trainable nodes.

map(f, pairs(x)) over trainable nodes is doable with a custom walk function on master, I believe. The logic would be similar to Optimisers._trainable, but with a callback. One could also envision another walk that does replace non-trainable fields with nothing. In fact I think you could MacGyver one with the current functionality in Optimisers:

# _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr) ...

function _trainable_fillnothing_walk(f, x)
  func, re = functor(x)
  tfunc = trainable(x)
  return re(_trainable(func, map(f, tfunc)))
end

cache=false is part of this PR, but could be accomplished on master with a dummy dict type with dummy setindex! and haskey methods.

As a concrete example, can this be written neatly with pre/post-walk? Or without, is it made easier / neater by this PR?

fmap is already doing an implicit post-order traversal, so I'd say so. This PR makes that implicit part explicit by pulling it into a function. What fold buys you over plain fmap is that it visits non-leaf nodes too. I think most of Optimisers could get away without such a thing (via packing logic into walk functions, assuming we get multi-tree walk functions), but it has the potential to make things neater.

To Peter's latest point, all this is why I became interested in StructWalk.jl. fold is more or less a convergently evolved StructWalk.postwalk, but we have no equivalent of prewalk in Functors at present. IMO WalkStyle is far nicer than custom walk functions (which were a bit of a hack from the beginning, we basically exposed part of the internals). The only thing missing is the ability to walk multiple trees (regardless of semantics), but perhaps that discussion would be better had on the StructWalk issue tracker. Thoughts?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants