
Standardization: count is shared and gets updated multiple times, leading to wrong results #524

Closed
@vpratz

Description


`Standardization._update_moments` is called once for each flattened input. Each input has its own `mean` and `m2`, but the `count` is shared, so it is incremented once per input. After a batch it has therefore grown by `n_inputs * batch_size` instead of only `batch_size`. Because the moment estimates depend on `count`, this leads to wrong results.
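A minimal, self-contained sketch of the failure mode follows. It assumes that `_update_moments` merges each batch into the running `(count, mean, m2)` using Chan et al.'s parallel update; the real layer may differ in detail. The names `SharedCountStandardization` and `merge_moments` are hypothetical, and plain Python lists stand in for the layer's tensors.

```python
from statistics import fmean, pvariance


def merge_moments(count, mean, m2, batch):
    """Merge a batch into a running (count, mean, m2) via Chan et al.'s parallel update."""
    n_b = len(batch)
    mean_b = fmean(batch)
    m2_b = pvariance(batch) * n_b  # sum of squared deviations from the batch mean
    n = count + n_b
    delta = mean_b - mean
    return mean + delta * n_b / n, m2 + m2_b + delta**2 * count * n_b / n


class SharedCountStandardization:
    """Hypothetical stand-in mimicking the reported bug: one count for all inputs."""

    def __init__(self, n_inputs):
        self.mean = [0.0] * n_inputs
        self.m2 = [0.0] * n_inputs
        self.count = 0  # shared across all inputs

    def update(self, inputs):
        for i, batch in enumerate(inputs):
            self.mean[i], self.m2[i] = merge_moments(self.count, self.mean[i], self.m2[i], batch)
            self.count += len(batch)  # BUG: incremented once per input, not once per batch


layer = SharedCountStandardization(n_inputs=2)
layer.update([[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]])
print(layer.count)  # 8, although only 4 samples were seen
print(layer.mean)   # [2.5, 12.5]; the second mean should be 25.0
```

The second input reads a count of 4 that it never contributed to, so its mean is pulled halfway toward the initial value of 0.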

Possible fixes are:

  1. Update count only once, at the end of the updates
  2. Keep a separate count for each input
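Option 1 can be sketched as follows, under the same assumption of a Chan-style batch merge. All inputs are updated against the count from before the batch, and the count is incremented exactly once afterwards. The class name is hypothetical, and the sketch assumes that all flattened inputs of a batch share the same batch size.

```python
from statistics import fmean, pvariance


def merge_moments(count, mean, m2, batch):
    """Merge a batch into a running (count, mean, m2) via Chan et al.'s parallel update."""
    n_b = len(batch)
    mean_b = fmean(batch)
    m2_b = pvariance(batch) * n_b
    n = count + n_b
    delta = mean_b - mean
    return mean + delta * n_b / n, m2 + m2_b + delta**2 * count * n_b / n


class CountOnceStandardization:
    """Sketch of option 1: a single shared count, incremented once per batch."""

    def __init__(self, n_inputs):
        self.mean = [0.0] * n_inputs
        self.m2 = [0.0] * n_inputs
        self.count = 0  # still a single shared value, so the stored state is unchanged

    def update(self, inputs):
        batch_size = len(inputs[0])  # assumption: every input has the same batch size
        # Every input is merged against the count from *before* this batch.
        for i, batch in enumerate(inputs):
            self.mean[i], self.m2[i] = merge_moments(self.count, self.mean[i], self.m2[i], batch)
        self.count += batch_size  # incremented exactly once per batch


layer = CountOnceStandardization(n_inputs=2)
layer.update([[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]])
layer.update([[5.0, 6.0], [50.0, 60.0]])
print(layer.count)  # 6
print(layer.mean)   # [3.5, 35.0]
```

Because `count` keeps its shape, a layer saved before the fix would still load, which matches the trade-off described in the issue.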

The first option is somewhat hacky, but it does not break serialization. The second would be consistent with how we handle `mean` and `m2` and, in my opinion, fits the existing structure better, but it would break serialization.
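Option 2 can be sketched by turning `count` into a per-input list alongside `mean` and `m2`. The same Chan-style merge is assumed, and the class name and the `std` helper are hypothetical. In a real layer this changes the set or shape of the stored state, which is why previously serialized layers would no longer match.

```python
import math
from statistics import fmean, pstdev, pvariance


def merge_moments(count, mean, m2, batch):
    """Merge a batch into a running (count, mean, m2) via Chan et al.'s parallel update."""
    n_b = len(batch)
    mean_b = fmean(batch)
    m2_b = pvariance(batch) * n_b
    n = count + n_b
    delta = mean_b - mean
    return mean + delta * n_b / n, m2 + m2_b + delta**2 * count * n_b / n


class PerInputCountStandardization:
    """Sketch of option 2: every input tracks its own count, like mean and m2."""

    def __init__(self, n_inputs):
        self.mean = [0.0] * n_inputs
        self.m2 = [0.0] * n_inputs
        self.count = [0] * n_inputs  # one count per input; changes the stored state layout

    def update(self, inputs):
        for i, batch in enumerate(inputs):
            self.mean[i], self.m2[i] = merge_moments(self.count[i], self.mean[i], self.m2[i], batch)
            self.count[i] += len(batch)

    def std(self, i):
        # Population standard deviation of all samples seen so far for input i.
        return math.sqrt(self.m2[i] / self.count[i])


layer = PerInputCountStandardization(n_inputs=2)
layer.update([[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]])
layer.update([[5.0, 6.0], [50.0, 60.0]])
print(layer.count)  # [6, 6]
print(math.isclose(layer.std(1), pstdev([10.0, 20.0, 30.0, 40.0, 50.0, 60.0])))  # True
```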

I will provide a fix for option 1, but option 2 would also be easy to implement.

Metadata


Assignees

Labels

No labels

Type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
