Skip to content

Commit

Permalink
Add sparse_text_embedding_column to OSS TF-Hub
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 252622345
  • Loading branch information
TensorFlow Hub Authors authored and andresusanopinto committed Jun 12, 2019
1 parent e6e0589 commit 2684046
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 42 deletions.
4 changes: 2 additions & 2 deletions tensorflow_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
from distutils.version import LooseVersion
import tensorflow as tf


# pylint: disable=g-import-not-at-top
# Only do imports after checking the TensorFlow version, so that a useful
# error message is raised instead of an obscure missing-symbol error
# while executing the imports.
from tensorflow_hub.estimator import LatestModuleExporter
from tensorflow_hub.estimator import register_module_for_export
from tensorflow_hub.feature_column import image_embedding_column
from tensorflow_hub.feature_column import sparse_text_embedding_column
from tensorflow_hub.feature_column import text_embedding_column
from tensorflow_hub.image_util import attach_image_module_info
from tensorflow_hub.image_util import get_expected_image_size
Expand Down Expand Up @@ -63,12 +63,12 @@
# pylint: enable=g-bad-import-order
# pylint: enable=g-import-not-at-top


# Used by doc generation script.
_allowed_symbols = [
"LatestModuleExporter",
"register_module_for_export",
"image_embedding_column",
"sparse_text_embedding_column",
"text_embedding_column",
"attach_image_module_info",
"get_expected_image_size",
Expand Down
151 changes: 135 additions & 16 deletions tensorflow_hub/feature_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@
# from it if this version does.
if hasattr(feature_column_lib, "DenseColumn"):
# Use feature columns v2 if available.
class DenseFeatureColumn(feature_column._DenseColumn, # pylint: disable=protected-access
feature_column_lib.DenseColumn):
class DenseFeatureColumn(
feature_column._DenseColumn, # pylint: disable=protected-access
feature_column_lib.DenseColumn):
pass
else:
class DenseFeatureColumn(feature_column._DenseColumn): # pylint: disable=protected-access
Expand Down Expand Up @@ -80,10 +81,10 @@ def text_embedding_column(key, module_spec, trainable=False):
key: A string or `_FeatureColumn` identifying the text feature.
module_spec: A ModuleSpec defining the Module to instantiate or a path where
to load a ModuleSpec via `load_module_spec`
trainable: Whether or not the Module is trainable. False by default,
meaning the pre-trained weights are frozen. This is different from the
ordinary tf.feature_column.embedding_column(), but that one is intended
for training from scratch.
trainable: Whether or not the Module is trainable. False by default, meaning
the pre-trained weights are frozen. This is different from the ordinary
tf.feature_column.embedding_column(), but that one is intended for
training from scratch.
Returns:
`_DenseColumn` that converts from text input.
Expand All @@ -93,8 +94,8 @@ def text_embedding_column(key, module_spec, trainable=False):
"""
module_spec = module.as_module_spec(module_spec)
_check_module_is_text_embedding(module_spec)
return _TextEmbeddingColumn(key=key, module_spec=module_spec,
trainable=trainable)
return _TextEmbeddingColumn(
key=key, module_spec=module_spec, trainable=trainable)


def _check_module_is_text_embedding(module_spec):
Expand All @@ -118,10 +119,8 @@ def _check_module_is_text_embedding(module_spec):
input_shape = input_info.get_shape()
if not (input_info.dtype == tf.string and input_shape.ndims == 1 and
input_shape.as_list() == [None]):
issues.append(
"Module default signature must have only one input "
"tf.Tensor(shape=(?,), dtype=string)"
)
issues.append("Module default signature must have only one input "
"tf.Tensor(shape=(?,), dtype=string)")

# Find issues with signature outputs.
output_info_dict = module_spec.get_output_info_dict()
Expand All @@ -132,10 +131,8 @@ def _check_module_is_text_embedding(module_spec):
output_shape = output_info.get_shape()
if not (output_info.dtype == tf.float32 and output_shape.ndims == 2 and
not output_shape.as_list()[0] and output_shape.as_list()[1]):
issues.append(
"Module default signature must have a 'default' output of "
"tf.Tensor(shape=(?,K), dtype=float32)."
)
issues.append("Module default signature must have a 'default' output of "
"tf.Tensor(shape=(?,K), dtype=float32).")

if issues:
raise ValueError("Module is not a text-embedding: %r" % issues)
Expand Down Expand Up @@ -360,3 +357,125 @@ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
def get_dense_tensor(self, transformation_cache, state_manager):
  """Returns the dense tensor for this column.

  Looks this column up in `transformation_cache` (forwarding `state_manager`)
  and hands the result to `_get_dense_tensor_for_images`.

  Args:
    transformation_cache: Object whose `get(column, state_manager)` yields the
      transformed input for this column.
    state_manager: Passed through unchanged to `transformation_cache.get`.

  Returns:
    Whatever `_get_dense_tensor_for_images` returns for the fetched input.
  """
  return self._get_dense_tensor_for_images(
      transformation_cache.get(self, state_manager))


def sparse_text_embedding_column(key,
                                 module_spec,
                                 combiner,
                                 default_value,
                                 trainable=False):
  """Uses a Module to construct dense representations from sparse text features.

  The input to this feature column is a batch of multiple strings with
  arbitrary size, assuming the input is a SparseTensor.

  This type of feature column is typically suited for modules that operate on
  pre-tokenized text to produce token level embeddings which are combined with
  the combiner into a text embedding. The combiner always treats the tokens as a
  bag of words rather than a sequence.

  The output (i.e., transformed input layer) is a DenseTensor, with shape
  [batch_size, num_embedding_dim].

  For Example:

  ```python
    comment = sparse_text_embedding_column(
        "comment", "/tmp/text_module", combiner="mean",
        default_value="UNK")  # Prefer the module's own OOV token, if any.
    feature_columns = [comment, ...]
    ...
    features = {
      "comment": tf.SparseTensor(indices=[[0, 0], [1, 2]],
                                 values=['sparse', 'embedding'],
                                 dense_shape=[3, 4]),
      ...
    }
    estimator = tf.estimator.DNNClassifier(hidden_units, feature_columns)
  ```

  Args:
    key: A string or `_FeatureColumn` identifying the text feature.
    module_spec: A string handle or a `_ModuleSpec` identifying the module.
    combiner: a string specifying reducing op for embeddings in the same
      Example. Currently, 'mean', 'sqrtn', 'sum' are supported. Any other
      value, including None, raises ValueError.
    default_value: default value for Examples where the text feature is empty.
      Note, it's recommended to have default_value consistent with OOV tokens,
      in case there was special handling of OOV in the text module. If None,
      the text feature is assumed to be non-empty for each Example.
    trainable: Whether or not the Module is trainable. False by default, meaning
      the pre-trained weights are frozen. This is different from the ordinary
      tf.feature_column.embedding_column(), but that one is intended for
      training from scratch.

  Returns:
    `_DenseColumn` that converts from text input.

  Raises:
    ValueError: if module_spec is not suitable for use in this feature column.
    ValueError: if combiner not in ('mean', 'sqrtn', 'sum').
  """
  # Check the cheap argument first so an invalid combiner is rejected before
  # the module spec is loaded.
  if combiner not in ("mean", "sqrtn", "sum"):
    # Wrap in a 1-tuple: a bare tuple-valued `combiner` would otherwise be
    # unpacked by `%` and raise TypeError instead of the documented ValueError.
    raise ValueError(
        "combiner must be 'mean', 'sqrtn' or 'sum': %r" % (combiner,))
  module_spec = module.as_module_spec(module_spec)
  _check_module_is_text_embedding(module_spec)
  return _SparseTextEmbeddingColumn(
      key=key,
      module_spec=module_spec,
      trainable=trainable,
      default_value=default_value,
      combiner=combiner)


class _SparseTextEmbeddingColumn(
    feature_column._DenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        "_ModuleEmbeddingColumn",
        ("key", "combiner", "module_spec", "default_value", "trainable"))):
  """Column type built by sparse_text_embedding_column(); not for direct use.

  Fields (from the namedtuple base): `key`, `combiner`, `module_spec`,
  `default_value` and `trainable`.
  """

  @property
  def name(self):
    """String name of this column, derived from `key` and cached in `_name`."""
    if hasattr(self, "_name"):
      return self._name
    # `key` is either a plain string or an object exposing `.name`.
    if isinstance(self.key, six.string_types):
      base_name = self.key
    else:
      base_name = self.key.name
    self._name = "{}_hub_module_embedding".format(base_name)
    return self._name

  def _transform_feature(self, inputs):
    """Fetches the entry for `self.key` from `inputs`, unmodified."""
    return inputs.get(self.key)

  @property
  def _parse_example_spec(self):
    """Parsing spec: `key` maps to a variable-length string feature."""
    return {self.key: tf_v1.VarLenFeature(tf.string)}

  @property
  def _variable_shape(self):
    """Shape of the module's "default" output with the first dim dropped."""
    default_output = self.module_spec.get_output_info_dict()["default"]
    return default_output.get_shape()[1:]

  def _get_dense_tensor_for_inputs(self, text_batch, trainable):
    """Embeds every token of `text_batch` and combines them per row.

    Args:
      text_batch: Sparse batch of string tokens; its `indices`, `values` and
        `dense_shape` are read.
      trainable: Combined with `self.trainable` to decide whether the
        instantiated Module is trainable.

    Returns:
      The result of `tf.nn.embedding_lookup_sparse` over the token embeddings,
      reduced with `self.combiner`.
    """
    hub_module = module.Module(
        self.module_spec, trainable=self.trainable and trainable)

    if self.default_value is not None:
      # Only the filled SparseTensor is needed; the empty-row indicator that
      # comes with it is discarded.
      text_batch, _ = tf.sparse.fill_empty_rows(text_batch, self.default_value)

    # Run the module once over the flat list of tokens.
    token_embeddings = hub_module(text_batch.values)

    # Give each token its own position in `token_embeddings` as its id, while
    # keeping the original sparse layout.
    num_tokens = tf.shape(text_batch.indices)[0]
    token_ids = tf.SparseTensor(
        indices=text_batch.indices,
        values=tf.range(num_tokens, dtype=tf.int32),
        dense_shape=text_batch.dense_shape)

    return tf.nn.embedding_lookup_sparse(
        params=token_embeddings,
        sp_ids=token_ids,
        sp_weights=None,
        combiner=self.combiner)

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    """Returns a `Tensor` for this column looked up in `inputs`.

    `weight_collections` is accepted but ignored.
    """
    del weight_collections
    effective_trainable = self.trainable and trainable
    return self._get_dense_tensor_for_inputs(
        inputs.get(self), effective_trainable)
Loading

0 comments on commit 2684046

Please sign in to comment.