Skip to content

Commit e41640d

Browse files
CloudManX, Ubuntu, and Ubuntu
authored and committed
Frontend Support for Sagemaker-Sklearn-Extension Models Part I (neo-ai#145)
* tmp commit * tmp checkpoint * checkpoint * add auto_ml frontend parser * +registration of shapefunc for isnan and isinf, enable dyn tiling in robust imputer * tmp * unit tests for robustImputer, thresholdOneHotEncoder, robustStandardScaler and ColumnTransformer * Add ASF header * docker sklearn installation * docker installation of Sklearn * typo fixes * documentation fixes and error handling when sklearn is not installed Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-239.us-west-2.compute.internal> Co-authored-by: Ubuntu <ubuntu@ip-172-31-72-105.ec2.internal>
1 parent cab407a commit e41640d

File tree

6 files changed

+333
-0
lines changed

6 files changed

+333
-0
lines changed

docker/Dockerfile.ci_cpu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ RUN bash /install/ubuntu_install_arm_compute_lib.sh
8080
COPY install/ubuntu_install_caffe.sh /install/ubuntu_install_caffe.sh
8181
RUN bash /install/ubuntu_install_caffe.sh
8282

83+
# Sagemaker-Sklearn-Extension deps
84+
COPY install/ubuntu_install_sklearn.sh /install/ubuntu_install_sklearn.sh
85+
RUN bash /install/ubuntu_install_sklearn.sh
86+
8387
# Github Arm(R) Ethos(TM)-N NPU driver
8488
COPY install/ubuntu_install_ethosn_driver_stack.sh /install/ubuntu_install_ethosn_driver_stack.sh
8589
RUN bash /install/ubuntu_install_ethosn_driver_stack.sh
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#!/bin/bash
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
19+
# Abort on the first failing command, on use of an unset variable, and on a
# failure anywhere inside a pipeline.
set -e
set -u
set -o pipefail

# Install the latest versions of scikit-learn and Sagemaker-Scikit-Learn-Extension.
# The PyPI distribution is named "scikit-learn"; the "sklearn" name is a
# deprecated alias whose installation is rejected by PyPI.
pip3 install scikit-learn
pip3 install sagemaker-scikit-learn-extension

python/tvm/relay/frontend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,5 @@
3434
from .darknet import from_darknet
3535
from .pytorch import from_pytorch
3636
from .caffe import from_caffe
37+
from .sklearn import from_sklearn
38+
from .sklearn import from_auto_ml
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
18+
# pylint: disable=import-outside-toplevel
19+
20+
import numpy as np
21+
import tvm
22+
from tvm.ir import IRModule
23+
24+
from ... import nd as _nd
25+
from .. import analysis
26+
from .. import expr as _expr
27+
from .. import function as _function
28+
from .. import op as _op
29+
from .. import vision as _vision
30+
31+
from ..function import Function
32+
from ..expr import Call, Let
33+
from ..expr import If, Tuple, TupleGetItem
34+
from ..expr import RefCreate, RefRead, RefWrite
35+
from ..expr_functor import ExprFunctor
36+
from ..adt import Match, Clause
37+
38+
from .common import AttrCvt, Renamer, ExprTable
39+
from .common import get_relay_op, new_var, infer_shape, infer_channels
40+
from .common import infer_type, get_name
41+
from .common import infer_value as _infer_value
42+
from .common import infer_value_simulated as _infer_value_simulated
43+
44+
45+
def _SimpleImputer(op, inexpr, dshape, dtype, columns=None):
    """
    Scikit-Learn Transformer:
    Imputation transformer for completing missing values.

    Builds a Relay expression in which every NaN entry of ``inexpr`` is replaced
    by the fitted per-column value stored in ``op.statistics_``. ``dshape`` and
    ``columns`` are accepted for signature parity with the other converters and
    are not used here.
    """
    num_features = len(op.statistics_)
    statistics = _op.const(np.array(op.statistics_, dtype=dtype))

    # The row count is read from the input expression itself, so the statistics
    # are repeated with reps == [rows, 1].
    num_rows = _op.take(_op.shape_of(inexpr), _op.const([0]))
    tile_reps = _op.concatenate([num_rows, _op.const([1])], axis=0)
    replacement = _op.tile(statistics, reps=tile_reps)

    # Gather every column, in order, from the tiled statistics.
    # NOTE(review): presumably this pins the column count after the dynamic
    # tile -- confirm before removing.
    all_columns = _op.const(np.arange(num_features))
    replacement = _op.take(replacement, indices=all_columns, axis=1)

    return _op.where(_op.isnan(inexpr), replacement, inexpr)
65+
66+
def _RobustImputer(op, inexpr, dshape, dtype, columns=None):
    """
    Sagemaker-Scikit-Learn-Extension Transformer:
    Imputation transformer for completing missing values with multi-column support.

    Optionally gathers ``columns`` from the input, optionally turns infinite
    entries into NaN, and then delegates the NaN replacement to
    ``_SimpleImputer`` using ``op.simple_imputer_``.
    """
    data = inexpr
    if columns:
        data = _op.take(data, indices=_op.const(columns), axis=1)

    if op.mask_function is not None:
        # NOTE(review): the mask function itself is never evaluated; any
        # non-None mask is handled as "treat +/-inf as missing" -- confirm this
        # matches the masks used in practice.
        nan_scalar = _op.const(np.array(np.nan, dtype=dtype))
        data = _op.where(_op.isinf(data), _op.full_like(data, nan_scalar), data)

    return _SimpleImputer(op.simple_imputer_, data, dshape, dtype, columns)
82+
83+
def _ThresholdOneHotEncoder(op, inexpr, dshape, dtype, columns=None):
    """
    Sagemaker-Scikit-Learn-Extension Transformer:
    Encode categorical integer features as a one-hot numeric array, with optional restrictions on
    feature encoding.

    The input is split along axis 1 into one section per entry of
    ``op.categories_``; each section is compared against that entry's category
    values and the resulting masks are cast to ``dtype`` and concatenated along
    axis 1.
    """
    data = inexpr
    if columns:
        data = _op.take(data, indices=_op.const(columns), axis=1)

    categories = op.categories_
    feature_cols = _op.split(data, len(categories), axis=1)

    def _encode(index, category):
        # Repeat the feature column once per known category so that a single
        # equality comparison yields the one-hot mask for that feature.
        known_values = _op.const(np.array(category, dtype=dtype))
        repeated = _op.tile(feature_cols[index], (1, len(category)))
        return _op.cast(_op.equal(repeated, known_values), dtype)

    encoded = [_encode(index, category) for index, category in enumerate(categories)]
    return _op.concatenate(encoded, axis=1)
107+
108+
def _RobustStandardScaler(op, inexpr, dshape, dtype, columns=None):
    """
    Sagemaker-Scikit-Learn-Extension Transformer:
    Standardize features by removing the mean and scaling to unit variance

    Parameters
    ----------
    op : object
        Fitted transformer exposing the wrapped scaler as ``op.scaler_``.
    inexpr : relay.Expr
        Input expression.
    dshape : tuple
        Unused; kept for signature parity with the other converters.
    dtype : str
        Data type used for the mean/scale constants.
    columns : list of int, optional
        When given (and non-empty), only these columns of ``inexpr`` are scaled,
        matching the column handling of the other converters in this file.

    Returns
    -------
    ret : relay.Expr
        The standardized expression.
    """
    if columns:
        inexpr = _op.take(inexpr, indices=_op.const(columns), axis=1)

    scaler = op.scaler_
    ret = inexpr
    # Mirror StandardScaler.transform: centering and scaling are each applied
    # only when enabled on the fitted scaler. Objects without these flags keep
    # the previous behaviour (both steps applied).
    if getattr(scaler, "with_mean", True):
        ret = _op.subtract(ret, _op.const(np.array(scaler.mean_, dtype), dtype))
    if getattr(scaler, "with_std", True):
        ret = _op.divide(ret, _op.const(np.array(scaler.scale_, dtype), dtype))
    return ret
117+
118+
def _ColumnTransformer(op, inexpr, dshape, dtype, columns=None):
    """
    Scikit-Learn Compose:
    Applies transformers to columns of an array

    Every entry of ``op.transformers_`` is converted on its own column subset and
    the results are concatenated along axis 1.

    Parameters
    ----------
    op : object
        Fitted ColumnTransformer exposing ``transformers_`` as
        ``(name, transformer, columns)`` triples.
    inexpr : relay.Expr
        Input expression.
    dshape : tuple
        Forwarded unchanged to the per-step converters.
    dtype : str
        Forwarded unchanged to the per-step converters.
    columns : list of int, optional
        Unused; each entry of ``transformers_`` carries its own columns.

    Returns
    -------
    ret : relay.Expr
        Concatenation of all converted branches.

    Raises
    ------
    ValueError
        If an entry is an unrecognised string, or no branch produces output.
    """
    out = []
    for name, transformer, cols in op.transformers_:
        # 'drop' entries and empty column selections contribute no output.
        if isinstance(transformer, str) and transformer == 'drop':
            continue
        if hasattr(cols, '__len__') and len(cols) == 0:
            continue

        # Select this branch's columns once, up front, so that every step
        # (including converters that ignore ``columns``) sees only its subset.
        branch = _op.take(inexpr, indices=_op.const(cols), axis=1)

        if isinstance(transformer, str):
            if transformer != 'passthrough':
                raise ValueError(
                    "Unsupported transformer specifier '{}' for '{}'".format(transformer, name))
            out.append(branch)
            continue

        # A Pipeline contributes all of its steps, applied in order; any other
        # estimator is converted directly.
        if hasattr(transformer, 'steps'):
            steps = [step for _, step in transformer.steps]
        else:
            steps = [transformer]

        for step in steps:
            # Pipelines may contain no-op placeholders.
            if step is None or (isinstance(step, str) and step == 'passthrough'):
                continue
            branch = sklearn_op_to_relay(step, branch, dshape, dtype, None)
        out.append(branch)

    if not out:
        raise ValueError("ColumnTransformer has no transformer that produces output")

    return _op.concatenate(out, axis=1)
129+
130+
# Dispatch table used by sklearn_op_to_relay: estimator class name -> converter.
_convert_map = dict(
    [
        ('ColumnTransformer', _ColumnTransformer),
        ('SimpleImputer', _SimpleImputer),
        ('RobustImputer', _RobustImputer),
        ('RobustStandardScaler', _RobustStandardScaler),
        ('ThresholdOneHotEncoder', _ThresholdOneHotEncoder),
    ]
)
137+
138+
def sklearn_op_to_relay(op, inexpr, dshape, dtype, columns=None):
    """
    Convert one fitted estimator to a Relay expression.

    The converter is looked up in ``_convert_map`` by the class name of ``op``.

    Parameters
    ----------
    op : object
        Fitted estimator to convert.
    inexpr : relay.Expr
        Input expression fed to the converter.
    dshape : tuple
        Input shape, forwarded to the converter.
    dtype : str
        Input data type, forwarded to the converter.
    columns : list of int, optional
        Column indices, forwarded to the converter.

    Returns
    -------
    ret : relay.Expr
        Expression produced by the selected converter.

    Raises
    ------
    KeyError
        If no converter is registered for the class of ``op``.
    """
    classname = type(op).__name__
    if classname not in _convert_map:
        # Still a KeyError (as a plain dict lookup would raise), but with a
        # message that names the offending class and the supported ones.
        raise KeyError(
            "Unsupported scikit-learn operator '{}'. Supported operators: {}".format(
                classname, ", ".join(sorted(_convert_map))))
    return _convert_map[classname](op, inexpr, dshape, dtype, columns)
141+
142+
def from_sklearn(model,
                 shape=None,
                 dtype="float32",
                 columns=None):
    """
    Import a single fitted transformer as a Relay module.

    Parameters
    ----------
    model : object
        Fitted transformer; its class name selects the converter.
    shape : tuple, optional
        Shape of the Relay input variable named ``input``.
    dtype : str
        Data type of the input variable.
    columns : list of int, optional
        Column indices forwarded to the converter.

    Returns
    -------
    mod : tvm.IRModule
        Module whose main function computes the transformer output.
    params : list
        Always an empty list.

    Raises
    ------
    ImportError
        If scikit-learn cannot be imported.
    """
    try:
        import sklearn  # pylint: disable=unused-import
    except ImportError as e:
        # Chain the original error so the root cause stays in the traceback.
        raise ImportError(
            "Unable to import scikit-learn which is required {}".format(e)) from e

    inexpr = _expr.var('input', shape=shape, dtype=dtype)
    outexpr = sklearn_op_to_relay(model, inexpr, shape, dtype, columns)

    func = _function.Function(analysis.free_vars(outexpr), outexpr)
    return IRModule.from_expr(func), []
158+
159+
def from_auto_ml(model,
                 shape=None,
                 dtype="float32"):
    """
    Import the feature-transformer pipeline of an AutoML model as a Relay module.

    Every step in ``model.feature_transformer.steps`` is converted in order, each
    one consuming the expression produced by the previous step.

    Parameters
    ----------
    model : object
        Object exposing ``feature_transformer.steps`` as ``(name, transformer)``
        pairs of fitted transformers.
    shape : tuple, optional
        Shape of the Relay input variable named ``input``. The same value is
        forwarded to every step.
    dtype : str
        Data type of the input variable, forwarded to every step.

    Returns
    -------
    mod : tvm.IRModule
        Module whose main function computes the pipeline output.
    params : list
        Always an empty list.

    Raises
    ------
    ImportError
        If scikit-learn cannot be imported.
    """
    try:
        import sklearn  # pylint: disable=unused-import
    except ImportError as e:
        # Chain the original error so the root cause stays in the traceback.
        raise ImportError(
            "Unable to import scikit-learn which is required {}".format(e)) from e

    outexpr = _expr.var('input', shape=shape, dtype=dtype)
    for _, transformer in model.feature_transformer.steps:
        outexpr = sklearn_op_to_relay(transformer, outexpr, shape, dtype, None)

    func = _function.Function(analysis.free_vars(outexpr), outexpr)
    return IRModule.from_expr(func), []

python/tvm/relay/op/_tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,7 @@ def elemwise_shape_func(attrs, inputs, _):
272272
register_shape_func("clip", False, elemwise_shape_func)
273273
register_shape_func("log2", False, elemwise_shape_func)
274274
register_shape_func("sigmoid", False, elemwise_shape_func)
275+
register_shape_func("isnan", False, elemwise_shape_func)
276+
register_shape_func("isinf", False, elemwise_shape_func)
277+
register_shape_func("where", False, elemwise_shape_func)
278+
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import numpy as np
18+
19+
from sklearn.pipeline import Pipeline
20+
from sklearn.impute import SimpleImputer
21+
from sklearn.compose import ColumnTransformer
22+
from sagemaker_sklearn_extension.impute import RobustImputer
23+
from sagemaker_sklearn_extension.preprocessing import RobustStandardScaler
24+
from sagemaker_sklearn_extension.preprocessing import ThresholdOneHotEncoder
25+
26+
from tvm import topi
27+
import tvm.topi.testing
28+
import tvm
29+
from tvm import te
30+
from tvm import relay
31+
from tvm.contrib import graph_runtime
32+
import scipy
33+
34+
class SklearnTestHelper:
    """Compile a fitted transformer through the Relay frontend and run it on the VM executor."""

    def __init__(self, target='llvm', ctx=tvm.cpu(0)):
        self.compiled_model = None
        self.target = target
        self.ctx = ctx
        # Executor created by compile(); None until then.
        self.ex = None

    def compile(self, model, dshape, dtype, columns=None, auto_ml=False):
        """Convert ``model`` to Relay and create the executor used by run().

        ``auto_ml`` selects ``from_auto_ml`` instead of ``from_sklearn``;
        ``columns`` is only forwarded on the ``from_sklearn`` path.
        """
        if auto_ml:
            mod, _ = relay.frontend.from_auto_ml(model, dshape, dtype)
        else:
            mod, _ = relay.frontend.from_sklearn(model, dshape, dtype, columns)

        self.ex = relay.create_executor('vm', mod=mod, ctx=self.ctx, target=self.target)

    def run(self, data):
        """Evaluate the compiled module on ``data`` and return a numpy array.

        Raises RuntimeError when compile() has not been called yet.
        """
        if self.ex is None:
            raise RuntimeError("compile() must be called before run()")
        result = self.ex.evaluate()(data)
        return result.asnumpy()
51+
52+
def _test_model_impl(helper, model, dshape, input_data):
    """Compile ``model`` as float32 and check the TVM output against ``model.transform``."""
    helper.compile(model, dshape, 'float32')
    expected = model.transform(input_data)
    actual = helper.run(input_data)
    tvm.testing.assert_allclose(expected, actual, rtol=1e-5, atol=1e-5)
57+
58+
def test_simple_imputer():
    """SimpleImputer with the median strategy on a dynamic-batch float32 input."""
    samples = np.array(
        [[4, 5, np.nan, 7], [0, np.nan, 2, 3], [8, 9, 10, 11], [np.nan, 13, 14, 15]],
        dtype=np.float32)

    median_imputer = SimpleImputer(missing_values=np.nan, strategy='median')
    median_imputer.fit(samples)

    # Row count is left dynamic; the column count comes from the data.
    input_shape = (relay.Any(), len(samples[0]))
    _test_model_impl(SklearnTestHelper(), median_imputer, input_shape, samples)
68+
69+
def test_robust_imputer():
    """RobustImputer with a constant NaN fill and no mask function."""
    samples = np.array(
        [[4, 5, np.nan, 7], [0, np.nan, 2, 3], [8, 9, 10, 11], [np.nan, 13, 14, 15]],
        dtype=np.float32)

    imputer = RobustImputer(dtype=None, strategy="constant", fill_values=np.nan,
                            mask_function=None)
    imputer.fit(samples)

    # Row count is left dynamic; the column count comes from the data.
    input_shape = (relay.Any(), len(samples[0]))
    _test_model_impl(SklearnTestHelper(), imputer, input_shape, samples)
79+
80+
def test_robust_scaler():
    """RobustStandardScaler fitted on a small two-column float32 array."""
    helper = SklearnTestHelper()
    scaler = RobustStandardScaler()

    samples = np.array([[0, 0], [0, 0], [1, 1], [1, 1]], dtype=np.float32)
    scaler.fit(samples)

    # Row count is left dynamic; the column count comes from the data.
    input_shape = (relay.Any(), len(samples[0]))
    _test_model_impl(helper, scaler, input_shape, samples)
89+
90+
def test_threshold_onehot_encoder():
    """ThresholdOneHotEncoder on int32 input, compared against the densified sklearn output."""
    helper = SklearnTestHelper()
    encoder = ThresholdOneHotEncoder()

    samples = np.array([[10, 1, 7], [11, 3, 8], [11, 2, 9]], dtype=np.int32)
    encoder.fit(samples)
    # The fitted categories are overwritten with an explicit list per feature.
    encoder.categories_ = [[10, 11], [1, 2, 3], [7, 8, 9]]

    input_shape = (relay.Any(), len(samples[0]))
    helper.compile(encoder, input_shape, 'int32')
    expected = encoder.transform(samples).toarray()
    actual = helper.run(samples)
    tvm.testing.assert_allclose(expected, actual, rtol=1e-5, atol=1e-5)
103+
104+
def test_column_transfomer():
    """ColumnTransformer wrapping a one-step RobustImputer pipeline over all four columns."""
    helper = SklearnTestHelper()

    samples = np.array(
        [[4, 5, np.nan, 7], [0, np.nan, 2, 3], [8, 9, 10, 11], [np.nan, 13, 14, 15]],
        dtype=np.float32)

    imputer_step = ('robustimputer', RobustImputer(fill_values=np.nan, strategy='constant'))
    numeric_pipeline = Pipeline(steps=[imputer_step])
    transformer = ColumnTransformer(
        transformers=[('numeric_processing', numeric_pipeline, [0, 1, 2, 3])])
    transformer.fit(samples)

    # Both dimensions are left dynamic for this test.
    input_shape = (relay.Any(), relay.Any())
    _test_model_impl(helper, transformer, input_shape, samples)
117+
118+
119+
if __name__ == '__main__':
    # Run every test in this file, in a fixed order, when executed as a script.
    for _test_case in (
            test_simple_imputer,
            test_robust_imputer,
            test_robust_scaler,
            test_column_transfomer,
            test_threshold_onehot_encoder,
    ):
        _test_case()

0 commit comments

Comments
 (0)