Commit 2890899

[Relay][Keras][Bugfix] fix the converters of GRU and SimpleRNN about the go_backwards attribute (#15829)
* Fix bug in GRU and SimpleRNN converters about go_backwards
* Update test_forward.py
* Update keras.py
Parent commit: def551d

2 files changed: 16 additions & 0 deletions
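
In Keras, go_backwards=True makes a recurrent layer consume its input from the last timestep to the first. The Relay converters unroll the sequence front to back, so the fix reverses the input along the time axis (axis 1) before unrolling. The sketch below is not part of the commit; the layer sizes and tolerances are arbitrary. It checks the equivalence the fix relies on for the return_sequences=False case:

import numpy as np
from tensorflow import keras

# A Keras RNN with go_backwards=True feeds the cell timesteps t = T-1 ... 0.
# Running the same weights forward over an input reversed along axis 1 feeds
# the cell the identical sequence, so the final outputs must match.
x = np.random.rand(2, 5, 3).astype("float32")

inp = keras.Input(shape=(5, 3))
backward = keras.Model(inp, keras.layers.SimpleRNN(4, go_backwards=True)(inp))
forward = keras.Model(inp, keras.layers.SimpleRNN(4)(inp))
forward.set_weights(backward.get_weights())  # copy kernel, recurrent kernel, bias

np.testing.assert_allclose(
    backward.predict(x),
    forward.predict(x[:, ::-1, :]),  # NumPy analogue of _op.reverse(in_data, axis=1)
    rtol=1e-5,
    atol=1e-5,
)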

python/tvm/relay/frontend/keras.py

Lines changed: 4 additions & 0 deletions
@@ -1062,6 +1062,8 @@ def _convert_simple_rnn(
         in_bias = etab.new_const(weightList[2])
     assert len(in_data.type_annotation.shape) == 3
     timeDim = in_data.type_annotation.shape[1].value
+    if keras_layer.go_backwards:
+        in_data = _op.reverse(in_data, axis=1)
     in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1)
     for i in range(len(in_data_split)):
         in_data_split_i = _op.nn.batch_flatten(in_data_split[i])

@@ -1090,6 +1092,8 @@ def _convert_gru(
     recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
     if keras_layer.use_bias:
         in_bias = etab.new_const(weightList[2])
+    if keras_layer.go_backwards:
+        in_data = _op.reverse(in_data, axis=1)
     units = list(weightList[0].shape)[1]
     assert units > 0, "The value of units must be a positive integer"
     in_data = _op.nn.batch_flatten(in_data)
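
Both converters unroll the sequence with _op.split along axis 1, so reversing that axis first is all go_backwards needs. A minimal sketch of the reverse step, assuming the public relay.reverse is the same op the frontend aliases as _op.reverse:

import numpy as np
import tvm
from tvm import relay

# Build a one-op Relay function that flips the time axis of a
# [batch, time, feature] tensor, mirroring the two lines added above.
x = relay.var("x", shape=(1, 4, 3), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.reverse(x, axis=1)))

data = np.arange(12, dtype="float32").reshape(1, 4, 3)
out = relay.create_executor("graph", mod=mod).evaluate()(data)
np.testing.assert_array_equal(out.numpy(), data[:, ::-1, :])  # time axis reversed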

tests/python/frontend/keras/test_forward.py

Lines changed: 12 additions & 0 deletions
@@ -568,12 +568,23 @@ def test_forward_rnn(self, keras_mod):
             keras_mod.layers.SimpleRNN(
                 units=16, return_state=False, activation="tanh", use_bias=False
             ),
+            keras_mod.layers.SimpleRNN(
+                units=16, return_state=False, activation="tanh", go_backwards=True
+            ),
+            keras_mod.layers.GRU(
+                units=16,
+                return_state=False,
+                recurrent_activation="sigmoid",
+                activation="tanh",
+                reset_after=False,
+            ),
             keras_mod.layers.GRU(
                 units=16,
                 return_state=False,
                 recurrent_activation="sigmoid",
                 activation="tanh",
                 reset_after=False,
+                use_bias=False,
             ),
             keras_mod.layers.GRU(
                 units=16,
@@ -582,6 +593,7 @@ def test_forward_rnn(self, keras_mod):
                 activation="tanh",
                 reset_after=False,
                 use_bias=False,
+                go_backwards=True,
             ),
         ]
         for rnn_func in rnn_funcs:
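
The new cases run through the suite's existing verify helper, which compares Keras and TVM outputs. A standalone check in the same spirit might look like the sketch below; the input name "input_1" is Keras's usual default but is an assumption here, as are the tolerances:

import numpy as np
from tensorflow import keras
from tvm import relay

inp = keras.Input(shape=(5, 3), name="input_1")
model = keras.Model(inp, keras.layers.GRU(16, reset_after=False, go_backwards=True)(inp))

x = np.random.rand(1, 5, 3).astype("float32")
mod, params = relay.frontend.from_keras(model, {"input_1": x.shape})

# With the fix, the converted graph reverses the time axis before unrolling,
# so it should agree with Keras; before the fix it silently ran forward.
out = relay.create_executor("graph", mod=mod, params=params).evaluate()(x)
np.testing.assert_allclose(out.numpy(), model.predict(x), rtol=1e-4, atol=1e-4)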
