Skip to content

Commit 334c075

Browse files
zoyahavtf-transform-team
authored and committed
Update graph_tools to allow it to recursively inspect FuncGraphs produced by tf.functions.
PiperOrigin-RevId: 275588693
1 parent cea633a commit 334c075

File tree

7 files changed

+752
-161
lines changed

7 files changed

+752
-161
lines changed

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
* Moved beam/shared lib to `tfx-bsl`. If running with latest master, `tfx-bsl`
2727
must also be latest master.
2828
* Depends on `tfx-bsl>=0.15,<0.16`.
29+
* `preprocessing_fn`s now have beta support of calls to `tf.function`s, as long
30+
as they don't contain calls to `tf.Transform` analyzers/mappers or table
31+
initializers.
2932

3033
## Breaking changes
3134
* `always_return_num_quantiles` changed to default to True in `tft.quantiles`

tensorflow_transform/beam/analysis_graph_builder.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import collections
2121
import copy
2222
import hashlib
23-
import uuid
2423

2524
# GOOGLE-INITIALIZATION
2625

@@ -41,6 +40,14 @@
4140
def _serialize_op_attr(op_attr):
4241
"""Deterministicly serializes tf.Operation attrs since it is a map."""
4342
sorted_attributes = sorted(op_attr.items(), key=lambda kv: kv[0])
43+
if 'f' in op_attr:
44+
# This is a tf.Function node, and it includes attributes that are
45+
# inconsistent across runs such as _gradient_op_type, config_proto, so we
46+
# only keep input and output types since other information will arrive from
47+
# the FuncGraph attributes.
48+
sorted_attributes = [
49+
kv for kv in sorted_attributes if kv[0] in ('Tin', 'Tout')
50+
]
4451
result = []
4552
for key, attr_value in sorted_attributes:
4653
result.append(key)
@@ -49,8 +56,7 @@ def _serialize_op_attr(op_attr):
4956
raise ValueError(
5057
'Unable to serialize op attributes that contain a `list.func` field')
5158
if attr_value.HasField('func'):
52-
# TODO(b/138796127): Support tf.function fingerprint.
53-
result.append(uuid.uuid4().hex)
59+
# There should be a separate call for the FuncGraph attributes.
5460
attr_value.ClearField('func')
5561
result.append(attr_value.SerializeToString())
5662
return result

tensorflow_transform/beam/analysis_graph_builder_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,12 @@ def _preprocessing_fn_with_no_analyzers(inputs):
4848

4949

5050
def _preprocessing_fn_with_one_analyzer(inputs):
51-
x = inputs['x']
51+
52+
@tf.function
53+
def _plus_one(x):
54+
return x + 1
55+
56+
x = _plus_one(inputs['x'])
5257
x_mean = tft.mean(x, name='x')
5358
x_centered = x - x_mean
5459
return {'x_centered': x_centered}

tensorflow_transform/beam/cached_impl_test.py

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121
import collections
22+
import functools
2223
import itertools
2324
import os
2425
import struct
@@ -1019,12 +1020,14 @@ def preprocessing_fn(inputs):
10191020
preprocessing_fn, pipeline=p))
10201021
self.assertFalse(output_cache)
10211022

1022-
def test_tf_function_fails_cache(self):
1023+
def test_tf_function_works_with_cache(self):
10231024

1024-
def preprocessing_fn(inputs):
1025+
def preprocessing_fn(inputs, should_add_one):
10251026

10261027
@tf.function
10271028
def identity(x):
1029+
if should_add_one:
1030+
x = x + 1
10281031
return x
10291032

10301033
return {
@@ -1035,8 +1038,9 @@ def identity(x):
10351038

10361039
feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)}
10371040
input_data_dict = {'span-0': [dict(x=-2), dict(x=4)]}
1038-
run_result = self._run_pipeline(feature_spec, input_data_dict,
1039-
preprocessing_fn)
1041+
run_result = self._run_pipeline(
1042+
feature_spec, input_data_dict,
1043+
functools.partial(preprocessing_fn, should_add_one=False))
10401044
first_cache_output, p1 = run_result.cache_output, run_result.pipeline
10411045

10421046
for key in input_data_dict:
@@ -1050,12 +1054,100 @@ def identity(x):
10501054
_get_counter_value(p1.metrics, 'saved_models_created'),
10511055
_SINGLE_PHASE_NUM_SAVED_MODELS)
10521056

1057+
# Cache is still valid since the contents of the tf.function are the same.
1058+
run_result = self._run_pipeline(
1059+
feature_spec,
1060+
input_data_dict,
1061+
functools.partial(preprocessing_fn, should_add_one=False),
1062+
should_read_cache=True)
1063+
second_cache_output, p2 = run_result.cache_output, run_result.pipeline
1064+
1065+
self.assertFalse(second_cache_output)
1066+
1067+
self.assertEqual(_get_counter_value(p2.metrics, 'num_instances'), 0)
1068+
self.assertEqual(_get_counter_value(p2.metrics, 'cache_entries_decoded'), 1)
1069+
self.assertEqual(_get_counter_value(p2.metrics, 'cache_entries_encoded'), 0)
1070+
self.assertEqual(
1071+
_get_counter_value(p2.metrics, 'saved_models_created'),
1072+
_ZERO_PHASE_NUM_SAVED_MODELS)
1073+
1074+
self.assertEqual(_get_counter_value(p2.metrics, 'num_instances'), 0)
1075+
self.assertEqual(_get_counter_value(p2.metrics, 'cache_entries_decoded'), 1)
1076+
self.assertEqual(_get_counter_value(p2.metrics, 'cache_entries_encoded'), 0)
1077+
self.assertEqual(_get_counter_value(p2.metrics, 'saved_models_created'), 1)
1078+
1079+
# Modifying the tf.function contents causes cache invalidation.
1080+
run_result = self._run_pipeline(
1081+
feature_spec,
1082+
input_data_dict,
1083+
functools.partial(preprocessing_fn, should_add_one=True),
1084+
should_read_cache=True)
1085+
third_output_cache, p3 = run_result.cache_output, run_result.pipeline
1086+
1087+
for key in input_data_dict:
1088+
self.assertIn(key, third_output_cache)
1089+
self.assertEqual(1, len(third_output_cache[key]))
1090+
1091+
self.assertEqual(_get_counter_value(p3.metrics, 'num_instances'), 2)
1092+
self.assertEqual(_get_counter_value(p3.metrics, 'cache_entries_decoded'), 0)
1093+
self.assertEqual(_get_counter_value(p3.metrics, 'cache_entries_encoded'), 1)
1094+
self.assertEqual(_get_counter_value(p3.metrics, 'saved_models_created'), 2)
1095+
1096+
def test_incomplete_graphs_fail_cache(self):
1097+
1098+
def preprocessing_fn(inputs):
1099+
# Subtract 10 from x using a tf.while_loop.
1100+
@tf.function(input_signature=[
1101+
tf.TensorSpec([], tf.int32),
1102+
tf.TensorSpec([], tf.int64)
1103+
])
1104+
def stop_condition(counter, x_minus_counter):
1105+
del x_minus_counter # unused
1106+
return tf.less(counter, 10)
1107+
1108+
@tf.function(input_signature=[
1109+
tf.TensorSpec([], tf.int32),
1110+
tf.TensorSpec([], tf.int64)
1111+
])
1112+
def iteration(counter, x_minus_counter):
1113+
return tf.add(counter, 1), tf.add(x_minus_counter, -1)
1114+
1115+
initial_values = [tf.constant(0), inputs['x']]
1116+
final_values = tf.raw_ops.While(
1117+
cond=stop_condition.get_concrete_function(),
1118+
body=iteration.get_concrete_function(),
1119+
input=initial_values)
1120+
1121+
y = final_values[1]
1122+
1123+
return {'y': tft.mean(y) + tf.zeros_like(inputs['x'], dtype=tf.float32)}
1124+
1125+
feature_spec = {
1126+
'x': tf.io.FixedLenFeature([], tf.int64),
1127+
}
1128+
input_data_dict = {
1129+
'span-0': [dict(x=-2), dict(x=4)],
1130+
}
10531131
run_result = self._run_pipeline(feature_spec, input_data_dict,
10541132
preprocessing_fn)
1133+
first_cache_output, p1 = run_result.cache_output, run_result.pipeline
1134+
1135+
for key in input_data_dict:
1136+
self.assertIn(key, first_cache_output)
1137+
self.assertEqual(1, len(first_cache_output[key]))
1138+
1139+
self.assertEqual(_get_counter_value(p1.metrics, 'num_instances'), 2)
1140+
self.assertEqual(_get_counter_value(p1.metrics, 'cache_entries_decoded'), 0)
1141+
self.assertEqual(_get_counter_value(p1.metrics, 'cache_entries_encoded'), 1)
1142+
self.assertEqual(
1143+
_get_counter_value(p1.metrics, 'saved_models_created'),
1144+
_SINGLE_PHASE_NUM_SAVED_MODELS)
1145+
1146+
run_result = self._run_pipeline(
1147+
feature_spec, input_data_dict, preprocessing_fn, should_read_cache=True)
10551148
second_cache_output, p2 = run_result.cache_output, run_result.pipeline
10561149

1057-
# We expect a full output cache again because tf.function in the
1058-
# preprocessing_fn broke that cache entry.
1150+
# We expect the cache to fail here because the tf.function is now different.
10591151
for key in input_data_dict:
10601152
self.assertIn(key, second_cache_output)
10611153
self.assertEqual(1, len(second_cache_output[key]))

0 commit comments

Comments
 (0)