@@ -29,11 +29,21 @@ def new_field_desc(self):
2929 vocabulary = ["a" , "b" , "c" ])
3030 return desc
3131
32- def check_from_dict (self , feature_column ):
32+ def check_serialize (self , feature_column ):
3333 d = fc .FeatureColumn .to_dict (feature_column )
34- new_fc = fc .FeatureColumn .from_dict (d )
34+ new_fc = fc .FeatureColumn .from_dict_or_feature_column (d )
35+ new_d = fc .FeatureColumn .to_dict (new_fc )
36+ typ = type (feature_column )
37+ self .assertEqual (typ , type (new_fc ))
38+ self .assertEqual (typ .__name__ , d ["type" ])
39+ self .assertEqual (d , new_d )
40+
41+ dump_json = json .dumps (feature_column ,
42+ cls = fc .JSONEncoderWithFeatureColumn )
43+ new_fc = json .loads (dump_json , cls = fc .JSONDecoderWithFeatureColumn )
3544 new_d = fc .FeatureColumn .to_dict (new_fc )
3645 self .assertEqual (type (feature_column ), type (new_fc ))
46+ self .assertEqual (typ , type (new_fc ))
3747 self .assertEqual (d , new_d )
3848
3949 def test_field_desc (self ):
@@ -45,7 +55,8 @@ def test_field_desc(self):
4555 self .assertEqual (json_desc ["format" ], desc .format )
4656 self .assertEqual (json_desc ["shape" ], desc .shape )
4757 self .assertEqual (json_desc ["is_sparse" ], desc .is_sparse )
48- self .assertEqual (json_desc ["vocabulary" ], desc .vocabulary )
58+ vocab = set (json_desc ["vocabulary" ])
59+ self .assertEqual (vocab , desc .vocabulary )
4960 self .assertEqual (json_desc ["max_id" ], desc .max_id )
5061
5162 def test_feature_column_subclass (self ):
@@ -73,13 +84,13 @@ def test_numeric_column(self):
7384
7485 d1 = fc .FeatureColumn .to_dict (nc1 )
7586 self .assertEqual (d1 ["type" ], "NumericColumn" )
76- self .assertEqual (d1 ["field_desc" ], desc1 .to_dict ())
77- self .check_from_dict (nc1 )
87+ self .assertEqual (d1 ["value" ][ " field_desc" ], desc1 .to_dict ())
88+ self .check_serialize (nc1 )
7889
7990 d2 = fc .FeatureColumn .to_dict (nc2 )
8091 self .assertEqual (d2 ["type" ], "NumericColumn" )
81- self .assertEqual (d2 ["field_desc" ], desc2 .to_dict ())
82- self .check_from_dict (nc2 )
92+ self .assertEqual (d2 ["value" ][ " field_desc" ], desc2 .to_dict ())
93+ self .check_serialize (nc2 )
8394
8495 def test_bucket_column (self ):
8596 desc = self .new_field_desc ()
@@ -93,10 +104,11 @@ def test_bucket_column(self):
93104 self .assertEqual (bc .get_field_desc ()[0 ].to_json (), desc .to_json ())
94105 d = fc .FeatureColumn .to_dict (bc )
95106 self .assertEqual (d ["type" ], "BucketColumn" )
96- self .assertEqual (d ["boundaries" ], boundaries )
97- self .assertEqual (d ["source_column" ]["type" ], "NumericColumn" )
98- self .assertEqual (d ["source_column" ]["field_desc" ], desc .to_dict ())
99- self .check_from_dict (bc )
107+ self .assertEqual (d ["value" ]["boundaries" ], boundaries )
108+ self .assertEqual (d ["value" ]["source_column" ]["type" ], "NumericColumn" )
109+ self .assertEqual (d ["value" ]["source_column" ]["value" ]["field_desc" ],
110+ desc .to_dict ())
111+ self .check_serialize (bc )
100112
101113 bc = bc .new_feature_column_from (desc )
102114 self .assertTrue (isinstance (bc , fc .BucketColumn ))
@@ -106,10 +118,11 @@ def test_bucket_column(self):
106118 self .assertEqual (bc .get_field_desc ()[0 ].to_json (), desc .to_json ())
107119 d = fc .FeatureColumn .to_dict (bc )
108120 self .assertEqual (d ["type" ], "BucketColumn" )
109- self .assertEqual (d ["boundaries" ], boundaries )
110- self .assertEqual (d ["source_column" ]["type" ], "NumericColumn" )
111- self .assertEqual (d ["source_column" ]["field_desc" ], desc .to_dict ())
112- self .check_from_dict (bc )
121+ self .assertEqual (d ["value" ]["boundaries" ], boundaries )
122+ self .assertEqual (d ["value" ]["source_column" ]["type" ], "NumericColumn" )
123+ self .assertEqual (d ["value" ]["source_column" ]["value" ]["field_desc" ],
124+ desc .to_dict ())
125+ self .check_serialize (bc )
113126
114127 def test_category_column (self ):
115128 desc = self .new_field_desc ()
@@ -126,9 +139,9 @@ def test_category_column(self):
126139
127140 d = fc .FeatureColumn .to_dict (cc )
128141 self .assertEqual (d ["type" ], fc_class .__name__ )
129- self .assertEqual (d ["field_desc" ], desc .to_dict ())
130- self .assertEqual (d ["bucket_size" ], bucket_size )
131- self .check_from_dict (cc )
142+ self .assertEqual (d ["value" ][ " field_desc" ], desc .to_dict ())
143+ self .assertEqual (d ["value" ][ " bucket_size" ], bucket_size )
144+ self .check_serialize (cc )
132145
133146 cc = cc .new_feature_column_from (desc )
134147 self .assertTrue (isinstance (cc , fc_class ))
@@ -138,9 +151,9 @@ def test_category_column(self):
138151
139152 d = fc .FeatureColumn .to_dict (cc )
140153 self .assertEqual (d ["type" ], fc_class .__name__ )
141- self .assertEqual (d ["field_desc" ], desc .to_dict ())
142- self .assertEqual (d ["bucket_size" ], bucket_size )
143- self .check_from_dict (cc )
154+ self .assertEqual (d ["value" ][ " field_desc" ], desc .to_dict ())
155+ self .assertEqual (d ["value" ][ " bucket_size" ], bucket_size )
156+ self .check_serialize (cc )
144157
145158 def test_cross_column (self ):
146159 desc = self .new_field_desc ()
@@ -155,13 +168,13 @@ def test_cross_column(self):
155168
156169 d = fc .FeatureColumn .to_dict (cc )
157170 self .assertEqual (d ["type" ], "CrossColumn" )
158- keys = d ["keys" ]
171+ keys = d ["value" ][ " keys" ]
159172 self .assertEqual (len (keys ), 2 )
160173 self .assertEqual (keys [0 ]["type" ], "NumericColumn" )
161- self .assertEqual (keys [0 ]["field_desc" ], desc .to_dict ())
174+ self .assertEqual (keys [0 ]["value" ][ " field_desc" ], desc .to_dict ())
162175 self .assertEqual (keys [1 ], "cross_feature_2" )
163- self .assertEqual (d ["hash_bucket_size" ], hash_bucket_size )
164- self .check_from_dict (cc )
176+ self .assertEqual (d ["value" ][ " hash_bucket_size" ], hash_bucket_size )
177+ self .check_serialize (cc )
165178
166179 def test_embedding_and_indicator_column (self ):
167180 desc = self .new_field_desc ()
@@ -179,13 +192,15 @@ def test_embedding_and_indicator_column(self):
179192
180193 d = fc .FeatureColumn .to_dict (fc1 )
181194 self .assertEqual (d ["type" ], fc_class .__name__ )
182- self .assertEqual (d ["name" ], "" )
183- self .assertEqual (d ["category_column" ]["type" ],
195+ self .assertEqual (d ["value" ][ " name" ], "" )
196+ self .assertEqual (d ["value" ][ " category_column" ]["type" ],
184197 "CategoryHashColumn" )
185- self .assertEqual (d ["category_column" ]["field_desc" ],
186- desc .to_dict ())
187- self .assertEqual (d ["category_column" ]["bucket_size" ], 4096 )
188- self .check_from_dict (fc1 )
198+ self .assertEqual (
199+ d ["value" ]["category_column" ]["value" ]["field_desc" ],
200+ desc .to_dict ())
201+ self .assertEqual (
202+ d ["value" ]["category_column" ]["value" ]["bucket_size" ], 4096 )
203+ self .check_serialize (fc1 )
189204
190205 fc2 = fc_class (category_column = None , name = "my_category_column" )
191206 fc2_descs = fc2 .get_field_desc ()
@@ -197,9 +212,9 @@ def test_embedding_and_indicator_column(self):
197212
198213 d = fc .FeatureColumn .to_dict (fc2 )
199214 self .assertEqual (d ["type" ], fc_class .__name__ )
200- self .assertEqual (d ["name" ], "my_category_column" )
201- self .assertEqual (d ["category_column" ], None )
202- self .check_from_dict (fc2 )
215+ self .assertEqual (d ["value" ][ " name" ], "my_category_column" )
216+ self .assertEqual (d ["value" ][ " category_column" ], None )
217+ self .check_serialize (fc2 )
203218
204219
205220if __name__ == '__main__' :
0 commit comments