Skip to content
This repository was archived by the owner on May 25, 2024. It is now read-only.

Commit b969b1a

Browse files
committed
port_trainset
1 parent 0bac123 commit b969b1a

File tree

5 files changed

+40
-4
lines changed

5 files changed

+40
-4
lines changed

MANIFEST

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ micromlgen/platforms.py
88
micromlgen/utils.py
99
micromlgen/templates/classmap.jinja
1010
micromlgen/templates/testset.jinja
11+
micromlgen/templates/trainset.jinja
1112
micromlgen/templates/pca/pca.jinja
1213
micromlgen/templates/rvm/rvm.jinja
1314
micromlgen/templates/svm/svm.jinja

micromlgen/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
import micromlgen.platforms as platforms
22
from micromlgen.micromlgen import port
3-
from micromlgen.utils import port_testset
3+
from micromlgen.utils import port_testset, port_trainset

micromlgen/templates/trainset.jinja

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#pragma once
2+
3+
namespace Eloquent {
4+
namespace ML {
5+
namespace Train {
6+
7+
/**
8+
* A tailor made training set
9+
*/
10+
class {{ classname }} {
11+
public:
12+
{{ classname }}() :
13+
_x{ {% for x in X %}
14+
{ {% for xi in x %} {% if loop.index > 1 %}, {% endif %} {{ f.round(xi) }} {% endfor %} },
15+
{% endfor %}
16+
},
17+
_y{ {% for yi in y %} {% if loop.index > 1 %}, {% endif %} {{ f.round(yi) }} {% endfor %} }
18+
{}
19+
20+
template<class Classifier>
21+
void fit(Classifier clf) {
22+
clf.fit(_x, _y, {{ X|length }});
23+
}
24+
25+
protected:
26+
float _x[{{ X|length }}][{{ X[0]|length }}];
27+
int _y[{{ X|length }}];
28+
};
29+
}
30+
}
31+
}

micromlgen/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,9 @@ def jinja(template_file, data):
4646
return prettify(code)
4747

4848

49+
def port_trainset(X, y, classname='TrainSet'):
50+
return jinja('trainset.jinja', locals())
51+
52+
4953
def port_testset(X, y, classname='TestSet'):
50-
return jinja('testset.jinja', locals())
54+
return jinja('testset.jinja', locals())

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
setup(
88
name = 'micromlgen',
99
packages = ['micromlgen'],
10-
version = '1.1.1',
10+
version = '1.1.3',
1111
license='MIT',
1212
description = 'Generate C code for microcontrollers from Python\'s sklearn classifiers',
1313
author = 'Simone Salerno',
1414
author_email = 'eloquentarduino@gmail.com',
1515
url = 'https://github.com/eloquentarduino/micromlgen',
16-
download_url = 'https://github.com/eloquentarduino/micromlgen/archive/v_111.tar.gz',
16+
download_url = 'https://github.com/eloquentarduino/micromlgen/archive/v_113.tar.gz',
1717
keywords = [
1818
'ML',
1919
'microcontrollers',

0 commit comments

Comments
 (0)