Skip to content

Commit

Permalink
Unblock control flow changes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 260897836
  • Loading branch information
TensorFlow Hub Authors authored and vbardiovskyg committed Aug 2, 2019
1 parent 9d60517 commit b31b500
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tensorflow_hub/native_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.eager import function as function_eager
from tensorflow.python.framework import function
from tensorflow.python.framework import test_util
from tensorflow.python.ops.control_flow_ops import ControlFlowContext
from tensorflow.python.ops.lookup_ops import HashTable
from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file
Expand Down Expand Up @@ -272,6 +273,7 @@ def testArgErrors(self):
with self.assertRaisesRegexp(TypeError, "missing"):
m()

@test_util.run_v1_only("b/138681007")
def testUseWithinWhileLoop(self):
with tf.Graph().as_default():
spec = hub.create_module_spec(double_module_fn)
Expand Down Expand Up @@ -563,6 +565,7 @@ def testNonResourceVariables(self):
variable_names = set(name for name, _ in variable_names_and_shapes)
self.assertEqual(variable_names, {"module/var123"})

@test_util.run_v1_only("b/138681007")
def testNonResourceVariableInWhileLoop(self):
with tf.Graph().as_default():
# This test uses non-Resource variables to see an actual colocation
Expand All @@ -582,6 +585,7 @@ def body(i, x):
sess.run(tf_v1.global_variables_initializer())
self.assertAllEqual(sess.run([oi, ox]), [4, 160.0])

@test_util.run_v1_only("b/138681007")
def testNonResourceVariableInCond(self):
with tf.Graph().as_default():
spec = hub.create_module_spec(stateful_non_rv_module_fn)
Expand Down Expand Up @@ -1545,6 +1549,7 @@ def testWhileModule(self):
grad = tf.gradients([y], [x])
self.assertAllClose(sess.run(grad, {x: 2, n: 3}), [12.0])

@test_util.run_v1_only("b/138681007")
def testUseModuleWithWhileLoopInsideCond(self):
spec = hub.create_module_spec(while_module_fn)
with tf.Graph().as_default():
Expand Down

0 comments on commit b31b500

Please sign in to comment.