Description
It would be nice to be able to temporarily exclude some parameters from training.
(Edit: I forgot that there is FluxML/Flux.jl#1931, now folded in here.)
1. One mechanism is to alter `Leaf` to record whether it is frozen. This is what Per-leaf freezing #49 does, and what Allow shared parameters, take III #106 suggests as an aside. The former is immutable, changed by walking & re-building. The latter makes `Leaf` mutable (for other reasons) so this can be changed in place. (Edit: implemented in Add `freeze!`/`thaw!` #112, merged.)
2. Another mechanism would be to insert some `Frozen` struct into the state tree which stops further exploration. This may make it easier to freeze a whole branch, but it will result in a tree with different labels from the model: a pattern like `model.layers[1].layers[1].weight` will no longer translate directly to one for the state tree.
3. A similar struct could equally be inserted into the model rather than the state, or into both. Since gradient calculation never sees the state, changing the model may allow for faster gradients. Does Optimisers.jl own the `struct Frozen`, if it is to recognise it?
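
To make the difference between the first two mechanisms concrete, here is a minimal sketch in plain Julia. It does not use Optimisers.jl: `FlagLeaf`, `FrozenBranch` and `toy_update` are hypothetical names invented for illustration, and the "optimiser" is just one gradient-descent step over nested NamedTuples.

```julia
# Hypothetical stand-ins for illustration; these are not Optimisers.jl types.
mutable struct FlagLeaf       # first mechanism: the leaf itself records the flag
    eta::Float64
    frozen::Bool
end

struct FrozenBranch{T}        # second mechanism: a wrapper that stops the walk
    inner::T
end

# One descent step, walking state, parameters and gradients together.
toy_update(leaf::FlagLeaf, x, dx) = leaf.frozen ? x : x .- leaf.eta .* dx
toy_update(::FrozenBranch, x, dx) = x                 # skip the whole branch
toy_update(tree::NamedTuple, x, dx) = map(toy_update, tree, x, dx)

model = (enc = (weight = [1.0, 2.0], bias = [0.0]), dec = (weight = [3.0], bias = [0.5]))
grads = (enc = (weight = [1.0, 1.0], bias = [1.0]), dec = (weight = [1.0], bias = [1.0]))

leaf() = FlagLeaf(0.5, false)

# Flag on the leaf: the state tree keeps the model's shape, one leaf is marked.
flagged = (enc = (weight = FlagLeaf(0.5, true), bias = leaf()),
           dec = (weight = leaf(), bias = leaf()))
new1 = toy_update(flagged, model, grads)

# Wrapper in the tree: `enc` is frozen as a whole, but its labels change.
wrapped = (enc = FrozenBranch((weight = leaf(), bias = leaf())),
           dec = (weight = leaf(), bias = leaf()))
new2 = toy_update(wrapped, model, grads)
```

With the wrapper, the leaf for `model.enc.weight` lives at `wrapped.enc.inner.weight`, which is the label mismatch described above.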
Maybe independently, it needs a friendly way to set & remove these labels.
4. PR Per-leaf freezing #49 proposes that you give an address like `freeze(state, (:layers, 1, :enc, 3))`. It seems a bit awkward to require you to know all the field names from the root.
5. It would also be possible to work based on just one field name: `freeze(state, :enc)` acts on anything within any field called `enc` (which in practice is some `Chain(enc = ..., dec = ...)`). Likewise `freeze(state, :bias)` could affect every layer.
6. Another possibility is to allow control based on the type in the model. Then it has to walk both, `state = freeze(state, model, cond)` or perhaps `state = freeze(f, state, model)` where `f` is a `do` block which tests `x isa Dense` or whatever. This doesn't lend itself so obviously to freezing only some fields, `enc` or `bias`... unless `f` returns not a Bool but a list of fields, like `x isa Chain && return :enc`.
7. If the modification is to the model, then 6. becomes `model = freeze(f, model)`.
8. If `Leaf` is mutable, then instead of an address you can just pass a part of the tree: `freeze!(tree.layers[1].enc[3])`, after confirming that `model.layers[1].enc[3]` is the part you want. (Edit: implemented as Add `freeze!`/`thaw!` #112, merged.)
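
For comparison, here is a rough sketch of three of these ways to name what gets frozen: by full address, by a single field name, and by passing part of the tree. It is plain Julia over a toy state tree, assuming a mutable leaf. `ToyLeaf`, `toy_freeze!`, `toy_freeze_at!` and `toy_freeze_named!` are hypothetical names, not the Optimisers.jl API.

```julia
# Toy stand-in for a mutable optimiser leaf; not the Optimisers.jl type.
mutable struct ToyLeaf
    frozen::Bool
end

# Pass part of the tree: freeze every leaf below it, in place.
toy_freeze!(leaf::ToyLeaf) = (leaf.frozen = true; nothing)
toy_freeze!(tree::Union{NamedTuple,Tuple}) = foreach(toy_freeze!, tree)

# Full address from the root: Symbols index fields, Ints index tuples.
function toy_freeze_at!(tree, address::Tuple)
    node = foldl((t, key) -> key isa Symbol ? getproperty(t, key) : t[key],
                 address; init = tree)
    toy_freeze!(node)
end

# Single field name: freeze everything stored under any field called `name`.
toy_freeze_named!(::ToyLeaf, ::Symbol) = nothing
toy_freeze_named!(tree::Tuple, name::Symbol) =
    foreach(t -> toy_freeze_named!(t, name), tree)
function toy_freeze_named!(tree::NamedTuple, name::Symbol)
    for (key, sub) in pairs(tree)
        key === name ? toy_freeze!(sub) : toy_freeze_named!(sub, name)
    end
end

# A fresh toy state tree with nothing frozen.
fresh() = (layers = ((enc = (ToyLeaf(false), ToyLeaf(false)), dec = (ToyLeaf(false),)),
                     (weight = ToyLeaf(false), bias = ToyLeaf(false))),)

a = fresh(); toy_freeze_at!(a, (:layers, 1, :enc, 2))   # one leaf, by address
b = fresh(); toy_freeze_named!(b, :enc)                 # every leaf under `enc`
c = fresh(); toy_freeze!(c.layers[2].bias)              # hand over the leaf itself
```

The address form needs every label from the root, the field-name form needs only `:enc`, and the last form needs no lookup logic at all but relies on the leaf being mutable.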
There's a related API question for shared weights. At present Flux (and Functors) rely on objectid. This won't work for immutable arrays.
9. One idea is to wrap them in a struct like `TiedWeight(array, Ref())` to get an `objectid` (and possibly remove this later).
10. The idea of Transparent handling of tied weights #100 is that instead the state tree can have the same (mutable) `Leaf` struct at the location of tied arrays. How do you construct this? With 4. this might be `tie(state, (:layers, 1, :enc, 3) => (:layers, 1, :dec, 3, :parent))` where the `:parent` is because of a Transpose. Is there a less ugly way?