
trainable for BatchNorm stops parameters from being saved and loaded #1027

Closed

xukai92 opened this issue Feb 8, 2020 · 14 comments

Comments

@xukai92

xukai92 commented Feb 8, 2020

Below is defined to only take gradient of the β and γ in batch norm layers.

trainable(bn::BatchNorm) = (bn.β, bn.γ)

However, this stops us from using params and loadparams! to save and load all of the parameters: the other two fields, μ and σ², are also updated during training and need to be saved and loaded as well.

Maybe it's just fine to not define trainable(bn::BatchNorm) = (bn.β, bn.γ), since μ and σ² don't seem to have gradients anyway?
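
For concreteness, here is a minimal sketch of the behaviour (assuming a plain BatchNorm layer and the implicit-parameter API):

julia> using Flux

julia> bn = BatchNorm(4);

julia> ps = Flux.params(bn);   # collects only β and γ, because of the trainable definition above

Since μ and σ² are not in ps, saving ps and calling loadparams! on a freshly constructed layer restores β and γ but leaves the running statistics at their defaults.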

@CarloLucibello
Member

Maybe it's fine to add μ and σ² to the trainable parameters, since they are only used at test time and have zero gradient during training.
A more general solution would be to add a flag to params to also return non-trainable fields.
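
A rough sketch of that more general idea (a hypothetical helper, not part of Flux; it relies on Functors.fmap, which walks all functor'd fields and ignores trainable):

using Flux, Functors

# Hypothetical: collect every array in the model, trainable or not,
# so BatchNorm's μ and σ² are included alongside β and γ.
function allweights(m)
    ws = AbstractArray[]
    Functors.fmap(m) do x
        x isa AbstractArray && push!(ws, x)
        x
    end
    return ws
end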

@xukai92
Author

xukai92 commented Feb 11, 2020

Maybe it's fine to add μ and σ² to the trainable parameters, since they are only used at test time and have zero gradient during training.

Yes, that's what I meant... I don't know why I typed "no" even though I said "doesn't" at the end of the sentence...

@CarloLucibello
Member

@xukai92 What's your specific use case? Can you work around it with BSON saving and loading?

@xukai92
Author

xukai92 commented Mar 5, 2020

I used to have issues using BSON.jl - I will try it again.

@nantonel

nantonel commented Nov 11, 2020

Apparently, if you save your model weights:

julia> weights = params(model);

julia> using BSON: @save

julia> @save "mymodel.bson" weights

the BatchNorm running statistics (μ and σ²) are lost, but if you save the whole model with @save "mymodel.bson" model they are preserved.
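
For reference, saving and restoring the whole struct looks like this; BSON serializes every field, so the running statistics survive the round trip:

julia> using BSON: @save, @load

julia> @save "mymodel.bson" model

julia> @load "mymodel.bson" model   # model now has its BatchNorm statistics intact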

@ToucheSir
Member

We should probably advocate against saving implicit Params in favour of saving explicit params or the model itself. Flat params are just too brittle and are easily invalidated with architectural changes.

@darsnack
Member

darsnack commented Feb 9, 2022

I suppose we should think more comprehensively about how to do this. It has come up in FluxML/MetalheadWeights#2. On one hand, Metalhead should just save the model itself and return that instead of copying over the weights. On the other hand, this makes it possible to silently return a different model with pretrain = true or pretrain = false. Hopefully, we would never make a release that allows that to happen, but it has the code smell of a poor design.

Ideally, the saving mechanism should allow users to specify state that isn't trainable to be saved. And the loading mechanism should error when the architectures don't match; e.g., maybe we should be doing a structural walk of the model instead of a flat iteration over the parameters.

@ToucheSir
Member

That all makes sense to me. I think it (mostly the structural walk) might even be possible to do with what's currently in Optimisers.jl.

Ideally, the saving mechanism should allow users to specify state that isn't trainable to be saved. And the loading mechanism should error when the architectures don't match.

PyTorch handles this by returning a set of mismatched locations when loading and optionally throwing when it detects a mismatch. I think we're on the same page about wanting to do something similar for Flux?

Ideally, the saving mechanism should allow users to specify state that isn't trainable to be saved.

This is where it gets "fun" IMO. Trainable state is probably (I can't think of a counterexample) a subset of saveable state. For example, norm layer statistics aren't trainable, but you'd definitely want to persist them. The simplest solution would be to make more use of the (fields...,) part of @functor StructType (fields...) so that fmapstructure gives you back only those params which are worth saving. Depending on the serialization format, I think a further pass would still be required to filter out types like functions.
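
As a rough illustration (not a settled API), fmapstructure already turns a model into a plain nested structure of its functor'd fields:

julia> using Flux, Functors

julia> bn = BatchNorm(2);

julia> state = Functors.fmapstructure(identity, bn);   # nested NamedTuple mirroring the layer, μ and σ² included

A further pass over that structure could then drop non-array leaves such as the activation function before handing it to a serializer.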

@darsnack
Member

darsnack commented Feb 10, 2022

For loading, I think using fmap(f, m1, m2) with the latest changes on Functors should throw an error when the structures don't match, right?

For saving more than just the trainable fields, isn't the easiest option to just define savable 😅? It can default to trainable (I agree that "trainable" is a subset of "savable"). Maybe we can modify params (of course this breaks implicit params... I meant have some saveparams(m, path)) to use fmapstructure with the appropriate "subset walk".
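
A rough sketch of what that could look like (hypothetical, leaning on the two-argument fmap mentioned above; mismatched structures are expected to fail during the walk):

using Flux, Functors

# Hypothetical: walk a freshly constructed model and a saved one in parallel,
# copying every array leaf (trainable or not) from src into dst.
function loadall!(dst, src)
    Functors.fmap(dst, src) do d, s
        d isa AbstractArray && copyto!(d, s)
        d
    end
    return dst
end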

@ToucheSir
Member

I'm not sure there's a need for a new function; functor can probably fill that role on its own :)

@darsnack
Member

True, but functor should be the most expansive set of fields and not all of them are worth saving. Maybe this distinction is not so important though. It's certainly safer to save all of them.

@ToucheSir
Member

I think the most expansive set of fields would be ConstructionBase.getproperties, i.e. all of them. This does surface a good observation though, which is that we'd really like to have more than one functor for different scenarios. Currently functor itself is unnecessarily privileged on one hand because all higher-level functions call it unconditionally, yet also unnecessarily constrained on the other because it needs to be as general as possible. For now, my vote would be to see how far we can get with fmapstructure sans custom walk and revisit if we hit a wall.

@xukai92
Author

xukai92 commented Jun 10, 2022

@darsnack Is this issue solved by #1875?

@darsnack
Member

darsnack commented Jun 10, 2022

Yes, it is resolved by loadmodel!, thanks for catching that.
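
For anyone landing here later, the loadmodel! route looks roughly like this (a sketch; freshmodel stands for a newly constructed model of the same architecture):

julia> using Flux, BSON

julia> BSON.@save "mymodel.bson" model      # save the whole model, statistics included

julia> BSON.@load "mymodel.bson" model      # load it back

julia> Flux.loadmodel!(freshmodel, model)   # copy all arrays, including μ and σ², into freshmodel; mismatched architectures error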
