Skip to content
This repository was archived by the owner on May 25, 2024. It is now read-only.

Commit a9a59c4

Browse files
committed
add XGBClassifier
1 parent 66a8fbb commit a9a59c4

File tree

8 files changed

+82
-4
lines changed

8 files changed

+82
-4
lines changed

MANIFEST

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ micromlgen/__init__.py
55
micromlgen/convolution.py
66
micromlgen/decisiontree.py
77
micromlgen/gaussiannb.py
8+
micromlgen/linear_regression.py
89
micromlgen/logisticregression.py
910
micromlgen/micromlgen.py
1011
micromlgen/pca.py
@@ -16,6 +17,7 @@ micromlgen/sefr.py
1617
micromlgen/svm.py
1718
micromlgen/utils.py
1819
micromlgen/wifiindoorpositioning.py
20+
micromlgen/xgboost.py
1921
micromlgen/templates/_skeleton.jinja
2022
micromlgen/templates/classmap.jinja
2123
micromlgen/templates/dot.jinja
@@ -27,6 +29,7 @@ micromlgen/templates/decisiontree/decisiontree.jinja
2729
micromlgen/templates/decisiontree/tree.jinja
2830
micromlgen/templates/gaussiannb/gaussiannb.jinja
2931
micromlgen/templates/gaussiannb/vote.jinja
32+
micromlgen/templates/linearregression/linearregression.jinja
3033
micromlgen/templates/logisticregression/logisticregression.jinja
3134
micromlgen/templates/logisticregression/vote.arduino.jinja
3235
micromlgen/templates/logisticregression/vote.attiny.jinja
@@ -49,3 +52,5 @@ micromlgen/templates/svm/kernel/arduino.jinja
4952
micromlgen/templates/svm/kernel/attiny.jinja
5053
micromlgen/templates/svm/kernel/kernel.jinja
5154
micromlgen/templates/wifiindoorpositioning/wifiindoorpositioning.jinja
55+
micromlgen/templates/xgboost/tree.jinja
56+
micromlgen/templates/xgboost/xgboost.jinja

micromlgen/__init__.pyc

504 Bytes
Binary file not shown.

micromlgen/micromlgen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from micromlgen.pca import is_pca, port_pca
1010
from micromlgen.principalfft import is_principalfft, port_principalfft
1111
from micromlgen.linear_regression import is_linear_regression, port_linear_regression
12-
12+
from micromlgen.xgboost import is_xgboost, port_xgboost
1313

1414
def port(
1515
clf,
@@ -40,4 +40,6 @@ def port(
4040
return port_principalfft(**locals(), **kwargs)
4141
elif is_linear_regression(clf):
4242
return port_linear_regression(**locals(), **kwargs)
43+
elif is_xgboost(clf):
44+
return port_xgboost(**locals(), **kwargs)
4345
raise TypeError('clf MUST be one of %s' % ', '.join(platforms.ALLOWED_CLASSIFIERS))

micromlgen/platforms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@
1616
'GaussianNB',
1717
'LogisticRegression',
1818
'PCA',
19-
'PrincipalFFT'
19+
'PrincipalFFT',
20+
'LinearRegression'
2021
]
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{% if tree['left'][i] != tree['right'][i] %}
2+
if (x[{{ tree['features'][i] }}] <= {{ tree['thresholds'][i] }}) {
3+
{% with i = tree['left'][i] %}
4+
{% include 'xgboost/tree.jinja' %}
5+
{% endwith %}
6+
}
7+
else {
8+
{% with i = tree['right'][i] %}
9+
{% include 'xgboost/tree.jinja' %}
10+
{% endwith %}
11+
}
12+
{% else %}
13+
votes[{{ class_idx }}] += {{ tree['thresholds'][i] }};
14+
{% endif %}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{% extends '_skeleton.jinja' %}
2+
3+
{% block predict %}
4+
float votes[{{ n_classes }}] = { 0.0f };
5+
6+
{% for k, tree in f.enumerate(trees) %}
7+
{% with i = 0, class_idx = k % n_classes %}
8+
// tree #{{ k + 1 }}
9+
{% include 'xgboost/tree.jinja' %}
10+
{% endwith %}
11+
{% endfor %}
12+
13+
{% include 'vote.jinja' %}
14+
{% endblock %}

micromlgen/xgboost.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from micromlgen.utils import jinja, check_type
2+
from tempfile import NamedTemporaryFile
3+
import json
4+
5+
6+
def format_tree(tree):
7+
"""
8+
Format xgboost tree like a sklearn DecisionTree
9+
:param tree:
10+
:return:
11+
"""
12+
split_indices = tree['split_indices']
13+
split_conditions = tree['split_conditions']
14+
left_children = tree['left_children']
15+
right_children = tree['right_children']
16+
return {
17+
'left': left_children,
18+
'right': right_children,
19+
'features': split_indices,
20+
'thresholds': split_conditions
21+
}
22+
23+
24+
def is_xgboost(clf):
25+
"""Test if classifier can be ported"""
26+
return check_type(clf, 'XGBClassifier')
27+
28+
29+
def port_xgboost(clf, **kwargs):
30+
"""Port a XGBoost classifier"""
31+
with NamedTemporaryFile('w+', suffix='.json', encoding='utf-8') as tmp:
32+
clf.save_model(tmp.name)
33+
tmp.seek(0)
34+
decoded = json.load(tmp)
35+
trees = [format_tree(tree) for tree in decoded['learner']['gradient_booster']['model']['trees']]
36+
print(trees)
37+
return jinja('xgboost/xgboost.jinja', {
38+
'n_classes': int(decoded['learner']['learner_model_param']['num_class']),
39+
'trees': trees,
40+
}, {
41+
'classname': 'XGBClassifier'
42+
}, **kwargs)

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
setup(
88
name = 'micromlgen',
99
packages = ['micromlgen'],
10-
version = '1.1.10',
10+
version = '1.1.11',
1111
license='MIT',
1212
description = 'Generate C code for microcontrollers from Python\'s sklearn classifiers',
1313
author = 'Simone Salerno',
1414
author_email = 'eloquentarduino@gmail.com',
1515
url = 'https://github.com/eloquentarduino/micromlgen',
16-
download_url = 'https://github.com/eloquentarduino/micromlgen/archive/v_1110.tar.gz',
16+
download_url = 'https://github.com/eloquentarduino/micromlgen/archive/v_1111.tar.gz',
1717
keywords = [
1818
'ML',
1919
'microcontrollers',

0 commit comments

Comments
 (0)