
Commit 2ddd0df

zoyahav authored and tf-transform-team committed
Adding a test case validating the expected TFT graph for a 2-phase analysis which is fully covered by cache.
PiperOrigin-RevId: 405878964
1 parent b06e87b commit 2ddd0df
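
The commit adds a traversal test for a 2-phase analysis in which every phase-0 analyzer is served from cache. For orientation, below is a rough, illustrative sketch of the kind of preprocessing_fn such a test exercises; the test itself reuses _preprocessing_fn_for_common_optimize_traversal from cached_impl_test.py, and the names here are only inferred from the analyzer labels in the expected graph (vocabulary, x#mean_and_var, x_square_deviations#mean_and_var), not taken from the commit.

# Illustrative sketch only, not part of this commit.
import tensorflow as tf
import tensorflow_transform as tft


def _two_phase_preprocessing_fn(inputs):  # hypothetical name
  # Phase 0 analyzers run directly on the raw inputs.
  _ = tft.vocabulary(inputs['s'], vocab_filename='vocabulary')
  x = inputs['x']
  x_mean = tft.mean(x, name='x')
  # This tensor depends on a phase 0 analyzer output, so analyzing it
  # requires a second pass over the data: the phase 1 analyzer below.
  x_square_deviations = tf.square(x - x_mean)
  x_var = tft.mean(x_square_deviations, name='x_square_deviations')
  return {'x_normalized': (x - x_mean) / tf.sqrt(x_var)}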


tensorflow_transform/beam/cached_impl_test.py

Lines changed: 99 additions & 11 deletions
@@ -18,13 +18,15 @@
 import functools
 import os
 import struct
+from typing import Callable, Mapping, List
 import apache_beam as beam
 from apache_beam.testing import util as beam_test_util
 import numpy as np
 
 import tensorflow as tf
 import tensorflow_transform as tft
 from tensorflow_transform import analyzer_nodes
+from tensorflow_transform import common_types
 from tensorflow_transform import impl_helper
 from tensorflow_transform import nodes
 import tensorflow_transform.beam as tft_beam
@@ -73,10 +75,10 @@ def _preprocessing_fn_for_common_optimize_traversal(inputs):
         's': tf.io.FixedLenFeature([], tf.string)
     },
     preprocessing_fn=_preprocessing_fn_for_common_optimize_traversal,
-    dataset_input_cache_dict={
+    dataset_input_cache_dicts=[{
         _make_cache_key(b'CacheableCombineAccumulate[x#mean_and_var]'):
             'cache hit',
-    },
+    }],
     expected_dot_graph_str=r"""digraph G {
 directed=True;
 node [shape=Mrecord];
@@ -165,6 +167,83 @@ def _preprocessing_fn_for_common_optimize_traversal(inputs):
 }
 """)
 
+_OPTIMIZE_TRAVERSAL_MULTI_PHASE_FULL_CACHE_HIT_CASE = dict(
+    testcase_name='multi_phase_full_cache_coverage',
+    feature_spec={
+        'x': tf.io.FixedLenFeature([], tf.float32),
+        's': tf.io.FixedLenFeature([], tf.string)
+    },
+    preprocessing_fn=_preprocessing_fn_for_common_optimize_traversal,
+    dataset_input_cache_dicts=[{
+        _make_cache_key(b'CacheableCombineAccumulate[x#mean_and_var]'):
+            'cache hit',
+        _make_cache_key(b'VocabularyAccumulate[vocabulary]'):
+            'cache hit',
+    }] * 2,
+    expected_dot_graph_str=r"""digraph G {
+directed=True;
+node [shape=Mrecord];
+"DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex0]" [label="{DecodeCache|dataset_key: DatasetKey(key='span-0')|cache_key: \<bytes\>|coder: \<_VocabularyAccumulatorCoder\>|label: DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex0]|partitionable: True}"];
+"DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex1]" [label="{DecodeCache|dataset_key: DatasetKey(key='span-1')|cache_key: \<bytes\>|coder: \<_VocabularyAccumulatorCoder\>|label: DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex1]|partitionable: True}"];
+"FlattenCache[VocabularyMerge[vocabulary]]" [label="{Flatten|label: FlattenCache[VocabularyMerge[vocabulary]]|partitionable: True}"];
+"DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex0]" -> "FlattenCache[VocabularyMerge[vocabulary]]";
+"DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex1]" -> "FlattenCache[VocabularyMerge[vocabulary]]";
+"VocabularyMerge[vocabulary]" [label="{VocabularyMerge|vocab_ordering_type: 1|use_adjusted_mutual_info: False|min_diff_from_avg: None|label: VocabularyMerge[vocabulary]}"];
+"FlattenCache[VocabularyMerge[vocabulary]]" -> "VocabularyMerge[vocabulary]";
+"VocabularyCount[vocabulary]" [label="{VocabularyCount|label: VocabularyCount[vocabulary]}"];
+"VocabularyMerge[vocabulary]" -> "VocabularyCount[vocabulary]";
+"CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]" [label="{CreateTensorBinding|tensor: vocabulary/vocab_vocabulary_unpruned_vocab_size:0|is_asset_filepath: False|label: CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]}"];
+"VocabularyCount[vocabulary]" -> "CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]";
+"VocabularyPrune[vocabulary]" [label="{VocabularyPrune|top_k: None|frequency_threshold: 0|informativeness_threshold: -inf|coverage_top_k: None|coverage_frequency_threshold: 0|coverage_informativeness_threshold: -inf|key_fn: None|filter_newline_characters: True|input_dtype: string|label: VocabularyPrune[vocabulary]}"];
+"VocabularyMerge[vocabulary]" -> "VocabularyPrune[vocabulary]";
+"VocabularyOrderAndWrite[vocabulary]" [label="{VocabularyOrderAndWrite|vocab_filename: vocab_vocabulary|store_frequency: False|input_dtype: string|label: VocabularyOrderAndWrite[vocabulary]|fingerprint_shuffle: False|file_format: text|input_is_sorted: False}"];
+"VocabularyPrune[vocabulary]" -> "VocabularyOrderAndWrite[vocabulary]";
+"CreateTensorBinding[vocabulary#Placeholder]" [label="{CreateTensorBinding|tensor: vocabulary/Placeholder:0|is_asset_filepath: True|label: CreateTensorBinding[vocabulary#Placeholder]}"];
+"VocabularyOrderAndWrite[vocabulary]" -> "CreateTensorBinding[vocabulary#Placeholder]";
+"DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex0]" [label="{DecodeCache|dataset_key: DatasetKey(key='span-0')|cache_key: \<bytes\>|coder: \<JsonNumpyCacheCoder\>|label: DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex0]|partitionable: True}"];
+"DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex1]" [label="{DecodeCache|dataset_key: DatasetKey(key='span-1')|cache_key: \<bytes\>|coder: \<JsonNumpyCacheCoder\>|label: DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex1]|partitionable: True}"];
+"FlattenCache[CacheableCombineMerge[x#mean_and_var]]" [label="{Flatten|label: FlattenCache[CacheableCombineMerge[x#mean_and_var]]|partitionable: True}"];
+"DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex0]" -> "FlattenCache[CacheableCombineMerge[x#mean_and_var]]";
+"DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex1]" -> "FlattenCache[CacheableCombineMerge[x#mean_and_var]]";
+"CacheableCombineMerge[x#mean_and_var]" [label="{CacheableCombineMerge|combiner: \<WeightedMeanAndVarCombiner\>|label: CacheableCombineMerge[x#mean_and_var]}"];
+"FlattenCache[CacheableCombineMerge[x#mean_and_var]]" -> "CacheableCombineMerge[x#mean_and_var]";
+"ExtractCombineMergeOutputs[x#mean_and_var]" [label="{ExtractCombineMergeOutputs|output_tensor_info_list: [TensorInfo(dtype=tf.float32, shape=(), temporary_asset_value=None), TensorInfo(dtype=tf.float32, shape=(), temporary_asset_value=None)]|label: ExtractCombineMergeOutputs[x#mean_and_var]|{<0>0|<1>1}}"];
+"CacheableCombineMerge[x#mean_and_var]" -> "ExtractCombineMergeOutputs[x#mean_and_var]";
+"CreateTensorBinding[x#mean_and_var#Placeholder]" [label="{CreateTensorBinding|tensor: x/mean_and_var/Placeholder:0|is_asset_filepath: False|label: CreateTensorBinding[x#mean_and_var#Placeholder]}"];
+"ExtractCombineMergeOutputs[x#mean_and_var]":0 -> "CreateTensorBinding[x#mean_and_var#Placeholder]";
+"CreateTensorBinding[x#mean_and_var#Placeholder_1]" [label="{CreateTensorBinding|tensor: x/mean_and_var/Placeholder_1:0|is_asset_filepath: False|label: CreateTensorBinding[x#mean_and_var#Placeholder_1]}"];
+"ExtractCombineMergeOutputs[x#mean_and_var]":1 -> "CreateTensorBinding[x#mean_and_var#Placeholder_1]";
+"CreateSavedModelForAnalyzerInputs[Phase1]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x_square_deviations/mean_and_var/Cast_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x_square_deviations/mean_and_var/div_no_nan', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x_square_deviations/mean_and_var/div_no_nan_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x_square_deviations/mean_and_var/zeros', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase1]}"];
+"CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]" -> "CreateSavedModelForAnalyzerInputs[Phase1]";
+"CreateTensorBinding[vocabulary#Placeholder]" -> "CreateSavedModelForAnalyzerInputs[Phase1]";
+"CreateTensorBinding[x#mean_and_var#Placeholder]" -> "CreateSavedModelForAnalyzerInputs[Phase1]";
+"CreateTensorBinding[x#mean_and_var#Placeholder_1]" -> "CreateSavedModelForAnalyzerInputs[Phase1]";
+"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
+"ApplySavedModel[Phase1]" [label="{ApplySavedModel|phase: 1|label: ApplySavedModel[Phase1]|partitionable: True}"];
+"CreateSavedModelForAnalyzerInputs[Phase1]" -> "ApplySavedModel[Phase1]";
+"ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase1]";
+"TensorSource[x_square_deviations#mean_and_var]" [label="{ExtractFromDict|keys: ('x_square_deviations/mean_and_var/Cast_1', 'x_square_deviations/mean_and_var/div_no_nan', 'x_square_deviations/mean_and_var/div_no_nan_1', 'x_square_deviations/mean_and_var/zeros')|label: TensorSource[x_square_deviations#mean_and_var]|partitionable: True}"];
+"ApplySavedModel[Phase1]" -> "TensorSource[x_square_deviations#mean_and_var]";
+"CacheableCombineAccumulate[x_square_deviations#mean_and_var]" [label="{CacheableCombineAccumulate|combiner: \<WeightedMeanAndVarCombiner\>|label: CacheableCombineAccumulate[x_square_deviations#mean_and_var]|partitionable: True}"];
+"TensorSource[x_square_deviations#mean_and_var]" -> "CacheableCombineAccumulate[x_square_deviations#mean_and_var]";
+"CacheableCombineMerge[x_square_deviations#mean_and_var]" [label="{CacheableCombineMerge|combiner: \<WeightedMeanAndVarCombiner\>|label: CacheableCombineMerge[x_square_deviations#mean_and_var]}"];
+"CacheableCombineAccumulate[x_square_deviations#mean_and_var]" -> "CacheableCombineMerge[x_square_deviations#mean_and_var]";
+"ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]" [label="{ExtractCombineMergeOutputs|output_tensor_info_list: [TensorInfo(dtype=tf.float32, shape=(), temporary_asset_value=None), TensorInfo(dtype=tf.float32, shape=(), temporary_asset_value=None)]|label: ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]|{<0>0|<1>1}}"];
+"CacheableCombineMerge[x_square_deviations#mean_and_var]" -> "ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]";
+"CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder]" [label="{CreateTensorBinding|tensor: x_square_deviations/mean_and_var/Placeholder:0|is_asset_filepath: False|label: CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder]}"];
+"ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]":0 -> "CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder]";
+"CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder_1]" [label="{CreateTensorBinding|tensor: x_square_deviations/mean_and_var/Placeholder_1:0|is_asset_filepath: False|label: CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder_1]}"];
+"ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]":1 -> "CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder_1]";
+CreateSavedModel [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x_normalized', \"Tensor\<shape: [None], \<dtype: 'float32'\>\>\")])|label: CreateSavedModel}"];
+"CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]" -> CreateSavedModel;
+"CreateTensorBinding[vocabulary#Placeholder]" -> CreateSavedModel;
+"CreateTensorBinding[x#mean_and_var#Placeholder]" -> CreateSavedModel;
+"CreateTensorBinding[x#mean_and_var#Placeholder_1]" -> CreateSavedModel;
+"CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder]" -> CreateSavedModel;
+"CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder_1]" -> CreateSavedModel;
+}
+""")
+
 _TF_VERSION_NAMED_PARAMETERS = [
     dict(testcase_name='CompatV1', use_tf_compat_v1=True),
     dict(testcase_name='V2', use_tf_compat_v1=False),
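
(Note on the case above: the "}] * 2" supplies an identical cache dict for each of the two analysis datasets, so both spans are full cache hits for every phase-0 analyzer. Written out by hand, assuming _make_cache_key and analyzer_cache.DatasetKey behave as elsewhere in this file, the test's key/dict pairing produces the following cache mapping, which is why the expected graph contains only DecodeCache and Flatten nodes for phase 0 and no ApplySavedModel[Phase0] at all.)

# Sketch of the resulting cache mapping; not part of the diff itself.
full_cache = {
    analyzer_cache.DatasetKey('span-0'): {
        _make_cache_key(b'CacheableCombineAccumulate[x#mean_and_var]'): 'cache hit',
        _make_cache_key(b'VocabularyAccumulate[vocabulary]'): 'cache hit',
    },
    analyzer_cache.DatasetKey('span-1'): {
        _make_cache_key(b'CacheableCombineAccumulate[x#mean_and_var]'): 'cache hit',
        _make_cache_key(b'VocabularyAccumulate[vocabulary]'): 'cache hit',
    },
}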
@@ -264,7 +343,7 @@ def is_partitionable(self):
     testcase_name='generalized_chained_ptransforms',
     feature_spec={'x': tf.io.FixedLenFeature([], tf.float32)},
     preprocessing_fn=_preprocessing_fn_for_generalized_chained_ptransforms,
-    dataset_input_cache_dict=None,
+    dataset_input_cache_dicts=None,
     expected_dot_graph_str=r"""digraph G {
 directed=True;
 node [shape=Mrecord];
@@ -334,6 +413,7 @@ def is_partitionable(self):
 
 _OPTIMIZE_TRAVERSAL_TEST_CASES = [
     _OPTIMIZE_TRAVERSAL_COMMON_CASE,
+    _OPTIMIZE_TRAVERSAL_MULTI_PHASE_FULL_CACHE_HIT_CASE,
     _OPTIMIZE_TRAVERSAL_GENERALIZED_CHAINED_PTRANSFORMS_CASE,
 ]
 
@@ -1006,18 +1086,26 @@ def preprocessing_fn(inputs):
 
   @tft_unit.named_parameters(*_OPTIMIZE_TRAVERSAL_TEST_CASES)
   @mock_out_cache_hash
-  def test_optimize_traversal(self, feature_spec, preprocessing_fn,
-                              dataset_input_cache_dict, expected_dot_graph_str):
-    span_0_key, span_1_key = analyzer_cache.DatasetKey(
-        'span-0'), analyzer_cache.DatasetKey('span-1')
-    if dataset_input_cache_dict is not None:
-      cache = {span_0_key: dataset_input_cache_dict}
+  def test_optimize_traversal(
+      self, feature_spec: Mapping[str, common_types.FeatureSpecType],
+      preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
+                                 Mapping[str, common_types.TensorType]],
+      dataset_input_cache_dicts: List[Mapping[str, str]],
+      expected_dot_graph_str: str):
+    dataset_keys = [
+        analyzer_cache.DatasetKey('span-0'),
+        analyzer_cache.DatasetKey('span-1')
+    ]
+    if dataset_input_cache_dicts is not None:
+      cache = {
+          key: cache_dict
+          for key, cache_dict in zip(dataset_keys, dataset_input_cache_dicts)
+      }
     else:
       cache = {}
     dot_string = self._publish_rendered_dot_graph_file(preprocessing_fn,
                                                        feature_spec,
-                                                       {span_0_key, span_1_key},
-                                                       cache)
+                                                       set(dataset_keys), cache)
 
     self.assertSameElements(
         expected_dot_graph_str.split('\n'),

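(Note on the refactored test body above: the single span-0 cache dict becomes a list that is paired element-wise with the dataset keys, so a test case can now describe cache for any subset of the analysis datasets. A minimal, hypothetical illustration of that pairing follows; the key 'k' is a placeholder, not a real cache key.)

keys = [analyzer_cache.DatasetKey('span-0'), analyzer_cache.DatasetKey('span-1')]

# One dict in the list: only span-0 receives cache entries (the common case).
partial_cache = dict(zip(keys, [{'k': 'cache hit'}]))

# The same dict repeated once per span: full cache coverage (the new case).
full_cache = dict(zip(keys, [{'k': 'cache hit'}] * 2))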