Description
`Standardization._update_moments` is called for each flattened input. While each input has its own mean and m2, the count is shared and is therefore updated multiple times, so it grows by `n_inputs * batch_size` instead of only `batch_size`. This leads to wrong results.
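For concreteness, a minimal sketch of how the over-counting comes about. The class and names below (`SharedCountMoments`, `update`) are illustrative stand-ins, not the actual implementation; the key point is only that the shared count is incremented inside the per-input loop.

```python
import numpy as np

class SharedCountMoments:
    """Illustrative stand-in: per-input mean/m2, but a single shared count."""

    def __init__(self, n_inputs: int, dim: int):
        self.mean = [np.zeros(dim) for _ in range(n_inputs)]
        self.m2 = [np.zeros(dim) for _ in range(n_inputs)]
        self.count = 0  # shared across all inputs

    def _update_moments(self, i: int, batch: np.ndarray) -> None:
        # Welford/Chan-style batch update for input i, run against the shared count.
        n_b = batch.shape[0]
        batch_mean = batch.mean(axis=0)
        batch_m2 = ((batch - batch_mean) ** 2).sum(axis=0)
        new_count = self.count + n_b  # bumped once per *input*, not once per batch
        delta = batch_mean - self.mean[i]
        self.mean[i] = self.mean[i] + delta * n_b / new_count
        self.m2[i] = self.m2[i] + batch_m2 + delta ** 2 * self.count * n_b / new_count
        self.count = new_count

    def update(self, inputs: list[np.ndarray]) -> None:
        # inputs: one (batch_size, dim) array per flattened input
        for i, batch in enumerate(inputs):
            self._update_moments(i, batch)


moments = SharedCountMoments(n_inputs=3, dim=2)
moments.update([np.random.randn(8, 2) for _ in range(3)])
print(moments.count)  # 24 == n_inputs * batch_size, not the expected 8
```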
Possible fixes are:
- Update `count` only once, at the end of the updates
- Have one `count` object for each input
The first is a bit more hacky, but does not break serialization. The latter would be in line with how we handle the mean and m2, and in my opinion would fit better into the existing structure, but it would break serialization.
I will provide a fix for 1., but 2. would be easy to implement as well.
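A minimal sketch of option 1, reusing the illustrative `SharedCountMoments` class from above (again, not the actual code): every per-input update sees the same old and new counts, and the shared count is committed only once after the loop, so it grows by `batch_size`.

```python
import numpy as np

def update_with_deferred_count(moments: "SharedCountMoments",
                               inputs: list[np.ndarray]) -> None:
    # Run each per-input moment update against the unchanged shared count,
    # then commit the new count a single time at the end.
    new_count = moments.count
    for i, batch in enumerate(inputs):
        n_b = batch.shape[0]
        batch_mean = batch.mean(axis=0)
        batch_m2 = ((batch - batch_mean) ** 2).sum(axis=0)
        new_count = moments.count + n_b  # same value for every input
        delta = batch_mean - moments.mean[i]
        moments.mean[i] = moments.mean[i] + delta * n_b / new_count
        moments.m2[i] = (moments.m2[i] + batch_m2
                         + delta ** 2 * moments.count * n_b / new_count)
    moments.count = new_count  # updated once, by batch_size only
```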