EMAHandler (Buffer update) does lead to bad results for Mean-Teacher experiments #2590

@RicherMans

Description

🚀 Feature

Hey there,
thanks for the continued effort on Ignite, the package just gets better with each update!

So my feature request (or bug report) concerns the recently developed EMAHandler.

Specifically, I recently ran a lot of experiments using Ignite for sound event classification and noticed that my results far underperform (F1 score of ~40 with Ignite vs. ~60 without) models that do not use the included EMAHandler class.
I investigated this issue and concluded that Ignite's EMAHandler implementation has some minor issues which lead to this behaviour.

Specifically, the lines responsible are:

self.ema_model.eval()

and

for ema_b, model_b in zip(self.ema_model.buffers(), self.model.buffers()):

First problem, evaluation mode

First, I'd like to discuss the problem with self.ema_model.eval().
While this default behaviour makes sense for the EMA parameter updates themselves, it leads to some major problems regarding the non-parameter (buffer) updates of the EMA model.
A particular example is BatchNorm, which will not update its running mean and variance at all during EMA training.
However, BatchNorm's affine transformation is still updated (it is a parameter), which can lead to a discrepancy between the EMA model and the original (student) model.
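The freezing effect described above can be checked with plain torch.nn (a minimal sketch, independent of the Ignite handler itself): in eval() mode a BatchNorm layer never updates its running statistics, no matter how much data flows through it.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) + 5.0  # input whose mean is far from 0

bn.train()
bn(x)
mean_after_train = bn.running_mean.clone()  # moved toward the batch mean

bn.eval()
bn(x)
mean_after_eval = bn.running_mean.clone()   # the eval pass changes nothing

print(torch.allclose(mean_after_train, mean_after_eval))  # True: eval() froze the stats
```

So an EMA model kept permanently in eval() will carry whatever statistics it had at copy time, while its affine parameters keep being averaged.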

Second problem, updating buffers

The second problem is directly connected to the first one.
Since the non-parameter buffers of the EMA model are in fact frozen in the Ignite implementation, the buffer variables are simply copied from the student (original) model to the EMA model using:

        for ema_b, model_b in zip(self.ema_model.buffers(), self.model.buffers()):
            ema_b.data = model_b.data

While this again makes perfect sense for integer buffers such as num_batches_tracked in a BatchNorm layer, it does not make sense for the corresponding running means and variances.
For example, after some EMA updates the weights of the EMA model will differ from those of the student model, yet the current implementation forces the EMA model to use the exact same means and variances as the student, which might not be the "true" statistics for the EMA weights.
This will then likely become a problem when training for longer, since the EMA model cannot properly normalize its input data and thus somewhat diverges.
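The mismatch can be reproduced with a minimal sketch (plain torch, not Ignite's actual class; the momentum value is illustrative): after the buffer copy, the EMA model's running stats are byte-for-byte the student's, even though its weights have been averaged and differ.

```python
import copy
import torch
import torch.nn as nn

model = nn.BatchNorm1d(2)
ema_model = copy.deepcopy(model)

# Simulate training: the student's stats and weights move on.
model.train()
model(torch.randn(16, 2) + 3.0)
with torch.no_grad():
    model.weight.add_(0.5)

# EMA parameter update with a small momentum (illustrative value).
momentum = 0.002
with torch.no_grad():
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(1.0 - momentum).add_(p, alpha=momentum)

# Buffer copy, as in the current implementation:
for ema_b, model_b in zip(ema_model.buffers(), model.buffers()):
    ema_b.data = model_b.data

# Weights differ (smoothed), but the normalization stats are forced equal.
print(torch.allclose(ema_model.weight, model.weight))            # False
print(torch.equal(ema_model.running_mean, model.running_mean))   # True
```

The EMA model ends up normalizing with statistics that were estimated for a different set of weights.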

Solutions

I suggest two solutions:

  1. (Not recommended) Also average floating-point buffers, such as running means and variances:
for ema_b, model_b in zip(self.ema_model.buffers(), self.model.buffers()):
    # Smooth floating-point buffers such as running mean and variance;
    # copy integer buffers (e.g. num_batches_tracked) as before.
    if torch.is_floating_point(ema_b):
        ema_b.mul_(1.0 - momentum).add_(model_b.data, alpha=momentum)
    else:
        ema_b.data = model_b.data

I already tried this solution in my experiments, and at least from a loss perspective it outperforms the standard EMA model by a large margin.
However, the final results were still worse than those obtained with solution 2:

  2. (Recommended) Do not force .eval() mode and remove the buffer updates.

I'd currently recommend following the implementation of most mean-teacher approaches, which actually require the EMA model to be in train() mode, so that the BatchNorm buffers are updated and some noise (dropout) is applied during training.
Thus, I'd suggest simply removing the forced self.ema_model.eval() statement and the buffer updates.
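Solution 2 could be sketched as follows (function names and the momentum value are illustrative, not Ignite API): keep the EMA model in train() mode so its BatchNorm buffers evolve from the data it sees, and average only the parameters, with no buffer copy at all.

```python
import copy
import torch
import torch.nn as nn

def make_ema_model(model: nn.Module) -> nn.Module:
    """Create an EMA copy that stays in train() mode (mean-teacher style)."""
    ema = copy.deepcopy(model)
    for p in ema.parameters():
        p.requires_grad_(False)  # updated by averaging, never by the optimizer
    ema.train()  # keep dropout noise and BatchNorm buffer updates active
    return ema

@torch.no_grad()
def ema_update(ema: nn.Module, model: nn.Module, momentum: float = 0.002) -> None:
    # Parameters only; buffers evolve through the EMA model's own forward passes.
    for ema_p, p in zip(ema.parameters(), model.parameters()):
        ema_p.mul_(1.0 - momentum).add_(p, alpha=momentum)
```

In a mean-teacher loop one would run the teacher's own forward pass on (possibly differently augmented) inputs, which is what keeps its buffers meaningful.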

What is your point of view about this issue? Did anybody maybe also encounter a similar problem?
