From a122c5146250fda93db5b28baf28c88ee56f8355 Mon Sep 17 00:00:00 2001 From: Yutaka Leon Date: Tue, 6 Dec 2016 15:13:30 -0800 Subject: [PATCH] Minor updates in lookup table names scopes. Change: 141233623 --- tensorflow/contrib/lookup/lookup_ops.py | 29 +++++++++++-------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 5047d8d87bd213..2e449afcfaea8c 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -141,11 +141,11 @@ def size(self, name=None): Returns: A scalar tensor containing the number of elements in this table. """ - if name is None: - name = "%s_Size" % self._name - # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name) - # pylint: enable=protected-access + with ops.name_scope(name, "%s_Size" % self._name, + [self._table_ref]) as scope: + # pylint: disable=protected-access + return gen_data_flow_ops._lookup_table_size(self._table_ref, name=scope) + # pylint: enable=protected-access def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -163,9 +163,6 @@ def lookup(self, keys, name=None): TypeError: when `keys` or `default_value` doesn't match the table data types. """ - if name is None: - name = "%s_lookup_table_find" % self._name - key_tensor = keys if isinstance(keys, sparse_tensor.SparseTensor): key_tensor = keys.values @@ -174,12 +171,12 @@ def lookup(self, keys, name=None): raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (self._key_dtype, keys.dtype)) - # pylint: disable=protected-access - values = gen_data_flow_ops._lookup_table_find(self._table_ref, - key_tensor, - self._default_value, - name=name) - # pylint: enable=protected-access + with ops.name_scope(name, "%s_Lookup" % self._name, + [self._table_ref]) as scope: + # pylint: disable=protected-access + values = gen_data_flow_ops._lookup_table_find( + self._table_ref, key_tensor, self._default_value, name=scope) + # pylint: enable=protected-access values.set_shape(key_tensor.get_shape()) if isinstance(keys, sparse_tensor.SparseTensor): @@ -220,13 +217,13 @@ def __init__(self, initializer, default_value, shared_name=None, name=None): Returns: A `HashTable` object. """ - with ops.name_scope(name, "hash_table", [initializer]): + with ops.name_scope(name, "hash_table", [initializer]) as scope: # pylint: disable=protected-access table_ref = gen_data_flow_ops._hash_table( shared_name=shared_name, key_dtype=initializer.key_dtype, value_dtype=initializer.value_dtype, - name=name) + name=scope) # pylint: enable=protected-access super(HashTable, self).__init__(table_ref, default_value, initializer)