Taggable functors #41

Open
@darsnack

Description

In Flux, we have trainable to designate a subset of leaves as the nodes to walk when updating parameters during training. In FluxPrune.jl, I defined pruneable to designate a subset of leaves for pruning (note that this subset is not necessarily the same as the trainable one).

Right now this creates an unfortunate situation, as discussed in FluxML/Flux.jl#1946. Users need to @functor their types and remember to define trainable where necessary. To use FluxPrune.jl, they may also need to remember to define pruneable. On the developer side, we can use the walk keyword of fmap to visit the differently labeled leaf nodes, but this usually requires writing a separate walk function for each subset you want to target.
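For comparison, the status quo can be sketched roughly as follows. This is a self-contained toy, not the real APIs: `trainable_fields` and `pruneable_fields` are hypothetical stand-ins for Flux's trainable and FluxPrune.jl's pruneable (they return field names rather than child values, for simplicity), and `map_fields` stands in for an fmap walk. The point is only that every new property needs its own per-type function plus its own walk wrapper.

```julia
# Hypothetical, self-contained sketch of the status quo; not the Flux/Functors API.
# `trainable_fields`, `pruneable_fields`, and `map_fields` are illustrative names.
struct ToyConv
    weight::Vector{Float64}
    bias::Vector{Float64}
    stride::Int
end

# One opt-in function per property, defined separately for every layer type.
trainable_fields(::ToyConv) = (:weight, :bias)
pruneable_fields(::ToyConv) = (:weight,)

# A one-level walk that applies `f` only to the named subset of fields.
function map_fields(f, x, fields)
    T = typeof(x)
    vals = map(fieldnames(T)) do name
        v = getfield(x, name)
        name in fields ? f(v) : v
    end
    return T(vals...)
end

# ...and one wrapper per subset a developer wants to target.
map_trainable(f, x) = map_fields(f, x, trainable_fields(x))
map_pruneable(f, x) = map_fields(f, x, pruneable_fields(x))
```

For example, `map_pruneable(zero, ToyConv([1.0, -2.0], [0.5], 1))` zeroes the weight and leaves the bias and stride untouched, while `map_trainable` would also touch the bias.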

An alternative would be to build this information directly into what @functor defines. Right now, each child of a functor has a name and a value. I propose also giving each child "tags", represented as a tuple of symbols. Then we could write something like

@functor Conv trainable=(weight, bias) pruneable=(weight,)
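To make the proposed data model concrete, here is a minimal sketch of what such a one-line declaration could expand to: the per-tag field lists are inverted so that each child ends up with a tuple of tag symbols. `@tagged`, `childtags`, and `fields_with` are hypothetical names chosen for illustration; this is not how `@functor` works today, and a real version would presumably hook into Functors.jl rather than define a separate function.

```julia
# Hypothetical sketch: `@tagged`, `childtags`, and `fields_with` are illustrative
# names, not part of Functors.jl or Flux.jl.

"""
    childtags(T::Type) -> NamedTuple

Map each tagged field of `T` to a tuple of tag symbols.
"""
function childtags end

# Expand `@tagged T tag1=(f1, f2) tag2=(f1,)` into a `childtags` method that
# lists, for each field, every tag attached to it.
macro tagged(T, specs...)
    order = Symbol[]                        # fields in first-mention order
    pertag = Dict{Symbol,Vector{Symbol}}()  # field => tags on that field
    for spec in specs
        (spec isa Expr && spec.head == :(=)) ||
            error("expected tag=(field, ...), got $spec")
        tag, fields = spec.args
        for field in (fields isa Symbol ? (fields,) : fields.args)
            haskey(pertag, field) || push!(order, field)
            push!(get!(pertag, field, Symbol[]), tag)
        end
    end
    # Build a NamedTuple literal such as (weight = (:trainable, :pruneable), ...)
    entries = [Expr(:(=), f, Expr(:tuple, QuoteNode.(pertag[f])...)) for f in order]
    return esc(:(childtags(::Type{$T}) = $(Expr(:tuple, entries...))))
end

# Fields of `T` carrying `tag`, in first-mention order.
fields_with(T::Type, tag::Symbol) =
    Tuple(name for (name, tags) in pairs(childtags(T)) if tag in tags)

struct ToyConv
    weight::Vector{Float64}
    bias::Vector{Float64}
    stride::Int
end

@tagged ToyConv trainable=(weight, bias) pruneable=(weight,)
```

Here `childtags(ToyConv)` evaluates to `(weight = (:trainable, :pruneable), bias = (:trainable,))`, and a walk only needs `fields_with(T, tag)` to pick its subset, instead of a dedicated function per property.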

Ideally, this mechanism should be dynamic, meaning that if Flux.jl already defines the trainable leaves of a type, then another package like FluxPrune.jl should be able to add a pruneable tag on top of that.
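One way to get this dynamic behavior, sketched under loose assumptions, is a mutable table keyed by (type, field) that any package can add to, plus a single walk that filters children by tag. `TAGS`, `addtags!`, `hastag`, `istagged`, and `map_tagged` are hypothetical names; nothing here is existing Functors.jl, Flux.jl, or FluxPrune.jl API.

```julia
# Hypothetical sketch: `TAGS`, `addtags!`, `hastag`, `istagged`, and `map_tagged`
# are illustrative names, not part of Functors.jl, Flux.jl, or FluxPrune.jl.

# Global, mutable table: (type, field) => set of tags on that field.
const TAGS = Dict{Tuple{DataType,Symbol},Set{Symbol}}()

# Attach `tag` to `fields` of `T`; tags already on those fields are kept.
function addtags!(T::DataType, tag::Symbol, fields::Symbol...)
    for f in fields
        f in fieldnames(T) || throw(ArgumentError("$T has no field $f"))
        push!(get!(TAGS, (T, f), Set{Symbol}()), tag)
    end
    return T
end

hastag(T::DataType, field::Symbol, tag::Symbol) =
    tag in get(TAGS, (T, field), Set{Symbol}())

# Does any field of `T` carry a tag at all?
istagged(T) = any(key -> key[1] === T, keys(TAGS))

# Rebuild `x`, applying `f` to every child tagged `tag` and descending into
# children whose types have tags of their own; everything else is kept as-is.
function map_tagged(f, x, tag::Symbol)
    T = typeof(x)
    vals = map(fieldnames(T)) do name
        v = getfield(x, name)
        if hastag(T, name, tag)
            f(v)
        elseif istagged(typeof(v))
            map_tagged(f, v, tag)
        else
            v
        end
    end
    return T(vals...)
end

struct ToyConv
    weight::Vector{Float64}
    bias::Vector{Float64}
    stride::Int
end

struct ToyBlock
    conv1::ToyConv
    conv2::ToyConv
end

# One package registers the trainable children...
addtags!(ToyConv, :trainable, :weight, :bias)
# ...and another later layers a new tag on top, leaving the first intact.
addtags!(ToyConv, :pruneable, :weight)

m = ToyBlock(ToyConv([1.0, -2.0], [0.5], 1), ToyConv([3.0], [0.0], 2))
pruned = map_tagged(zero, m, :pruneable)         # only weights change
scaled = map_tagged(w -> 2 .* w, m, :trainable)  # weights and biases change
```

A global table keeps tags fully dynamic at the cost of a runtime lookup per field, and because it is keyed on concrete types, parametric layers would need more care. A dispatch-based design (one method per tag, as trainable is today) would let the compiler resolve the filter statically. In a real package, registration like this would likely have to run in the package's `__init__` function, since mutating another module's global state during precompilation is not preserved.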

My hope is that this makes things easier for users, since a single line would make a type Flux-compatible, and easier for developers, since walks could filter nodes by tag. I haven't spent much time on the implementation yet; I wanted to float the idea of tags first and get some feedback.

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
