Skip to content

Commit 0616d54

Browse files
hertschuhtensorflower-gardener
authored andcommitted
Make UnitNormalization layer stateless.
There is no need to resolve negative axes in `build`, as `tf.linalg.l2_normalize` can handle them. Kept the build method to validate the axes in the context of the `input_shape`. Also added call to `super.build(...)` per best practice on Keras 2. Note that in Keras 3, `UnitNormalization` is already stateless. PiperOrigin-RevId: 699253410
1 parent 2da1800 commit 0616d54

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tf_keras/layers/normalization/unit_normalization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ def __init__(self, axis=-1, **kwargs):
6060
self.supports_masking = True
6161

6262
def build(self, input_shape):
63-
self.axis = tf_utils.validate_axis(self.axis, input_shape)
63+
tf_utils.validate_axis(self.axis, input_shape)
64+
super().build(input_shape)
6465

6566
def call(self, inputs):
6667
inputs = tf.cast(inputs, self.compute_dtype)

0 commit comments

Comments
 (0)