`trainable` for BatchNorm stops parameters from being saved and loaded #1027
Comments
Maybe it is fine to add μ and σ² to the trainable parameters, since they are only used at test time and have zero gradient during training.
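To see why μ and σ² get no gradient in training mode, here is a minimal sketch using a hypothetical scalar `ToyNorm` type (not Flux's `BatchNorm`), with central finite differences standing in for automatic differentiation. It assumes, as in standard batch normalisation, that training mode normalises with the batch's own mean and variance while test mode uses the stored running statistics.

```julia
# Hypothetical one-feature toy layer (not Flux's BatchNorm); all fields are scalars.
mutable struct ToyNorm
    β::Float64   # shift, trainable
    γ::Float64   # scale, trainable
    μ::Float64   # running mean, non-trainable state
    σ²::Float64  # running variance, non-trainable state
end

batchmean(x) = sum(x) / length(x)
batchvar(x) = (m = batchmean(x); sum(abs2, x .- m) / length(x))

# Training mode: normalise with the batch's own statistics; μ and σ² are never read.
forward_train(l::ToyNorm, x; ϵ = 1e-5) =
    l.γ .* (x .- batchmean(x)) ./ sqrt(batchvar(x) + ϵ) .+ l.β

# Test mode: normalise with the stored running statistics instead.
forward_test(l::ToyNorm, x; ϵ = 1e-5) =
    l.γ .* (x .- l.μ) ./ sqrt(l.σ² + ϵ) .+ l.β

# Central finite difference of sum(f(l, x)) with respect to the scalar field `name`.
function fd_grad(f, l::ToyNorm, x, name::Symbol; h = 1e-6)
    v = getfield(l, name)
    setfield!(l, name, v + h)
    up = sum(f(l, x))
    setfield!(l, name, v - h)
    down = sum(f(l, x))
    setfield!(l, name, v)  # restore the original value
    return (up - down) / (2h)
end

l = ToyNorm(0.0, 1.0, 0.5, 2.0)
x = [1.0, 2.0, 4.0]
g_train = fd_grad(forward_train, l, x, :μ)  # derivative w.r.t. μ in training mode
g_test = fd_grad(forward_test, l, x, :μ)    # derivative w.r.t. μ in test mode
```

`g_train` is exactly zero because `forward_train` never reads `μ`, while `g_test` is clearly nonzero. A plain gradient step would therefore leave μ and σ² unchanged; in typical batch norm implementations they change only through a separate running-average update.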
Yes, that's what I meant... I don't know why I typed "no" even though I said "doesn't" at the end of the sentence.
@xukai92, what's your specific use case? Can you work around it by saving and loading with BSON?
I used to have issues with BSON.jl; I will try it again.
Apparently, if you save your model weights:

```julia
julia> weights = params(model);

julia> using BSON: @save

julia> @save "mymodel.bson" weights
```

the BatchNorm parameters are lost, but if you save with …
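The loss can be reproduced without Flux or BSON. This sketch uses a hypothetical `BNLike` struct, with hypothetical `save_flat` and `load_flat!` helpers standing in for saving a flat parameter list and copying it back in the spirit of `loadparams!`. Only the arrays returned by `trainable_arrays` (mirroring `trainable(bn::BatchNorm) = (bn.β, bn.γ)`) are ever visited.

```julia
# Hypothetical stand-in for a batch norm layer; each field is a per-channel vector.
struct BNLike
    β::Vector{Float64}
    γ::Vector{Float64}
    μ::Vector{Float64}
    σ²::Vector{Float64}
end

# Freshly initialised layer: running mean starts at 0, running variance at 1.
BNLike(n::Int) = BNLike(zeros(n), ones(n), zeros(n), ones(n))

# Mirrors trainable(bn::BatchNorm) = (bn.β, bn.γ): only these arrays are collected.
trainable_arrays(l::BNLike) = (l.β, l.γ)

# "Save": snapshot the flat list of trainable arrays.
save_flat(l) = [copy(a) for a in trainable_arrays(l)]

# "Load": copy the saved arrays back in order, in the spirit of loadparams!.
function load_flat!(l, saved)
    for (dst, src) in zip(trainable_arrays(l), saved)
        copyto!(dst, src)
    end
    return l
end

trained = BNLike([0.1, 0.2], [1.5, 0.5], [3.0, -1.0], [4.0, 9.0])
restored = load_flat!(BNLike(2), save_flat(trained))
```

After the round trip, `restored.β` and `restored.γ` match the trained values, but `restored.μ` and `restored.σ²` still hold their default initialisation, which matches the behaviour reported for BatchNorm.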
We should probably advocate against saving implicit …
I suppose we should think more comprehensively about how to do this. It has come up in FluxML/MetalheadWeights#2. On one hand, Metalhead should just save the model itself and return that instead of copying over the weights. On the other hand, this makes it possible to silently return a different model with … Ideally, the saving mechanism should let users mark non-trainable state as something to save, and the loading mechanism should error when the architectures don't match. For example, maybe we should be doing a structural walk of the model instead of a flat iteration over the parameters.
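A structural walk along these lines can be sketched in plain Julia. Everything here is hypothetical (`Dense2`, `Norm2`, `Chain2`, `snapshot`, `restore!`) and assumes every leaf field is an array: saving copies every field, trainable or not, into nested `NamedTuple`s, and loading walks the model and the snapshot side by side, raising an error that names the first mismatched path.

```julia
# Hypothetical layer types for illustration; every leaf field is an array.
struct Dense2
    W::Matrix{Float64}
    b::Vector{Float64}
end

struct Norm2
    β::Vector{Float64}   # trainable
    γ::Vector{Float64}   # trainable
    μ::Vector{Float64}   # running statistic: not trainable, but must be saved
    σ²::Vector{Float64}  # running statistic: not trainable, but must be saved
end

struct Chain2
    layers::Tuple
end

# Save: mirror the model's structure as nested NamedTuples/Tuples of array copies.
snapshot(x::AbstractArray) = copy(x)
snapshot(x::Tuple) = map(snapshot, x)
function snapshot(x)
    names = fieldnames(typeof(x))
    return NamedTuple{names}(map(n -> snapshot(getfield(x, n)), names))
end

# Load: walk the model and the snapshot together; any structural difference errors.
function restore!(dst::AbstractArray, src::AbstractArray, path = "model")
    size(dst) == size(src) ||
        error("size mismatch at $path: $(size(dst)) vs $(size(src))")
    return copyto!(dst, src)
end

function restore!(dst::Tuple, src::Tuple, path = "model")
    length(dst) == length(src) ||
        error("length mismatch at $path: $(length(dst)) vs $(length(src))")
    for i in eachindex(dst)
        restore!(dst[i], src[i], "$path[$i]")
    end
    return dst
end

function restore!(dst, src::NamedTuple, path = "model")
    names = fieldnames(typeof(dst))
    names == keys(src) ||
        error("field mismatch at $path: $names vs $(keys(src))")
    for n in names
        restore!(getfield(dst, n), src[n], "$path.$n")
    end
    return dst
end

# Anything else (e.g. an array where the snapshot has a tuple) is a mismatch too.
restore!(dst, src, path = "model") =
    error("type mismatch at $path: $(typeof(dst)) vs $(typeof(src))")

model = Chain2((Dense2([1.0 2.0; 3.0 4.0; 5.0 6.0], zeros(3)),
                Norm2(zeros(3), ones(3), fill(0.7, 3), fill(2.0, 3))))
saved = snapshot(model)

fresh = Chain2((Dense2(zeros(3, 2), zeros(3)),
                Norm2(zeros(3), ones(3), zeros(3), ones(3))))
restore!(fresh, saved)

# A model with a differently shaped first layer must be rejected.
wrong = Chain2((Dense2(zeros(4, 2), zeros(4)),
                Norm2(zeros(3), ones(3), zeros(3), ones(3))))
mismatch_caught = try
    restore!(wrong, saved)
    false
catch err
    occursin("model.layers[1].W", sprint(showerror, err))
end
```

Because the snapshot mirrors the model's nesting, the running statistics of `Norm2` come along automatically, and a mismatched architecture raises an error naming the offending path instead of loading silently. This sketch copies as it walks, so a mismatch found late leaves earlier layers already overwritten.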
That all makes sense to me. I think it (mostly the structural walk) might even be possible to do with what's currently in Optimisers.jl.
PyTorch handles this by returning a set of mismatched locations when loading and optionally throwing when it detects a mismatch. I think we're on the same page about wanting to do something similar for Flux?
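A report in that style can be layered on a flat, path-keyed view of the model. In this sketch the types and the `flatstate`/`load_state!` helpers are hypothetical: with `strict = true` any missing or unexpected key throws before anything is copied, while `strict = false` copies the entries present on both sides and returns the mismatched keys. A size mismatch on a shared key always throws.

```julia
# Hypothetical layer types for illustration only.
struct LinearLayer
    W::Matrix{Float64}
    b::Vector{Float64}
end

struct NormState
    β::Vector{Float64}
    γ::Vector{Float64}
    μ::Vector{Float64}
    σ²::Vector{Float64}
end

# Collect every array leaf under a dotted path, similar in spirit to a state dict.
# Leaves that are neither arrays nor containers of arrays are silently skipped.
function flatstate!(out::Dict{String,AbstractArray}, x, prefix::String)
    if x isa AbstractArray
        out[prefix] = x  # aliased, not copied
    elseif x isa Tuple
        for (i, child) in enumerate(x)
            flatstate!(out, child, "$prefix.$i")
        end
    else
        for n in fieldnames(typeof(x))
            flatstate!(out, getfield(x, n), "$prefix.$n")
        end
    end
    return out
end
flatstate(x) = flatstate!(Dict{String,AbstractArray}(), x, "model")

# Copy matching entries into `model`; report (or throw on) mismatched keys.
function load_state!(model, state::Dict{String,AbstractArray}; strict::Bool = true)
    target = flatstate(model)  # arrays alias the model, so copyto! writes into it
    missing_keys = sort!(collect(setdiff(keys(target), keys(state))))
    unexpected_keys = sort!(collect(setdiff(keys(state), keys(target))))
    if strict && !(isempty(missing_keys) && isempty(unexpected_keys))
        error("state mismatch; missing: $missing_keys, unexpected: $unexpected_keys")
    end
    for k in intersect(keys(target), keys(state))
        size(target[k]) == size(state[k]) ||
            error("size mismatch at $k: $(size(target[k])) vs $(size(state[k]))")
        copyto!(target[k], state[k])
    end
    return (missing_keys = missing_keys, unexpected_keys = unexpected_keys)
end

# A checkpoint saved from a model that had no norm layer.
source = (LinearLayer([1.0 2.0; 3.0 4.0], [0.5, -0.5]),)
saved = Dict{String,AbstractArray}(k => copy(v) for (k, v) in flatstate(source))
model = (LinearLayer(zeros(2, 2), zeros(2)),
         NormState(zeros(2), ones(2), zeros(2), ones(2)))

report = load_state!(model, saved; strict = false)
strict_threw = try
    load_state!(model, saved)  # strict = true by default
    false
catch
    true
end
```

Here `report.missing_keys` lists the four norm-layer paths such as `"model.2.μ"`, the linear layer is still loaded, and the strict call refuses the checkpoint outright.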
This is where it gets "fun", IMO. Trainable state is probably (I can't think of a counterexample) a subset of saveable state. For example, norm-layer statistics aren't trainable, but you'd definitely want to persist them. The simplest solution would be to use the …
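The subset relationship can be made explicit by declaring the two sets separately. In this sketch, `RunNorm`, `trainable_fields`, `saveable_fields`, `is_consistent`, and `save_state` are all hypothetical names: saveable state defaults to every field of the struct, the trainable fields are declared narrower, and saving goes through the saveable set so the running statistics are kept.

```julia
# Hypothetical norm layer for illustration; not Flux's type.
struct RunNorm
    β::Vector{Float64}
    γ::Vector{Float64}
    μ::Vector{Float64}
    σ²::Vector{Float64}
end

# Fields an optimiser may update (by analogy with restricting trainable to β and γ).
trainable_fields(::Type{RunNorm}) = (:β, :γ)

# Fields worth persisting: by default, every field of the struct.
saveable_fields(T::Type) = fieldnames(T)

# Check the assumed invariant: trainable state ⊆ saveable state.
is_consistent(T::Type) = issubset(trainable_fields(T), saveable_fields(T))

# Persist exactly the saveable fields as a NamedTuple of copies.
function save_state(x)
    names = saveable_fields(typeof(x))
    return NamedTuple{names}(map(n -> copy(getfield(x, n)), names))
end

layer = RunNorm([0.0, 0.1], [1.0, 1.2], [0.3, 0.4], [1.5, 2.5])
state = save_state(layer)
```

`state` keeps `μ` and `σ²` even though an optimiser would only see `β` and `γ`; a layer type that declared a trainable field outside its saveable set would fail `is_consistent`.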
For loading, I think using … For saving more than …
I'm not sure there's a need for a new function, …
True, but …
I think the most expansive set of fields would be …
Yes, it is resolved by …
Currently, `trainable` is defined for BatchNorm so that only the `β` and `γ` fields of batch norm layers receive gradients. However, this stops us from using `params` and `loadparams!` to save and load the layer's parameters: the other two fields, `μ` and `σ²`, which are also updated during training, are not saved or loaded. Maybe it's just fine to not define `trainable(bn::BatchNorm) = (bn.β, bn.γ)`, as `μ` and `σ²` don't seem to have gradients?