Skip to content
This repository was archived by the owner on May 25, 2024. It is now read-only.

Commit ae129f3

Browse files
committed
add RVC
1 parent 46c501c commit ae129f3

File tree

11 files changed

+60
-12
lines changed

11 files changed

+60
-12
lines changed

MANIFEST

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ micromlgen/templates/kernel_function.jinja
1515
micromlgen/templates/pca_function.jinja
1616
micromlgen/templates/self_test.jinja
1717
micromlgen/templates/svm.jinja
18+
micromlgen/templates/pca/_dot_product.jinja
19+
micromlgen/templates/pca/pca.jinja

dist/micromlgen-0.11.tar.gz

-3.88 KB
Binary file not shown.

dist/micromlgen-0.12.tar.gz

4.05 KB
Binary file not shown.

micromlgen/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from micromlgen.micromlgen import port, port_pca
1+
from micromlgen.micromlgen import port, port_pca, port_rvm
19 Bytes
Binary file not shown.
109 Bytes
Binary file not shown.

micromlgen/micromlgen.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,26 @@ def port_pca(pca):
2424
})
2525

2626

27+
def port_rvm(clf, classmap=None, test_set=None, **kwargs):
28+
from skbayes.rvm_ard_models import RVC
29+
assert isinstance(clf, RVC), 'Not an RVC classifier'
30+
return jinja('rvm/rvm.jinja', {
31+
'clf': clf,
32+
'FEATURES_DIM': clf.relevant_vectors_[0].shape[1],
33+
'KERNEL_TYPE': clf.kernel,
34+
'KERNEL_GAMMA': clf.gamma,
35+
'KERNEL_COEF': clf.coef0,
36+
'KERNEL_DEGREE': clf.degree,
37+
'N_CLASSES': len(clf.classes_),
38+
'classmap': classmap,
39+
'X': test_set[0] if test_set else None,
40+
'y': test_set[1] if test_set else None,
41+
'enumerate': enumerate,
42+
'zip': zip,
43+
'round': round
44+
})
45+
46+
2747
def port(clf,
2848
test_set=None,
2949
classmap=None,
@@ -56,3 +76,4 @@ def port(clf,
5676
'isAttiny': platform == 'attiny',
5777
}
5878
return jinja('svm.jinja', template_data)
79+

micromlgen/templates/pca/pca.jinja

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,16 @@
44
* Apply PCA to x vector.
55
* Overrides the original vector.
66
*/
7-
void pca(float x[{{ X_DIM }}]) {
8-
float u[{{ PCA_DIM }}];
7+
void pca(double x[{{ X_DIM }}]) {
8+
double u[{{ PCA_DIM }}] = { 0 };
99

10-
{% for vector in pca_components %}
11-
u[{{ loop.index - 1 }}] = dot_product(x, {{ X_DIM }}, {% for vi in vector %} {% if loop.index > 1 %}, {% endif %} {{ F.round(vi, 9) }} {% endfor %});
10+
{% for i, vector in F.enumerate(pca_components) %}
11+
{% for chunk in F.chunk(vector, 100) %}
12+
u[{{ i }}] += dot_product(x, {{ chunk|length }}, {% for vi in chunk %} {% if loop.index > 1 %}, {% endif %} {{ F.round(vi, 9) }} {% endfor %});
13+
{% endfor %}
1214
{% endfor %}
1315

1416
{% for vector in pca_components %}
1517
x[{{ loop.index - 1 }}] = u[{{ loop.index - 1 }}];
1618
{% endfor %}
17-
18-
{% for i in range(PCA_DIM, X_DIM) %}
19-
x[{{ i }}] = 0;
20-
{% endfor %}
2119
}

micromlgen/templates/rvm/rvm.jinja

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
{% include 'kernel_function.jinja' %}
2+
3+
uint8_t predict(double *x) {
4+
double decision[{{ clf.relevant_vectors_|length }}] = { 0 };
5+
{% for i, (rv, cf, act, b) in enumerate(zip(clf.relevant_vectors_, clf.coef_, clf.active_, clf.intercept_)) %}
6+
{% if rv.shape[0] == 0 %}
7+
decision[{{ i }}] = {{ b }};
8+
{% else %}
9+
decision[{{ i }}] = (compute_kernel(x, {% for vi in rv[0] %}{% if loop.index > 1 %},{% endif %} {{ round(vi, 7) }}{% endfor %}) - {{ clf._x_mean[act][0] }} ) * {{ cf[act][0] / clf._x_std[act][0] }} + {{ b }};
10+
{% endif %}
11+
{% endfor %}
12+
13+
uint8_t idx = 0;
14+
double val = decision[0];
15+
16+
for (uint8_t i = 1; i < {{ clf.relevant_vectors_|length }}; i++) {
17+
if (decision[i] > val) {
18+
idx = i;
19+
val = decision[i];
20+
}
21+
}
22+
23+
return idx;
24+
}
25+
26+
{% include 'self_test.jinja' %}
27+
{% include 'classmap.jinja' %}

micromlgen/templates/self_test.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ void self_test() {
1111
{% endfor %}
1212
};
1313

14-
int y[{{ X|length }}] = { {% for yi in y %}{% if loop.index > 1 %},{% endif %} {{ yi }} {% endfor %} };
14+
int y[{{ X|length }}] = { {% for yi in y %}{% if loop.index > 1 %},{% endif %} {{ yi|int }} {% endfor %} };
1515

1616
for (int i = 0; i < {{ X|length }}; i++) {
1717
int predicted = predict(X[i]);

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
setup(
33
name = 'micromlgen',
44
packages = ['micromlgen'],
5-
version = '0.12',
5+
version = '0.13',
66
license='MIT',
77
description = 'Generate C code for microcontrollers from Python\'s sklearn classifiers',
88
author = 'Simone Salerno',
99
author_email = 'eloquentarduino@gmail.com',
1010
url = 'https://github.com/eloquentarduino/micromlgen',
11-
download_url = 'https://github.com/eloquentarduino/micromlgen/archive/v_012.tar.gz',
11+
download_url = 'https://github.com/eloquentarduino/micromlgen/archive/v_013.tar.gz',
1212
keywords = ['ML', 'microcontrollers', 'sklearn', 'machine learning'],
1313
install_requires=[
1414
'jinja2',

0 commit comments

Comments
 (0)