
Commit 14ee57f

tf-transform-team authored and elmer-garduno committed
Project import generated by Copybara (go/copybara).
PiperOrigin-RevId: 154310859
1 parent 77e0e29 commit 14ee57f

25 files changed: +2511 -725 lines

README.md
Lines changed: 1 addition & 1 deletion

@@ -49,4 +49,4 @@ independent from the specific runner as possible.
 ## Getting Started
 
 For instructions on using tf.Transform see the [getting started
-guide](./getting_started.md)
+guide](./getting_started.md).

examples/sentiment_example.py
Lines changed: 11 additions & 37 deletions

@@ -46,9 +46,10 @@
 NUM_TEST_INSTANCES = 25000
 
 REVIEW_COLUMN = 'review'
+REVIEW_WEIGHT = 'review_weight'
 LABEL_COLUMN = 'label'
 
-PUNCTUATION_CHARACTERS = ['.', ',', '!', '?', '(', ')']
+DELIMITERS = '.,!?() '
 
 
 # pylint: disable=invalid-name

@@ -139,43 +140,14 @@ def preprocessing_fn(inputs):
   """Preprocess input columns into transformed columns."""
   review = inputs[REVIEW_COLUMN]
 
-  def remove_character(s, char):
-    """Remove a character from a string.
-
-    Args:
-      s: A SparseTensor of rank 1 of type tf.string
-      char: A string of length 1
-
-    Returns:
-      The string `s` with the given character removed (i.e. replaced by
-      '')
-    """
-    # Hacky implementation where we split and rejoin.
-    split = tf.string_split(s, char)
-    rejoined = tf.reduce_join(
-        tf.sparse_to_dense(
-            split.indices, split.dense_shape, split.values, ''),
-        1)
-    return rejoined
-
-  def remove_punctuation(s):
-    """Remove punctuation from a string.
-
-    Args:
-      s: A SparseTensor of rank 1 of type tf.string
-
-    Returns:
-      The string `s` with punctuation removed.
-    """
-    for char in PUNCTUATION_CHARACTERS:
-      s = remove_character(s, char)
-    return s
-
-  cleaned_review = tft.map(remove_punctuation, review)
-  review_tokens = tft.map(tf.string_split, cleaned_review)
+  review_tokens = tft.map(lambda x: tf.string_split(x, DELIMITERS),
+                          review)
   review_indices = tft.string_to_int(review_tokens, top_k=VOCAB_SIZE)
+  # Add one for the oov bucket created by string_to_int.
+  review_weight = tft.tfidf_weights(review_indices, VOCAB_SIZE + 1)
   return {
       REVIEW_COLUMN: review_indices,
+      REVIEW_WEIGHT: review_weight,
       LABEL_COLUMN: inputs[LABEL_COLUMN]
   }
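The rewritten preprocessing_fn leans on the fact that tf.string_split treats every character of its delimiter argument as a separator, so one call now both tokenizes and strips punctuation; tft.tfidf_weights then assigns each token a weight that shrinks the more reviews it appears in. A minimal sketch of the tokenization half against the TF 1.x API used by this commit (the input strings are invented for illustration):

    import tensorflow as tf

    DELIMITERS = '.,!?() '
    reviews = tf.constant(['A great movie, truly!', 'Not (very) good.'])
    # Each character in DELIMITERS may split, so punctuation vanishes and
    # the values of the resulting SparseTensor are the cleaned tokens.
    tokens = tf.string_split(reviews, DELIMITERS)

    with tf.Session() as sess:
      print(sess.run(tokens).values)  # ['A', 'great', 'movie', 'truly', ...]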

@@ -230,9 +202,11 @@ def train_and_evaluate(transformed_train_filepattern,
   review_column = feature_column.sparse_column_with_integerized_feature(
       REVIEW_COLUMN,
       bucket_size=VOCAB_SIZE + 1,
-      combiner='sqrtn')
+      combiner='sum')
+  weighted_reviews = feature_column.weighted_sparse_column(review_column,
+                                                           REVIEW_WEIGHT)
 
-  estimator = learn.LinearClassifier([review_column])
+  estimator = learn.LinearClassifier([weighted_reviews])
 
   transformed_metadata = metadata_io.read_metadata(transformed_metadata_dir)
   train_input_fn = input_fn_maker.build_training_input_fn(
setup.py
Lines changed: 7 additions & 16 deletions

@@ -13,38 +13,29 @@
 # limitations under the License.
 """Package Setup script for the tf.Transform binary.
 """
-import os
-
 from setuptools import find_packages
 from setuptools import setup
 
+# Tensorflow transform version.
+__version__ = '0.1.8'
 
-def get_required_install_packages():
-  return [
 
+def _make_required_install_packages():
+  return [
       # Using >= for better integration tests. During release this is
      # automatically changed to a ==.
-      'google-cloud-dataflow == 0.6.0',
+      'apache-beam[gcp] == 0.6.0',
  ]
 
 
-def get_version():
-  # Obtain the version from the global names on version.py
-  # We cannot do 'from tensorflow_transform import version' since the transitive
-  # dependencies will not be available when the installer is created.
-  global_names = {}
-  execfile(os.path.normpath('tensorflow_transform/version.py'), global_names)
-  return global_names['__version__']
-
-
 setup(
     name='tensorflow-transform',
-    version=get_version(),
+    version=__version__,
     author='Google Inc.',
     author_email='tf-transform-feedback@google.com',
     license='Apache 2.0',
     namespace_packages=[],
-    install_requires=get_required_install_packages(),
+    install_requires=_make_required_install_packages(),
     packages=find_packages(),
     include_package_data=True,
     description='A library for data preprocessing with TensorFlow',
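The deleted get_version() leaned on execfile(), a Python-2-only builtin, which is why reading the version out of tensorflow_transform/version.py needed the os.path gymnastics; inlining __version__ lets setup.py avoid executing package code entirely. For contrast, a hedged Python-3-compatible sketch of what the removed helper did (not part of this commit):

    import os

    def get_version():
      # Execute version.py in an isolated namespace rather than importing the
      # package, whose transitive dependencies are absent at install time.
      path = os.path.normpath('tensorflow_transform/version.py')
      global_names = {}
      with open(path) as f:
        exec(compile(f.read(), path, 'exec'), global_names)
      return global_names['__version__']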

tensorflow_transform/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -17,4 +17,5 @@
 from tensorflow_transform.analyzers import *
 from tensorflow_transform.api import *
 from tensorflow_transform.mappers import *
+from tensorflow_transform.pretrained_models import *
 # pylint: enable=wildcard-import

tensorflow_transform/analyzers.py
Lines changed: 60 additions & 13 deletions

@@ -21,63 +21,105 @@
 from tensorflow_transform import api
 
 
-def min(x):  # pylint: disable=redefined-builtin
+def _get_output_shape(x, reduce_instance_dims):
+  """Determines the shape of the output of a numerical analyzer.
+
+  Args:
+    x: An input `Column' wrapping a `Tensor`.
+    reduce_instance_dims: If true, collapses the batch and instance dimensions
+        to arrive at a single scalar output. If False, only collapses the batch
+        dimension and outputs a vector of the same shape as the output.
+
+  Returns:
+    The shape to use for the output placeholder.
+  """
+  if reduce_instance_dims:
+    # Numerical analyzers produce scalar output by default
+    return ()
+  else:
+    in_shape = x.tensor.shape
+    if in_shape:
+      # The output will be the same shape as the input, but without the batch.
+      return in_shape.as_list()[1:]
+    else:
+      return None
+
+
+def min(x, reduce_instance_dims=True):  # pylint: disable=redefined-builtin
   """Computes the minimum of a `Column`.
 
   Args:
     x: An input `Column' wrapping a `Tensor`.
+    reduce_instance_dims: By default collapses the batch and instance dimensions
+        to arrive at a single scalar output. If False, only collapses the batch
+        dimension and outputs a vector of the same shape as the output.
 
   Returns:
     A `Statistic`.
   """
   if not isinstance(x.tensor, tf.Tensor):
     raise TypeError('Expected a Tensor, but got %r' % x.tensor)
 
+  arg_dict = {'reduce_instance_dims': reduce_instance_dims}
 
   # pylint: disable=protected-access
-  return api._AnalyzerOutput(tf.placeholder(x.tensor.dtype, ()),
-                             api.CanonicalAnalyzers.MIN, [x], {})
+  return api._AnalyzerOutput(
+      tf.placeholder(x.tensor.dtype, _get_output_shape(
+          x, reduce_instance_dims)), api.CanonicalAnalyzers.MIN, [x], arg_dict)
 
 
-def max(x):  # pylint: disable=redefined-builtin
+def max(x, reduce_instance_dims=True):  # pylint: disable=redefined-builtin
   """Computes the maximum of a `Column`.
 
   Args:
     x: An input `Column' wrapping a `Tensor`.
+    reduce_instance_dims: By default collapses the batch and instance dimensions
+        to arrive at a single scalar output. If False, only collapses the batch
+        dimension and outputs a vector of the same shape as the output.
 
   Returns:
     A `Statistic`.
   """
   if not isinstance(x.tensor, tf.Tensor):
     raise TypeError('Expected a Tensor, but got %r' % x.tensor)
 
+  arg_dict = {'reduce_instance_dims': reduce_instance_dims}
   # pylint: disable=protected-access
-  return api._AnalyzerOutput(tf.placeholder(x.tensor.dtype, ()),
-                             api.CanonicalAnalyzers.MAX, [x], {})
+  return api._AnalyzerOutput(
+      tf.placeholder(x.tensor.dtype, _get_output_shape(
+          x, reduce_instance_dims)), api.CanonicalAnalyzers.MAX, [x], arg_dict)
 
 
-def sum(x):  # pylint: disable=redefined-builtin
+def sum(x, reduce_instance_dims=True):  # pylint: disable=redefined-builtin
   """Computes the sum of a `Column`.
 
   Args:
     x: An input `Column' wrapping a `Tensor`.
+    reduce_instance_dims: By default collapses the batch and instance dimensions
+        to arrive at a single scalar output. If False, only collapses the batch
+        dimension and outputs a vector of the same shape as the output.
 
   Returns:
     A `Statistic`.
   """
   if not isinstance(x.tensor, tf.Tensor):
     raise TypeError('Expected a Tensor, but got %r' % x.tensor)
 
+  arg_dict = {'reduce_instance_dims': reduce_instance_dims}
   # pylint: disable=protected-access
-  return api._AnalyzerOutput(tf.placeholder(x.tensor.dtype, ()),
-                             api.CanonicalAnalyzers.SUM, [x], {})
+  return api._AnalyzerOutput(
+      tf.placeholder(x.tensor.dtype, _get_output_shape(
+          x, reduce_instance_dims)), api.CanonicalAnalyzers.SUM, [x], arg_dict)
 
 
-def size(x):
+def size(x, reduce_instance_dims=True):
   """Computes the total size of instances in a `Column`.
 
   Args:
     x: An input `Column' wrapping a `Tensor`.
+    reduce_instance_dims: By default collapses the batch and instance dimensions
+        to arrive at a single scalar output. If False, only collapses the batch
+        dimension and outputs a vector of the same shape as the output.
 
   Returns:
     A `Statistic`.
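The new reduce_instance_dims flag decides whether an analyzer collapses a whole batch of instances to one scalar or keeps one statistic per instance position, which is also why _get_output_shape drops only the batch dimension in the False case. A hedged numpy analogy of the two reductions (not the tf.Transform API itself, just the shapes involved):

    import numpy as np

    # A batch of two instances, each a 3-vector.
    batch = np.array([[1, 5, 2],
                      [4, 0, 7]])

    np.min(batch)          # reduce_instance_dims=True: scalar output, 0
    np.min(batch, axis=0)  # reduce_instance_dims=False: per-position [1, 0, 2]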
@@ -86,14 +128,17 @@ def size(x):
     raise TypeError('Expected a Tensor, but got %r' % x.tensor)
 
   # Note: Calling `sum` defined in this module, not the builtin.
-  return sum(api.map(tf.ones_like, x))
+  return sum(api.map(tf.ones_like, x), reduce_instance_dims)
 
 
-def mean(x):
+def mean(x, reduce_instance_dims=True):
   """Computes the mean of the values in a `Column`.
 
   Args:
     x: An input `Column' wrapping a `Tensor`.
+    reduce_instance_dims: By default collapses the batch and instance dimensions
+        to arrive at a single scalar output. If False, only collapses the batch
+        dimension and outputs a vector of the same shape as the output.
 
   Returns:
     A `Column` with an underlying `Tensor` of shape [1], containing the mean.

@@ -102,7 +147,9 @@ def mean(x):
     raise TypeError('Expected a Tensor, but got %r' % x.tensor)
 
   # Note: Calling `sum` defined in this module, not the builtin.
-  return api.map_statistics(tf.divide, sum(x), size(x))
+  return api.map_statistics(tf.divide,
+                            sum(x, reduce_instance_dims),
+                            size(x, reduce_instance_dims))
 
 
 def uniques(x, top_k=None, frequency_threshold=None):
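Because mean is composed from sum and size via map_statistics, the flag has to be threaded through both calls or the two reductions would disagree in shape. A hypothetical usage sketch in a preprocessing_fn, modeled on the getting-started example (column name and shape are invented):

    import tensorflow_transform as tft

    def preprocessing_fn(inputs):
      x = inputs['x']  # hypothetical column whose instances are 3-vectors
      # One mean per instance position instead of a single scalar.
      col_means = tft.mean(x, reduce_instance_dims=False)
      return {'x_centered': tft.map(lambda t, m: t - m, x, col_means)}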

tensorflow_transform/beam/analyzer_impls.py
Lines changed: 60 additions & 1 deletion

@@ -28,9 +28,22 @@
 from apache_beam.typehints import with_output_types
 
 import six
+import tensorflow as tf
 from tensorflow_transform.beam import common
 
 
+def flatten_value_to_list(batch):
+  """Converts an N-D dense or sparse batch to a 1-D list."""
+  if isinstance(batch, tf.SparseTensorValue):
+    dense_values = batch.values
+  else:
+    dense_values = batch
+  # Ravel for flattening and tolist so that we go to native Python types
+  # for more efficient followup processing.
+  #
+  return dense_values.ravel().tolist()
+
+
 @with_input_types(List[common.NUMERIC_TYPE])
 @with_output_types(common.NUMERIC_TYPE)
 class _NumericAnalyzer(beam.PTransform):
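flatten_value_to_list gives every analyzer a uniform view of a batch, whether it arrives as a dense ndarray or as a tf.SparseTensorValue whose .values carries the data. A hedged standalone sketch of the dense path (input invented):

    import numpy as np

    batch = np.array([[1, 2], [3, 4]])  # a 2-D dense batch
    flat = batch.ravel().tolist()       # [1, 2, 3, 4] as native Python ints
    assert all(isinstance(v, int) for v in flat)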
@@ -40,12 +53,57 @@ def __init__(self, fn):
     self._fn = fn
 
   def expand(self, pcoll):
+    pcoll |= 'FlattenValueToList' >> beam.Map(flatten_value_to_list)
     return (pcoll
             | 'CombineWithinList' >> beam.Map(self._fn)
             | 'CombineGlobally'
             >> beam.CombineGlobally(self._fn).without_defaults())
 
 
+@with_input_types(List[common.PRIMITIVE_TYPE])
+@with_output_types(List[common.PRIMITIVE_TYPE])
+class _NumericAnalyzerOnBatchDim(beam.PTransform):
+  """Reduces a PCollection on the batch dimension using the given function.
+
+  Args:
+    fn: The function used to reduce the PCollection. It must take as inputs an
+        ndarray of data, and an axis parameter used to specify that the
+        reduction should only happen along the batch dimension, and all
+        instance dimensions should be preserved.
+  """
+
+  class _CombineOnBatchDim(beam.CombineFn):
+    """Combines the PCollection only on the 0th dimension using nparray."""
+
+    def __init__(self, fn):
+      self._fn = fn
+
+    def create_accumulator(self):
+      return []
+
+    def add_input(self, accumulator, next_input):
+      batch = self._fn(next_input, axis=0)
+      if any(accumulator):
+        return self._fn((accumulator, batch), axis=0)
+      else:
+        return batch
+
+    def merge_accumulators(self, accumulators):
+      # numpy's sum, min, max, etc functions operate on array-like objects, but
+      # not arbitrary iterables. Convert the provided accumulators into a list
+      return self._fn(list(accumulators), axis=0)
+
+    def extract_output(self, accumulator):
+      return accumulator
+
+  def __init__(self, fn):
+    self._fn = fn
+
+  def expand(self, pcoll):
+    return (pcoll | 'CombineOnBatchDim'
+            >> beam.CombineGlobally(self._CombineOnBatchDim(self._fn)))
+
+
 @with_input_types(List[common.PRIMITIVE_TYPE])
 @with_output_types(List[common.PRIMITIVE_TYPE])
 class _UniquesAnalyzer(beam.PTransform):
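_CombineOnBatchDim folds each incoming [batch, ...] ndarray into a running per-position accumulator along axis 0 and merges partial results the same way, so the instance dimensions survive. A hedged standalone sketch of that accumulation with numpy's min, outside Beam (batches invented; an explicit len() check stands in for the class's truthiness test so an all-zero accumulator is not mistaken for an empty one):

    import numpy as np

    fn = np.min
    acc = []                                   # create_accumulator()
    for batch in (np.array([[3, 1], [2, 5]]),
                  np.array([[0, 9]])):
      reduced = fn(batch, axis=0)              # collapse only the batch dim
      # add_input(): fold the new partial result into the accumulator.
      acc = fn((acc, reduced), axis=0) if len(acc) else reduced

    print(acc)  # [0 1], the per-position minimum across both batches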
@@ -63,6 +121,7 @@ def expand(self, pcoll):
     # this to create a single element PCollection containing this list of
     # pairs in sorted order by decreasing counts (and by values for equal
     # counts).
+    pcoll |= 'FlattenValueToList' >> beam.Map(flatten_value_to_list)
 
     counts = (
         pcoll

@@ -90,7 +149,7 @@ def expand(self, pcoll):
     # from a single file.
     #
     @beam.ptransform_fn
-    def Reshard(pcoll):
+    def Reshard(pcoll):  # pylint: disable=invalid-name
       return (
           pcoll
           | 'PairWithNone' >> beam.Map(lambda x: (None, x))
