19
19
from __future__ import division
20
20
from __future__ import print_function
21
21
import collections
22
+ import functools
22
23
import itertools
23
24
import os
24
25
import struct
@@ -1019,12 +1020,14 @@ def preprocessing_fn(inputs):
1019
1020
preprocessing_fn , pipeline = p ))
1020
1021
self .assertFalse (output_cache )
1021
1022
1022
- def test_tf_function_fails_cache (self ):
1023
+ def test_tf_function_works_with_cache (self ):
1023
1024
1024
- def preprocessing_fn (inputs ):
1025
+ def preprocessing_fn (inputs , should_add_one ):
1025
1026
1026
1027
@tf .function
1027
1028
def identity (x ):
1029
+ if should_add_one :
1030
+ x = x + 1
1028
1031
return x
1029
1032
1030
1033
return {
@@ -1035,8 +1038,9 @@ def identity(x):
1035
1038
1036
1039
feature_spec = {'x' : tf .io .FixedLenFeature ([], tf .float32 )}
1037
1040
input_data_dict = {'span-0' : [dict (x = - 2 ), dict (x = 4 )]}
1038
- run_result = self ._run_pipeline (feature_spec , input_data_dict ,
1039
- preprocessing_fn )
1041
+ run_result = self ._run_pipeline (
1042
+ feature_spec , input_data_dict ,
1043
+ functools .partial (preprocessing_fn , should_add_one = False ))
1040
1044
first_cache_output , p1 = run_result .cache_output , run_result .pipeline
1041
1045
1042
1046
for key in input_data_dict :
@@ -1050,12 +1054,100 @@ def identity(x):
1050
1054
_get_counter_value (p1 .metrics , 'saved_models_created' ),
1051
1055
_SINGLE_PHASE_NUM_SAVED_MODELS )
1052
1056
1057
+ # Cache is still valid since the contents of the tf.function are the same.
1058
+ run_result = self ._run_pipeline (
1059
+ feature_spec ,
1060
+ input_data_dict ,
1061
+ functools .partial (preprocessing_fn , should_add_one = False ),
1062
+ should_read_cache = True )
1063
+ second_cache_output , p2 = run_result .cache_output , run_result .pipeline
1064
+
1065
+ self .assertFalse (second_cache_output )
1066
+
1067
+ self .assertEqual (_get_counter_value (p2 .metrics , 'num_instances' ), 0 )
1068
+ self .assertEqual (_get_counter_value (p2 .metrics , 'cache_entries_decoded' ), 1 )
1069
+ self .assertEqual (_get_counter_value (p2 .metrics , 'cache_entries_encoded' ), 0 )
1070
+ self .assertEqual (
1071
+ _get_counter_value (p2 .metrics , 'saved_models_created' ),
1072
+ _ZERO_PHASE_NUM_SAVED_MODELS )
1073
+
1074
+ self .assertEqual (_get_counter_value (p2 .metrics , 'num_instances' ), 0 )
1075
+ self .assertEqual (_get_counter_value (p2 .metrics , 'cache_entries_decoded' ), 1 )
1076
+ self .assertEqual (_get_counter_value (p2 .metrics , 'cache_entries_encoded' ), 0 )
1077
+ self .assertEqual (_get_counter_value (p2 .metrics , 'saved_models_created' ), 1 )
1078
+
1079
+ # Modifying the tf.function contents causes cache invalidation.
1080
+ run_result = self ._run_pipeline (
1081
+ feature_spec ,
1082
+ input_data_dict ,
1083
+ functools .partial (preprocessing_fn , should_add_one = True ),
1084
+ should_read_cache = True )
1085
+ third_output_cache , p3 = run_result .cache_output , run_result .pipeline
1086
+
1087
+ for key in input_data_dict :
1088
+ self .assertIn (key , third_output_cache )
1089
+ self .assertEqual (1 , len (third_output_cache [key ]))
1090
+
1091
+ self .assertEqual (_get_counter_value (p3 .metrics , 'num_instances' ), 2 )
1092
+ self .assertEqual (_get_counter_value (p3 .metrics , 'cache_entries_decoded' ), 0 )
1093
+ self .assertEqual (_get_counter_value (p3 .metrics , 'cache_entries_encoded' ), 1 )
1094
+ self .assertEqual (_get_counter_value (p3 .metrics , 'saved_models_created' ), 2 )
1095
+
1096
+ def test_incomplete_graphs_fail_cache (self ):
1097
+
1098
+ def preprocessing_fn (inputs ):
1099
+ # Subtract 10 from x using a tf.while_loop.
1100
+ @tf .function (input_signature = [
1101
+ tf .TensorSpec ([], tf .int32 ),
1102
+ tf .TensorSpec ([], tf .int64 )
1103
+ ])
1104
+ def stop_condition (counter , x_minus_counter ):
1105
+ del x_minus_counter # unused
1106
+ return tf .less (counter , 10 )
1107
+
1108
+ @tf .function (input_signature = [
1109
+ tf .TensorSpec ([], tf .int32 ),
1110
+ tf .TensorSpec ([], tf .int64 )
1111
+ ])
1112
+ def iteration (counter , x_minus_counter ):
1113
+ return tf .add (counter , 1 ), tf .add (x_minus_counter , - 1 )
1114
+
1115
+ initial_values = [tf .constant (0 ), inputs ['x' ]]
1116
+ final_values = tf .raw_ops .While (
1117
+ cond = stop_condition .get_concrete_function (),
1118
+ body = iteration .get_concrete_function (),
1119
+ input = initial_values )
1120
+
1121
+ y = final_values [1 ]
1122
+
1123
+ return {'y' : tft .mean (y ) + tf .zeros_like (inputs ['x' ], dtype = tf .float32 )}
1124
+
1125
+ feature_spec = {
1126
+ 'x' : tf .io .FixedLenFeature ([], tf .int64 ),
1127
+ }
1128
+ input_data_dict = {
1129
+ 'span-0' : [dict (x = - 2 ), dict (x = 4 )],
1130
+ }
1053
1131
run_result = self ._run_pipeline (feature_spec , input_data_dict ,
1054
1132
preprocessing_fn )
1133
+ first_cache_output , p1 = run_result .cache_output , run_result .pipeline
1134
+
1135
+ for key in input_data_dict :
1136
+ self .assertIn (key , first_cache_output )
1137
+ self .assertEqual (1 , len (first_cache_output [key ]))
1138
+
1139
+ self .assertEqual (_get_counter_value (p1 .metrics , 'num_instances' ), 2 )
1140
+ self .assertEqual (_get_counter_value (p1 .metrics , 'cache_entries_decoded' ), 0 )
1141
+ self .assertEqual (_get_counter_value (p1 .metrics , 'cache_entries_encoded' ), 1 )
1142
+ self .assertEqual (
1143
+ _get_counter_value (p1 .metrics , 'saved_models_created' ),
1144
+ _SINGLE_PHASE_NUM_SAVED_MODELS )
1145
+
1146
+ run_result = self ._run_pipeline (
1147
+ feature_spec , input_data_dict , preprocessing_fn , should_read_cache = True )
1055
1148
second_cache_output , p2 = run_result .cache_output , run_result .pipeline
1056
1149
1057
- # We expect a full output cache again because tf.function in the
1058
- # preprocessing_fn broke that cache entry.
1150
+ # We expect the cache to fail here because the tf.function is now different.
1059
1151
for key in input_data_dict :
1060
1152
self .assertIn (key , second_cache_output )
1061
1153
self .assertEqual (1 , len (second_cache_output [key ]))
0 commit comments