Skip to content

Commit ee93f21

Browse files
authored
Add json dump and load support for FeatureColumn (#2794)
* add json dump load support * update vocabulary type * update * update * update
1 parent ae288e3 commit ee93f21

File tree

4 files changed

+180
-55
lines changed

4 files changed

+180
-55
lines changed

python/runtime/feature/column.py

Lines changed: 93 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
import json
15+
1416
import six
1517
from runtime.feature.field_desc import DataType, FieldDesc
1618

@@ -56,9 +58,10 @@ def to_dict(cls, feature_column):
5658
Returns:
5759
A Python dict which represents the FeatureColumn object.
5860
"""
59-
d = feature_column._to_dict()
60-
d["type"] = type(feature_column).__name__
61-
return d
61+
return {
62+
"type": type(feature_column).__name__,
63+
"value": feature_column._to_dict(),
64+
}
6265

6366
def _to_dict(self):
6467
"""
@@ -70,19 +73,25 @@ def _to_dict(self):
7073
raise NotImplementedError()
7174

7275
@classmethod
73-
def from_dict(cls, d):
76+
def from_dict_or_feature_column(cls, obj):
7477
"""
75-
Create a FeatureColumn object from a Python dict. It can
76-
be used to deserialize a FeatureColumn object from a JSON string.
78+
If obj is of type dict, create a FeatureColumn object from a Python
79+
dict. If obj is of type FeatureColumn, return itself. This method
80+
can be used to deserialize a FeatureColumn object from a JSON string.
7781
7882
Args:
79-
d (dict): a Python dict object.
83+
obj (dict|FeatureColumn): a Python dict or FeatureColumn object.
8084
8185
Returns:
8286
A FeatureColumn object.
8387
"""
84-
typ = d.get("type")
85-
return eval(typ)._from_dict(d)
88+
if isinstance(obj, dict):
89+
typ = obj.get("type")
90+
return eval(typ)._from_dict(obj.get("value"))
91+
elif isinstance(obj, FeatureColumn):
92+
return obj
93+
else:
94+
raise TypeError("not supported type %s" % type(obj))
8695

8796
@classmethod
8897
def _from_dict(self, d):
@@ -169,14 +178,14 @@ def num_class(self):
169178

170179
def _to_dict(self):
171180
return {
172-
"type": "BucketColumn",
173181
"source_column": FeatureColumn.to_dict(self.source_column),
174182
"boundaries": self.boundaries,
175183
}
176184

177185
@classmethod
178186
def _from_dict(cls, d):
179-
source_column = FeatureColumn.from_dict(d["source_column"])
187+
source_column = FeatureColumn.from_dict_or_feature_column(
188+
d["source_column"])
180189
boundaries = d["boundaries"]
181190
return BucketColumn(source_column, boundaries)
182191

@@ -240,7 +249,6 @@ def num_class(self):
240249

241250
def _to_dict(self):
242251
return {
243-
"type": "CategoryHashColumn",
244252
"field_desc": self.field_desc.to_dict(),
245253
"bucket_size": self.bucket_size,
246254
}
@@ -343,7 +351,7 @@ def _from_dict(cls, d):
343351
if isinstance(k, six.string_types):
344352
keys.append(k)
345353
else:
346-
keys.append(FeatureColumn.from_dict(k))
354+
keys.append(FeatureColumn.from_dict_or_feature_column(k))
347355

348356
hash_bucket_size = d["hash_bucket_size"]
349357
return CrossColumn(keys, hash_bucket_size)
@@ -414,7 +422,8 @@ def _to_dict(self):
414422
def _from_dict(cls, d):
415423
category_column = d["category_column"]
416424
if category_column is not None:
417-
category_column = FeatureColumn.from_dict(category_column)
425+
category_column = FeatureColumn.from_dict_or_feature_column(
426+
category_column)
418427

419428
return EmbeddingColumn(category_column=category_column,
420429
dimension=d["dimension"],
@@ -469,6 +478,75 @@ def _to_dict(self):
469478
def _from_dict(cls, d):
470479
category_column = d["category_column"]
471480
if category_column is not None:
472-
category_column = FeatureColumn.from_dict(category_column)
481+
category_column = FeatureColumn.from_dict_or_feature_column(
482+
category_column)
473483

474484
return IndicatorColumn(category_column=category_column, name=d["name"])
485+
486+
487+
class JSONEncoderWithFeatureColumn(json.JSONEncoder):
488+
"""
489+
A helper class to serialize FeatureColumn objects to JSON string.
490+
"""
491+
def default(self, obj):
492+
"""
493+
Convert obj to an object that `json.dumps` accepts.
494+
If obj is of type FeatureColumn, convert it to a Python
495+
dict.
496+
497+
Args:
498+
obj: any Python object.
499+
500+
Returns:
501+
A Python object that `json.dumps` accepts.
502+
"""
503+
if isinstance(obj, FeatureColumn):
504+
return FeatureColumn.to_dict(obj)
505+
506+
# Use the default JSONEncoder if obj is not FeatureColumn
507+
return json.JSONEncoder.default(self, obj)
508+
509+
510+
SUPPORTED_CONCRETE_FEATURE_COLUMNS = [
511+
'NumericColumn',
512+
'BucketColumn',
513+
'CategoryIDColumn',
514+
'CategoryHashColumn',
515+
'SeqCategoryIDColumn',
516+
'CrossColumn',
517+
'EmbeddingColumn',
518+
'IndicatorColumn',
519+
]
520+
521+
522+
def feature_column_json_hook(obj):
523+
"""
524+
An object hook method that json.JSONDecoder accepts.
525+
It is used to convert a Python dict to FeatureColumn object
526+
if possible. See https://docs.python.org/3/library/json.html
527+
for the usage of object hook.
528+
529+
Args:
530+
obj: any Python object.
531+
532+
Returns:
533+
If obj can be converted to a FeatureColumn object, convert
534+
it. Otherwise, return itself.
535+
"""
536+
if isinstance(obj, dict):
537+
typ = obj.get("type")
538+
if typ in SUPPORTED_CONCRETE_FEATURE_COLUMNS:
539+
return FeatureColumn.from_dict_or_feature_column(obj)
540+
541+
return obj
542+
543+
544+
class JSONDecoderWithFeatureColumn(json.JSONDecoder):
545+
"""
546+
A helper class to deserialize JSON string to FeatureColumn objects.
547+
"""
548+
def __init__(self, *args, **kwargs):
549+
# See here: https://docs.python.org/3/library/json.html
550+
# for the usage of object_hook
551+
kwargs['object_hook'] = feature_column_json_hook
552+
super(JSONDecoderWithFeatureColumn, self).__init__(*args, **kwargs)

python/runtime/feature/column_test.py

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,21 @@ def new_field_desc(self):
2929
vocabulary=["a", "b", "c"])
3030
return desc
3131

32-
def check_from_dict(self, feature_column):
32+
def check_serialize(self, feature_column):
3333
d = fc.FeatureColumn.to_dict(feature_column)
34-
new_fc = fc.FeatureColumn.from_dict(d)
34+
new_fc = fc.FeatureColumn.from_dict_or_feature_column(d)
35+
new_d = fc.FeatureColumn.to_dict(new_fc)
36+
typ = type(feature_column)
37+
self.assertEqual(typ, type(new_fc))
38+
self.assertEqual(typ.__name__, d["type"])
39+
self.assertEqual(d, new_d)
40+
41+
dump_json = json.dumps(feature_column,
42+
cls=fc.JSONEncoderWithFeatureColumn)
43+
new_fc = json.loads(dump_json, cls=fc.JSONDecoderWithFeatureColumn)
3544
new_d = fc.FeatureColumn.to_dict(new_fc)
3645
self.assertEqual(type(feature_column), type(new_fc))
46+
self.assertEqual(typ, type(new_fc))
3747
self.assertEqual(d, new_d)
3848

3949
def test_field_desc(self):
@@ -45,7 +55,8 @@ def test_field_desc(self):
4555
self.assertEqual(json_desc["format"], desc.format)
4656
self.assertEqual(json_desc["shape"], desc.shape)
4757
self.assertEqual(json_desc["is_sparse"], desc.is_sparse)
48-
self.assertEqual(json_desc["vocabulary"], desc.vocabulary)
58+
vocab = set(json_desc["vocabulary"])
59+
self.assertEqual(vocab, desc.vocabulary)
4960
self.assertEqual(json_desc["max_id"], desc.max_id)
5061

5162
def test_feature_column_subclass(self):
@@ -73,13 +84,13 @@ def test_numeric_column(self):
7384

7485
d1 = fc.FeatureColumn.to_dict(nc1)
7586
self.assertEqual(d1["type"], "NumericColumn")
76-
self.assertEqual(d1["field_desc"], desc1.to_dict())
77-
self.check_from_dict(nc1)
87+
self.assertEqual(d1["value"]["field_desc"], desc1.to_dict())
88+
self.check_serialize(nc1)
7889

7990
d2 = fc.FeatureColumn.to_dict(nc2)
8091
self.assertEqual(d2["type"], "NumericColumn")
81-
self.assertEqual(d2["field_desc"], desc2.to_dict())
82-
self.check_from_dict(nc2)
92+
self.assertEqual(d2["value"]["field_desc"], desc2.to_dict())
93+
self.check_serialize(nc2)
8394

8495
def test_bucket_column(self):
8596
desc = self.new_field_desc()
@@ -93,10 +104,11 @@ def test_bucket_column(self):
93104
self.assertEqual(bc.get_field_desc()[0].to_json(), desc.to_json())
94105
d = fc.FeatureColumn.to_dict(bc)
95106
self.assertEqual(d["type"], "BucketColumn")
96-
self.assertEqual(d["boundaries"], boundaries)
97-
self.assertEqual(d["source_column"]["type"], "NumericColumn")
98-
self.assertEqual(d["source_column"]["field_desc"], desc.to_dict())
99-
self.check_from_dict(bc)
107+
self.assertEqual(d["value"]["boundaries"], boundaries)
108+
self.assertEqual(d["value"]["source_column"]["type"], "NumericColumn")
109+
self.assertEqual(d["value"]["source_column"]["value"]["field_desc"],
110+
desc.to_dict())
111+
self.check_serialize(bc)
100112

101113
bc = bc.new_feature_column_from(desc)
102114
self.assertTrue(isinstance(bc, fc.BucketColumn))
@@ -106,10 +118,11 @@ def test_bucket_column(self):
106118
self.assertEqual(bc.get_field_desc()[0].to_json(), desc.to_json())
107119
d = fc.FeatureColumn.to_dict(bc)
108120
self.assertEqual(d["type"], "BucketColumn")
109-
self.assertEqual(d["boundaries"], boundaries)
110-
self.assertEqual(d["source_column"]["type"], "NumericColumn")
111-
self.assertEqual(d["source_column"]["field_desc"], desc.to_dict())
112-
self.check_from_dict(bc)
121+
self.assertEqual(d["value"]["boundaries"], boundaries)
122+
self.assertEqual(d["value"]["source_column"]["type"], "NumericColumn")
123+
self.assertEqual(d["value"]["source_column"]["value"]["field_desc"],
124+
desc.to_dict())
125+
self.check_serialize(bc)
113126

114127
def test_category_column(self):
115128
desc = self.new_field_desc()
@@ -126,9 +139,9 @@ def test_category_column(self):
126139

127140
d = fc.FeatureColumn.to_dict(cc)
128141
self.assertEqual(d["type"], fc_class.__name__)
129-
self.assertEqual(d["field_desc"], desc.to_dict())
130-
self.assertEqual(d["bucket_size"], bucket_size)
131-
self.check_from_dict(cc)
142+
self.assertEqual(d["value"]["field_desc"], desc.to_dict())
143+
self.assertEqual(d["value"]["bucket_size"], bucket_size)
144+
self.check_serialize(cc)
132145

133146
cc = cc.new_feature_column_from(desc)
134147
self.assertTrue(isinstance(cc, fc_class))
@@ -138,9 +151,9 @@ def test_category_column(self):
138151

139152
d = fc.FeatureColumn.to_dict(cc)
140153
self.assertEqual(d["type"], fc_class.__name__)
141-
self.assertEqual(d["field_desc"], desc.to_dict())
142-
self.assertEqual(d["bucket_size"], bucket_size)
143-
self.check_from_dict(cc)
154+
self.assertEqual(d["value"]["field_desc"], desc.to_dict())
155+
self.assertEqual(d["value"]["bucket_size"], bucket_size)
156+
self.check_serialize(cc)
144157

145158
def test_cross_column(self):
146159
desc = self.new_field_desc()
@@ -155,13 +168,13 @@ def test_cross_column(self):
155168

156169
d = fc.FeatureColumn.to_dict(cc)
157170
self.assertEqual(d["type"], "CrossColumn")
158-
keys = d["keys"]
171+
keys = d["value"]["keys"]
159172
self.assertEqual(len(keys), 2)
160173
self.assertEqual(keys[0]["type"], "NumericColumn")
161-
self.assertEqual(keys[0]["field_desc"], desc.to_dict())
174+
self.assertEqual(keys[0]["value"]["field_desc"], desc.to_dict())
162175
self.assertEqual(keys[1], "cross_feature_2")
163-
self.assertEqual(d["hash_bucket_size"], hash_bucket_size)
164-
self.check_from_dict(cc)
176+
self.assertEqual(d["value"]["hash_bucket_size"], hash_bucket_size)
177+
self.check_serialize(cc)
165178

166179
def test_embedding_and_indicator_column(self):
167180
desc = self.new_field_desc()
@@ -179,13 +192,15 @@ def test_embedding_and_indicator_column(self):
179192

180193
d = fc.FeatureColumn.to_dict(fc1)
181194
self.assertEqual(d["type"], fc_class.__name__)
182-
self.assertEqual(d["name"], "")
183-
self.assertEqual(d["category_column"]["type"],
195+
self.assertEqual(d["value"]["name"], "")
196+
self.assertEqual(d["value"]["category_column"]["type"],
184197
"CategoryHashColumn")
185-
self.assertEqual(d["category_column"]["field_desc"],
186-
desc.to_dict())
187-
self.assertEqual(d["category_column"]["bucket_size"], 4096)
188-
self.check_from_dict(fc1)
198+
self.assertEqual(
199+
d["value"]["category_column"]["value"]["field_desc"],
200+
desc.to_dict())
201+
self.assertEqual(
202+
d["value"]["category_column"]["value"]["bucket_size"], 4096)
203+
self.check_serialize(fc1)
189204

190205
fc2 = fc_class(category_column=None, name="my_category_column")
191206
fc2_descs = fc2.get_field_desc()
@@ -197,9 +212,9 @@ def test_embedding_and_indicator_column(self):
197212

198213
d = fc.FeatureColumn.to_dict(fc2)
199214
self.assertEqual(d["type"], fc_class.__name__)
200-
self.assertEqual(d["name"], "my_category_column")
201-
self.assertEqual(d["category_column"], None)
202-
self.check_from_dict(fc2)
215+
self.assertEqual(d["value"]["name"], "my_category_column")
216+
self.assertEqual(d["value"]["category_column"], None)
217+
self.check_serialize(fc2)
203218

204219

205220
if __name__ == '__main__':

python/runtime/feature/derivation_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
import json
1415
import unittest
1516

17+
import runtime.feature.column as fc
1618
import runtime.feature.derivation as fd
1719
import runtime.testing as testing
1820
from runtime.feature.column import (CategoryIDColumn, CrossColumn,
@@ -82,6 +84,14 @@ def test_infer_index(self):
8284
@unittest.skipUnless(testing.get_driver() in ["mysql", "hive"],
8385
"skip non MySQL and Hive tests")
8486
class TestFeatureDerivationWithMockedFeatures(unittest.TestCase):
87+
def check_json_dump(self, features):
88+
dump_json = json.dumps(features, cls=fc.JSONEncoderWithFeatureColumn)
89+
new_features = json.loads(dump_json,
90+
cls=fc.JSONDecoderWithFeatureColumn)
91+
new_dump_json = json.dumps(new_features,
92+
cls=fc.JSONEncoderWithFeatureColumn)
93+
self.assertEqual(dump_json, new_dump_json)
94+
8595
def test_without_cross(self):
8696
features = {
8797
'feature_columns': [
@@ -108,6 +118,9 @@ def test_without_cross(self):
108118
features, label = fd.infer_feature_columns(conn, select, features,
109119
label)
110120

121+
self.check_json_dump(features)
122+
self.check_json_dump(label)
123+
111124
self.assertEqual(len(features), 1)
112125
self.assertTrue("feature_columns" in features)
113126
features = features["feature_columns"]
@@ -230,6 +243,9 @@ def test_with_cross(self):
230243
features, label = fd.infer_feature_columns(conn, select, features,
231244
label)
232245

246+
self.check_json_dump(features)
247+
self.check_json_dump(label)
248+
233249
self.assertEqual(len(features), 1)
234250
self.assertTrue("feature_columns" in features)
235251
features = features["feature_columns"]
@@ -322,6 +338,9 @@ def test_no_column_clause(self):
322338
features, label = fd.infer_feature_columns(conn, select, features,
323339
label)
324340

341+
self.check_json_dump(features)
342+
self.check_json_dump(label)
343+
325344
self.assertEqual(len(features), 1)
326345
self.assertTrue("feature_columns" in features)
327346
features = features["feature_columns"]

0 commit comments

Comments
 (0)