
Commit 4551644

Fix the compatibility issues of Orthogonal and GRU (#19844)
* Add legacy `Orthogonal` class name
* Add legacy `implementation` arg to `GRU`
1 parent 46df341 · commit 4551644

4 files changed: +29 −0 lines changed


keras/src/initializers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@
         "uniform": RandomUniform,
         "normal": RandomNormal,
         "orthogonal": OrthogonalInitializer,
+        "Orthogonal": OrthogonalInitializer,  # Legacy
         "one": Ones,
         "zero": Zeros,
     }
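
Configs written with Keras 2 / tf.keras often serialize this initializer under its class name, "Orthogonal", rather than the Keras 3 identifier "orthogonal". A minimal sketch of what the legacy alias enables, assuming a Keras 3 build that includes this commit (the Dense layer here is only for illustration, not part of the commit):

import keras
from keras import initializers

# Both the Keras 3 identifier and the legacy class name now resolve
# to the same initializer class.
new_style = initializers.get("orthogonal")
legacy = initializers.get("Orthogonal")
assert type(new_style) is type(legacy)

# A layer config that references the legacy name also deserializes
# instead of raising an "unknown initializer" error.
layer = keras.layers.Dense(4, kernel_initializer="Orthogonal")
print(type(layer.kernel_initializer).__name__)  # OrthogonalInitializer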

keras/src/initializers/random_initializers_test.py

Lines changed: 4 additions & 0 deletions
@@ -147,6 +147,10 @@ def test_orthogonal_initializer(self):
 
         self.run_class_serialization_test(initializer)
 
+        # Test legacy class_name
+        initializer = initializers.get("Orthogonal")
+        self.assertIsInstance(initializer, initializers.OrthogonalInitializer)
+
     def test_get_method(self):
         obj = initializers.get("glorot_normal")
         self.assertTrue(obj, initializers.GlorotNormal)

keras/src/layers/rnn/gru.py

Lines changed: 1 addition & 0 deletions
@@ -500,6 +500,7 @@ def __init__(
             trainable=kwargs.get("trainable", True),
             name="gru_cell",
             seed=seed,
+            implementation=kwargs.pop("implementation", 2),
         )
         super().__init__(
             cell,
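
The `implementation` key is popped from the GRU layer's kwargs and forwarded to the inner GRUCell (defaulting to 2), so configs saved by Keras 2, which still carry the key, deserialize instead of failing on an unrecognized argument. A minimal sketch of the user-facing effect, assuming a Keras 3 build that includes this commit; the shapes and the injected value are illustrative, not taken from the commit:

import numpy as np
import keras

# A config produced by an older Keras version may still contain the
# "implementation" key.
config = keras.layers.GRU(8).get_config()
config["implementation"] = 1  # as found in legacy configs

# With the fix, from_config() accepts the key instead of erroring out.
layer = keras.layers.GRU.from_config(config)
output = layer(np.zeros((2, 5, 4), dtype="float32"))
print(output.shape)  # (2, 8)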

keras/src/layers/rnn/gru_test.py

Lines changed: 23 additions & 0 deletions
@@ -286,3 +286,26 @@ def test_masking(self):
             np.array([[0.11669192, 0.11669192], [0.28380975, 0.28380975]]),
             output,
         )
+
+    def test_legacy_implementation_argument(self):
+        sequence = np.arange(72).reshape((3, 6, 4)).astype("float32")
+        layer = layers.GRU(
+            3,
+            kernel_initializer=initializers.Constant(0.01),
+            recurrent_initializer=initializers.Constant(0.02),
+            bias_initializer=initializers.Constant(0.03),
+        )
+        config = layer.get_config()
+        config["implementation"] = 0  # Add legacy argument
+        layer = layers.GRU.from_config(config)
+        output = layer(sequence)
+        self.assertAllClose(
+            np.array(
+                [
+                    [0.5217289, 0.5217289, 0.5217289],
+                    [0.6371659, 0.6371659, 0.6371659],
+                    [0.39384964, 0.39384964, 0.3938496],
+                ]
+            ),
+            output,
+        )
