@@ -964,22 +964,7 @@ var NeuralNetwork = function () {
964964
965965 _createClass(NeuralNetwork, [{
966966 key: 'initialize',
967- value: function initialize(data) {
968- var sizes = [];
969- var inputSize = data[0].input.length;
970- var outputSize = data[0].output.length;
971- var hiddenSizes = this.hiddenSizes;
972- if (!hiddenSizes) {
973- sizes.push(Math.max(3, Math.floor(inputSize / 2)));
974- } else {
975- hiddenSizes.forEach(function (size) {
976- sizes.push(size);
977- });
978- }
979-
980- sizes.unshift(inputSize);
981- sizes.push(outputSize);
982-
967+ value: function initialize(sizes) {
983968 this.sizes = sizes;
984969 this.outputLayer = this.sizes.length - 1;
985970 this.biases = []; // weights for bias nodes
@@ -1149,6 +1134,70 @@ var NeuralNetwork = function () {
11491134 return output;
11501135 }
11511136
1137+ /**
1138+ *
1139+ * @param data
1140+ * @returns sizes
1141+ */
1142+
1143+ }, {
1144+ key: '_getSizesFromData',
1145+ value: function _getSizesFromData(data) {
1146+ var sizes = [];
1147+ var inputSize = data[0].input.length;
1148+ var outputSize = data[0].output.length;
1149+ var hiddenSizes = this.hiddenSizes;
1150+ if (!hiddenSizes) {
1151+ sizes.push(Math.max(3, Math.floor(inputSize / 2)));
1152+ } else {
1153+ hiddenSizes.forEach(function (size) {
1154+ sizes.push(size);
1155+ });
1156+ }
1157+
1158+ sizes.unshift(inputSize);
1159+ sizes.push(outputSize);
1160+ return sizes;
1161+ }
1162+
1163+ /**
1164+ *
1165+ * @param data
1166+ * @param learning Rate
1167+ * @returns error
1168+ */
1169+
1170+ }, {
1171+ key: '_calculateTrainingError',
1172+ value: function _calculateTrainingError(data, learningRate) {
1173+ var sum = 0;
1174+ for (var i = 0; i < data.length; ++i) {
1175+ sum += this.trainPattern(data[i].input, data[i].output, learningRate);
1176+ }
1177+ return sum / data.length;
1178+ }
1179+
1180+ /**
1181+ *
1182+ * @param status { iterations: number, error: number}
1183+ * @param options
1184+ */
1185+
1186+ }, {
1187+ key: '_checkTrainingTick',
1188+ value: function _checkTrainingTick(data, status, options) {
1189+ status.iterations++;
1190+ status.error = this._calculateTrainingError(data, options.learningRate);
1191+
1192+ if (options.log && status.iterations % options.logPeriod === 0) {
1193+ console.log('iterations: ' + status.iterations + ', training error: ' + status.error);
1194+ }
1195+
1196+ if (options.callback && status.iterations % options.callbackPeriod === 0) {
1197+ options.callback(Object.assign(status));
1198+ }
1199+ }
1200+
11521201 /**
11531202 *
11541203 * @param data
@@ -1163,42 +1212,21 @@ var NeuralNetwork = function () {
11631212
11641213 var options = Object.assign({}, this.constructor.trainDefaults, _options);
11651214 data = this.formatData(data);
1166- var iterations = options.iterations;
1167- var errorThresh = options.errorThresh;
1168- var log = options.log === true ? console.log : options.log;
1169- var logPeriod = options.logPeriod;
1170- var learningRate = _options.learningRate || this.learningRate || options.learningRate;
1171- var callback = options.callback;
1172- var callbackPeriod = options.callbackPeriod;
1215+ options.learningRate = _options.learningRate || this.learningRate || options.learningRate;
11731216 var endTime = Date.now() + options.trainTimeMs;
1174- var res = {
1217+ var status = {
11751218 error: 1,
11761219 iterations: 0
11771220 };
11781221
11791222 if (this.sizes === null) {
1180- this.initialize(data);
1223+ var sizes = this._getSizesFromData(data);
1224+ this.initialize(sizes);
11811225 }
1182-
1183- while (res.iterations < iterations && res.error > errorThresh && Date.now() > endTime) {
1184- res.iterations++;
1185- var sum = 0;
1186- for (var i = 0; i < data.length; ++i) {
1187- sum += this.trainPattern(data[i].input, data[i].output, learningRate);
1188- }
1189-
1190- res.error = sum / data.length;
1191-
1192- if (log && res.iterations % logPeriod === 0) {
1193- log('iterations:', res.iterations, 'training error:', res.error);
1194- }
1195-
1196- if (callback && res.iterations % callbackPeriod === 0) {
1197- // JSON.parse/stringify to clone the object so the callback doesn't have side effects to training
1198- callback(JSON.parse(JSON.stringify(res)));
1199- }
1226+ while (status.iterations < options.iterations && status.error > options.errorThresh && Date.now() > endTime) {
1227+ this._checkTrainingTick(data, status, options);
12001228 }
1201- return res ;
1229+ return status ;
12021230 }
12031231
12041232 /**
@@ -1218,57 +1246,38 @@ var NeuralNetwork = function () {
12181246
12191247 var cb = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : function () {};
12201248
1221- if (typeof _options === " function" ) {
1249+ if (typeof _options === ' function' ) {
12221250 cb = _options;
12231251 _options = {};
12241252 }
12251253 var options = Object.assign({}, this.constructor.trainDefaults, _options);
12261254 data = this.formatData(data);
1227- var iterations = options.iterations;
1228- var errorThresh = options.errorThresh;
1229- var log = options.log === true ? console.log : options.log;
1230- var logPeriod = options.logPeriod;
1231- var learningRate = _options.learningRate || this.learningRate || options.learningRate;
1232- var callback = options.callback;
1233- var callbackPeriod = options.callbackPeriod;
1255+ options.learningRate = _options.learningRate || this.learningRate || options.learningRate;
12341256 var endTime = Date.now() + options.trainTimeMs;
1235- var res = {
1257+
1258+ var status = {
12361259 error: 1,
12371260 iterations: 0
12381261 };
12391262
12401263 if (this.sizes === null) {
1241- this.initialize(data);
1264+ var sizes = this._getSizesFromData(data);
1265+ this.initialize(sizes);
12421266 }
12431267
1244- var items = new Array(iterations);
1268+ var items = new Array(options. iterations);
12451269 var thaw = new _thaw2.default(items, {
12461270 delay: true,
12471271 each: function each() {
1248- res.iterations++;
1249- var sum = 0;
1250- for (var i = 0; i < data.length; ++i) {
1251- sum += _this.trainPattern(data[i].input, data[i].output, learningRate);
1252- }
1253-
1254- res.error = sum / data.length;
1255-
1256- if (log && res.iterations % logPeriod === 0) {
1257- log('iterations: ' + res.iterations + ' training error: ' + res.error);
1258- }
1259-
1260- if (callback && res.iterations % callbackPeriod === 0) {
1261- // JSON.parse/stringify to clone the object so the callback doesn't have side effects to training
1262- callback(JSON.parse(JSON.stringify(res)));
1263- }
1272+ _this._checkTrainingTick(data, status, options);
12641273
1265- if (res .error < errorThresh || endTime > 0 && Date.now() > endTime) {
1274+ if (status .error < options. errorThresh || Date.now() < endTime) {
12661275 thaw.stop();
12671276 }
12681277 },
12691278 done: function done() {
1270- if (cb && typeof cb === " function" ) {
1271- cb(res );
1279+ if (cb && typeof cb === ' function' ) {
1280+ cb(status );
12721281 }
12731282 }
12741283 });
0 commit comments