File tree Expand file tree Collapse file tree 2 files changed +16
-0
lines changed
python/tvm/relay/frontend
tests/python/frontend/keras Expand file tree Collapse file tree 2 files changed +16
-0
lines changed Original file line number Diff line number Diff line change @@ -1062,6 +1062,8 @@ def _convert_simple_rnn(
10621062 in_bias = etab .new_const (weightList [2 ])
10631063 assert len (in_data .type_annotation .shape ) == 3
10641064 timeDim = in_data .type_annotation .shape [1 ].value
1065+ if keras_layer .go_backwards :
1066+ in_data = _op .reverse (in_data , axis = 1 )
10651067 in_data_split = _op .split (in_data , indices_or_sections = timeDim , axis = 1 )
10661068 for i in range (len (in_data_split )):
10671069 in_data_split_i = _op .nn .batch_flatten (in_data_split [i ])
@@ -1090,6 +1092,8 @@ def _convert_gru(
10901092 recurrent_weight = etab .new_const (weightList [1 ].transpose ([1 , 0 ]))
10911093 if keras_layer .use_bias :
10921094 in_bias = etab .new_const (weightList [2 ])
1095+ if keras_layer .go_backwards :
1096+ in_data = _op .reverse (in_data , axis = 1 )
10931097 units = list (weightList [0 ].shape )[1 ]
10941098 assert units > 0 , "The value of units must be a positive integer"
10951099 in_data = _op .nn .batch_flatten (in_data )
Original file line number Diff line number Diff line change @@ -568,12 +568,23 @@ def test_forward_rnn(self, keras_mod):
568568 keras_mod .layers .SimpleRNN (
569569 units = 16 , return_state = False , activation = "tanh" , use_bias = False
570570 ),
571+ keras_mod .layers .SimpleRNN (
572+ units = 16 , return_state = False , activation = "tanh" , go_backwards = True
573+ ),
574+ keras_mod .layers .GRU (
575+ units = 16 ,
576+ return_state = False ,
577+ recurrent_activation = "sigmoid" ,
578+ activation = "tanh" ,
579+ reset_after = False ,
580+ ),
571581 keras_mod .layers .GRU (
572582 units = 16 ,
573583 return_state = False ,
574584 recurrent_activation = "sigmoid" ,
575585 activation = "tanh" ,
576586 reset_after = False ,
587+ use_bias = False ,
577588 ),
578589 keras_mod .layers .GRU (
579590 units = 16 ,
@@ -582,6 +593,7 @@ def test_forward_rnn(self, keras_mod):
582593 activation = "tanh" ,
583594 reset_after = False ,
584595 use_bias = False ,
596+ go_backwards = True ,
585597 ),
586598 ]
587599 for rnn_func in rnn_funcs :
You can’t perform that action at this time.
0 commit comments