BatchNormalization gives incorrect output with masked inputs > 3 dimensions #19848

@drasmuss

Description

The mean/variance calculations are incorrect, which means the inputs are not normalized correctly. For example:

import keras

x = keras.ops.ones((1, 2, 3, 4))
x._keras_mask = keras.ops.ones((1, 2, 1))

y = keras.layers.BatchNormalization()(x, training=True)

print(keras.ops.mean(y, axis=-1))

gives output

tf.Tensor([-0.57732624 -0.57732624 -0.57732624 -0.57732624], shape=(4,), dtype=float32)

instead of the correct normalized output ([0, 0, 0, 0]).
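For reference, computing the masked moments by hand gives the expected all-zero output. This is a NumPy sketch, not the layer's actual implementation: the mask is broadcast to the full input shape before summing, and 1e-3 is assumed as the epsilon (the `BatchNormalization` default).

```python
import numpy as np

x = np.ones((1, 2, 3, 4))
# Mask expanded to rank 4 and broadcast to x's full shape.
mask = np.broadcast_to(np.ones((1, 2, 1, 1)), x.shape)

# Masked mean/variance per channel (reducing over axes 0-2,
# i.e. every axis except the channel axis).
w = mask.sum(axis=(0, 1, 2), keepdims=True)
mean = (x * mask).sum(axis=(0, 1, 2), keepdims=True) / w
var = ((x - mean) ** 2 * mask).sum(axis=(0, 1, 2), keepdims=True) / w

y = (x - mean) / np.sqrt(var + 1e-3)
print(y.mean(axis=(0, 1, 2)))  # [0. 0. 0. 0.]
```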

The basic issue is that this calculation is incorrect:

sum_of_weights = ops.sum(
    mask_weights_broadcasted,
    self._reduction_axes,
    keepdims=True,
)

because it doesn't account for the broadcasting (i.e. it gives a value of 2 in the above example, when it should be 2 * 3 * 4).
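The mismatch can be seen directly in NumPy. This is a minimal sketch, where the reshape to rank 4 stands in for the layer's internal mask broadcasting (exact internals may differ):

```python
import numpy as np

x = np.ones((1, 2, 3, 4))
mask = np.ones((1, 2, 1))

# The mask reshaped to x's rank, but not its full shape.
mask_weights = mask.reshape(1, 2, 1, 1)

# Summing the reshaped mask counts only its own entries...
print(mask_weights.sum())  # 2.0

# ...but each mask entry covers a whole slice of x, so the number of
# input elements it actually weights is 2 * 3 * 4.
print(np.broadcast_to(mask_weights, x.shape).sum())  # 24.0
```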

See #19818 for more discussion/background.
