Skip to content
This repository was archived by the owner on May 25, 2024. It is now read-only.

Commit aea8a9b

Browse files
committed
pca + update README
1 parent 3775a6a commit aea8a9b

File tree

11 files changed

+82
-18
lines changed

11 files changed

+82
-18
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
venv
22
.idea
33
tests
4-
dist
4+
dist
5+
micromlgen/__pycache__

MANIFEST

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ micromlgen/micromlgen.py
66
micromlgen/platforms.py
77
micromlgen/utils.py
88
micromlgen/templates/classmap.jinja
9-
micromlgen/templates/xy.jinja
9+
micromlgen/templates/testset.jinja
1010
micromlgen/templates/rvm/rvm.jinja
1111
micromlgen/templates/svm/svm.jinja
1212
micromlgen/templates/svm/computations/class.jinja

README.md

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ to an introduction to the topic.
88

99
`pip install micromlgen`
1010

11-
## Use
11+
## Support (and Relevant) Vector Machines
12+
13+
`micromlgen` can port to plain C SVM-based (SVC, LinearSVC, OneClassSVM)
14+
and RVM-based (from `skbayes.rvm_ard_models` package) classifiers.
1215

1316
```python
1417
from micromlgen import port
@@ -44,18 +47,18 @@ if __name__ == '__main__':
4447
}))
4548
```
4649

47-
You can pass a test set to generate self test code
50+
## PCA
51+
52+
It can export a PCA transformer.
4853

4954
```python
50-
from micromlgen import port
51-
from sklearn.svm import SVC
55+
from sklearn.decomposition import PCA
5256
from sklearn.datasets import load_iris
53-
57+
from micromlgen import port
5458

5559
if __name__ == '__main__':
56-
iris = load_iris()
57-
X_train, X_test = iris.data[:-10, :], iris.data[-10:, :]
58-
y_train, y_test = iris.target[:-10], iris.target[-10:]
59-
clf = SVC(kernel='linear').fit(X_train, y_train)
60-
print(port(clf, test_set=(X_test, y_test)))
60+
X = load_iris().data
61+
pca = PCA(n_components=2, whiten=False).fit(X)
62+
63+
print(port(pca))
6164
```
-297 Bytes
Binary file not shown.
-2.07 KB
Binary file not shown.
-208 Bytes
Binary file not shown.
-1.8 KB
Binary file not shown.

micromlgen/micromlgen.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from skbayes.rvm_ard_models import RVC
22
from sklearn.svm import SVC, LinearSVC, OneClassSVM
3+
from sklearn.decomposition import PCA
34

45
from micromlgen import platforms
56
from micromlgen.utils import jinja
@@ -65,6 +66,18 @@ def port_svm(clf, classname=None, **kwargs):
6566
return jinja('svm/svm.jinja', template_data)
6667

6768

69+
def port_pca(pca, classname=None, **kwargs):
70+
"""Port a PCA"""
71+
template_data = {
72+
'arrays': {
73+
'components': pca.components_,
74+
'mean': pca.mean_
75+
},
76+
'classname': classname if classname is not None else 'PCA'
77+
}
78+
return jinja('pca/pca.jinja', template_data)
79+
80+
6881
def port(
6982
clf,
7083
classname=None,
@@ -76,4 +89,6 @@ def port(
7689
return port_svm(**locals())
7790
elif isinstance(clf, RVC):
7891
return port_rvm(**locals())
79-
raise TypeError('clf MUST be one of SVC, LinearSVC, OneClassSVC, RVC')
92+
elif isinstance(clf, PCA):
93+
return port_pca(pca=clf, **locals())
94+
raise TypeError('clf MUST be one of SVC, LinearSVC, OneClassSVC, RVC, PCA')

micromlgen/templates/pca/pca.jinja

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#pragma once
2+
3+
namespace Eloquent {
4+
namespace ML {
5+
namespace Port {
6+
7+
class {{ classname }} {
8+
public:
9+
10+
/**
11+
* Apply dimensionality reduction
12+
* @warn Will override the source vector if no dest provided!
13+
*/
14+
void transform(float *x, float *dest = NULL) {
15+
static float u[{{ arrays.components|length }}] = { 0 };
16+
17+
{% for i, component in f.enumerate(arrays.components) %}
18+
u[{{ i }}] = dot(x, {% for j, cj in f.enumerate(component) %} {% if j > 0 %},{% endif %} {{ f.round(cj) }} {% endfor %});
19+
{% endfor %}
20+
21+
memcpy(dest != NULL ? dest : x, u, sizeof(float) * {{ arrays.components|length }});
22+
}
23+
24+
protected:
25+
26+
/**
27+
* Compute dot product with varargs
28+
*/
29+
float dot(float *x, ...) {
30+
va_list w;
31+
va_start(w, {{ arrays.components[0]|length }});
32+
33+
static float mean[] = { {% for i, m in f.enumerate(arrays.mean) %}{% if i > 0 %},{% endif %} {{ f.round(m) }} {% endfor %} };
34+
float dot = 0.0;
35+
36+
for (uint16_t i = 0; i < {{ arrays.components[0]|length }}; i++) {
37+
dot += (x[i] - mean[i]) * va_arg(w, double);
38+
}
39+
40+
return dot;
41+
}
42+
};
43+
}
44+
}
45+
}

micromlgen/templates/rvm/rvm.jinja

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ namespace Eloquent {
1010

1111
}
1212

13-
uint8_t predict(double *x) {
14-
double decision[{{ arrays.vectors|length }}] = { 0 };
13+
uint8_t predict(float *x) {
14+
float decision[{{ arrays.vectors|length }}] = { 0 };
1515
{% for i, (rv, cf, act, b) in f.enumerate(f.zip(arrays.vectors, arrays.coefs, arrays.actives, arrays.intercepts)) %}
1616
{% if rv.shape[0] == 0 %}
1717
decision[{{ i }}] = {{ b }};
@@ -21,7 +21,7 @@ namespace Eloquent {
2121
{% endfor %}
2222

2323
uint8_t idx = 0;
24-
double val = decision[0];
24+
float val = decision[0];
2525

2626
for (uint8_t i = 1; i < {{ arrays.vectors|length }}; i++) {
2727
if (decision[i] > val) {

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
setup(
88
name = 'micromlgen',
99
packages = ['micromlgen'],
10-
version = '1.0.2',
10+
version = '1.1.0',
1111
license='MIT',
1212
description = 'Generate C code for microcontrollers from Python\'s sklearn classifiers',
1313
author = 'Simone Salerno',
1414
author_email = 'eloquentarduino@gmail.com',
1515
url = 'https://github.com/eloquentarduino/micromlgen',
16-
download_url = 'https://github.com/eloquentarduino/micromlgen/archive/v_10ter.tar.gz',
16+
download_url = 'https://github.com/eloquentarduino/micromlgen/archive/v_11.tar.gz',
1717
keywords = [
1818
'ML',
1919
'microcontrollers',

0 commit comments

Comments
 (0)