Commit 12dcf29

Adds max_norm option to embedding_lookup (and upstream functions).
Change: 139325873
1 parent 9003342 commit 12dcf29

6 files changed, 66 insertions(+), 16 deletions(-)
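
The new argument lets lookup results be l2-clipped before any combining. A minimal usage sketch, based on the tests added in this commit (the 1.x-era graph/Session style is assumed):

import tensorflow as tf

with tf.Session() as sess:
  embeddings = tf.constant([[2.0]])        # single row with l2 norm 2.0
  ids = tf.constant([0], dtype=tf.int32)
  # With max_norm=1.0, the looked-up row is rescaled to have l2 norm 1.0.
  clipped = tf.nn.embedding_lookup([embeddings], ids, max_norm=1.0)
  print(sess.run(clipped))                 # [[1.0]]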

tensorflow/contrib/layers/python/layers/embedding_ops.py

Lines changed: 6 additions & 2 deletions
@@ -41,7 +41,8 @@ def safe_embedding_lookup_sparse(embedding_weights,
                                  combiner=None,
                                  default_id=None,
                                  name=None,
-                                 partition_strategy="div"):
+                                 partition_strategy="div",
+                                 max_norm=None):
   """Lookup embedding results, accounting for invalid IDs and empty features.
 
   The partitioned embedding in `embedding_weights` must all be the same shape
@@ -75,6 +76,8 @@ def safe_embedding_lookup_sparse(embedding_weights,
     name: A name for this operation (optional).
     partition_strategy: A string specifying the partitioning strategy.
         Currently `"div"` and `"mod"` are supported. Default is `"div"`.
+    max_norm: If not None, all embeddings are l2-normalized to max_norm before
+        combining.
 
 
   Returns:
@@ -135,7 +138,8 @@ def safe_embedding_lookup_sparse(embedding_weights,
         sparse_weights,
         combiner=combiner,
         partition_strategy=partition_strategy,
-        name=None if default_id is None else scope)
+        name=None if default_id is None else scope,
+        max_norm=max_norm)
 
     if default_id is None:
       # Broadcast is_row_empty to the same shape as embedding_lookup_result,
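
The argument is simply forwarded by safe_embedding_lookup_sparse to the underlying sparse lookup. A rough sketch of a call through tf.contrib.layers (the sparse-input construction and the zero default for an all-invalid row are assumptions based on the function's documented behavior, not shown in this hunk):

import tensorflow as tf

with tf.Session() as sess:
  # One partition of embedding weights.
  embedding_weights = [tf.constant([[1.0, 2.0], [3.0, 4.0]])]
  # Row 0 has a valid id; row 1 contains only an invalid id (-1).
  sparse_ids = tf.SparseTensor([[0, 0], [1, 0]],
                               tf.constant([0, -1], dtype=tf.int64),
                               [2, 1])
  embedded = tf.contrib.layers.safe_embedding_lookup_sparse(
      embedding_weights, sparse_ids, combiner="mean", max_norm=1.0)
  # Row 0 is clipped to unit norm; row 1 should come back as zeros (no default_id).
  print(sess.run(embedded))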

tensorflow/contrib/layers/python/layers/feature_column.py

Lines changed: 11 additions & 6 deletions
@@ -162,7 +162,8 @@ class _DeepEmbeddingLookupArguments(
         "combiner",
         "dimension",
         "shared_embedding_name",
-        "hashed"])):
+        "hashed",
+        "max_norm"])):
   """Represents the information needed from a column for embedding lookup.
 
   Used to to compute DNN inputs and weighted sum.
@@ -822,7 +823,7 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
     "_EmbeddingColumn",
     ["sparse_id_column", "dimension", "combiner", "initializer",
      "ckpt_to_load_from", "tensor_name_in_ckpt", "shared_embedding_name",
-     "shared_vocab_size"])):
+     "shared_vocab_size", "max_norm"])):
   """Represents an embedding column.
 
   Args:
@@ -863,7 +864,8 @@ def __new__(cls,
               ckpt_to_load_from=None,
               tensor_name_in_ckpt=None,
               shared_embedding_name=None,
-              shared_vocab_size=None):
+              shared_vocab_size=None,
+              max_norm=None):
     if initializer is not None and not callable(initializer):
       raise ValueError("initializer must be callable if specified. "
                        "Embedding of column_name: {}".format(
@@ -882,7 +884,8 @@ def __new__(cls,
                                                 initializer, ckpt_to_load_from,
                                                 tensor_name_in_ckpt,
                                                 shared_embedding_name,
-                                                shared_vocab_size)
+                                                shared_vocab_size,
+                                                max_norm)
 
   @property
   def name(self):
@@ -922,7 +925,8 @@ def _deep_embedding_lookup_arguments(self, input_tensor):
         initializer=self.initializer,
         combiner=self.combiner,
         shared_embedding_name=self.shared_embedding_name,
-        hashed=False)
+        hashed=False,
+        max_norm=self.max_norm)
 
   def _checkpoint_path(self):
     if self.ckpt_to_load_from is not None:
@@ -1133,7 +1137,8 @@ def _deep_embedding_lookup_arguments(self, input_tensor):
         combiner=self.combiner,
         dimension=self.dimension,
         shared_embedding_name=None,
-        hashed=True)
+        hashed=True,
+        max_norm=None)
 
 
 def hashed_embedding_column(column_name,

tensorflow/contrib/layers/python/layers/feature_column_ops.py

Lines changed: 2 additions & 1 deletion
@@ -130,7 +130,8 @@ def _embeddings_from_arguments(column,
         input_tensor,
         sparse_weights=weight_tensor,
         combiner=args.combiner,
-        name=column.name + 'weights')
+        name=column.name + 'weights',
+        max_norm=args.max_norm)
 
 
 def _input_from_feature_columns(columns_to_tensors,

tensorflow/contrib/opt/python/training/variable_clipping_optimizer.py

Lines changed: 5 additions & 0 deletions
@@ -42,7 +42,12 @@ class VariableClippingOptimizer(optimizer.Optimizer):
   Multiple instances of `VariableClippingOptimizer` may be chained to specify
   different max norms for different subsets of variables.
 
+  This is more efficient at serving-time than using normalization during
+  embedding lookup, at the expense of more expensive training and fewer
+  guarantees about the norms.
+
   @@__init__
+
   """
 
   def __init__(self,
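
For contrast with the new lookup-time argument, the train-time alternative referenced in the docstring might look roughly like this. It is only a sketch: the export as tf.contrib.opt.VariableClippingOptimizer, the positional constructor arguments (wrapped optimizer, a dict mapping each variable to the dimensions its norm is computed over, then the max norm), and the toy loss are assumptions.

import tensorflow as tf

# Hypothetical setup: clip each embedding row after optimizer updates,
# instead of normalizing at lookup time.
embeddings = tf.get_variable("embeddings", shape=[1000, 64])
loss = tf.reduce_sum(tf.nn.embedding_lookup(embeddings, tf.constant([1, 2])))
base_opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
clipping_opt = tf.contrib.opt.VariableClippingOptimizer(
    base_opt, {embeddings: [1]}, 2.0)
train_op = clipping_opt.minimize(loss)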

tensorflow/python/kernel_tests/embedding_ops_test.py

Lines changed: 20 additions & 0 deletions
@@ -228,6 +228,26 @@ def testSimpleSharded(self):
       self.assertAllEqual(np_result, tf_result)
       self.assertShapeEqual(np_result, embedding)
 
+  def testMaxNorm(self):
+    with self.test_session():
+      embeddings = tf.constant([[2.0]])
+
+      ids = tf.constant([0], dtype=tf.int32)
+      embedding = tf.nn.embedding_lookup([embeddings], ids, max_norm=1.0)
+
+      self.assertAllEqual(embedding.eval(), [[1.0]])
+
+  def testMaxNormNontrivial(self):
+    with self.test_session():
+      embeddings = tf.constant([[2.0, 4.0], [3.0, 1.0]])
+
+      ids = tf.constant([0, 1], dtype=tf.int32)
+      embedding = tf.nn.embedding_lookup([embeddings], ids, max_norm=2.0)
+
+      norms = tf.sqrt(tf.reduce_sum(embeddings * embeddings, axis=1))
+      normalized = embeddings / tf.stack([norms, norms], axis=1)
+      self.assertAllEqual(embedding.eval(), 2 * normalized.eval())
+
   def testSimpleShardedPartitionedVariable(self):
     with self.test_session() as sess:
       num_shards = 2
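
As a sanity check on testMaxNormNontrivial: both rows of [[2.0, 4.0], [3.0, 1.0]] have l2 norm greater than 2.0 (sqrt(20) and sqrt(10)), so each is rescaled to norm exactly 2.0. The expected values can be reproduced with NumPy:

import numpy as np

embeddings = np.array([[2.0, 4.0], [3.0, 1.0]])
norms = np.sqrt((embeddings ** 2).sum(axis=1))        # [sqrt(20), sqrt(10)]
expected = 2.0 * embeddings / norms[:, np.newaxis]    # rescale each row to norm 2.0
print(expected)  # approx [[0.894, 1.789], [1.897, 0.632]]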

tensorflow/python/ops/embedding_ops.py

Lines changed: 22 additions & 7 deletions
@@ -25,6 +25,7 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -33,7 +34,7 @@
 
 
 def embedding_lookup(params, ids, partition_strategy="mod", name=None,
-                     validate_indices=True):
+                     validate_indices=True, max_norm=None):
   """Looks up `ids` in a list of embedding tensors.
 
   This function is used to perform parallel lookups on the list of
@@ -73,6 +74,8 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
       is `"mod"`.
     name: A name for the operation (optional).
     validate_indices: Whether or not to validate gather indices.
+    max_norm: If not None, embedding values are l2-normalized to the value of
+      max_norm.
 
   Returns:
     A `Tensor` with the same type as the tensors in `params`.
@@ -86,17 +89,26 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
     params = list(params)  # Iterate to get the underlying Variables.
   if not isinstance(params, list):
     params = [params]
+  def maybe_normalize(x):
+    if max_norm is not None:
+      if x.get_shape().ndims is not None:
+        ndims = x.get_shape().ndims
+      else:
+        ndims = array_ops.size(array_ops.shape(x))
+      return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
+    return x
   with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
     np = len(params)  # Number of partitions
     params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
     if np == 1:
       with ops.colocate_with(params[0]):
         # TODO(apassos): implement the sharded version as well.
         if isinstance(params[0], resource_variable_ops.ResourceVariable):
-          return params[0].sparse_read(ids, name=name)
+          ret = params[0].sparse_read(ids, name=name)
         else:
-          return array_ops.gather(params[0], ids, name=name,
-                                  validate_indices=validate_indices)
+          ret = array_ops.gather(params[0], ids, name=name,
+                                 validate_indices=validate_indices)
+        return maybe_normalize(ret)
     else:
       ids = ops.convert_to_tensor(ids, name="ids")
       flat_ids = array_ops.reshape(ids, [-1])
@@ -180,13 +192,14 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
       # Normally the reshape is sufficient, but setting shape explicitly
       # teaches shape inference that params[1:].get_shape() matters.
       ret.set_shape(ids.get_shape().concatenate(element_shape))
-      return ret
+      return maybe_normalize(ret)
 
 
 def embedding_lookup_sparse(params, sp_ids, sp_weights,
                             partition_strategy="mod",
                             name=None,
-                            combiner=None):
+                            combiner=None,
+                            max_norm=None):
   """Computes embeddings for the given ids and weights.
 
   This op assumes that there is at least one id for each row in the dense tensor
@@ -216,6 +229,8 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
         "mean" is the weighted sum divided by the total weight.
         "sqrtn" is the weighted sum divided by the square root of the sum of the
         squares of the weights.
+    max_norm: If not None, each embedding is normalized to have l2 norm equal
+      to max_norm before combining.
 
   Returns:
     A dense tensor representing the combined embeddings for the
@@ -291,7 +306,7 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
       idx = None
 
     embeddings = embedding_lookup(
-        params, ids, partition_strategy=partition_strategy)
+        params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
     if not ignore_weights:
       weights = sp_weights.values
       if weights.dtype != embeddings.dtype:
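
Since maybe_normalize delegates to clip_ops.clip_by_norm, rows whose norm is already at or below max_norm pass through unchanged; only rows exceeding it are rescaled. A small sketch of that behavior (1.x Session style assumed):

import tensorflow as tf

with tf.Session() as sess:
  embeddings = tf.constant([[3.0, 4.0],    # l2 norm 5.0 -> rescaled to norm 1.0
                            [0.3, 0.4]])   # l2 norm 0.5 -> left unchanged
  ids = tf.constant([0, 1], dtype=tf.int32)
  looked_up = tf.nn.embedding_lookup([embeddings], ids, max_norm=1.0)
  print(sess.run(looked_up))               # approx [[0.6, 0.8], [0.3, 0.4]]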
