Skip to content

Commit

Permalink
Xiaowu/one vs one classifier (#904)
Browse files Browse the repository at this point in the history
* add essential files

Signed-off-by: xiaowuhu <xiaowuhu@microsoft.com>

* update files

Signed-off-by: xiaowuhu <xiaowuhu@microsoft.com>

* updated

Signed-off-by: xiaowuhu <xiaowuhu@microsoft.com>

* make the converter working

Signed-off-by: xiaowuhu <xiaowuhu@microsoft.com>

* Update test_sklearn_one_vs_one_classifier_converter.py

Signed-off-by: xiaowuhu <xiaowuhu@microsoft.com>

* Update one_vs_one_classifier.py

Signed-off-by: xiaowuhu <xiaowuhu@microsoft.com>

* update files

Signed-off-by: xiaowuhu <xiaowuhu@microsoft.com>

* first fix for ovo converter

Signed-off-by: xadupre <xadupre@microsoft.com>

* fix ovo, still an issue with LogisticRegression and DecisionTree

Signed-off-by: xadupre <xadupre@microsoft.com>

* Update requirements.txt

Signed-off-by: xiaowuhu <xiaowuhu@microsoft.com>

* Update requirements.txt

Signed-off-by: xiaowuhu <xiaowuhu@microsoft.com>

* remove unnecessary files

Signed-off-by: xiaowuhu <xiaowuhu@microsoft.com>

* fix ovo

Signed-off-by: xadupre <xadupre@microsoft.com>

* final fix for ovo

Signed-off-by: xadupre <xadupre@microsoft.com>

* remove unnecessary option

Signed-off-by: xadupre <xadupre@microsoft.com>

* lint issues

Signed-off-by: xadupre <xadupre@microsoft.com>

* update ci

Signed-off-by: xadupre <xadupre@microsoft.com>

* Update linux-conda-CI.yml

Signed-off-by: xiaowuhu <xiaowuhu@microsoft.com>

* change CI

Signed-off-by: xiaowuhu <xiaowuhu@microsoft.com>

Signed-off-by: xiaowuhu <xiaowuhu@microsoft.com>
Signed-off-by: xadupre <xadupre@microsoft.com>
Co-authored-by: xadupre <xadupre@microsoft.com>
  • Loading branch information
xiaowuhu and xadupre authored Sep 2, 2022
1 parent 9ece520 commit 2260310
Show file tree
Hide file tree
Showing 13 changed files with 407 additions and 115 deletions.
4 changes: 3 additions & 1 deletion .azure-pipelines/linux-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,10 @@ jobs:
displayName: 'pytest'
- script: |
# some of this is triggering the following error when importing scipy on python 3.10
# ImportError: /lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29'
conda install -c conda-forge "lightgbm${lgbm.version}" xgboost --no-deps
pip install xgboost lightgbm hummingbird-ml hummingbird
pip install hummingbird-ml hummingbird xgboost lightgbm
pip install --no-deps git+https://github.com/microsoft/onnxconverter-common.git
pip install onnxmltools
displayName: 'install onnxmltools'
Expand Down
19 changes: 1 addition & 18 deletions .azure-pipelines/win32-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -121,24 +121,7 @@ jobs:
onnxrt.version: 'onnxruntime==1.7.0' # -i https://test.pypi.org/simple/ ort-nightly'
onnxcc.version: 'onnxconverter-common==1.7.0' # git+https://github.com/microsoft/onnxconverter-common.git
sklearn.version: '==0.24.1'
Py38-Onnx181-Rt160-Skl0240:
python.version: '3.8'
onnx.version: 'onnx==1.8.1'
onnx.target_opset: ''
numpy.version: 'numpy>=1.18.1'
scipy.version: 'scipy'
onnxrt.version: 'onnxruntime==1.6.0'
onnxcc.version: 'onnxconverter-common==1.7.0'
sklearn.version: '==0.24.0'
Py38-Onnx170-Rt160-Skl0240:
python.version: '3.8'
onnx.version: 'onnx==1.7.0'
onnx.target_opset: ''
numpy.version: 'numpy>=1.18.1'
scipy.version: 'scipy'
onnxrt.version: 'onnxruntime==1.6.0'
onnxcc.version: 'onnxconverter-common==1.7.0'
sklearn.version: '==0.24.0'


maxParallel: 3

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ onnx>=1.2.1
scikit-learn>=0.19
scikit-learn<=1.1.1
onnxconverter-common>=1.7.0
scikit-learn<=1.1.1
4 changes: 3 additions & 1 deletion skl2onnx/_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
)

# Multi-class
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multiclass import OneVsRestClassifier, OneVsOneClassifier

# Tree-based models
from sklearn.ensemble import (
Expand Down Expand Up @@ -284,6 +284,7 @@
MLPClassifier,
MultinomialNB,
NuSVC,
OneVsOneClassifier,
OneVsRestClassifier,
PassiveAggressiveClassifier,
Perceptron,
Expand Down Expand Up @@ -373,6 +374,7 @@ def build_sklearn_operator_name_map():
Normalizer,
OneClassSVM,
OneHotEncoder,
OneVsOneClassifier,
OneVsRestClassifier,
OrdinalEncoder,
PCA,
Expand Down
2 changes: 2 additions & 0 deletions skl2onnx/operator_converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from . import nearest_neighbours
from . import normaliser
from . import one_hot_encoder
from . import one_vs_one_classifier
from . import one_vs_rest_classifier
from . import ordinal_encoder
from . import ovr_decision_function
Expand Down Expand Up @@ -104,6 +105,7 @@
nearest_neighbours,
normaliser,
one_hot_encoder,
one_vs_one_classifier,
one_vs_rest_classifier,
ordinal_encoder,
ovr_decision_function,
Expand Down
154 changes: 154 additions & 0 deletions skl2onnx/operator_converters/one_vs_one_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0

from sklearn.base import is_regressor
from ..proto import onnx_proto
from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer
from ..common._apply_operation import apply_cast, apply_concat, apply_reshape
from ..common.data_types import guess_proto_type, Int64TensorType
from .._supported_operators import sklearn_operator_name_map


def _iteration_one_versus(scope, container, inputs, i, estimator, cl_type,
                          proto_dtype, use_raw_scores=True, prob_shape=None):
    """
    Declares the sub-operator converting one binary *estimator* and
    extracts the tensors the one-versus-one converter needs from it.

    :param scope: scope used to declare operators, variables and
        unique names
    :param container: container receiving the nodes, the initializers
        and the options of the sub-operator
    :param inputs: inputs given to the sub-operator
    :param i: index of the estimator, only used to build names
    :param estimator: estimator to convert
    :param cl_type: tensor type class used for the score of a regressor
    :param proto_dtype: ONNX element type the predicted label is cast to
    :param use_raw_scores: value given to option `raw_scores` when the
        estimator supports it
    :param prob_shape: name of the tensor holding the shape
        `[first dimension of the input, -1]`, it is built when None
    :return: tuple *(prob_shape, cast_label, prob1)* for a classifier,
        *(None, None, score name)* for a regressor
    :raises RuntimeError: if a regressor has a two-dimensional `coef_`
    """
    sub_operator = scope.declare_local_operator(
        sklearn_operator_name_map[type(estimator)], raw_model=estimator)
    sub_operator.inputs = inputs

    if is_regressor(estimator):
        score_var = scope.declare_local_variable('score_%d' % i, cl_type())
        sub_operator.outputs.append(score_var)

        multi_target = (hasattr(estimator, 'coef_') and
                        len(estimator.coef_.shape) == 2)
        if multi_target:
            raise RuntimeError(
                "OneVsRestClassifier or OneVsOneClassifier accepts "
                "regressor with only one target.")
        return None, None, score_var.onnx_name

    def int64_constant(base_name, value):
        # registers a one-element INT64 initializer under a unique name
        unique = scope.get_unique_variable_name(base_name)
        container.add_initializer(
            unique, onnx_proto.TensorProto.INT64, [1], [value])
        return unique

    # option `raw_scores` has priority, otherwise zipmap is disabled
    # for estimators exposing that option
    sub_options = None
    if container.has_options(estimator, 'raw_scores'):
        sub_options = {'raw_scores': use_raw_scores}
    elif container.has_options(estimator, 'zipmap'):
        sub_options = {'zipmap': False}
    if sub_options is not None:
        for holder in (container, scope):
            holder.add_options(id(estimator), sub_options)

    label_var = scope.declare_local_variable(
        'label_%d' % i, Int64TensorType())
    proba_var = scope.declare_local_variable(
        'proba_%d' % i, inputs[0].type.__class__())
    sub_operator.outputs.append(label_var)
    sub_operator.outputs.append(proba_var)

    # predicted label as a column, cast to proto_dtype
    column_label = scope.get_unique_variable_name('lab_%d' % i)
    apply_reshape(scope, label_var.onnx_name, column_label, container,
                  desired_shape=(-1, 1))
    cast_label = scope.get_unique_variable_name('cast_lab_%d' % i)
    apply_cast(scope, column_label, cast_label, container, to=proto_dtype)

    if prob_shape is None:
        # builds [Shape(input)[0], -1], the shape used to reshape the scores
        zero = int64_constant('cst0', 0)
        input_shape = scope.get_unique_variable_name('shape')
        container.add_node('Shape', [inputs[0].full_name], [input_shape])
        batch_dim = scope.get_unique_variable_name('dim')
        container.add_node('Gather', [input_shape, zero], [batch_dim])
        minus_one = int64_constant('cst_1', -1)
        prob_shape = scope.get_unique_variable_name('shape')
        apply_concat(scope, [batch_dim, minus_one], prob_shape, container,
                     axis=0)

    reshaped = scope.get_unique_variable_name('prob_%d' % i)
    container.add_node('Reshape', [proba_var.onnx_name, prob_shape],
                       [reshaped])

    one = int64_constant('cst1', 1)
    two = int64_constant('cst2', 2)

    # keeps column 1 only (starts=1, ends=2, axes=1):
    # the probability for the class 1
    prob1 = scope.get_unique_variable_name('prob1_%d' % i)
    container.add_node('Slice', [reshaped, one, two, one], prob1)
    return prob_shape, cast_label, prob1


def convert_one_vs_one_classifier(scope: Scope, operator: Operator,
                                  container: ModelComponentContainer):
    """
    Converts a *OneVsOneClassifier* into ONNX.

    Every estimator in `estimators_` is converted through
    `_iteration_one_versus`. The labels and the class 1 probabilities
    they return are concatenated along axis 1 and given to a
    `SklearnOVRDecisionFunction` sub-operator. Its output is copied into
    the second output of the operator, the first output receives the
    ArgMax of that tensor along axis 1.

    :param scope: scope used to declare operators, variables and
        unique names
    :param operator: operator wrapping the fitted model
        (`operator.raw_operator`)
    :param container: container receiving the nodes
    """
    proto_dtype = guess_proto_type(operator.inputs[0].type)
    if proto_dtype != onnx_proto.TensorProto.DOUBLE:
        proto_dtype = onnx_proto.TensorProto.FLOAT
    op = operator.raw_operator
    cl_type = operator.inputs[0].type.__class__

    # The shape used to reshape the scores is built by
    # _iteration_one_versus on the first call (prob_shape is None) and
    # reused afterwards. Building it here as well only left unused nodes
    # and initializers in the graph.
    # NOTE(review): for a regressor the helper returns None as label
    # name, which the concatenation below cannot consume - presumably
    # regressors are not supported here, confirm.
    label_names = []
    prob_names = []
    prob_shape = None
    for i, estimator in enumerate(op.estimators_):
        prob_shape, cast_label, prob1 = _iteration_one_versus(
            scope, container, operator.inputs, i, estimator, cl_type,
            proto_dtype, True, prob_shape=prob_shape)

        label_names.append(cast_label)
        prob_names.append(prob1)

    conc_lab_name = scope.get_unique_variable_name('concat_out_ovo_label')
    apply_concat(scope, label_names, conc_lab_name, container, axis=1)
    conc_prob_name = scope.get_unique_variable_name('concat_out_ovo_prob')
    apply_concat(scope, prob_names, conc_prob_name, container, axis=1)

    # calls _ovr_decision_function
    this_operator = scope.declare_local_operator(
        "SklearnOVRDecisionFunction", op)

    label = scope.declare_local_variable("label", cl_type())
    container.add_node('Identity', [conc_lab_name], [label.onnx_name])
    prob_score = scope.declare_local_variable("prob_score", cl_type())
    container.add_node('Identity', [conc_prob_name], [prob_score.onnx_name])

    this_operator.inputs.append(label)
    this_operator.inputs.append(prob_score)

    ovr_name = scope.declare_local_variable('ovr_output', cl_type())
    this_operator.outputs.append(ovr_name)

    output_name = operator.outputs[1].full_name
    container.add_node('Identity', [ovr_name.onnx_name], [output_name])

    # The declared variable may have been renamed to stay unique
    # ('ovr_output' may already exist in the scope), so ArgMax must read
    # the actual onnx name and not the raw string 'ovr_output'.
    container.add_node(
        'ArgMax', [ovr_name.onnx_name], [operator.outputs[0].full_name],
        axis=1)


# Registers the converter under the alias 'SklearnOneVsOneClassifier'
# together with the options it declares and the values listed for each.
register_converter(
    'SklearnOneVsOneClassifier',
    convert_one_vs_one_classifier,
    options={
        'zipmap': [True, False, 'columns'],
        'nocl': [True, False],
        'output_class_labels': [False, True],
    })
79 changes: 45 additions & 34 deletions skl2onnx/operator_converters/one_vs_rest_classifier.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from sklearn.base import is_regressor
from sklearn.svm import LinearSVC
from ..proto import onnx_proto
from ..common._apply_operation import (
apply_concat, apply_identity, apply_mul, apply_reshape)
Expand All @@ -15,6 +16,45 @@
from .._supported_operators import sklearn_operator_name_map


def _iteration_one_versus(scope, container, inputs, i, estimator, cl_type,
                          proto_dtype, use_raw_scores=True, prob_shape=None):
    """
    Declares the sub-operator converting one *estimator* of a
    one-versus-rest model and returns the name of its score.

    :param scope: scope used to declare operators, variables and
        unique names
    :param container: container receiving the nodes and the options
        of the sub-operator
    :param inputs: inputs given to the sub-operator
    :param i: index of the estimator, only used to build names
    :param estimator: estimator to convert
    :param cl_type: tensor type class of the score or the probabilities
    :param proto_dtype: not used by this function
    :param use_raw_scores: value given to option `raw_scores` when the
        estimator supports it
    :param prob_shape: not used by this function
    :return: tuple *(None, None, p1)*, *p1* is the name of the
        regressor score, of the whole output for a *LinearSVC*, or of
        column 1 of the probabilities for any other classifier
    :raises RuntimeError: if a regressor has a two-dimensional `coef_`
    """
    sub_operator = scope.declare_local_operator(
        sklearn_operator_name_map[type(estimator)], raw_model=estimator)
    sub_operator.inputs = inputs

    if is_regressor(estimator):
        score_var = scope.declare_local_variable('score_%d' % i, cl_type())
        sub_operator.outputs.append(score_var)

        multi_target = (hasattr(estimator, 'coef_') and
                        len(estimator.coef_.shape) == 2)
        if multi_target:
            raise RuntimeError(
                "OneVsRestClassifier or OneVsOneClassifier accepts "
                "regressor with only one target.")
        return None, None, score_var.onnx_name

    # classifier: forwards option raw_scores when it is supported
    if container.has_options(estimator, 'raw_scores'):
        estimator_id = id(estimator)
        container.add_options(
            estimator_id, {'raw_scores': use_raw_scores})
        scope.add_options(
            estimator_id, {'raw_scores': use_raw_scores})

    label_var = scope.declare_local_variable(
        'label_%d' % i, Int64TensorType())
    proba_var = scope.declare_local_variable('proba_%d' % i, cl_type())
    sub_operator.outputs.append(label_var)
    sub_operator.outputs.append(proba_var)

    # gets the probability for the class 1,
    # the output of a LinearSVC is taken as is
    p1 = scope.get_unique_variable_name('probY_%d' % i)
    if not isinstance(estimator, LinearSVC):
        apply_slice(scope, proba_var.onnx_name, p1, container, starts=[1],
                    ends=[2], axes=[1])
    else:
        apply_identity(scope, proba_var.onnx_name, p1, container)
    return None, None, p1


def convert_one_vs_rest_classifier(scope: Scope, operator: Operator,
container: ModelComponentContainer):
"""
Expand All @@ -31,41 +71,12 @@ def convert_one_vs_rest_classifier(scope: Scope, operator: Operator,
options = container.get_options(op, dict(raw_scores=False))
use_raw_scores = options['raw_scores']
probs_names = []
cl_type = operator.inputs[0].type.__class__
prob_shape = None
for i, estimator in enumerate(op.estimators_):
op_type = sklearn_operator_name_map[type(estimator)]

this_operator = scope.declare_local_operator(
op_type, raw_model=estimator)
this_operator.inputs = operator.inputs

if is_regressor(estimator):
score_name = scope.declare_local_variable(
'score_%d' % i, operator.inputs[0].type.__class__())
this_operator.outputs.append(score_name)

if hasattr(estimator, 'coef_') and len(estimator.coef_.shape) == 2:
raise RuntimeError("OneVsRestClassifier accepts "
"regressor with only one target.")
p1 = score_name.onnx_name
else:
if container.has_options(estimator, 'raw_scores'):
container.add_options(
id(estimator), {'raw_scores': use_raw_scores})
scope.add_options(
id(estimator), {'raw_scores': use_raw_scores})
label_name = scope.declare_local_variable(
'label_%d' % i, Int64TensorType())
prob_name = scope.declare_local_variable(
'proba_%d' % i, operator.inputs[0].type.__class__())
this_operator.outputs.append(label_name)
this_operator.outputs.append(prob_name)

# gets the probability for the class 1
p1 = scope.get_unique_variable_name('probY_%d' % i)
apply_slice(scope, prob_name.onnx_name, p1, container, starts=[1],
ends=[2], axes=[1],
operator_name=scope.get_unique_operator_name('Slice'))

prob_shape, _, p1 = _iteration_one_versus(
scope, container, operator.inputs, i, estimator, cl_type,
proto_dtype, use_raw_scores, prob_shape=prob_shape)
probs_names.append(p1)

if op.multilabel_:
Expand Down
Loading

0 comments on commit 2260310

Please sign in to comment.