Description
In Flux, we have `trainable` to designate a subset of leaves as nodes to walk when updating parameters for training. In FluxPrune.jl, I defined `pruneable` to designate a subset of leaves for pruning (note that these cannot be the same as the `trainable` nodes).
Right now this creates an unfortunate circumstance, as discussed in FluxML/Flux.jl#1946. Users need to `@functor` their types and remember to define `trainable` if necessary. To use FluxPrune.jl, they might also need to remember to define `pruneable`. On the developer side of things, we can use the `walk` keyword of `fmap` to walk the differently labeled leaf nodes. But this usually requires defining a separate walk function for each subset that you are hoping to target.
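To make the duplication concrete, here is a minimal sketch in plain Julia. It deliberately does not use the real `@functor` or `fmap`; `ToyConv`, `walk_trainable`, and `walk_pruneable` are hypothetical stand-ins that only illustrate the shape of the problem.

```julia
# Stand-in layer type for illustration; this is not Flux's actual Conv.
struct ToyConv
    weight::Matrix{Float64}
    bias::Vector{Float64}
    stride::Int
end

# Each package defines its own accessor for the subset of children it cares about.
trainable(m::ToyConv) = (weight = m.weight, bias = m.bias)
pruneable(m::ToyConv) = (weight = m.weight,)

# One hand-written walk per subset (hypothetical helpers).
# The two differ only in which accessor they call.
walk_trainable(f, m) = map(f, trainable(m))
walk_pruneable(f, m) = map(f, pruneable(m))

m = ToyConv(ones(2, 2), zeros(2), 1)
walk_trainable(size, m)  # visits weight and bias
walk_pruneable(size, m)  # visits weight only
```

Every new subset means another accessor for users to remember and another near-identical walk for developers to write.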
An alternative would be to build this information directly into what `@functor` defines. Right now, each child of a functor has a name and a value. I propose adding "tags", which would be a tuple of symbols. Then we could do something like:
@functor Conv trainable=(weight, bias) pruneable=(weight,)
Ideally, this mechanism should be dynamic, meaning that if Flux.jl already defines the `trainable` leaves of a type, then another package like FluxPrune.jl should be able to add a `pruneable` tag on top of that.
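As a rough sketch of what "dynamic" could mean, the tags might live in a registry that any package can extend after the fact. Everything below is an assumption for illustration: `TAGS`, `addtag!`, `tagged`, and `ToyConv` are hypothetical names, not part of Functors.jl or Flux.jl, and a real design would likely hook into `@functor` rather than use a global `Dict`.

```julia
# Hypothetical registry mapping (type, field name) => set of tags.
const TAGS = Dict{Tuple{DataType,Symbol},Set{Symbol}}()

# Attach `tag` to the given fields of `T`. Any package may call this later.
function addtag!(T::DataType, tag::Symbol, fields::Symbol...)
    for field in fields
        push!(get!(() -> Set{Symbol}(), TAGS, (T, field)), tag)
    end
    return nothing
end

# Return the children of `x` that carry `tag`, as a NamedTuple.
function tagged(x, tag::Symbol)
    T = typeof(x)
    names = filter(n -> tag in get(TAGS, (T, n), ()), fieldnames(T))
    return NamedTuple{names}(map(n -> getfield(x, n), names))
end

# Stand-in layer type for illustration; this is not Flux's actual Conv.
struct ToyConv
    weight::Matrix{Float64}
    bias::Vector{Float64}
    stride::Int
end

addtag!(ToyConv, :trainable, :weight, :bias)  # e.g. registered by Flux.jl
addtag!(ToyConv, :pruneable, :weight)         # added on top by FluxPrune.jl

layer = ToyConv(ones(2, 2), zeros(2), 1)
keys(tagged(layer, :trainable))  # (:weight, :bias)
keys(tagged(layer, :pruneable))  # (:weight,)
```

With a single `tagged(x, tag)` lookup, a generic walk can filter by any tag instead of needing one walk function per subset.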
My hope is that we make it easier on users by only having one line for making your type Flux-compatible, and easier on developers by making it simple to filter nodes by tag when walking. I haven't spent a lot of time on the implementation aspect, but I just wanted to float the notion of tags first and get some feedback.