Skip to content

Commit ffbcec4

Browse files
authored
Merge pull request #549 from BrainJS/master
merge master
2 parents 818eb2d + 7b73a89 commit ffbcec4

File tree

9 files changed

+196
-211
lines changed

9 files changed

+196
-211
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
const { recurrent } = require('../../src/layer/recurrent');
1+
const { rnnCell } = require('../../src/layer/rnn-cell');
22

33
describe('Recurrent Layer', () => {
44
test('properly sets width and height', () => {
@@ -12,7 +12,7 @@ describe('Recurrent Layer', () => {
1212
},
1313
};
1414

15-
const layer = recurrent(settings, input, recurrentInput);
15+
const layer = rnnCell(settings, input, recurrentInput);
1616

1717
expect(layer.width).toEqual(1);
1818
expect(layer.height).toEqual(settings.height);

__tests__/recurrent/end-to-end.js

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
const { GPU } = require('gpu.js');
2-
const { add, input, multiply, output, random, recurrent, lstm } = require('../../src/layer');
2+
const { add, input, multiply, output, random, rnnCell, lstmCell } = require('../../src/layer');
33
const { setup, teardown } = require('../../src/utilities/kernel');
44

55
const { Recurrent } = require('../../src/recurrent');
@@ -35,7 +35,7 @@ describe('Recurrent Class: End to End', () => {
3535
inputLayer: () => input({ height: 1 }),
3636
hiddenLayers: [
3737
(inputLayer, recurrentInput) =>
38-
recurrent({ width: 1, height: 3 }, inputLayer, recurrentInput),
38+
rnnCell({ width: 1, height: 3 }, inputLayer, recurrentInput),
3939
],
4040
outputLayer: inputLayer => output({ height: 1 }, inputLayer),
4141
});
@@ -440,7 +440,7 @@ describe('Recurrent Class: End to End', () => {
440440
inputLayer: () => input({ width: 1 }),
441441
hiddenLayers: [
442442
(inputLayer, recurrentInput) =>
443-
recurrent({ width: 1, height: 1 }, inputLayer, recurrentInput),
443+
rnnCell({ width: 1, height: 1 }, inputLayer, recurrentInput),
444444
],
445445
outputLayer: inputLayer => output({ width: 1, height: 1 }, inputLayer),
446446
});
@@ -463,9 +463,9 @@ describe('Recurrent Class: End to End', () => {
463463
inputLayer: () => input({ width: 1 }),
464464
hiddenLayers: [
465465
(inputLayer, recurrentInput) =>
466-
recurrent({ height: 3, width: 1 }, inputLayer, recurrentInput),
466+
rnnCell({ height: 3, width: 1 }, inputLayer, recurrentInput),
467467
(inputLayer, recurrentInput) =>
468-
recurrent({ height: 1, width: 1 }, inputLayer, recurrentInput),
468+
rnnCell({ height: 1, width: 1 }, inputLayer, recurrentInput),
469469
],
470470
outputLayer: inputLayer => output({ height: 1 }, inputLayer),
471471
});
@@ -481,7 +481,7 @@ describe('Recurrent Class: End to End', () => {
481481
inputLayer: () => input({ height: 1 }),
482482
hiddenLayers: [
483483
(inputLayer, recurrentInput) =>
484-
recurrent({ height: 3 }, inputLayer, recurrentInput),
484+
rnnCell({ height: 3 }, inputLayer, recurrentInput),
485485
],
486486
outputLayer: inputLayer => output({ height: 1 }, inputLayer),
487487
});
@@ -500,38 +500,34 @@ describe('Recurrent Class: End to End', () => {
500500

501501
it('can learn xor', () => {
502502
const net = new Recurrent({
503+
praxisOpts: {
504+
regularizationStrength: 0.000001,
505+
learningRate: 0.01,
506+
},
503507
inputLayer: () => input({ height: 1 }),
504508
hiddenLayers: [
505-
(input, recurrentInput) => recurrent({ height: 3 }, input, recurrentInput)
509+
(input, recurrentInput) => lstmCell({ height: 20 }, input, recurrentInput)
506510
],
507511
outputLayer: input => output({ height: 1 }, input)
508512
});
509-
net.initialize();
510-
net.initializeDeep();
511-
expect(net._model.length).toBe(5);
512-
expect(net._layerSets[0].length).toBe(15);
513-
expect(net._layerSets[1].length).toBe(15);
514-
let error;
515-
for (let i = 0; i < 100; i++) {
516-
let sum = 0;
517-
sum = net._trainPattern([0, 0, 0], true)[0];
518-
sum += net._trainPattern([0, 1, 1], true)[0];
519-
sum += net._trainPattern([1, 0, 1], true)[0];
520-
sum += net._trainPattern([1, 1, 0], true)[0];
521-
error = sum / 4;
522-
}
523-
console.log(net.runInput([0, 0]));
524-
console.log(net.runInput([0, 1]));
525-
console.log(net.runInput([1, 0]));
526-
console.log(net.runInput([1, 1]));
527-
expect(error / 4).toBe(0.005);
513+
let xorNetValues = [
514+
[0.001, 0.001, 0.001],
515+
[0.001, 1, 1],
516+
[1, 0.001, 1],
517+
[1, 1, 0.001]
518+
];
519+
net.train(xorNetValues, { iterations: 300 });
520+
expect(net.run([0.001, 0.001])[0][0] < 0.1).toBeTruthy();
521+
expect(net.run([0.001, 1])[0][0] > 0.9).toBeTruthy();
522+
expect(net.run([1, 0.001])[0][0] > 0.9).toBeTruthy();
523+
expect(net.run([1, 1])[0][0] < 0.1).toBeTruthy();
528524
});
529525
test('can learn 1,2,3', () => {
530526
const net = new Recurrent({
531527
inputLayer: () => input({ height: 1 }),
532528
hiddenLayers: [
533529
(inputLayer, recurrentInput) =>
534-
lstm({ height: 10 }, inputLayer, recurrentInput),
530+
lstmCell({ height: 10 }, inputLayer, recurrentInput),
535531
],
536532
outputLayer: inputLayer => output({ height: 1 }, inputLayer),
537533
});
@@ -554,7 +550,7 @@ describe('Recurrent Class: End to End', () => {
554550
inputLayer: () => input({ height: 1 }),
555551
hiddenLayers: [
556552
(inputLayer, recurrentInput) =>
557-
recurrent({ height: 3 }, inputLayer, recurrentInput),
553+
rnnCell({ height: 3 }, inputLayer, recurrentInput),
558554
],
559555
outputLayer: inputLayer => output({ height: 1 }, inputLayer),
560556
});

0 commit comments

Comments
 (0)