@@ -95,6 +95,56 @@ def call(self, inputs):
         # Test with a different batch size
         revived_model.serve(tf.random.normal((6, 10)))
 
+    @parameterized.named_parameters(
+        named_product(struct_type=["tuple", "array", "dict"])
+    )
+    def test_model_with_input_structure(self, struct_type):
+
+        class TupleModel(models.Model):
+
+            def call(self, inputs):
+                x, y = inputs
+                return ops.add(x, y)
+
+        class ArrayModel(models.Model):
+
+            def call(self, inputs):
+                x = inputs[0]
+                y = inputs[1]
+                return ops.add(x, y)
+
+        class DictModel(models.Model):
+
+            def call(self, inputs):
+                x = inputs["x"]
+                y = inputs["y"]
+                return ops.add(x, y)
+
+        if struct_type == "tuple":
+            model = TupleModel()
+            ref_input = (tf.random.normal((3, 10)), tf.random.normal((3, 10)))
+        elif struct_type == "array":
+            model = ArrayModel()
+            ref_input = [tf.random.normal((3, 10)), tf.random.normal((3, 10))]
+        elif struct_type == "dict":
+            model = DictModel()
+            ref_input = {
+                "x": tf.random.normal((3, 10)),
+                "y": tf.random.normal((3, 10)),
+            }
+
+        temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
+        ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input))
+
+        export_lib.export_model(model, temp_filepath)
+        revived_model = tf.saved_model.load(temp_filepath)
+        self.assertAllClose(ref_output, revived_model.serve(ref_input))
+        # Test with a different batch size
+        bigger_input = tree.map_structure(
+            lambda x: tf.concat([x, x], axis=0), ref_input
+        )
+        revived_model.serve(bigger_input)
+
     @parameterized.named_parameters(
         named_product(model_type=["sequential", "functional", "subclass"])
     )