Add a handle_buffers option for EMAHandler #2592
Conversation
@sandylaker thanks for the quick PR! As for the problem with integral params in buffers, does it mean that in pytorch they should have a similar issue here: https://github.com/pytorch/pytorch/blob/080cf84bed46c6c118c37fc2fa6fbd484fd9b4cd/torch/optim/swa_utils.py#L124-L132 ?
Yeah, they should, but the default usage for the SWA model is to update the batchnorm after training, not during it, via https://github.com/pytorch/pytorch/blob/080cf84bed46c6c118c37fc2fa6fbd484fd9b4cd/torch/optim/swa_utils.py#L136 . Also, I would hypothesize that most code and examples usually set the model to .train() before the transfer, so the problem is less likely to occur. Actually, I once trained a common mean-teacher recipe and checked the absolute difference between the online (student) model and the EMA model.
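For reference, a minimal sketch of that post-training batchnorm update with torch.optim.swa_utils (the training loop and the names `model`, `optimizer`, `train_loader` are placeholders, not from the thread):

```python
import torch
from torch.optim.swa_utils import AveragedModel, update_bn

# `model`, `optimizer` and `train_loader` are assumed to exist already.
swa_model = AveragedModel(model)

for inputs, targets in train_loader:
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    swa_model.update_parameters(model)  # averages parameters only, not buffers

# Batchnorm running stats (buffers) are recomputed once, after training:
update_bn(train_loader, swa_model)
```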
@vfdev-5 I did not test the

>>> num_batches_tracked = torch.tensor(0, dtype=torch.int64)
>>> num_batches_tracked.copy_(num_batches_tracked * 0.9998 + 1 * 0.0002)
tensor(0)

I think the rounding error will be more severe when the online value is small. In the current implementation, the
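For contrast, a quick added illustration (not from the original thread): the same style of update leaves a small but non-zero trace in a floating-point buffer, while copy_ into an int64 buffer truncates the result back to zero.

```python
import torch

momentum = 0.0002

# Float buffer: the EMA update accumulates a small but non-zero change.
running_mean = torch.tensor(0.0)
running_mean.copy_(running_mean * (1 - momentum) + 1.0 * momentum)
print(running_mean)  # tensor(0.0002)

# Int64 buffer: copy_ casts the float result back to an integer,
# so the value never moves away from 0.
num_batches_tracked = torch.tensor(0, dtype=torch.int64)
num_batches_tracked.copy_(num_batches_tracked * (1 - momentum) + 1 * momentum)
print(num_batches_tracked)  # tensor(0)
```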
@sandylaker if you check again the "Solutions" section of #2590, it looks like RicherMans already tried to update floating-point buffers and copy integer ones (see his code snippet). For this PR I think we could add an arg with 3 options: 1) "copy" = keep the current (copy) behaviour, 2) "update_buffers" = update buffers according to Richer's option 1, and 3) "ema_train" = Richer's option 2. What do you think?
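For illustration, a rough sketch of how such a 3-option argument could branch inside the EMA update step (the name `handle_buffers` and the option strings simply mirror the discussion above; everything else is a hypothetical, simplified implementation, not the PR's code):

```python
import torch

@torch.no_grad()
def update_ema(ema_model, online_model, momentum, handle_buffers="copy"):
    # Parameters are always updated with the exponential moving average.
    for ema_p, p in zip(ema_model.parameters(), online_model.parameters()):
        ema_p.mul_(1.0 - momentum).add_(p, alpha=momentum)

    if handle_buffers == "copy":
        # Option 1 (current behaviour): buffers are copied verbatim.
        for ema_b, b in zip(ema_model.buffers(), online_model.buffers()):
            ema_b.copy_(b)
    elif handle_buffers == "update_buffers":
        # Option 2: EMA-update floating-point buffers, copy integral ones
        # (e.g. num_batches_tracked) to avoid the rounding issue above.
        for ema_b, b in zip(ema_model.buffers(), online_model.buffers()):
            if ema_b.dtype.is_floating_point:
                ema_b.mul_(1.0 - momentum).add_(b, alpha=momentum)
            else:
                ema_b.copy_(b)
    elif handle_buffers == "ema_train":
        # Option 3: leave the EMA model's buffers untouched and keep the
        # EMA model in train mode, so it maintains its own running stats.
        pass
```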
So I actually already checked all the proposed methods during my investigations. To be more precise about my experiments: I do sound event classification for the DCASE 2022 task, where the baseline uses a mean-teacher approach that I have reimplemented using EMAHandler. The results are as follows:
So as I originally stated, the momentum update of buffers such as the mean and variance clearly works better than the simple copying mechanism. Of course this is a mean-teacher learning case, and my results might not be representative of other cases where one might want to "freeze" the buffers. What do you guys think?
@RicherMans thanks for sharing the details! I think we could provide all 3 options so that, depending on the use case, the user can pick the appropriate configuration (as suggested above: #2592 (comment)). We have to ensure that the 3rd option
@sandylaker can we move forward with this PR and implement the 3 options described in #2592 (comment)? Thank you
@vfdev-5 Hi. Sorry for the late reply; I have been busy during the working days. I will implement it this weekend.
Changed the title: "use_buffers option for EMAHandler" → "handle_buffers option for EMAHandler"
Looks great, thanks a lot @sandylaker !
@RicherMans could you please check whether this PR solves the issue you have?
I merged this PR; if we need more changes, we can do that in a follow-up PR.
Sorry @vfdev-5 for my late response. Thanks guys for the good work! (Even though I hope I can also commit something myself one day.)
That would be awesome! By the way, if you need some help getting started with that, feel free to join our Discord: https://pytorch-ignite.ai/chat
Fixes #2590

Description: Add a handle_buffers option for EMAHandler.

cc @vfdev-5 @RicherMans
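A hedged usage sketch of what the new option looks like from a user's perspective (the accepted string values follow the discussion above and may differ slightly in the merged code; the tiny model and no-op train step are placeholders):

```python
import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import EMAHandler

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

def train_step(engine, batch):
    ...  # usual forward/backward/step on `model`

trainer = Engine(train_step)

# handle_buffers controls how BatchNorm-style buffers are treated;
# "copy" keeps the previous behaviour of copying them from the online model.
ema_handler = EMAHandler(model, momentum=0.0002, handle_buffers="copy")
ema_handler.attach(trainer, name="ema_momentum", event=Events.ITERATION_COMPLETED)
```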
Check list: