Skip to content
This repository was archived by the owner on May 25, 2024. It is now read-only.

Commit 06e75d4

Browse files
committed
refactor templates
1 parent e5b062d commit 06e75d4

File tree

20 files changed

+219
-311
lines changed

20 files changed

+219
-311
lines changed

micromlgen/decisiontree.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
1-
from sklearn.tree import DecisionTreeClassifier
2-
from micromlgen.utils import jinja
1+
from micromlgen.utils import jinja, check_type
32

43

54
def is_decisiontree(clf):
65
"""Test if classifier can be ported"""
7-
return isinstance(clf, DecisionTreeClassifier)
6+
return check_type(clf, 'DecisionTreeClassifier')
87

98

109
def port_decisiontree(clf, **kwargs):
1110
"""Port sklearn's DecisionTreeClassifier"""
12-
kwargs['classname'] = kwargs['classname'] or 'DecisionTree'
1311
return jinja('decisiontree/decisiontree.jinja', {
1412
'left': clf.tree_.children_left,
1513
'right': clf.tree_.children_right,
1614
'features': clf.tree_.feature,
1715
'thresholds': clf.tree_.threshold,
1816
'classes': clf.tree_.value,
1917
'i': 0
18+
}, {
19+
'classname': 'DecisionTree'
2020
}, **kwargs)

micromlgen/gaussiannb.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
from sklearn.naive_bayes import GaussianNB
2-
from micromlgen.utils import jinja
1+
from micromlgen.utils import jinja, check_type
32

43

54
def is_gaussiannb(clf):
65
"""Test if classifier can be ported"""
7-
return isinstance(clf, GaussianNB)
6+
return check_type(clf, 'GaussianNB')
87

98

109
def port_gaussiannb(clf, **kwargs):
11-
"""Port sklearn's DecisionTreeClassifier"""
12-
kwargs['classname'] = kwargs['classname'] or 'GaussianNB'
10+
"""Port sklearn's GaussianNB"""
1311
return jinja('gaussiannb/gaussiannb.jinja', {
1412
'sigma': clf.sigma_,
1513
'theta': clf.theta_,
1614
'prior': clf.class_prior_,
1715
'classes': clf.classes_,
1816
'n_classes': len(clf.classes_)
17+
}, {
18+
'classname': 'GaussianNB'
1919
}, **kwargs)

micromlgen/logisticregression.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
from sklearn.linear_model import LogisticRegression
2-
from micromlgen.utils import jinja
1+
from micromlgen.utils import jinja, check_type
32

43

54
def is_logisticregression(clf):
65
"""Test if classifier can be ported"""
7-
return isinstance(clf, LogisticRegression)
6+
return check_type(clf, 'LogisticRegression')
87

98

109
def port_logisticregression(clf, **kwargs):
11-
"""Port sklearn's DecisionTreeClassifier"""
12-
kwargs['classname'] = kwargs['classname'] or 'LogisticRegression'
10+
"""Port sklearn's LogisticRegressionClassifier"""
1311
return jinja('logisticregression/logisticregression.jinja', {
1412
'weights': clf.coef_,
1513
'intercept': clf.intercept_,
1614
'classes': clf.classes_,
1715
'n_classes': len(clf.classes_)
16+
}, {
17+
'classname': 'LogisticRegression'
1818
}, **kwargs)

micromlgen/micromlgen.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,12 @@
1-
from sklearn.decomposition import PCA
2-
from sklearn.svm import SVC, LinearSVC, OneClassSVM
3-
4-
try:
5-
from skbayes.rvm_ard_models import RVC
6-
except ImportError:
7-
from micromlgen.patches import RVC
8-
try:
9-
from sefr import SEFR
10-
except ImportError:
11-
from micromlgen.patches import SEFR
12-
131
from micromlgen import platforms
14-
from micromlgen.svm import port_svm
15-
from micromlgen.rvm import port_rvm
16-
from micromlgen.sefr import port_sefr
2+
from micromlgen.svm import is_svm, port_svm
3+
from micromlgen.rvm import is_rvm, port_rvm
4+
from micromlgen.sefr import is_sefr, port_sefr
175
from micromlgen.decisiontree import is_decisiontree, port_decisiontree
186
from micromlgen.randomforest import is_randomforest, port_randomforest
197
from micromlgen.logisticregression import is_logisticregression, port_logisticregression
208
from micromlgen.gaussiannb import is_gaussiannb, port_gaussiannb
21-
from micromlgen.pca import port_pca
9+
from micromlgen.pca import is_pca, port_pca
2210

2311

2412
def port(
@@ -29,13 +17,11 @@ def port(
2917
precision=None):
3018
"""Port a classifier to plain C++"""
3119
assert platform in platforms.ALL, 'Unknown platform %s. Use one of %s' % (platform, ', '.join(platforms.ALL))
32-
if isinstance(clf, (SVC, LinearSVC, OneClassSVM)):
20+
if is_svm(clf):
3321
return port_svm(**locals())
34-
elif isinstance(clf, RVC):
22+
elif is_rvm(clf):
3523
return port_rvm(**locals())
36-
elif isinstance(clf, PCA):
37-
return port_pca(pca=clf, **locals())
38-
elif isinstance(clf, SEFR):
24+
elif is_sefr(clf):
3925
return port_sefr(**locals())
4026
elif is_decisiontree(clf):
4127
return port_decisiontree(**locals())
@@ -45,4 +31,6 @@ def port(
4531
return port_logisticregression(**locals())
4632
elif is_gaussiannb(clf):
4733
return port_gaussiannb(**locals())
34+
elif is_pca(clf):
35+
return port_pca(**locals())
4836
raise TypeError('clf MUST be one of SVC, LinearSVC, OneClassSVC, RVC, DecisionTree, RandomForest, LogisticRegression, GaussianNB, SEFR, PCA')

micromlgen/patches.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

micromlgen/pca.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1-
from micromlgen.utils import jinja
1+
from micromlgen.utils import jinja, check_type
22

33

4-
def port_pca(pca, classname=None, **kwargs):
4+
def is_pca(clf):
5+
"""Test if classifier can be ported"""
6+
return check_type(clf, 'PCA')
7+
8+
9+
def port_pca(clf, **kwargs):
510
"""Port a PCA"""
6-
template_data = {
11+
return jinja('pca/pca.jinja', {
712
'arrays': {
8-
'components': pca.components_,
9-
'mean': pca.mean_
13+
'components': clf.components_,
14+
'mean': clf.mean_
1015
},
11-
'classname': classname if classname is not None else 'PCA'
12-
}
13-
return jinja('pca/pca.jinja', template_data)
16+
}, {
17+
'classname': 'PCA'
18+
}, **kwargs)

micromlgen/randomforest.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
from sklearn.ensemble import RandomForestClassifier
2-
from micromlgen.utils import jinja
1+
from micromlgen.utils import jinja, check_type
32

43

54
def is_randomforest(clf):
65
"""Test if classifier can be ported"""
7-
return isinstance(clf, RandomForestClassifier)
6+
return check_type(clf, 'RandomForestClassifier')
87

98

109
def port_randomforest(clf, **kwargs):
1110
"""Port sklearn's RandomForestClassifier"""
12-
kwargs['classname'] = kwargs['classname'] or 'RandomForest'
1311
return jinja('randomforest/randomforest.jinja', {
1412
'n_classes': clf.n_classes_,
1513
'trees': [{
@@ -19,4 +17,6 @@ def port_randomforest(clf, **kwargs):
1917
'thresholds': clf.tree_.threshold,
2018
'classes': clf.tree_.value,
2119
} for clf in clf.estimators_]
22-
}, **locals())
20+
}, {
21+
'classname': 'RandomForest'
22+
}, **kwargs)

micromlgen/rvm.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
1-
from micromlgen.utils import jinja
1+
from micromlgen.utils import jinja, check_type
22

33

4-
def port_rvm(clf, classname, **kwargs):
4+
def is_rvm(clf):
5+
"""Test if classifier can be ported"""
6+
return check_type(clf, 'RVC')
7+
8+
9+
def port_rvm(clf, **kwargs):
510
"""Port a RVM classifier"""
6-
assert classname is None or len(classname) > 0, 'Invalid class name'
7-
template_data = {
8-
**kwargs,
11+
return jinja('rvm/rvm.jinja', {
12+
'n_classes': len(clf.intercept_),
913
'kernel': {
1014
'type': clf.kernel,
1115
'gamma': clf.gamma,
1216
'coef0': clf.coef0,
1317
'degree': clf.degree
1418
},
1519
'sizes': {
16-
'features': len(clf.relevant_vectors_[0]),
20+
'features': clf.relevant_vectors_[0].shape[1],
1721
},
1822
'arrays': {
1923
'vectors': clf.relevant_vectors_,
@@ -23,6 +27,6 @@ def port_rvm(clf, classname, **kwargs):
2327
'mean': clf._x_mean,
2428
'std': clf._x_std
2529
},
26-
'classname': classname if classname is not None else 'RVM',
27-
}
28-
return jinja('rvm/rvm.jinja', template_data)
30+
}, {
31+
'classname': 'RVC'
32+
}, **kwargs)

micromlgen/sefr.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1-
from micromlgen.utils import jinja
1+
from micromlgen.utils import jinja, check_type
2+
3+
4+
def is_sefr(clf):
5+
"""Test if classifier can be ported"""
6+
return check_type(clf, 'SEFR')
27

38

49
def port_sefr(clf, classname=None, **kwargs):
510
"""Port SEFR classifier"""
6-
kwargs.update({
11+
return jinja('sefr/sefr.jinja', {
712
'weights': clf.weights,
813
'bias': clf.bias,
914
'dimension': len(clf.weights),
10-
'classname': classname or 'SEFR'
11-
})
12-
return jinja('sefr/sefr.jinja', kwargs)
15+
}, {
16+
'classname': 'SEFR'
17+
}, **kwargs)

micromlgen/svm.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
1-
from sklearn.svm import OneClassSVM
1+
from micromlgen.utils import jinja, check_type
22

3-
from micromlgen.utils import jinja
43

4+
def is_svm(clf):
5+
"""Test if classifier can be ported"""
6+
return check_type(clf, 'SVC', 'LinearSVC', 'OneClassSVM')
57

6-
def port_svm(clf, classname=None, **kwargs):
8+
9+
def port_svm(clf, **kwargs):
710
"""Port a SVC / LinearSVC classifier"""
811
assert isinstance(clf.gamma, float), 'You probably didn\'t set an explicit value for gamma: 0.001 is a good default'
9-
assert classname is None or len(classname) > 0, 'Invalid class name'
10-
if classname is None:
11-
classname = 'OneClassSVM' if isinstance(clf, OneClassSVM) else 'SVM'
1212
support_v = clf.support_vectors_
1313
n_classes = len(clf.n_support_)
14-
template_data = {
15-
**kwargs,
14+
return jinja('svm/svm.jinja', {
1615
'kernel': {
1716
'type': clf.kernel,
1817
'gamma': clf.gamma,
@@ -30,7 +29,7 @@ def port_svm(clf, classname=None, **kwargs):
3029
'supports': support_v,
3130
'intercepts': clf.intercept_,
3231
'coefs': clf.dual_coef_
33-
},
34-
'classname': classname
35-
}
36-
return jinja('svm/svm.jinja', template_data)
32+
}
33+
}, {
34+
'classname': 'OneClassSVM' if check_type(clf, 'OneClassSVM') else 'SVM'
35+
}, **kwargs)

micromlgen/templates/_skeleton.jinja

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#pragma once
2+
3+
namespace Eloquent {
4+
namespace ML {
5+
namespace Port {
6+
class {{ classname }} {
7+
public:
8+
9+
/**
10+
* Predict class for features vector
11+
*/
12+
int predict(float *x) {
13+
{% block predict %}{% endblock %}
14+
}
15+
16+
{% include 'classmap.jinja' %}
17+
18+
{% block public %}{% endblock %}
19+
20+
protected:
21+
22+
{% block protected %}{% endblock %}
23+
};
24+
}
25+
}
26+
}
Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,5 @@
1-
#pragma once
1+
{% extends '_skeleton.jinja' %}
22

3-
namespace Eloquent {
4-
namespace ML {
5-
namespace Port {
6-
7-
class {{ classname }} {
8-
public:
9-
10-
/**
11-
* Predict class for features vector
12-
*/
13-
int predict(float *x) {
14-
{% include 'decisiontree/tree.jinja' %}
15-
}
16-
17-
{% include 'classmap.jinja' %}
18-
};
19-
}
20-
}
21-
}
3+
{% block predict %}
4+
{% include 'decisiontree/tree.jinja' %}
5+
{% endblock %}

0 commit comments

Comments
 (0)