Skip to content

Commit ca437f3

Browse files
fix: Fix CrossValidate to have tests for when data too small
Also upgrade to use newer es6 syntax for defining optional properties. Upgrade examples to be more straightforward when using CrossValidate. Fix values from being set to NaN when training with smaller data in CrossValidate.
1 parent f0a1a56 commit ca437f3

File tree

14 files changed

+85
-53
lines changed

14 files changed

+85
-53
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,9 @@ With multiple networks you can train in parallel like this:
279279
### Cross Validation
280280
[Cross Validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)) can provide a less fragile way of training on larger data sets. The brain.js api provides Cross Validation in this example:
281281
```js
282-
const crossValidate = new CrossValidate(brain.NeuralNetwork, networkOptions);
283-
const stats = crossValidate.train(data, trainingOptions, k); //note k (or KFolds) is optional
282+
const crossValidate = new brain.CrossValidate(brain.NeuralNetwork, networkOptions);
283+
crossValidate.train(data, trainingOptions, k); //note k (or KFolds) is optional
284+
const json = crossValidate.toJSON(); // all stats in json as well as neural networks
284285
const net = crossValidate.toNeuralNetwork();
285286

286287

bower.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@
3131
"node_modules",
3232
"test"
3333
],
34-
"version": "1.4.1"
34+
"version": "1.4.2"
3535
}

browser.js

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* license: MIT (http://opensource.org/licenses/MIT)
77
* author: Heather Arthur <fayearthur@gmail.com>
88
* homepage: https://github.com/brainjs/brain.js#readme
9-
* version: 1.4.1
9+
* version: 1.4.2
1010
*
1111
* acorn:
1212
* license: MIT (http://opensource.org/licenses/MIT)
@@ -214,8 +214,13 @@ var CrossValidate = function () {
214214

215215
}, {
216216
key: "train",
217-
value: function train(data, trainOpts, k) {
218-
k = k || 4;
217+
value: function train(data) {
218+
var trainOpts = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
219+
var k = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 4;
220+
221+
if (data.length <= k) {
222+
throw new Error("Training set size is too small for " + data.length + " k folds of " + k);
223+
}
219224
var size = data.length / k;
220225

221226
if (data.constructor === Array) {
@@ -1946,8 +1951,8 @@ var NeuralNetwork = function () {
19461951
falseNeg: falseNeg,
19471952
falsePos: falsePos,
19481953
total: data.length,
1949-
precision: truePos / (truePos + falsePos),
1950-
recall: truePos / (truePos + falseNeg),
1954+
precision: truePos > 0 ? truePos / (truePos + falsePos) : 0,
1955+
recall: truePos > 0 ? truePos / (truePos + falseNeg) : 0,
19511956
accuracy: (trueNeg + truePos) / data.length
19521957
});
19531958
}

browser.min.js

Lines changed: 7 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/cross-validate.js

Lines changed: 7 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/cross-validate.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/neural-network.js

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/neural-network.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples-typescript/cross-validate.ts

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,11 @@ const trainingData = [
88
{ input: [1, 1], output: [0] },
99
{ input: [1, 0], output: [1] },
1010

11-
// xor repeats
11+
// repeat xor data to have enough to train with
1212
{ input: [0, 1], output: [1] },
1313
{ input: [0, 0], output: [0] },
1414
{ input: [1, 1], output: [0] },
15-
{ input: [1, 0], output: [1] },
16-
17-
// xor repeats
18-
{ input: [0, 1], output: [1] },
19-
{ input: [0, 0], output: [0] },
20-
{ input: [1, 1], output: [0] },
21-
{ input: [1, 0], output: [1] },
22-
23-
// xor repeats
24-
{ input: [0, 1], output: [1] },
25-
{ input: [0, 0], output: [0] },
26-
{ input: [1, 1], output: [0] },
27-
{ input: [1, 0], output: [1] },
15+
{ input: [1, 0], output: [1] }
2816
];
2917

3018
const netOptions = {

examples/cross-validate.js

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,11 @@ const trainingData = [
88
{ input: [1, 1], output: [0] },
99
{ input: [1, 0], output: [1] },
1010

11-
// xor repeats
11+
// repeat xor data to have enough to train with
1212
{ input: [0, 1], output: [1] },
1313
{ input: [0, 0], output: [0] },
1414
{ input: [1, 1], output: [0] },
15-
{ input: [1, 0], output: [1] },
16-
17-
// xor repeats
18-
{ input: [0, 1], output: [1] },
19-
{ input: [0, 0], output: [0] },
20-
{ input: [1, 1], output: [0] },
21-
{ input: [1, 0], output: [1] },
22-
23-
// xor repeats
24-
{ input: [0, 1], output: [1] },
25-
{ input: [0, 0], output: [0] },
26-
{ input: [1, 1], output: [0] },
27-
{ input: [1, 0], output: [1] },
15+
{ input: [1, 0], output: [1] }
2816
];
2917

3018
const netOptions = {

0 commit comments

Comments
 (0)