2 changes: 2 additions & 0 deletions .eslintrc.json
@@ -11,6 +11,7 @@
},
"root": true,
"rules": {
"arrow-parens": "off",
"class-methods-use-this": "off",
"linebreak-style": "off",
"max-classes-per-file": "off",
@@ -19,6 +20,7 @@
"no-param-reassign": "off",
"no-plusplus": "off",
"no-prototype-builtins": "off",
"no-restricted-globals": "off",
"no-underscore-dangle": "off",
"semi": "error"
}
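Note: the two rules switched off here relax checks that would otherwise conflict with the Prettier-driven reformatting in this PR. A minimal, illustrative sketch (not from the diff) of what now passes lint:

// Illustrative only; not part of the diff.
// With "arrow-parens" off, both arrow styles pass ESLint:
const double = x => x * 2;
const triple = (x) => x * 3;
// With "no-restricted-globals" off, bare globals such as isNaN are permitted
// (which restrictions this lifts depends on the base config being extended):
const isUsableNumber = (val) => typeof val === 'number' && !isNaN(val);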
2 changes: 1 addition & 1 deletion .prettierrc
@@ -1,5 +1,5 @@
{
"semi": true,
"singleQuote": true,
"trailingComma": "none"
"trailingComma": "es5"
}
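Note: "es5" tells Prettier to add trailing commas wherever ES5 allows them (multi-line object and array literals) while still omitting them from function parameter lists; "none" stripped them everywhere. An illustrative before/after, not taken from the diff:

// With "trailingComma": "es5", Prettier keeps the comma in multi-line literals:
const exampleDefaults = {
  iterations: 20000,
  errorThresh: 0.005, // trailing comma retained
};
// but still emits no trailing comma after function parameters (not valid ES5):
function train(data, options) {}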
1 change: 1 addition & 0 deletions __tests__/.eslintrc.json
@@ -22,6 +22,7 @@
"no-param-reassign": "off",
"no-plusplus": "off",
"no-prototype-builtins": "off",
"no-restricted-globals": "off",
"no-underscore-dangle": "off",
"semi": "error"
}
132 changes: 79 additions & 53 deletions src/feed-forward.js
@@ -29,7 +29,11 @@ class FeedForward {
inputLayer: null,
outputLayer: null,
praxisOpts: null,
praxis: (layer, settings) => praxis.momentumRootMeanSquaredPropagation({ ...layer }, layer.praxisOpts || settings),
praxis: (layer, settings) =>
praxis.momentumRootMeanSquaredPropagation(
{ ...layer },
layer.praxisOpts || settings
),
};
}

@@ -40,21 +44,19 @@
*/
static _validateTrainingOptions(options) {
const validations = {
iterations: val => typeof val === 'number' && val > 0,
errorThresh: val => typeof val === 'number' && val > 0 && val < 1,
log: val => typeof val === 'function' || typeof val === 'boolean',
logPeriod: val => typeof val === 'number' && val > 0,
learningRate: val => typeof val === 'number' && val > 0 && val < 1,
callback: val => typeof val === 'function' || val === null,
callbackPeriod: val => typeof val === 'number' && val > 0,
timeout: val => typeof val === 'number' && val > 0,
iterations: (val) => typeof val === 'number' && val > 0,
errorThresh: (val) => typeof val === 'number' && val > 0 && val < 1,
log: (val) => typeof val === 'function' || typeof val === 'boolean',
logPeriod: (val) => typeof val === 'number' && val > 0,
learningRate: (val) => typeof val === 'number' && val > 0 && val < 1,
callback: (val) => typeof val === 'function' || val === null,
callbackPeriod: (val) => typeof val === 'number' && val > 0,
timeout: (val) => typeof val === 'number' && val > 0,
};
Object.keys(FeedForward.trainDefaults).forEach(key => {
Object.keys(FeedForward.trainDefaults).forEach((key) => {
if (validations.hasOwnProperty(key) && !validations[key](options[key])) {
throw new Error(
`[${key}, ${
options[key]
}] is out of normal training range, your network will probably not train.`
`[${key}, ${options[key]}] is out of normal training range, your network will probably not train.`
);
}
});
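Note: each validator is a type/range predicate over the keys of trainDefaults, and any supplied value that fails its predicate throws immediately. A hedged usage sketch (option name taken from the validators above; layer options omitted for brevity):

// Hypothetical sketch: learningRate must be a number strictly between 0 and 1,
// so constructing with an out-of-range value throws
// "[learningRate, 2] is out of normal training range, your network will probably not train."
new FeedForward({ learningRate: 2 });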
@@ -72,7 +74,7 @@
this.trainOpts.log = log;
} else if (log) {
// eslint-disable-next-line
this.trainOpts.log = console.log
this.trainOpts.log = console.log;
} else {
this.trainOpts.log = false;
}
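Note: the log option accepts a boolean or a function; true is mapped to console.log, a function is used as-is, and anything falsy disables logging. A hedged sketch (layer options omitted; myLogger is made up):

// Hypothetical sketch of the three accepted forms.
new FeedForward({ log: true });                              // logs via console.log
new FeedForward({ log: (status) => myLogger.info(status) }); // custom log sink
new FeedForward({ log: false });                             // stays silent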
@@ -86,15 +88,20 @@
* learningRate: (number)
*/
_updateTrainingOptions(opts) {
Object.keys(this.constructor.trainDefaults).forEach(opt => {
Object.keys(this.constructor.trainDefaults).forEach((opt) => {
this.trainOpts[opt] = opts.hasOwnProperty(opt)
? opts[opt]
: this.trainOpts[opt];
});
this.constructor._validateTrainingOptions(this.trainOpts);
this._setLogMethod(opts.log || this.trainOpts.log);
if (this.trainOpts.callback && this.trainOpts.callbackPeriod !== this.trainOpts.errorCheckInterval) {
console.warn(`options.callbackPeriod with value of ${ this.trainOpts.callbackPeriod } does not match options.errorCheckInterval with value of ${ this.trainOpts.errorCheckInterval }, if logging error, it will repeat. These values may need to match`);
if (
this.trainOpts.callback &&
this.trainOpts.callbackPeriod !== this.trainOpts.errorCheckInterval
) {
console.warn(
`options.callbackPeriod with value of ${this.trainOpts.callbackPeriod} does not match options.errorCheckInterval with value of ${this.trainOpts.errorCheckInterval}, if logging error, it will repeat. These values may need to match`
);
}
}
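Note: the new guard only warns (it does not throw) when a callback is set and callbackPeriod differs from errorCheckInterval. A hedged sketch of options that avoid the warning, assuming errorCheckInterval is itself a trainDefaults key:

// Hypothetical sketch: keeping the two periods equal avoids the warning above.
new FeedForward({
  callback: (status) => console.log(status),
  callbackPeriod: 10,
  errorCheckInterval: 10,
});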

@@ -121,9 +128,10 @@
this.praxis = null;
Object.assign(this, this.constructor.defaults, options);
this.trainOpts = {};
this._updateTrainingOptions(
Object.assign({}, this.constructor.trainDefaults, options)
);
this._updateTrainingOptions({
...this.constructor.trainDefaults,
...options,
});
Object.assign(this, this.constructor.structure);
this._inputLayer = null;
this._hiddenLayers = null;
@@ -159,15 +167,19 @@
initialize() {
this._connectLayers();
this.initializeLayers(this.layers);
this._model = this.layers.filter(l => l instanceof Model);
this._model = this.layers.filter((l) => l instanceof Model);
}

initializeLayers(layers) {
for (let i = 0; i < layers.length; i++) {
const layer = layers[i];
// TODO: optimize for when training or just running
layer.setupKernels(true);
if (layer instanceof Model && layer.hasOwnProperty('praxis') && layer.praxis === null) {
if (
layer instanceof Model &&
layer.hasOwnProperty('praxis') &&
layer.praxis === null
) {
layer.praxis = this.praxis(layer, layer.praxisOpts || this.praxisOpts);
layer.praxis.setupKernels();
}
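Note: for every Model layer that does not define its own praxis, the default factory from this class's defaults wraps it in praxis.momentumRootMeanSquaredPropagation, with a per-layer praxisOpts taking precedence over the network-wide one. A rough sketch (the option name inside praxisOpts is hypothetical, not verified against the praxis API):

// Hypothetical sketch: a network-wide praxisOpts is forwarded to the default
// momentumRootMeanSquaredPropagation factory for layers without their own praxisOpts.
const net = new FeedForward({
  // ...inputLayer / hiddenLayers / outputLayer as configured elsewhere...
  praxisOpts: { decayRate: 0.999 }, // illustrative option name
});
net.initialize(); // assigns and sets up a praxis instance per Model layer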
@@ -386,11 +398,7 @@ class FeedForward {
adjustWeights() {
const { _model } = this;
for (let i = 0; i < _model.length; i++) {
_model[i].learn(
null,
null,
this.trainOpts.learningRate
);
_model[i].learn(null, null, this.trainOpts.learningRate);
}
}

@@ -409,43 +417,57 @@

// turn sparse hash input into arrays with 0s as filler
const inputDatumCheck = data[0].input;
if (!Array.isArray(inputDatumCheck) && !(inputDatumCheck instanceof Float32Array)) {
if (
!Array.isArray(inputDatumCheck) &&
!(inputDatumCheck instanceof Float32Array)
) {
if (!this.inputLookup) {
this.inputLookup = lookup.buildLookup(data.map(value => value.input));
this.inputLookup = lookup.buildLookup(data.map((value) => value.input));
}
data = data.map(datumParam => {
data = data.map((datumParam) => {
const array = lookup.toArray(this.inputLookup, datumParam.input);
return Object.assign({}, datumParam, { input: array });
return { ...datumParam, input: array };
}, this);
}

const outputDatumCheck = data[0].output;
if (!Array.isArray(outputDatumCheck) && !(outputDatumCheck instanceof Float32Array)) {
if (
!Array.isArray(outputDatumCheck) &&
!(outputDatumCheck instanceof Float32Array)
) {
if (!this.outputLookup) {
this.outputLookup = lookup.buildLookup(data.map(value => value.output));
this.outputLookup = lookup.buildLookup(
data.map((value) => value.output)
);
}
data = data.map(datumParam => {
data = data.map((datumParam) => {
const array = lookup.toArray(this.outputLookup, datumParam.output);
return Object.assign({}, datumParam, { output: array });
return { ...datumParam, output: array };
}, this);
}
return data;
}
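Note: when input or output arrive as sparse hashes rather than arrays, formatData builds (and caches) a lookup of every key seen in the data set and expands each datum into a dense array. A rough illustration of that conversion:

// Illustrative only: sparse hash training data...
const data = [
  { input: { red: 1 }, output: { stop: 1 } },
  { input: { green: 1 }, output: { go: 1 } },
];
// ...is rewritten by formatData into dense arrays, one slot per known key, e.g.
// { input: [1, 0], output: [1, 0] } and { input: [0, 1], output: [0, 1] },
// with this.inputLookup / this.outputLookup remembering which index maps to which key.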

transferData(formattedData) {
const transferredData = new Array(formattedData.length);
const transferInput = makeKernel(function(value) {
return value[this.thread.x];
}, {
output: [formattedData[0].input.length],
immutable: true,
});
const transferOutput = makeKernel(function(value) {
return value[this.thread.x];
}, {
output: [formattedData[0].output.length],
immutable: true,
});
const transferInput = makeKernel(
function (value) {
return value[this.thread.x];
},
{
output: [formattedData[0].input.length],
immutable: true,
}
);
const transferOutput = makeKernel(
function (value) {
return value[this.thread.x];
},
{
output: [formattedData[0].output.length],
immutable: true,
}
);

for (let i = 0; i < formattedData.length; i++) {
const formattedDatum = formattedData[i];
@@ -497,7 +519,7 @@ class FeedForward {
return {
type: this.constructor.name,
sizes: [this._inputLayer.height]
.concat(this._hiddenLayers.map(l => l.height))
.concat(this._hiddenLayers.map((l) => l.height))
.concat([this._outputLayer.height]),
layers: jsonLayers,
};
@@ -525,19 +547,23 @@
);
} else {
if (!jsonLayer.hasOwnProperty('inputLayer1Index'))
throw new Error('inputLayer1Index not defined');
throw new Error(
'Cannot create network from provided JSON. inputLayer1Index not defined.'
);
if (!jsonLayer.hasOwnProperty('inputLayer2Index'))
throw new Error('inputLayer2Index not defined');
throw new Error(
'Cannot create network from provided JSON. inputLayer2Index not defined.'
);
const inputLayer1 = layers[jsonLayer.inputLayer1Index];
const inputLayer2 = layers[jsonLayer.inputLayer2Index];

if (inputLayer1 === undefined)
throw new Error(
`layer of index ${jsonLayer.inputLayer1Index} not found`
`Cannot create network from provided JSON. layer of index ${jsonLayer.inputLayer1Index} not found.`
);
if (inputLayer2 === undefined)
throw new Error(
`layer of index ${jsonLayer.inputLayer2Index} not found`
`Cannot create network from provided JSON. layer of index ${jsonLayer.inputLayer2Index} not found.`
);

layers.push(
@@ -575,5 +601,5 @@
}

module.exports = {
FeedForward
FeedForward,
};
8 changes: 7 additions & 1 deletion src/likely.js
@@ -5,11 +5,17 @@
* @returns {*}
*/
module.exports = function likely(input, net) {
if (!net) {
throw new TypeError(
`Required parameter 'net' is of type ${typeof net}. Must be of type 'brain.NeuralNetwork'`
);
}

const output = net.run(input);
let maxProp = null;
let maxValue = -1;

Object.keys(output).forEach(key => {
Object.keys(output).forEach((key) => {
const value = output[key];
if (value > maxValue) {
maxProp = key;
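Note: likely() now fails fast with a TypeError when the trained network is missing, instead of crashing inside net.run(). A hedged usage sketch (assuming likely is exposed from the package root, as in brain.js's documented API):

// Hypothetical usage sketch.
const brain = require('brain.js');
const guess = brain.likely({ r: 0.8, g: 0.2, b: 0.1 }, trainedNet); // key with the highest output, e.g. 'red'

// Calling it without a network is now an explicit error:
// TypeError: Required parameter 'net' is of type undefined. Must be of type 'brain.NeuralNetwork'
brain.likely({ r: 1 });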