Skip to content

Commit

Permalink
Add property hub.Module.variables with the values of variable_map,
Browse files Browse the repository at this point in the history
except that nested lists are flattened.

PiperOrigin-RevId: 209943227
  • Loading branch information
TensorFlow Hub Authors authored and arnoegw committed Aug 24, 2018
1 parent eebe436 commit ebb7367
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tensorflow_hub/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,17 @@ def variable_map(self):
"""
return self._impl.variable_map

@property
def variables(self):
"""Returns the list of all tf.Variables created by module instantiation."""
result = []
for _, value in sorted(self.variable_map.items()):
if isinstance(value, list):
result.extend(value)
else:
result.append(value)
return result


def _try_get_state_scope(name, mark_name_scope_used=True):
"""Returns a fresh variable/name scope for a module's state.
Expand Down
10 changes: 10 additions & 0 deletions tensorflow_hub/native_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def testVariables(self):
out = m()
self.assertEqual(list(m.variable_map.keys()), ["var123"])
self.assertEqual(m.variable_map["var123"].name, "test/var123:0")
self.assertEqual([v.name for v in m.variables], ["test/var123:0"])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
self.assertAllClose(sess.run(out), [1.0, 2.0, 3.0])
Expand All @@ -373,6 +374,7 @@ def testResourceVariables(self):
out = m()
self.assertEqual(list(m.variable_map.keys()), ["rv_var123"])
self.assertEqual(m.variable_map["rv_var123"].name, "test_rv/rv_var123:0")
self.assertEqual([v.name for v in m.variables], ["test_rv/rv_var123:0"])

# Check that "shared_name" attributes are adapted correctly:
for op_prefix in ["test_rv", "test_rv_apply_default"]:
Expand Down Expand Up @@ -440,6 +442,7 @@ def testNonResourceVariables(self):
out = m()
self.assertEqual(list(m.variable_map.keys()), ["var123"])
self.assertEqual(m.variable_map["var123"].name, "test_non_rv/var123:0")
self.assertEqual([v.name for v in m.variables], ["test_non_rv/var123:0"])

export_path = os.path.join(self.get_temp_dir(), "non-resource-variables")
with tf.Session() as sess:
Expand Down Expand Up @@ -537,6 +540,12 @@ def testPartitionedVariables(self):
["test/partitioned_variable/part_0:0",
"test/partitioned_variable/part_1:0",
"test/partitioned_variable/part_2:0"])
self.assertAllEqual( # Check deterministric order (by variable_map key).
[variable.name for variable in m.variables],
["test/normal_variable:0",
"test/partitioned_variable/part_0:0",
"test/partitioned_variable/part_1:0",
"test/partitioned_variable/part_2:0"])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
self.assertAllClose(sess.run(out), 2 * np.ones([7, 3]))
Expand All @@ -548,6 +557,7 @@ def testLargePartitionedVariables(self):
out = m()
self.assertEqual(len(m.variable_map), 2)
self.assertEqual(len(m.variable_map["partitioned_variable"]), 25)
self.assertEqual(len(m.variables), 26)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
self.assertAllClose(sess.run(out), 2 * np.ones([600, 3]))
Expand Down

0 comments on commit ebb7367

Please sign in to comment.