Skip to content

Commit 87d2927

Browse files
authored
[Debugger Plugin] Fix Python unit tests for TF2 (#2583)
* Motivation for features / changes * Re. #1705 * Make sure all debugger plugin's python tests pass with TF2 * Technical description of changes * In session_debug_test.py and interactive_debugger_plugin_test.py: use tf.compat.v1 and disable_v2_behavior(), as tfdbg (V1) is indeed wedded to the Session::Run() paradigm. * In debug_graphs_helper_test.py: make conditional assertions to fit v2-specific behaviors
1 parent 69abce6 commit 87d2927

File tree

3 files changed

+49
-44
lines changed

3 files changed

+49
-44
lines changed

tensorboard/plugins/debugger/debug_graphs_helper_test.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737

3838
import tensorflow as tf
3939
from tensorflow.python import debug as tf_debug
40+
# See discussion on issue #1996 for private module import justification.
41+
from tensorflow.python import tf2 as tensorflow_python_tf2
4042
from tensorflow.python.debug.lib import grpc_debug_test_server
4143

4244
from tensorboard.compat.proto import config_pb2
@@ -85,7 +87,6 @@ def _createTestGraphAndRunOptions(self, sess, gated_grpc=True):
8587
debug_urls=self.debug_server_url)
8688
return z, run_options
8789

88-
@test_util.run_v1_only('Ops differ. Similar to [1].')
8990
def testExtractGatedGrpcTensorsFoundGatedGrpcOps(self):
9091
with tf.compat.v1.Session() as sess:
9192
z, run_options = self._createTestGraphAndRunOptions(sess, gated_grpc=True)
@@ -108,15 +109,11 @@ def testExtractGatedGrpcTensorsFoundGatedGrpcOps(self):
108109
gated_debug_ops = [
109110
(item[0], item[2], item[3]) for item in gated_debug_ops]
110111

111-
# TODO(#1705): TF 2.0 breaks below.
112112
self.assertIn(('a', 0, 'DebugIdentity'), gated_debug_ops)
113-
self.assertIn(('a/read', 0, 'DebugIdentity'), gated_debug_ops)
114113
self.assertIn(('b', 0, 'DebugIdentity'), gated_debug_ops)
115-
self.assertIn(('b/read', 0, 'DebugIdentity'), gated_debug_ops)
116114
self.assertIn(('c', 0, 'DebugIdentity'), gated_debug_ops)
117-
self.assertIn(('c/read', 0, 'DebugIdentity'), gated_debug_ops)
118115
self.assertIn(('d', 0, 'DebugIdentity'), gated_debug_ops)
119-
self.assertIn(('d/read', 0, 'DebugIdentity'), gated_debug_ops)
116+
120117
self.assertIn(('x', 0, 'DebugIdentity'), gated_debug_ops)
121118
self.assertIn(('y', 0, 'DebugIdentity'), gated_debug_ops)
122119
self.assertIn(('z', 0, 'DebugIdentity'), gated_debug_ops)
@@ -151,9 +148,6 @@ def testExtractGatedGrpcTensorsFoundNoGatedGrpcOps(self):
151148
self.assertEqual([], gated_debug_ops)
152149

153150

154-
@test_util.run_v1_only((
155-
'Graph creates different op structure in v2. See '
156-
'debug_graphs_helper_test.py[1].'))
157151
class BaseExpandedNodeNameTest(tf.test.TestCase):
158152

159153
def testMaybeBaseExpandedNodeName(self):
@@ -174,9 +168,14 @@ def testMaybeBaseExpandedNodeName(self):
174168
self.assertEqual(
175169
'bar/b/read',
176170
graph_wrapper.maybe_base_expanded_node_name('bar/b/read'))
177-
# TODO(#1705): TF 2.0 tf.add creates nested nodes.
178-
self.assertEqual(
179-
'baz/c', graph_wrapper.maybe_base_expanded_node_name('baz/c'))
171+
172+
if tensorflow_python_tf2.enabled():
173+
# NOTE(#1705): TF 2.0 tf.add creates nested nodes.
174+
self.assertEqual(
175+
'baz/c/(c)', graph_wrapper.maybe_base_expanded_node_name('baz/c'))
176+
else:
177+
self.assertEqual(
178+
'baz/c', graph_wrapper.maybe_base_expanded_node_name('baz/c'))
180179

181180

182181
if __name__ == '__main__':

tensorboard/plugins/debugger/interactive_debugger_plugin_test.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import numpy as np
3434
import portpicker # pylint: disable=import-error
3535
from six.moves import urllib # pylint: disable=wrong-import-order
36-
import tensorflow as tf # pylint: disable=wrong-import-order
36+
import tensorflow.compat.v1 as tf # pylint: disable=wrong-import-order
3737
from tensorflow.python import debug as tf_debug # pylint: disable=wrong-import-order
3838
from werkzeug import test as werkzeug_test # pylint: disable=wrong-import-order
3939
from werkzeug import wrappers # pylint: disable=wrong-import-order
@@ -44,11 +44,14 @@
4444
from tensorboard.plugins.debugger import interactive_debugger_plugin
4545
from tensorboard.util import test_util
4646

47+
# These unit tests for Debugger Plugin V1 are tied to TF1.x behavior
48+
# (`tf.Session`s).
49+
tf.disable_v2_behavior()
50+
4751

4852
_SERVER_URL_PREFIX = '/data/plugin/debugger/'
4953

5054

51-
@test_util.run_v1_only('Test fails to run and clean up properly; they time out.')
5255
class InteractiveDebuggerPluginTest(tf.test.TestCase):
5356

5457
def setUp(self):
@@ -111,14 +114,14 @@ def _deserializeResponse(self, response):
111114
def _runSimpleAddMultiplyGraph(self, variable_size=1):
112115
session_run_results = []
113116
def session_run_job():
114-
with tf.compat.v1.Session() as sess:
117+
with tf.Session() as sess:
115118
a = tf.Variable([10.0] * variable_size, name='a')
116119
b = tf.Variable([20.0] * variable_size, name='b')
117120
c = tf.Variable([30.0] * variable_size, name='c')
118121
x = tf.multiply(a, b, name="x")
119122
y = tf.add(x, c, name="y")
120123

121-
sess.run(tf.compat.v1.global_variables_initializer())
124+
sess.run(tf.global_variables_initializer())
122125

123126
sess = tf_debug.TensorBoardDebugWrapperSession(sess, self._debugger_url)
124127
session_run_results.append(sess.run(y))
@@ -129,12 +132,12 @@ def session_run_job():
129132
def _runMultiStepAssignAddGraph(self, steps):
130133
session_run_results = []
131134
def session_run_job():
132-
with tf.compat.v1.Session() as sess:
135+
with tf.Session() as sess:
133136
a = tf.Variable(10, dtype=tf.int32, name='a')
134137
b = tf.Variable(1, dtype=tf.int32, name='b')
135-
inc_a = tf.compat.v1.assign_add(a, b, name='inc_a')
138+
inc_a = tf.assign_add(a, b, name='inc_a')
136139

137-
sess.run(tf.compat.v1.global_variables_initializer())
140+
sess.run(tf.global_variables_initializer())
138141

139142
sess = tf_debug.TensorBoardDebugWrapperSession(sess, self._debugger_url)
140143
for _ in range(steps):
@@ -146,15 +149,15 @@ def session_run_job():
146149
def _runTfGroupGraph(self):
147150
session_run_results = []
148151
def session_run_job():
149-
with tf.compat.v1.Session() as sess:
152+
with tf.Session() as sess:
150153
a = tf.Variable(10, dtype=tf.int32, name='a')
151154
b = tf.Variable(20, dtype=tf.int32, name='b')
152155
d = tf.constant(1, dtype=tf.int32, name='d')
153-
inc_a = tf.compat.v1.assign_add(a, d, name='inc_a')
154-
inc_b = tf.compat.v1.assign_add(b, d, name='inc_b')
156+
inc_a = tf.assign_add(a, d, name='inc_a')
157+
inc_b = tf.assign_add(b, d, name='inc_b')
155158
inc_ab = tf.group([inc_a, inc_b], name="inc_ab")
156159

157-
sess.run(tf.compat.v1.global_variables_initializer())
160+
sess.run(tf.global_variables_initializer())
158161

159162
sess = tf_debug.TensorBoardDebugWrapperSession(sess, self._debugger_url)
160163
session_run_results.append(sess.run(inc_ab))
@@ -703,7 +706,7 @@ def testGetSourceOpTraceback(self):
703706
def _runInitializer(self):
704707
session_run_results = []
705708
def session_run_job():
706-
with tf.compat.v1.Session() as sess:
709+
with tf.Session() as sess:
707710
a = tf.Variable([10.0] * 10, name='a')
708711
sess = tf_debug.TensorBoardDebugWrapperSession(sess, self._debugger_url)
709712
# Run the initializer with a debugger-wrapped tf.Session.
@@ -781,7 +784,7 @@ def testCommDataForUninitializedTensorIsHandledCorrectly(self):
781784
def _runHealthPillNetwork(self):
782785
session_run_results = []
783786
def session_run_job():
784-
with tf.compat.v1.Session() as sess:
787+
with tf.Session() as sess:
785788
a = tf.Variable(
786789
[np.nan, np.inf, np.inf, -np.inf, -np.inf, -np.inf, 10, 20, 30],
787790
dtype=tf.float32, name='a')
@@ -828,11 +831,11 @@ def testHealthPill(self):
828831
def _runAsciiStringNetwork(self):
829832
session_run_results = []
830833
def session_run_job():
831-
with tf.compat.v1.Session() as sess:
834+
with tf.Session() as sess:
832835
str1 = tf.Variable('abc', name='str1')
833836
str2 = tf.Variable('def', name='str2')
834837
str_concat = tf.add(str1, str2, name='str_concat')
835-
sess.run(tf.compat.v1.global_variables_initializer())
838+
sess.run(tf.global_variables_initializer())
836839
sess = tf_debug.TensorBoardDebugWrapperSession(sess, self._debugger_url)
837840
session_run_results.append(sess.run(str_concat))
838841
session_run_thread = threading.Thread(target=session_run_job)
@@ -885,11 +888,11 @@ def testAsciiStringTensorIsHandledCorrectly(self):
885888
def _runBinaryStringNetwork(self):
886889
session_run_results = []
887890
def session_run_job():
888-
with tf.compat.v1.Session() as sess:
891+
with tf.Session() as sess:
889892
str1 = tf.Variable([b'\x01' * 3, b'\x02' * 3], name='str1')
890893
str2 = tf.Variable([b'\x03' * 3, b'\x04' * 3], name='str2')
891894
str_concat = tf.add(str1, str2, name='str_concat')
892-
sess.run(tf.compat.v1.global_variables_initializer())
895+
sess.run(tf.global_variables_initializer())
893896
sess = tf_debug.TensorBoardDebugWrapperSession(sess, self._debugger_url)
894897
session_run_results.append(sess.run(str_concat))
895898
session_run_thread = threading.Thread(target=session_run_job)

tensorboard/plugins/debugger/session_debug_test.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,18 @@
3535

3636
import numpy as np
3737
import portpicker # pylint: disable=import-error
38-
import tensorflow as tf # pylint: disable=wrong-import-order
38+
import tensorflow.compat.v1 as tf # pylint: disable=wrong-import-order
3939
from tensorflow.python import debug as tf_debug # pylint: disable=wrong-import-order
4040

4141
from tensorboard.plugins.debugger import constants
4242
from tensorboard.plugins.debugger import debugger_server_lib
4343
from tensorboard.util import test_util
4444

45+
# These unit tests for Debugger Plugin V1 are tied to TF1.x behavior
46+
# (`tf.Session`s).
47+
tf.disable_v2_behavior()
48+
4549

46-
@test_util.run_v1_only("Server fails to come up. Requires study.")
4750
class SessionDebugTestBase(tf.test.TestCase):
4851

4952
def setUp(self):
@@ -67,17 +70,17 @@ def tearDown(self):
6770
if os.path.isdir(self._logdir):
6871
shutil.rmtree(self._logdir)
6972

70-
tf.compat.v1.reset_default_graph()
73+
tf.reset_default_graph()
7174

7275
def _poll_server_till_success(self, max_tries, poll_interval_seconds):
7376
for _ in range(max_tries):
7477
try:
75-
with tf.compat.v1.Session() as sess:
78+
with tf.Session() as sess:
7679
a_init_val = np.array([42.0])
7780
a_init = tf.constant(a_init_val, shape=[1], name="a_init")
7881
a = tf.Variable(a_init, name="a")
7982

80-
run_options = tf.compat.v1.RunOptions(output_partition_graphs=True)
83+
run_options = tf.RunOptions(output_partition_graphs=True)
8184
tf_debug.watch_graph(run_options,
8285
sess.graph,
8386
debug_ops=["DebugNumericSummary"],
@@ -125,7 +128,7 @@ def _compute_health_pill(self, x):
125128
def _check_health_pills_in_events_file(self,
126129
events_file_path,
127130
debug_key_to_tensors):
128-
reader = tf.compat.v1.python_io.tf_record_iterator(events_file_path)
131+
reader = tf.python_io.tf_record_iterator(events_file_path)
129132
event_read = tf.Event()
130133

131134
# The first event in the file should contain the events version, which is
@@ -159,7 +162,7 @@ def _check_health_pills_in_events_file(self,
159162
health_pills[debug_key][i])
160163

161164
def testRunSimpleNetworkoWithInfAndNaNWorks(self):
162-
with tf.compat.v1.Session() as sess:
165+
with tf.Session() as sess:
163166
x_init_val = np.array([[2.0], [-1.0]])
164167
y_init_val = np.array([[0.0], [-0.25]])
165168
z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]])
@@ -171,14 +174,14 @@ def testRunSimpleNetworkoWithInfAndNaNWorks(self):
171174
z_init = tf.constant(z_init_val, shape=[2, 2])
172175
z = tf.Variable(z_init, name="z")
173176

174-
u = tf.compat.v1.div(x, y, name="u") # Produces an Inf.
177+
u = tf.div(x, y, name="u") # Produces an Inf.
175178
v = tf.matmul(z, u, name="v") # Produces NaN and Inf.
176179

177180
sess.run(x.initializer)
178181
sess.run(y.initializer)
179182
sess.run(z.initializer)
180183

181-
run_options = tf.compat.v1.RunOptions(output_partition_graphs=True)
184+
run_options = tf.RunOptions(output_partition_graphs=True)
182185
tf_debug.watch_graph(run_options,
183186
sess.graph,
184187
debug_ops=["DebugNumericSummary"],
@@ -221,18 +224,18 @@ def testRunSimpleNetworkoWithInfAndNaNWorks(self):
221224
self.assertEqual(0, report[1].pos_inf_event_count)
222225

223226
def testMultipleInt32ValuesOverMultipleRunsAreRecorded(self):
224-
with tf.compat.v1.Session() as sess:
227+
with tf.Session() as sess:
225228
x_init_val = np.array([10], dtype=np.int32)
226229
x_init = tf.constant(x_init_val, shape=[1], name="x_init")
227230
x = tf.Variable(x_init, name="x")
228231

229232
x_inc_val = np.array([2], dtype=np.int32)
230233
x_inc = tf.constant(x_inc_val, name="x_inc")
231-
inc_x = tf.compat.v1.assign_add(x, x_inc, name="inc_x")
234+
inc_x = tf.assign_add(x, x_inc, name="inc_x")
232235

233236
sess.run(x.initializer)
234237

235-
run_options = tf.compat.v1.RunOptions(output_partition_graphs=True)
238+
run_options = tf.RunOptions(output_partition_graphs=True)
236239
tf_debug.watch_graph(run_options,
237240
sess.graph,
238241
debug_ops=["DebugNumericSummary"],
@@ -266,7 +269,7 @@ def testConcurrentNumericsAlertsAreRegisteredCorrectly(self):
266269
# Before any Session runs, the report ought to be empty.
267270
self.assertEqual([], self._debug_data_server.numerics_alert_report())
268271

269-
with tf.compat.v1.Session() as sess:
272+
with tf.Session() as sess:
270273
x_init_val = np.array([[2.0], [-1.0]])
271274
y_init_val = np.array([[0.0], [-0.25]])
272275
z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]])
@@ -278,7 +281,7 @@ def testConcurrentNumericsAlertsAreRegisteredCorrectly(self):
278281
z_init = tf.constant(z_init_val, shape=[2, 2])
279282
z = tf.Variable(z_init, name="z")
280283

281-
u = tf.compat.v1.div(x, y, name="u") # Produces an Inf.
284+
u = tf.div(x, y, name="u") # Produces an Inf.
282285
v = tf.matmul(z, u, name="v") # Produces NaN and Inf.
283286

284287
sess.run(x.initializer)
@@ -287,7 +290,7 @@ def testConcurrentNumericsAlertsAreRegisteredCorrectly(self):
287290

288291
run_options_list = []
289292
for i in range(num_threads):
290-
run_options = tf.compat.v1.RunOptions(output_partition_graphs=True)
293+
run_options = tf.RunOptions(output_partition_graphs=True)
291294
# Use different grpc:// URL paths so that each thread opens a separate
292295
# gRPC stream to the debug data server, simulating multi-worker setting.
293296
tf_debug.watch_graph(run_options,

0 commit comments

Comments
 (0)