Skip to content

Commit 1b3ddd6

Browse files
Ramana RadhakrishnanTrevor Morris
authored andcommitted
Tf2 test fixups (apache#5391)
* Fix oversight in importing tf.compat.v1 as tf. * Actually disable test for lstm in TF2.1 Since the testing framework actually uses pytest, the version check needs to be moved.
1 parent b9b9422 commit 1b3ddd6

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

tests/python/frontend/tensorflow/test_bn_dynamic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
"""
2323
import tvm
2424
import numpy as np
25-
import tensorflow as tf
25+
try:
26+
import tensorflow.compat.v1 as tf
27+
except ImportError:
28+
import tensorflow as tf
2629
from tvm import relay
2730
from tensorflow.python.framework import graph_util
2831

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1901,7 +1901,9 @@ def _get_tensorflow_output():
19011901

19021902
def test_forward_lstm():
19031903
'''test LSTM block cell'''
1904-
_test_lstm_cell(1, 2, 1, 0.5, 'float32')
1904+
if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'):
1905+
#in 2.0, tf.contrib.rnn.LSTMBlockCell is removed
1906+
_test_lstm_cell(1, 2, 1, 0.5, 'float32')
19051907

19061908

19071909
#######################################################################
@@ -3308,9 +3310,7 @@ def test_forward_isfinite():
33083310
test_forward_ptb()
33093311

33103312
# RNN
3311-
if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'):
3312-
#in 2.0, tf.contrib.rnn.LSTMBlockCell is removed
3313-
test_forward_lstm()
3313+
test_forward_lstm()
33143314

33153315
# Elementwise
33163316
test_forward_ceil()

0 commit comments

Comments
 (0)