Skip to content

Commit 5e7a4e5

Browse files
committed
Committing before I change to promise so I don't break things too bad
1 parent 7556c1c commit 5e7a4e5

File tree

6 files changed

+216
-208
lines changed

6 files changed

+216
-208
lines changed

browser.js

Lines changed: 83 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -964,22 +964,7 @@ var NeuralNetwork = function () {
964964

965965
_createClass(NeuralNetwork, [{
966966
key: 'initialize',
967-
value: function initialize(data) {
968-
var sizes = [];
969-
var inputSize = data[0].input.length;
970-
var outputSize = data[0].output.length;
971-
var hiddenSizes = this.hiddenSizes;
972-
if (!hiddenSizes) {
973-
sizes.push(Math.max(3, Math.floor(inputSize / 2)));
974-
} else {
975-
hiddenSizes.forEach(function (size) {
976-
sizes.push(size);
977-
});
978-
}
979-
980-
sizes.unshift(inputSize);
981-
sizes.push(outputSize);
982-
967+
value: function initialize(sizes) {
983968
this.sizes = sizes;
984969
this.outputLayer = this.sizes.length - 1;
985970
this.biases = []; // weights for bias nodes
@@ -1149,6 +1134,70 @@ var NeuralNetwork = function () {
11491134
return output;
11501135
}
11511136

1137+
/**
1138+
*
1139+
* @param data
1140+
* @returns sizes
1141+
*/
1142+
1143+
}, {
1144+
key: '_getSizesFromData',
1145+
value: function _getSizesFromData(data) {
1146+
var sizes = [];
1147+
var inputSize = data[0].input.length;
1148+
var outputSize = data[0].output.length;
1149+
var hiddenSizes = this.hiddenSizes;
1150+
if (!hiddenSizes) {
1151+
sizes.push(Math.max(3, Math.floor(inputSize / 2)));
1152+
} else {
1153+
hiddenSizes.forEach(function (size) {
1154+
sizes.push(size);
1155+
});
1156+
}
1157+
1158+
sizes.unshift(inputSize);
1159+
sizes.push(outputSize);
1160+
return sizes;
1161+
}
1162+
1163+
/**
1164+
*
1165+
* @param data
1166+
* @param learning Rate
1167+
* @returns error
1168+
*/
1169+
1170+
}, {
1171+
key: '_calculateTrainingError',
1172+
value: function _calculateTrainingError(data, learningRate) {
1173+
var sum = 0;
1174+
for (var i = 0; i < data.length; ++i) {
1175+
sum += this.trainPattern(data[i].input, data[i].output, learningRate);
1176+
}
1177+
return sum / data.length;
1178+
}
1179+
1180+
/**
1181+
*
1182+
* @param status { iterations: number, error: number}
1183+
* @param options
1184+
*/
1185+
1186+
}, {
1187+
key: '_checkTrainingTick',
1188+
value: function _checkTrainingTick(data, status, options) {
1189+
status.iterations++;
1190+
status.error = this._calculateTrainingError(data, options.learningRate);
1191+
1192+
if (options.log && status.iterations % options.logPeriod === 0) {
1193+
console.log('iterations: ' + status.iterations + ', training error: ' + status.error);
1194+
}
1195+
1196+
if (options.callback && status.iterations % options.callbackPeriod === 0) {
1197+
options.callback(Object.assign(status));
1198+
}
1199+
}
1200+
11521201
/**
11531202
*
11541203
* @param data
@@ -1163,42 +1212,21 @@ var NeuralNetwork = function () {
11631212

11641213
var options = Object.assign({}, this.constructor.trainDefaults, _options);
11651214
data = this.formatData(data);
1166-
var iterations = options.iterations;
1167-
var errorThresh = options.errorThresh;
1168-
var log = options.log === true ? console.log : options.log;
1169-
var logPeriod = options.logPeriod;
1170-
var learningRate = _options.learningRate || this.learningRate || options.learningRate;
1171-
var callback = options.callback;
1172-
var callbackPeriod = options.callbackPeriod;
1215+
options.learningRate = _options.learningRate || this.learningRate || options.learningRate;
11731216
var endTime = Date.now() + options.trainTimeMs;
1174-
var res = {
1217+
var status = {
11751218
error: 1,
11761219
iterations: 0
11771220
};
11781221

11791222
if (this.sizes === null) {
1180-
this.initialize(data);
1223+
var sizes = this._getSizesFromData(data);
1224+
this.initialize(sizes);
11811225
}
1182-
1183-
while (res.iterations < iterations && res.error > errorThresh && Date.now() > endTime) {
1184-
res.iterations++;
1185-
var sum = 0;
1186-
for (var i = 0; i < data.length; ++i) {
1187-
sum += this.trainPattern(data[i].input, data[i].output, learningRate);
1188-
}
1189-
1190-
res.error = sum / data.length;
1191-
1192-
if (log && res.iterations % logPeriod === 0) {
1193-
log('iterations:', res.iterations, 'training error:', res.error);
1194-
}
1195-
1196-
if (callback && res.iterations % callbackPeriod === 0) {
1197-
// JSON.parse/stringify to clone the object so the callback doesn't have side effects to training
1198-
callback(JSON.parse(JSON.stringify(res)));
1199-
}
1226+
while (status.iterations < options.iterations && status.error > options.errorThresh && Date.now() > endTime) {
1227+
this._checkTrainingTick(data, status, options);
12001228
}
1201-
return res;
1229+
return status;
12021230
}
12031231

12041232
/**
@@ -1218,57 +1246,38 @@ var NeuralNetwork = function () {
12181246

12191247
var cb = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : function () {};
12201248

1221-
if (typeof _options === "function") {
1249+
if (typeof _options === 'function') {
12221250
cb = _options;
12231251
_options = {};
12241252
}
12251253
var options = Object.assign({}, this.constructor.trainDefaults, _options);
12261254
data = this.formatData(data);
1227-
var iterations = options.iterations;
1228-
var errorThresh = options.errorThresh;
1229-
var log = options.log === true ? console.log : options.log;
1230-
var logPeriod = options.logPeriod;
1231-
var learningRate = _options.learningRate || this.learningRate || options.learningRate;
1232-
var callback = options.callback;
1233-
var callbackPeriod = options.callbackPeriod;
1255+
options.learningRate = _options.learningRate || this.learningRate || options.learningRate;
12341256
var endTime = Date.now() + options.trainTimeMs;
1235-
var res = {
1257+
1258+
var status = {
12361259
error: 1,
12371260
iterations: 0
12381261
};
12391262

12401263
if (this.sizes === null) {
1241-
this.initialize(data);
1264+
var sizes = this._getSizesFromData(data);
1265+
this.initialize(sizes);
12421266
}
12431267

1244-
var items = new Array(iterations);
1268+
var items = new Array(options.iterations);
12451269
var thaw = new _thaw2.default(items, {
12461270
delay: true,
12471271
each: function each() {
1248-
res.iterations++;
1249-
var sum = 0;
1250-
for (var i = 0; i < data.length; ++i) {
1251-
sum += _this.trainPattern(data[i].input, data[i].output, learningRate);
1252-
}
1253-
1254-
res.error = sum / data.length;
1255-
1256-
if (log && res.iterations % logPeriod === 0) {
1257-
log('iterations: ' + res.iterations + ' training error: ' + res.error);
1258-
}
1259-
1260-
if (callback && res.iterations % callbackPeriod === 0) {
1261-
// JSON.parse/stringify to clone the object so the callback doesn't have side effects to training
1262-
callback(JSON.parse(JSON.stringify(res)));
1263-
}
1272+
_this._checkTrainingTick(data, status, options);
12641273

1265-
if (res.error < errorThresh || endTime > 0 && Date.now() > endTime) {
1274+
if (status.error < options.errorThresh || Date.now() < endTime) {
12661275
thaw.stop();
12671276
}
12681277
},
12691278
done: function done() {
1270-
if (cb && typeof cb === "function") {
1271-
cb(res);
1279+
if (cb && typeof cb === 'function') {
1280+
cb(status);
12721281
}
12731282
}
12741283
});

0 commit comments

Comments
 (0)