|
18 | 18 | import functools
|
19 | 19 | import os
|
20 | 20 | import struct
|
| 21 | +from typing import Callable, Mapping, List |
21 | 22 | import apache_beam as beam
|
22 | 23 | from apache_beam.testing import util as beam_test_util
|
23 | 24 | import numpy as np
|
24 | 25 |
|
25 | 26 | import tensorflow as tf
|
26 | 27 | import tensorflow_transform as tft
|
27 | 28 | from tensorflow_transform import analyzer_nodes
|
| 29 | +from tensorflow_transform import common_types |
28 | 30 | from tensorflow_transform import impl_helper
|
29 | 31 | from tensorflow_transform import nodes
|
30 | 32 | import tensorflow_transform.beam as tft_beam
|
@@ -73,10 +75,10 @@ def _preprocessing_fn_for_common_optimize_traversal(inputs):
|
73 | 75 | 's': tf.io.FixedLenFeature([], tf.string)
|
74 | 76 | },
|
75 | 77 | preprocessing_fn=_preprocessing_fn_for_common_optimize_traversal,
|
76 | | - dataset_input_cache_dict={ |
| 78 | + dataset_input_cache_dicts=[{ |
77 | 79 | _make_cache_key(b'CacheableCombineAccumulate[x#mean_and_var]'):
|
78 | 80 | 'cache hit',
|
79 | | - }, |
| 81 | + }], |
80 | 82 | expected_dot_graph_str=r"""digraph G {
|
81 | 83 | directed=True;
|
82 | 84 | node [shape=Mrecord];
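
Note on the rename in this hunk: `dataset_input_cache_dict` (a single dict, implicitly attached to `span-0`) becomes `dataset_input_cache_dicts`, a list that is paired positionally with the analysis dataset keys, so each span can carry its own cache entries. A minimal sketch of the resulting shape, assuming this test file's `_make_cache_key` helper and the two spans these cases use:

```python
from tensorflow_transform.beam import analyzer_cache

dataset_keys = [
    analyzer_cache.DatasetKey('span-0'),
    analyzer_cache.DatasetKey('span-1'),
]
# One entry here but two keys above: zip() pairs them positionally,
# so 'span-0' gets this cache dict and 'span-1' simply has no cache.
dataset_input_cache_dicts = [{
    # _make_cache_key is the test file's cache-key helper used throughout
    # this diff; treat it as given.
    _make_cache_key(b'CacheableCombineAccumulate[x#mean_and_var]'):
        'cache hit',
}]
cache = dict(zip(dataset_keys, dataset_input_cache_dicts))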
|
@@ -165,6 +167,83 @@ def _preprocessing_fn_for_common_optimize_traversal(inputs):
|
165 | 167 | }
|
166 | 168 | """)
|
167 | 169 |
|
| 170 | +_OPTIMIZE_TRAVERSAL_MULTI_PHASE_FULL_CACHE_HIT_CASE = dict( |
| 171 | + testcase_name='multi_phase_full_cache_coverage', |
| 172 | + feature_spec={ |
| 173 | + 'x': tf.io.FixedLenFeature([], tf.float32), |
| 174 | + 's': tf.io.FixedLenFeature([], tf.string) |
| 175 | + }, |
| 176 | + preprocessing_fn=_preprocessing_fn_for_common_optimize_traversal, |
| 177 | + dataset_input_cache_dicts=[{ |
| 178 | + _make_cache_key(b'CacheableCombineAccumulate[x#mean_and_var]'): |
| 179 | + 'cache hit', |
| 180 | + _make_cache_key(b'VocabularyAccumulate[vocabulary]'): |
| 181 | + 'cache hit', |
| 182 | + }] * 2, |
| 183 | + expected_dot_graph_str=r"""digraph G { |
| 184 | +directed=True; |
| 185 | +node [shape=Mrecord]; |
| 186 | +"DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex0]" [label="{DecodeCache|dataset_key: DatasetKey(key='span-0')|cache_key: \<bytes\>|coder: \<_VocabularyAccumulatorCoder\>|label: DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex0]|partitionable: True}"]; |
| 187 | +"DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex1]" [label="{DecodeCache|dataset_key: DatasetKey(key='span-1')|cache_key: \<bytes\>|coder: \<_VocabularyAccumulatorCoder\>|label: DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex1]|partitionable: True}"]; |
| 188 | +"FlattenCache[VocabularyMerge[vocabulary]]" [label="{Flatten|label: FlattenCache[VocabularyMerge[vocabulary]]|partitionable: True}"]; |
| 189 | +"DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex0]" -> "FlattenCache[VocabularyMerge[vocabulary]]"; |
| 190 | +"DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex1]" -> "FlattenCache[VocabularyMerge[vocabulary]]"; |
| 191 | +"VocabularyMerge[vocabulary]" [label="{VocabularyMerge|vocab_ordering_type: 1|use_adjusted_mutual_info: False|min_diff_from_avg: None|label: VocabularyMerge[vocabulary]}"]; |
| 192 | +"FlattenCache[VocabularyMerge[vocabulary]]" -> "VocabularyMerge[vocabulary]"; |
| 193 | +"VocabularyCount[vocabulary]" [label="{VocabularyCount|label: VocabularyCount[vocabulary]}"]; |
| 194 | +"VocabularyMerge[vocabulary]" -> "VocabularyCount[vocabulary]"; |
| 195 | +"CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]" [label="{CreateTensorBinding|tensor: vocabulary/vocab_vocabulary_unpruned_vocab_size:0|is_asset_filepath: False|label: CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]}"]; |
| 196 | +"VocabularyCount[vocabulary]" -> "CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]"; |
| 197 | +"VocabularyPrune[vocabulary]" [label="{VocabularyPrune|top_k: None|frequency_threshold: 0|informativeness_threshold: -inf|coverage_top_k: None|coverage_frequency_threshold: 0|coverage_informativeness_threshold: -inf|key_fn: None|filter_newline_characters: True|input_dtype: string|label: VocabularyPrune[vocabulary]}"]; |
| 198 | +"VocabularyMerge[vocabulary]" -> "VocabularyPrune[vocabulary]"; |
| 199 | +"VocabularyOrderAndWrite[vocabulary]" [label="{VocabularyOrderAndWrite|vocab_filename: vocab_vocabulary|store_frequency: False|input_dtype: string|label: VocabularyOrderAndWrite[vocabulary]|fingerprint_shuffle: False|file_format: text|input_is_sorted: False}"]; |
| 200 | +"VocabularyPrune[vocabulary]" -> "VocabularyOrderAndWrite[vocabulary]"; |
| 201 | +"CreateTensorBinding[vocabulary#Placeholder]" [label="{CreateTensorBinding|tensor: vocabulary/Placeholder:0|is_asset_filepath: True|label: CreateTensorBinding[vocabulary#Placeholder]}"]; |
| 202 | +"VocabularyOrderAndWrite[vocabulary]" -> "CreateTensorBinding[vocabulary#Placeholder]"; |
| 203 | +"DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex0]" [label="{DecodeCache|dataset_key: DatasetKey(key='span-0')|cache_key: \<bytes\>|coder: \<JsonNumpyCacheCoder\>|label: DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex0]|partitionable: True}"]; |
| 204 | +"DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex1]" [label="{DecodeCache|dataset_key: DatasetKey(key='span-1')|cache_key: \<bytes\>|coder: \<JsonNumpyCacheCoder\>|label: DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex1]|partitionable: True}"]; |
| 205 | +"FlattenCache[CacheableCombineMerge[x#mean_and_var]]" [label="{Flatten|label: FlattenCache[CacheableCombineMerge[x#mean_and_var]]|partitionable: True}"]; |
| 206 | +"DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex0]" -> "FlattenCache[CacheableCombineMerge[x#mean_and_var]]"; |
| 207 | +"DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex1]" -> "FlattenCache[CacheableCombineMerge[x#mean_and_var]]"; |
| 208 | +"CacheableCombineMerge[x#mean_and_var]" [label="{CacheableCombineMerge|combiner: \<WeightedMeanAndVarCombiner\>|label: CacheableCombineMerge[x#mean_and_var]}"]; |
| 209 | +"FlattenCache[CacheableCombineMerge[x#mean_and_var]]" -> "CacheableCombineMerge[x#mean_and_var]"; |
| 210 | +"ExtractCombineMergeOutputs[x#mean_and_var]" [label="{ExtractCombineMergeOutputs|output_tensor_info_list: [TensorInfo(dtype=tf.float32, shape=(), temporary_asset_value=None), TensorInfo(dtype=tf.float32, shape=(), temporary_asset_value=None)]|label: ExtractCombineMergeOutputs[x#mean_and_var]|{<0>0|<1>1}}"]; |
| 211 | +"CacheableCombineMerge[x#mean_and_var]" -> "ExtractCombineMergeOutputs[x#mean_and_var]"; |
| 212 | +"CreateTensorBinding[x#mean_and_var#Placeholder]" [label="{CreateTensorBinding|tensor: x/mean_and_var/Placeholder:0|is_asset_filepath: False|label: CreateTensorBinding[x#mean_and_var#Placeholder]}"]; |
| 213 | +"ExtractCombineMergeOutputs[x#mean_and_var]":0 -> "CreateTensorBinding[x#mean_and_var#Placeholder]"; |
| 214 | +"CreateTensorBinding[x#mean_and_var#Placeholder_1]" [label="{CreateTensorBinding|tensor: x/mean_and_var/Placeholder_1:0|is_asset_filepath: False|label: CreateTensorBinding[x#mean_and_var#Placeholder_1]}"]; |
| 215 | +"ExtractCombineMergeOutputs[x#mean_and_var]":1 -> "CreateTensorBinding[x#mean_and_var#Placeholder_1]"; |
| 216 | +"CreateSavedModelForAnalyzerInputs[Phase1]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x_square_deviations/mean_and_var/Cast_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x_square_deviations/mean_and_var/div_no_nan', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x_square_deviations/mean_and_var/div_no_nan_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x_square_deviations/mean_and_var/zeros', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase1]}"]; |
| 217 | +"CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]" -> "CreateSavedModelForAnalyzerInputs[Phase1]"; |
| 218 | +"CreateTensorBinding[vocabulary#Placeholder]" -> "CreateSavedModelForAnalyzerInputs[Phase1]"; |
| 219 | +"CreateTensorBinding[x#mean_and_var#Placeholder]" -> "CreateSavedModelForAnalyzerInputs[Phase1]"; |
| 220 | +"CreateTensorBinding[x#mean_and_var#Placeholder_1]" -> "CreateSavedModelForAnalyzerInputs[Phase1]"; |
| 221 | +"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"]; |
| 222 | +"ApplySavedModel[Phase1]" [label="{ApplySavedModel|phase: 1|label: ApplySavedModel[Phase1]|partitionable: True}"]; |
| 223 | +"CreateSavedModelForAnalyzerInputs[Phase1]" -> "ApplySavedModel[Phase1]"; |
| 224 | +"ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase1]"; |
| 225 | +"TensorSource[x_square_deviations#mean_and_var]" [label="{ExtractFromDict|keys: ('x_square_deviations/mean_and_var/Cast_1', 'x_square_deviations/mean_and_var/div_no_nan', 'x_square_deviations/mean_and_var/div_no_nan_1', 'x_square_deviations/mean_and_var/zeros')|label: TensorSource[x_square_deviations#mean_and_var]|partitionable: True}"]; |
| 226 | +"ApplySavedModel[Phase1]" -> "TensorSource[x_square_deviations#mean_and_var]"; |
| 227 | +"CacheableCombineAccumulate[x_square_deviations#mean_and_var]" [label="{CacheableCombineAccumulate|combiner: \<WeightedMeanAndVarCombiner\>|label: CacheableCombineAccumulate[x_square_deviations#mean_and_var]|partitionable: True}"]; |
| 228 | +"TensorSource[x_square_deviations#mean_and_var]" -> "CacheableCombineAccumulate[x_square_deviations#mean_and_var]"; |
| 229 | +"CacheableCombineMerge[x_square_deviations#mean_and_var]" [label="{CacheableCombineMerge|combiner: \<WeightedMeanAndVarCombiner\>|label: CacheableCombineMerge[x_square_deviations#mean_and_var]}"]; |
| 230 | +"CacheableCombineAccumulate[x_square_deviations#mean_and_var]" -> "CacheableCombineMerge[x_square_deviations#mean_and_var]"; |
| 231 | +"ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]" [label="{ExtractCombineMergeOutputs|output_tensor_info_list: [TensorInfo(dtype=tf.float32, shape=(), temporary_asset_value=None), TensorInfo(dtype=tf.float32, shape=(), temporary_asset_value=None)]|label: ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]|{<0>0|<1>1}}"]; |
| 232 | +"CacheableCombineMerge[x_square_deviations#mean_and_var]" -> "ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]"; |
| 233 | +"CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder]" [label="{CreateTensorBinding|tensor: x_square_deviations/mean_and_var/Placeholder:0|is_asset_filepath: False|label: CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder]}"]; |
| 234 | +"ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]":0 -> "CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder]"; |
| 235 | +"CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder_1]" [label="{CreateTensorBinding|tensor: x_square_deviations/mean_and_var/Placeholder_1:0|is_asset_filepath: False|label: CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder_1]}"]; |
| 236 | +"ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]":1 -> "CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder_1]"; |
| 237 | +CreateSavedModel [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x_normalized', \"Tensor\<shape: [None], \<dtype: 'float32'\>\>\")])|label: CreateSavedModel}"]; |
| 238 | +"CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]" -> CreateSavedModel; |
| 239 | +"CreateTensorBinding[vocabulary#Placeholder]" -> CreateSavedModel; |
| 240 | +"CreateTensorBinding[x#mean_and_var#Placeholder]" -> CreateSavedModel; |
| 241 | +"CreateTensorBinding[x#mean_and_var#Placeholder_1]" -> CreateSavedModel; |
| 242 | +"CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder]" -> CreateSavedModel; |
| 243 | +"CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder_1]" -> CreateSavedModel; |
| 244 | +} |
| 245 | +""") |
| 246 | + |
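
The new `multi_phase_full_cache_coverage` case gives both first-phase analyzers (`CacheableCombineAccumulate[x#mean_and_var]` and `VocabularyAccumulate[vocabulary]`) a cache hit for both spans; `[{...}] * 2` reuses one dict object for both dataset keys, which is safe here because the test only reads it. With full coverage, the expected graph contains no `ApplySavedModel[Phase0]` node at all: per-span `DecodeCache` reads are flattened straight into the merge stages, and only the second phase still touches data. A hedged sanity check of that property, phrased against the graph string above:

```python
# Phase 0 is fully served from cache, so it never applies the saved model.
assert 'ApplySavedModel[Phase0]' not in expected_dot_graph_str
# Phase 1 (x_square_deviations) has no cache and still runs over the data.
assert 'ApplySavedModel[Phase1]' in expected_dot_graph_str
```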
168 | 247 | _TF_VERSION_NAMED_PARAMETERS = [
|
169 | 248 | dict(testcase_name='CompatV1', use_tf_compat_v1=True),
|
170 | 249 | dict(testcase_name='V2', use_tf_compat_v1=False),
|
@@ -264,7 +343,7 @@ def is_partitionable(self):
|
264 | 343 | testcase_name='generalized_chained_ptransforms',
|
265 | 344 | feature_spec={'x': tf.io.FixedLenFeature([], tf.float32)},
|
266 | 345 | preprocessing_fn=_preprocessing_fn_for_generalized_chained_ptransforms,
|
267 | | - dataset_input_cache_dict=None, |
| 346 | + dataset_input_cache_dicts=None, |
268 | 347 | expected_dot_graph_str=r"""digraph G {
|
269 | 348 | directed=True;
|
270 | 349 | node [shape=Mrecord];
|
@@ -334,6 +413,7 @@ def is_partitionable(self):
|
334 | 413 |
|
335 | 414 | _OPTIMIZE_TRAVERSAL_TEST_CASES = [
|
336 | 415 | _OPTIMIZE_TRAVERSAL_COMMON_CASE,
|
| 416 | + _OPTIMIZE_TRAVERSAL_MULTI_PHASE_FULL_CACHE_HIT_CASE, |
337 | 417 | _OPTIMIZE_TRAVERSAL_GENERALIZED_CHAINED_PTRANSFORMS_CASE,
|
338 | 418 | ]
|
339 | 419 |
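
For context, each dict in `_OPTIMIZE_TRAVERSAL_TEST_CASES` is expanded into a named subtest by the decorator that appears further down in this diff, with the dict's keys passed as keyword arguments:

```python
@tft_unit.named_parameters(*_OPTIMIZE_TRAVERSAL_TEST_CASES)
@mock_out_cache_hash
def test_optimize_traversal(self, feature_spec, preprocessing_fn,
                            dataset_input_cache_dicts, expected_dot_graph_str):
  ...  # 'testcase_name' becomes the subtest's name suffix
```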
|
@@ -1006,18 +1086,26 @@ def preprocessing_fn(inputs):
|
1006 | 1086 |
|
1007 | 1087 | @tft_unit.named_parameters(*_OPTIMIZE_TRAVERSAL_TEST_CASES)
|
1008 | 1088 | @mock_out_cache_hash
|
1009 | | - def test_optimize_traversal(self, feature_spec, preprocessing_fn, |
1010 | | - dataset_input_cache_dict, expected_dot_graph_str): |
1011 | | - span_0_key, span_1_key = analyzer_cache.DatasetKey( |
1012 | | - 'span-0'), analyzer_cache.DatasetKey('span-1') |
1013 | | - if dataset_input_cache_dict is not None: |
1014 | | - cache = {span_0_key: dataset_input_cache_dict} |
| 1089 | + def test_optimize_traversal( |
| 1090 | + self, feature_spec: Mapping[str, common_types.FeatureSpecType], |
| 1091 | + preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]], |
| 1092 | + Mapping[str, common_types.TensorType]], |
| 1093 | + dataset_input_cache_dicts: List[Mapping[str, str]], |
| 1094 | + expected_dot_graph_str: str): |
| 1095 | + dataset_keys = [ |
| 1096 | + analyzer_cache.DatasetKey('span-0'), |
| 1097 | + analyzer_cache.DatasetKey('span-1') |
| 1098 | + ] |
| 1099 | + if dataset_input_cache_dicts is not None: |
| 1100 | + cache = { |
| 1101 | + key: cache_dict |
| 1102 | + for key, cache_dict in zip(dataset_keys, dataset_input_cache_dicts) |
| 1103 | + } |
1015 | 1104 | else:
|
1016 | 1105 | cache = {}
|
1017 | 1106 | dot_string = self._publish_rendered_dot_graph_file(preprocessing_fn,
|
1018 | 1107 | feature_spec,
|
1019 | | - {span_0_key, span_1_key}, |
1020 | | - cache) |
| 1108 | + set(dataset_keys), cache) |
1021 | 1109 |
|
1022 | 1110 | self.assertSameElements(
|
1023 | 1111 | expected_dot_graph_str.split('\n'),
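
One typing nit on the new signature: the `generalized_chained_ptransforms` case passes `dataset_input_cache_dicts=None`, which the annotation `List[Mapping[str, str]]` does not admit. A stricter sketch, keeping everything else exactly as in the diff:

```python
from typing import Callable, List, Mapping, Optional

def test_optimize_traversal(
    self, feature_spec: Mapping[str, common_types.FeatureSpecType],
    preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
                               Mapping[str, common_types.TensorType]],
    dataset_input_cache_dicts: Optional[List[Mapping[str, str]]],
    expected_dot_graph_str: str) -> None:
  ...
```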
|
|