Closed
Description
import keras
x = keras.ops.ones((1, 2, 3, 4))
mask = keras.ops.ones((1, 2), dtype="bool")
y = keras.layers.BatchNormalization()(x, mask=mask, training=True)
With the TensorFlow backend, this raises:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Exception encountered when calling BatchNormalization.call().
{{function_node __wrapped__Mul_device_/job:localhost/replica:0/task:0/device:CPU:0}} Incompatible shapes: [1,2,1] vs. [1,2,3,4] [Op:Mul] name:
Arguments received by BatchNormalization.call():
• inputs=tf.Tensor(shape=(1, 2, 3, 4), dtype=float32)
• training=True
• mask=tf.Tensor(shape=(1, 2), dtype=bool)
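The [1,2,1] shape in the error suggests that the layer appends a single trailing axis to the (1, 2) mask before multiplying it with the (1, 2, 3, 4) inputs. Broadcasting aligns shapes from the right, so 2 is compared with 3 and the multiply fails. The NumPy sketch below reproduces that mismatch. It shows that a mask of rank `inputs.ndim - 1`, for example shape (1, 2, 3), does broadcast after one `expand_dims`. It also includes a hypothetical `masked_moments` helper (not a Keras API) that pads the mask with as many trailing axes as needed. This is only an illustration of the shape arithmetic under those assumptions, not the layer's actual implementation.

```python
import numpy as np

x = np.ones((1, 2, 3, 4), dtype=np.float32)
mask = np.ones((1, 2), dtype=bool)

def broadcasts(a, b):
    """Return True if a * b broadcasts without error."""
    try:
        a * b
        return True
    except ValueError:
        return False

# One trailing axis, mirroring the [1,2,1] shape in the traceback:
# (1, 2, 1) vs (1, 2, 3, 4) fails because 2 is compared with 3.
one_axis = np.expand_dims(mask.astype(x.dtype), axis=-1)
broadcast_ok = broadcasts(one_axis, x)

# A mask with rank inputs.ndim - 1 becomes (1, 2, 3, 1) and broadcasts.
mask_rank3 = np.ones((1, 2, 3), dtype=bool)
rank3_ok = broadcasts(np.expand_dims(mask_rank3.astype(x.dtype), axis=-1), x)

def masked_moments(inputs, mask, axis):
    """Hypothetical helper: masked mean/variance over `axis`.

    Pads `mask` with trailing length-1 axes until its rank matches
    `inputs`, so a (batch, time) mask covers every remaining position.
    """
    weights = mask.astype(inputs.dtype)
    while weights.ndim < inputs.ndim:
        weights = weights[..., np.newaxis]
    weights = np.broadcast_to(weights, inputs.shape)
    total = weights.sum(axis=axis, keepdims=True)
    mean = (weights * inputs).sum(axis=axis, keepdims=True) / total
    var = (weights * (inputs - mean) ** 2).sum(axis=axis, keepdims=True) / total
    # Drop the reduced axes, leaving one statistic per channel.
    return np.squeeze(mean, axis=axis), np.squeeze(var, axis=axis)

mean, var = masked_moments(x, mask, axis=(0, 1, 2))
```

Masking out the second time step should make its values irrelevant to the statistics:

```python
x2 = np.ones((1, 2, 3, 4), dtype=np.float32)
x2[:, 1] = 10.0
mean2, var2 = masked_moments(x2, np.array([[True, False]]), axis=(0, 1, 2))
```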