
Error in masked BatchNormalization with > 3 dimensions #19818

@drasmuss

Description

import keras

x = keras.ops.ones((1, 2, 3, 4))
mask = keras.ops.ones((1, 2), dtype="bool")
y = keras.layers.BatchNormalization()(x, mask=mask, training=True)

fails with the following error (the traceback comes from the TensorFlow backend):

tensorflow.python.framework.errors_impl.InvalidArgumentError: Exception encountered when calling BatchNormalization.call().

{{function_node __wrapped__Mul_device_/job:localhost/replica:0/task:0/device:CPU:0}} Incompatible shapes: [1,2,1] vs. [1,2,3,4] [Op:Mul] name: 

Arguments received by BatchNormalization.call():
  • inputs=tf.Tensor(shape=(1, 2, 3, 4), dtype=float32)
  • training=True
  • mask=tf.Tensor(shape=(1, 2), dtype=bool)
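The shape `[1,2,1]` in the error suggests that the `(1, 2)` mask gets only a single trailing axis before it is multiplied with the rank-4 input. Under standard broadcasting rules, `(1, 2, 1)` against `(1, 2, 3, 4)` aligns the mask's `2` with the input's `3`, so the multiplication fails. This reading comes from the error message, not from the Keras source. The NumPy sketch below reproduces the mismatch and shows one way to compute masked per-channel moments for inputs of any rank: pad the mask with trailing singleton axes until it matches the input's rank. `expand_mask_to_rank` and `masked_moments` are hypothetical helpers for illustration only. They are not Keras APIs, and they assume the channel axis is last.

```python
import numpy as np


def expand_mask_to_rank(mask, rank):
    """Append trailing singleton axes until `mask` has `rank` dimensions."""
    mask = np.asarray(mask)
    return mask.reshape(mask.shape + (1,) * (rank - mask.ndim))


def masked_moments(x, mask):
    """Hypothetical helper (not the Keras implementation).

    Computes per-channel mean and variance over every axis except the last,
    counting only positions where `mask` is True. Assumes at least one
    position is unmasked; otherwise the division below is by zero.
    """
    # Broadcast the mask to the full input shape, e.g. (1, 2) -> (1, 2, 3, 4).
    w = np.broadcast_to(expand_mask_to_rank(mask, x.ndim).astype(x.dtype), x.shape)
    axes = tuple(range(x.ndim - 1))  # reduce over all non-channel axes
    total = np.sum(w, axis=axes)
    mean = np.sum(w * x, axis=axes) / total
    var = np.sum(w * (x - mean) ** 2, axis=axes) / total
    return mean, var


x = np.ones((1, 2, 3, 4), dtype=np.float32)
mask = np.ones((1, 2), dtype=bool)

# Adding only one trailing axis to the mask reproduces the reported mismatch.
try:
    np.expand_dims(mask.astype(np.float32), -1) * x
except ValueError as err:
    print("broadcast error:", err)

mean, var = masked_moments(x, mask)
print(mean.shape, var.shape)  # (4,) (4,)
```

Padding the mask to the input's full rank, instead of adding a fixed single axis, makes the multiplication well-defined for inputs of rank 3, 4, or higher.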
