When the inputs carry a mask, BatchNormalization's mean/variance calculations are incorrect, which means the inputs are not normalized correctly. E.g.
import keras
x = keras.ops.ones((1, 2, 3, 4))
x._keras_mask = keras.ops.ones((1, 2, 1))
y = keras.layers.BatchNormalization()(x, training=True)
print(keras.ops.mean(y, axis=(0, 1, 2)))
gives the output
tf.Tensor([-0.57732624 -0.57732624 -0.57732624 -0.57732624], shape=(4,), dtype=float32)
instead of the correct normalized output ([0, 0, 0, 0]).
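For reference, an all-ones mask should behave the same as no mask at all, and the unmasked path does produce the expected zeros (a quick sanity check with the same input):
import keras
x = keras.ops.ones((1, 2, 3, 4))
# Without a mask: mean = 1 and variance = 0 per feature, so every element normalizes to 0.
y_unmasked = keras.layers.BatchNormalization()(x, training=True)
print(keras.ops.mean(y_unmasked, axis=(0, 1, 2)))  # [0. 0. 0. 0.]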
The basic issue is that this calculation is incorrect:
keras/keras/src/layers/normalization/batch_normalization.py
Lines 310 to 314 in efaaf85
sum_of_weights = ops.sum(
    mask_weights_broadcasted,
    self._reduction_axes,
    keepdims=True,
)
because it doesn't account for the broadcasting (i.e. it gives a value of 2 in the above example, when it should be 1 * 2 * 3 = 6, the number of positions reduced over for each feature).
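A minimal sketch of the mismatch, and one possible way to account for the broadcasting by expanding the mask weights to the input shape before reducing (an illustration only, not necessarily the exact fix; the rank-4 mask_weights below is an assumed stand-in for the layer's mask_weights_broadcasted):
from keras import ops

x = ops.ones((1, 2, 3, 4))
mask_weights = ops.expand_dims(ops.ones((1, 2, 1)), axis=-1)  # shape (1, 2, 1, 1)
reduction_axes = (0, 1, 2)  # every axis except the feature axis (default axis=-1)

# Current behaviour: summing the un-broadcast weights gives 2 ...
sum_of_weights = ops.sum(mask_weights, reduction_axes, keepdims=True)

# ... while the weighted input sum covers 1 * 2 * 3 = 6 positions per feature,
# so the computed "mean" is 6 / 2 = 3 instead of 1.
weighted_sum = ops.sum(x * mask_weights, reduction_axes, keepdims=True)
print(weighted_sum / sum_of_weights)  # 3.0 per feature (wrong)

# Sketch of a fix: broadcast the weights to the input shape before summing.
sum_of_weights_fixed = ops.sum(
    ops.broadcast_to(mask_weights, (1, 2, 3, 4)), reduction_axes, keepdims=True
)
print(weighted_sum / sum_of_weights_fixed)  # 1.0 per feature (correct for all-ones input)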
See #19818 for more discussion/background.