Skip to content
This repository was archived by the owner on May 25, 2024. It is now read-only.

Commit b5beb91

Browse files
committed
refactored SVM + RVM, deleted PCA for now
1 parent dbaa25f commit b5beb91

33 files changed

+302
-348
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
venv
22
.idea
3-
tests
3+
tests
4+
dist

MANIFEST

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,15 @@ setup.cfg
33
setup.py
44
micromlgen/__init__.py
55
micromlgen/micromlgen.py
6-
micromlgen/templates/_scalar_product.jinja
7-
micromlgen/templates/binary_classification.jinja
6+
micromlgen/platforms.py
87
micromlgen/templates/classmap.jinja
9-
micromlgen/templates/compute_class.jinja
10-
micromlgen/templates/compute_decisions.jinja
11-
micromlgen/templates/compute_kernels.bck.jinja
12-
micromlgen/templates/compute_kernels.jinja
13-
micromlgen/templates/compute_votes.jinja
14-
micromlgen/templates/kernel_function.jinja
15-
micromlgen/templates/pca_function.jinja
16-
micromlgen/templates/self_test.jinja
17-
micromlgen/templates/svm.jinja
18-
micromlgen/templates/pca/_dot_product.jinja
19-
micromlgen/templates/pca/pca.jinja
8+
micromlgen/templates/rvm/rvm.jinja
9+
micromlgen/templates/svm/svm.jinja
10+
micromlgen/templates/svm/computations/class.jinja
11+
micromlgen/templates/svm/computations/decisions.jinja
12+
micromlgen/templates/svm/computations/votes.jinja
13+
micromlgen/templates/svm/computations/kernel/arduino.jinja
14+
micromlgen/templates/svm/computations/kernel/attiny.jinja
15+
micromlgen/templates/svm/kernel/arduino.jinja
16+
micromlgen/templates/svm/kernel/attiny.jinja
17+
micromlgen/templates/svm/kernel/kernel.jinja

dist/micromlgen-0.13.tar.gz

-4.14 KB
Binary file not shown.

micromlgen/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from micromlgen.micromlgen import port, port_pca, port_rvm
1+
import micromlgen.platforms as platforms
2+
from micromlgen.micromlgen import port
7 Bytes
Binary file not shown.
958 Bytes
Binary file not shown.
208 Bytes
Binary file not shown.

micromlgen/micromlgen.py

Lines changed: 74 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import os
22
import re
3-
from math import factorial
3+
from micromlgen import platforms
4+
from sklearn.svm import SVC, LinearSVC, OneClassSVM
5+
from skbayes.rvm_ard_models import RVC
46
from jinja2 import FileSystemLoader, Environment
57

68

79
def jinja(template_file, data):
10+
"""Render Jinja template"""
811
dir_path = os.path.dirname(os.path.realpath(__file__))
912
loader = FileSystemLoader(dir_path + '/templates')
1013
template = Environment(loader=loader).get_template(template_file)
@@ -13,67 +16,85 @@ def jinja(template_file, data):
1316
return code
1417

1518

16-
def port_pca(pca):
17-
return jinja('pca/pca.jinja', {
18-
'X_DIM': pca.components_.shape[1],
19-
'PCA_DIM': pca.components_.shape[0],
20-
'F': {
21-
'round': round
22-
},
23-
'pca_components': pca.components_,
19+
def _port(template, data):
20+
"""Add common template data before rendering"""
21+
data.update(**{
22+
'f': {
23+
'enumerate': enumerate,
24+
'round': lambda x: round(x, data.get('precision', 9) or 9),
25+
'zip': zip
26+
}
2427
})
28+
return jinja(template, data)
2529

2630

27-
def port_rvm(clf, classmap=None, test_set=None, **kwargs):
28-
from skbayes.rvm_ard_models import RVC
29-
assert isinstance(clf, RVC), 'Not an RVC classifier'
30-
return jinja('rvm/rvm.jinja', {
31-
'clf': clf,
32-
'FEATURES_DIM': clf.relevant_vectors_[0].shape[1],
33-
'KERNEL_TYPE': clf.kernel,
34-
'KERNEL_GAMMA': clf.gamma,
35-
'KERNEL_COEF': clf.coef0,
36-
'KERNEL_DEGREE': clf.degree,
37-
'N_CLASSES': len(clf.classes_),
38-
'classmap': classmap,
39-
'X': test_set[0] if test_set else None,
40-
'y': test_set[1] if test_set else None,
41-
'enumerate': enumerate,
42-
'zip': zip,
43-
'round': round
44-
})
31+
def port_rvm(clf, classname, **kwargs):
32+
"""Port a RVM classifier"""
33+
assert classname is None or len(classname) > 0, 'Invalid class name'
34+
template_data = {
35+
**kwargs,
36+
'kernel': {
37+
'type': clf.kernel,
38+
'gamma': clf.gamma,
39+
'coef0': clf.coef0,
40+
'degree': clf.degree
41+
},
42+
'sizes': {
43+
'features': len(clf.relevant_vectors_[0]),
44+
},
45+
'arrays': {
46+
'vectors': clf.relevant_vectors_,
47+
'coefs': clf.coef_,
48+
'actives': clf.active_,
49+
'intercepts': clf.intercept_,
50+
'mean': clf._x_mean,
51+
'std': clf._x_std
52+
},
53+
'classname': classname if classname is not None else 'RVM',
54+
}
55+
return _port('rvm/rvm.jinja', template_data)
4556

4657

47-
def port(clf,
48-
test_set=None,
49-
classmap=None,
50-
platform='arduino',
51-
**kwargs):
52-
assert type(clf).__name__ == 'SVC', 'Only sklearn.svm.SVC is supported for now'
58+
def port_svm(clf, classname=None, **kwargs):
59+
"""Port a SVC / LinearSVC classifier"""
5360
assert isinstance(clf.gamma, float), 'You probably didn\'t set an explicit value for gamma: 0.001 is a good default'
61+
assert classname is None or len(classname) > 0, 'Invalid class name'
5462
support_v = clf.support_vectors_
5563
n_classes = len(clf.n_support_)
5664
template_data = {
57-
'KERNEL_TYPE': clf.kernel,
58-
'KERNEL_GAMMA': clf.gamma,
59-
'KERNEL_COEF': clf.coef0,
60-
'KERNEL_DEGREE': clf.degree,
61-
'FEATURES_DIM': len(support_v[0]),
62-
'VECTORS_COUNT': len(support_v),
63-
'CLASSES_COUNT': n_classes,
64-
'DECISIONS_COUNT': n_classes * (n_classes - 1) // 2,
65-
'support_v': support_v,
66-
'n_support': clf.n_support_,
67-
'intercepts': clf.intercept_,
68-
'coefs': clf.dual_coef_,
69-
'X': test_set[0] if test_set else None,
70-
'y': test_set[1] if test_set else None,
71-
'classmap': classmap,
72-
'F': {
73-
'enumerate': enumerate,
74-
'round': round
65+
**kwargs,
66+
'kernel': {
67+
'type': clf.kernel,
68+
'gamma': clf.gamma,
69+
'coef0': clf.coef0,
70+
'degree': clf.degree
7571
},
76-
'isAttiny': platform == 'attiny',
72+
'sizes': {
73+
'features': len(support_v[0]),
74+
'vectors': len(support_v),
75+
'classes': n_classes,
76+
'decisions': n_classes * (n_classes - 1) // 2,
77+
'supports': clf.n_support_
78+
},
79+
'arrays': {
80+
'supports': support_v,
81+
'intercepts': clf.intercept_,
82+
'coefs': clf.dual_coef_
83+
},
84+
'classname': classname if classname is not None else 'SVM'
7785
}
78-
return jinja('svm.jinja', template_data)
86+
return _port('svm/svm.jinja', template_data)
87+
7988

89+
def port(
90+
clf,
91+
classname=None,
92+
classmap=None,
93+
platform=platforms.ARDUINO,
94+
precision=None):
95+
assert platform in platforms.ALL, 'Unknown platform %s. Use one of %s' % (platform, ', '.join(platforms.ALL))
96+
if isinstance(clf, (SVC, LinearSVC, OneClassSVM)):
97+
return port_svm(**locals())
98+
elif isinstance(clf, RVC):
99+
return port_rvm(**locals())
100+
raise TypeError('clf MUST be one of SVC, LinearSVC, OneClassSVC, RVC')

micromlgen/platforms.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
ARDUINO = 'arduino'
2+
ATTINY = 'attiny'
3+
ALL = [
4+
ARDUINO,
5+
ATTINY
6+
]

micromlgen/templates/_scalar_product.jinja

Lines changed: 0 additions & 14 deletions
This file was deleted.

micromlgen/templates/binary_classification.jinja

Lines changed: 0 additions & 6 deletions
This file was deleted.

micromlgen/templates/classmap.jinja

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ const char* classIdxToName(uint8_t classIdx) {
1010
return "{{ name }}";
1111
{% endfor %}
1212
default:
13-
return "UNKNOWN";
13+
return "Houston we have a problem";
1414
}
1515
}
16+
1617
{% endif %}

micromlgen/templates/compute_class.jinja

Lines changed: 0 additions & 11 deletions
This file was deleted.

micromlgen/templates/compute_decisions.jinja

Lines changed: 0 additions & 32 deletions
This file was deleted.

micromlgen/templates/compute_kernels.bck.jinja

Lines changed: 0 additions & 10 deletions
This file was deleted.

micromlgen/templates/compute_kernels.jinja

Lines changed: 0 additions & 12 deletions
This file was deleted.

micromlgen/templates/kernel_function.jinja

Lines changed: 0 additions & 34 deletions
This file was deleted.

micromlgen/templates/pca/_dot_product.jinja

Lines changed: 0 additions & 14 deletions
This file was deleted.

micromlgen/templates/pca/pca.jinja

Lines changed: 0 additions & 19 deletions
This file was deleted.

0 commit comments

Comments
 (0)