Skip to content

Commit d729528

Browse files
committed
PR comments
1 parent 3c849f2 commit d729528

File tree

2 files changed

+43
-6
lines changed

2 files changed

+43
-6
lines changed

python/sparkdl/graph/utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def get_op(tfobj_or_name, graph):
6262
By default the graph we don't require this argument to be provided.
6363
"""
6464
graph = validated_graph(graph)
65+
_assert_same_graph(tfobj_or_name, graph)
6566
if isinstance(tfobj_or_name, tf.Operation):
6667
return tfobj_or_name
6768
name = tfobj_or_name
@@ -71,8 +72,7 @@ def get_op(tfobj_or_name, graph):
7172
raise TypeError('invalid op request for [type {}] {}'.format(type(name), name))
7273
_op_name = op_name(name, graph=None)
7374
op = graph.get_operation_by_name(_op_name)
74-
assert op is not None, \
75-
'cannot locate op {} in current graph'.format(_op_name)
75+
assert isinstance(op, tf.Operation), 'expect tf.Operation, but got {}'.format(type(op))
7676
return op
7777

7878
def get_tensor(tfobj_or_name, graph):
@@ -85,6 +85,7 @@ def get_tensor(tfobj_or_name, graph):
8585
By default the graph we don't require this argument to be provided.
8686
"""
8787
graph = validated_graph(graph)
88+
_assert_same_graph(tfobj_or_name, graph)
8889
if isinstance(tfobj_or_name, tf.Tensor):
8990
return tfobj_or_name
9091
name = tfobj_or_name
@@ -94,8 +95,7 @@ def get_tensor(tfobj_or_name, graph):
9495
raise TypeError('invalid tensor request for {} of {}'.format(name, type(name)))
9596
_tensor_name = tensor_name(name, graph=None)
9697
tnsr = graph.get_tensor_by_name(_tensor_name)
97-
assert tnsr is not None, \
98-
'cannot locate tensor {} in current graph'.format(_tensor_name)
98+
assert isinstance(tnsr, tf.Tensor), 'expect tf.Tensor, but got {}'.format(type(tnsr))
9999
return tnsr
100100

101101
def tensor_name(tfobj_or_name, graph=None):
@@ -213,3 +213,10 @@ def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False):
213213
return g
214214
else:
215215
return gdef_frozen
216+
217+
218+
def _assert_same_graph(tfobj, graph):
219+
if graph is None or not hasattr(tfobj, 'graph'):
220+
return
221+
assert tfobj.graph == graph, \
222+
'the graph of TensorFlow element {} != graph {}'.format(tfobj, graph)

python/tests/graph/test_utils.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,31 @@ def _gen_invalid_tensor_or_op_input_with_wrong_types():
4242
yield TestCase(data=wrong_val, description='wrong type {}'.format(type(wrong_val)))
4343

4444

45-
def _gen_valid_tensor_op_objects():
45+
def _gen_invalid_tensor_or_op_with_graph_pairing():
46+
tnsr = tf.constant(1427.08, name='someConstOp')
47+
other_graph = tf.Graph()
48+
op_name = tnsr.op.name
49+
50+
# Test get_tensor and get_op returns tensor or op contained in the same graph
51+
yield TestCase(data=lambda: tfx.get_op(tnsr, other_graph),
52+
description='test graph from getting op fron tensor')
53+
yield TestCase(data=lambda: tfx.get_tensor(tnsr, other_graph),
54+
description='test graph from getting tensor from tensor')
55+
yield TestCase(data=lambda: tfx.get_op(tnsr.name, other_graph),
56+
description='test graph from getting op fron tensor name')
57+
yield TestCase(data=lambda: tfx.get_tensor(tnsr.name, other_graph),
58+
description='test graph from getting tensor from tensor name')
59+
yield TestCase(data=lambda: tfx.get_op(tnsr.op, other_graph),
60+
description='test graph from getting op from op')
61+
yield TestCase(data=lambda: tfx.get_tensor(tnsr.op, other_graph),
62+
description='test graph from getting tensor from op')
63+
yield TestCase(data=lambda: tfx.get_op(op_name, other_graph),
64+
description='test graph from getting op from op name')
65+
yield TestCase(data=lambda: tfx.get_tensor(op_name, other_graph),
66+
description='test graph from getting tensor from op name')
67+
68+
69+
def _gen_valid_tensor_op_input_combos():
4670
op_name = 'someConstOp'
4771
tnsr_name = '{}:0'.format(op_name)
4872
tnsr = tf.constant(1427.08, name=op_name)
@@ -154,8 +178,14 @@ def test_invalid_tensor_inputs_with_wrong_types(self, data, description):
154178
with self.assertRaises(TypeError, msg=description):
155179
tfx.get_tensor(data, tf.Graph())
156180

157-
@parameterized.expand(_gen_valid_tensor_op_objects)
181+
@parameterized.expand(_gen_valid_tensor_op_input_combos)
158182
def test_valid_tensor_op_object_inputs(self, data, description):
159183
""" Must get correct graph elements from valid graph elements or their names """
160184
tfobj_or_name_a, tfobj_or_name_b = data
161185
self.assertEqual(tfobj_or_name_a, tfobj_or_name_b, msg=description)
186+
187+
@parameterized.expand(_gen_invalid_tensor_or_op_with_graph_pairing)
188+
def test_invalid_tensor_op_object_graph_pairing(self, data, description):
189+
""" Must fail when the graph element is from a different graph than the provided """
190+
with self.assertRaises((KeyError, AssertionError, TypeError), msg=description):
191+
data()

0 commit comments

Comments
 (0)