Skip to content

Commit

Permalink
Minor updates in lookup table names scopes.
Browse files Browse the repository at this point in the history
Change: 141233623
  • Loading branch information
ysuematsu authored and tensorflower-gardener committed Dec 7, 2016
1 parent 89d31de commit a122c51
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions tensorflow/contrib/lookup/lookup_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,11 @@ def size(self, name=None):
Returns:
A scalar tensor containing the number of elements in this table.
"""
if name is None:
name = "%s_Size" % self._name
# pylint: disable=protected-access
return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)
# pylint: enable=protected-access
with ops.name_scope(name, "%s_Size" % self._name,
[self._table_ref]) as scope:
# pylint: disable=protected-access
return gen_data_flow_ops._lookup_table_size(self._table_ref, name=scope)
# pylint: enable=protected-access

def lookup(self, keys, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
Expand All @@ -163,9 +163,6 @@ def lookup(self, keys, name=None):
TypeError: when `keys` or `default_value` doesn't match the table data
types.
"""
if name is None:
name = "%s_lookup_table_find" % self._name

key_tensor = keys
if isinstance(keys, sparse_tensor.SparseTensor):
key_tensor = keys.values
Expand All @@ -174,12 +171,12 @@ def lookup(self, keys, name=None):
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))

# pylint: disable=protected-access
values = gen_data_flow_ops._lookup_table_find(self._table_ref,
key_tensor,
self._default_value,
name=name)
# pylint: enable=protected-access
with ops.name_scope(name, "%s_Lookup" % self._name,
[self._table_ref]) as scope:
# pylint: disable=protected-access
values = gen_data_flow_ops._lookup_table_find(
self._table_ref, key_tensor, self._default_value, name=scope)
# pylint: enable=protected-access

values.set_shape(key_tensor.get_shape())
if isinstance(keys, sparse_tensor.SparseTensor):
Expand Down Expand Up @@ -220,13 +217,13 @@ def __init__(self, initializer, default_value, shared_name=None, name=None):
Returns:
A `HashTable` object.
"""
with ops.name_scope(name, "hash_table", [initializer]):
with ops.name_scope(name, "hash_table", [initializer]) as scope:
# pylint: disable=protected-access
table_ref = gen_data_flow_ops._hash_table(
shared_name=shared_name,
key_dtype=initializer.key_dtype,
value_dtype=initializer.value_dtype,
name=name)
name=scope)
# pylint: enable=protected-access

super(HashTable, self).__init__(table_ref, default_value, initializer)
Expand Down

0 comments on commit a122c51

Please sign in to comment.