Skip to content
This repository was archived by the owner on May 25, 2024. It is now read-only.

Commit b4a809d

Browse files
committed
add PrincipalFFT
1 parent 3c0fcd0 commit b4a809d

File tree

8 files changed

+107
-6
lines changed

8 files changed

+107
-6
lines changed

micromlgen/micromlgen.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
from micromlgen.logisticregression import is_logisticregression, port_logisticregression
88
from micromlgen.gaussiannb import is_gaussiannb, port_gaussiannb
99
from micromlgen.pca import is_pca, port_pca
10+
from micromlgen.principalfft import is_principalfft, port_principalfft
1011

1112

1213
def port(
1314
clf,
1415
classname=None,
1516
classmap=None,
1617
platform=platforms.ARDUINO,
17-
precision=None):
18+
precision=None,
19+
**kwargs):
1820
"""Port a classifier to plain C++"""
1921
assert platform in platforms.ALL, 'Unknown platform %s. Use one of %s' % (platform, ', '.join(platforms.ALL))
2022
if is_svm(clf):
@@ -33,4 +35,6 @@ def port(
3335
return port_gaussiannb(**locals())
3436
elif is_pca(clf):
3537
return port_pca(**locals())
38+
elif is_principalfft(clf):
39+
return port_principalfft(**locals(), **kwargs)
3640
raise TypeError('clf MUST be one of %s' % ', '.join(platforms.ALLOWED_CLASSIFIERS))

micromlgen/platforms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@
1515
'RandomForest',
1616
'GaussianNB',
1717
'LogisticRegression',
18-
'PCA'
18+
'PCA',
19+
'PrincipalFFT'
1920
]

micromlgen/principalfft.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from micromlgen.utils import jinja, check_type
2+
from math import pi
3+
4+
5+
def is_principalfft(clf):
6+
"""Test if classifier can be ported"""
7+
return check_type(clf, 'PrincipalFFT')
8+
9+
10+
def port_principalfft(clf, optimize_sin=False, lookup_cos=None, lookup_sin=None, **kwargs):
11+
"""Port PrincipalFFT classifier"""
12+
return jinja('principalfft/principalfft.jinja', {
13+
'fft': clf,
14+
'PI': pi,
15+
'size': len(clf.idx),
16+
'optmize_sin': optimize_sin,
17+
'lookup_cos': lookup_cos,
18+
'lookup_sin': lookup_sin,
19+
}, **kwargs)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
const float {{ op }}LUT[{{ size }}][{{ fft.original_size }}] = {
2+
{% for i in range(0, size) %}
3+
{ {% for n in range(0, fft.original_size) %} {{ math[op](2 * PI / fft.original_size * fft.idx[i] * n) }}, {% endfor %} },
4+
{% endfor %}
5+
};
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
const bool {{ op }}LUT[{{ size }}][{{ fft.original_size }}] = {
2+
{% for i in range(0, size) %}
3+
{ {% for n in range(0, fft.original_size) %} {{ "true" if math[op](2 * PI / fft.original_size * fft.idx[i] * n) > 0 else "false" }}, {% endfor %} },
4+
{% endfor %}
5+
};
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
void principalFFT(float *features, float *fft) {
2+
// apply principal FFT (naive implementation for the top N frequencies only)
3+
const int topFrequencies[] = { {{ f.to_array(fft.idx, True) }} };
4+
5+
{% if lookup_cos %}
6+
{% with op="cos" %}
7+
{% include "principalfft/lut.jinja" %}
8+
{% endwith %}
9+
10+
{# sin lookup is available only if cos lookup is used #}
11+
{% if lookup_sin %}
12+
{% with op="sin" %}
13+
{% include "principalfft/lut.jinja" %}
14+
{% endwith %}
15+
{% else %}
16+
{% with op="sin" %}
17+
{% include "principalfft/lut_bool.jinja" %}
18+
{% endwith %}
19+
{% endif %}
20+
{% endif %}
21+
22+
for (int i = 0; i < {{ size }}; i++) {
23+
const int k = topFrequencies[i];
24+
{% if not lookup_cos %}
25+
const float harmonic = {{ 2 * PI / fft.original_size }} * k;
26+
{% endif %}
27+
float re = 0;
28+
float im = 0;
29+
30+
// optimized case
31+
if (k == 0) {
32+
for (int n = 0; n < {{ fft.original_size }}; n++) {
33+
re += features[n];
34+
}
35+
}
36+
else {
37+
for (int n = 0; n < {{ fft.original_size }}; n++) {
38+
{% if lookup_cos %}
39+
const float cos_n = cosLUT[i][n];
40+
41+
{% if lookup_sin %}
42+
const float sin_n = sinLUT[i][n];
43+
{% else %}
44+
const float sin_n = sinLUT[i][n] * sqrt(1 - cos_n * cos_n);
45+
{% endif %}
46+
{% else %}
47+
const float harmonicN = harmonic * n;
48+
const float cos_n = cos(harmonicN);
49+
const float sin_n = sin(harmonicN);
50+
{% endif %}
51+
52+
re += features[n] * cos_n;
53+
im -= features[n] * sin_n;
54+
}
55+
}
56+
57+
fft[i] = sqrt(re * re + im * im);
58+
}
59+
}

micromlgen/utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import re
3+
from math import sin, cos
34
from inspect import getmro
45
from jinja2 import FileSystemLoader, Environment
56

@@ -51,16 +52,23 @@ def jinja(template_file, data, defaults=None, **kwargs):
5152
template = Environment(loader=loader).get_template(template_file)
5253
data = {k: v for k, v in data.items() if v is not None}
5354
kwargs = {k: v for k, v in kwargs.items() if v is not None}
55+
precision = data.get('precision', 12) or 12
56+
precision_fmt = '%.' + str(precision) + 'f'
5457
if defaults is None:
5558
defaults = {}
5659
defaults.setdefault('platform', 'arduino')
5760
defaults.setdefault('classmap', None)
5861
defaults.update({
5962
'f': {
6063
'enumerate': enumerate,
61-
'round': lambda x: round(x, data.get('precision', 12) or 12),
64+
'round': lambda x: round(x, precision),
6265
'zip': zip,
63-
'signed': lambda x: '' if x == 0 else '+' + str(x) if x >= 0 else x
66+
'signed': lambda x: '' if x == 0 else '+' + str(x) if x >= 0 else x,
67+
'to_array': lambda x, as_int=False: ', '.join([precision_fmt % xx if not as_int else str(xx) for xx in x])
68+
},
69+
'math': {
70+
'cos': cos,
71+
'sin': sin
6472
}
6573
})
6674
data = {

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
setup(
88
name = 'micromlgen',
99
packages = ['micromlgen'],
10-
version = '1.1.8',
10+
version = '1.1.9',
1111
license='MIT',
1212
description = 'Generate C code for microcontrollers from Python\'s sklearn classifiers',
1313
author = 'Simone Salerno',
1414
author_email = 'eloquentarduino@gmail.com',
1515
url = 'https://github.com/eloquentarduino/micromlgen',
16-
download_url = 'https://github.com/eloquentarduino/micromlgen/archive/v_118.tar.gz',
16+
download_url = 'https://github.com/eloquentarduino/micromlgen/archive/v_119.tar.gz',
1717
keywords = [
1818
'ML',
1919
'microcontrollers',

0 commit comments

Comments
 (0)