11const { GPU } = require ( 'gpu.js' ) ;
2- const { add, input, multiply, output, random, recurrent , lstm } = require ( '../../src/layer' ) ;
2+ const { add, input, multiply, output, random, rnnCell , lstmCell } = require ( '../../src/layer' ) ;
33const { setup, teardown } = require ( '../../src/utilities/kernel' ) ;
44
55const { Recurrent } = require ( '../../src/recurrent' ) ;
@@ -35,7 +35,7 @@ describe('Recurrent Class: End to End', () => {
3535 inputLayer : ( ) => input ( { height : 1 } ) ,
3636 hiddenLayers : [
3737 ( inputLayer , recurrentInput ) =>
38- recurrent ( { width : 1 , height : 3 } , inputLayer , recurrentInput ) ,
38+ rnnCell ( { width : 1 , height : 3 } , inputLayer , recurrentInput ) ,
3939 ] ,
4040 outputLayer : inputLayer => output ( { height : 1 } , inputLayer ) ,
4141 } ) ;
@@ -440,7 +440,7 @@ describe('Recurrent Class: End to End', () => {
440440 inputLayer : ( ) => input ( { width : 1 } ) ,
441441 hiddenLayers : [
442442 ( inputLayer , recurrentInput ) =>
443- recurrent ( { width : 1 , height : 1 } , inputLayer , recurrentInput ) ,
443+ rnnCell ( { width : 1 , height : 1 } , inputLayer , recurrentInput ) ,
444444 ] ,
445445 outputLayer : inputLayer => output ( { width : 1 , height : 1 } , inputLayer ) ,
446446 } ) ;
@@ -463,9 +463,9 @@ describe('Recurrent Class: End to End', () => {
463463 inputLayer : ( ) => input ( { width : 1 } ) ,
464464 hiddenLayers : [
465465 ( inputLayer , recurrentInput ) =>
466- recurrent ( { height : 3 , width : 1 } , inputLayer , recurrentInput ) ,
466+ rnnCell ( { height : 3 , width : 1 } , inputLayer , recurrentInput ) ,
467467 ( inputLayer , recurrentInput ) =>
468- recurrent ( { height : 1 , width : 1 } , inputLayer , recurrentInput ) ,
468+ rnnCell ( { height : 1 , width : 1 } , inputLayer , recurrentInput ) ,
469469 ] ,
470470 outputLayer : inputLayer => output ( { height : 1 } , inputLayer ) ,
471471 } ) ;
@@ -481,7 +481,7 @@ describe('Recurrent Class: End to End', () => {
481481 inputLayer : ( ) => input ( { height : 1 } ) ,
482482 hiddenLayers : [
483483 ( inputLayer , recurrentInput ) =>
484- recurrent ( { height : 3 } , inputLayer , recurrentInput ) ,
484+ rnnCell ( { height : 3 } , inputLayer , recurrentInput ) ,
485485 ] ,
486486 outputLayer : inputLayer => output ( { height : 1 } , inputLayer ) ,
487487 } ) ;
@@ -500,38 +500,34 @@ describe('Recurrent Class: End to End', () => {
500500
501501 it ( 'can learn xor' , ( ) => {
502502 const net = new Recurrent ( {
503+ praxisOpts : {
504+ regularizationStrength : 0.000001 ,
505+ learningRate : 0.01 ,
506+ } ,
503507 inputLayer : ( ) => input ( { height : 1 } ) ,
504508 hiddenLayers : [
505- ( input , recurrentInput ) => recurrent ( { height : 3 } , input , recurrentInput )
509+ ( input , recurrentInput ) => lstmCell ( { height : 20 } , input , recurrentInput )
506510 ] ,
507511 outputLayer : input => output ( { height : 1 } , input )
508512 } ) ;
509- net . initialize ( ) ;
510- net . initializeDeep ( ) ;
511- expect ( net . _model . length ) . toBe ( 5 ) ;
512- expect ( net . _layerSets [ 0 ] . length ) . toBe ( 15 ) ;
513- expect ( net . _layerSets [ 1 ] . length ) . toBe ( 15 ) ;
514- let error ;
515- for ( let i = 0 ; i < 100 ; i ++ ) {
516- let sum = 0 ;
517- sum = net . _trainPattern ( [ 0 , 0 , 0 ] , true ) [ 0 ] ;
518- sum += net . _trainPattern ( [ 0 , 1 , 1 ] , true ) [ 0 ] ;
519- sum += net . _trainPattern ( [ 1 , 0 , 1 ] , true ) [ 0 ] ;
520- sum += net . _trainPattern ( [ 1 , 1 , 0 ] , true ) [ 0 ] ;
521- error = sum / 4 ;
522- }
523- console . log ( net . runInput ( [ 0 , 0 ] ) ) ;
524- console . log ( net . runInput ( [ 0 , 1 ] ) ) ;
525- console . log ( net . runInput ( [ 1 , 0 ] ) ) ;
526- console . log ( net . runInput ( [ 1 , 1 ] ) ) ;
527- expect ( error / 4 ) . toBe ( 0.005 ) ;
513+ let xorNetValues = [
514+ [ 0.001 , 0.001 , 0.001 ] ,
515+ [ 0.001 , 1 , 1 ] ,
516+ [ 1 , 0.001 , 1 ] ,
517+ [ 1 , 1 , 0.001 ]
518+ ] ;
519+ net . train ( xorNetValues , { iterations : 300 } ) ;
520+ expect ( net . run ( [ 0.001 , 0.001 ] ) [ 0 ] [ 0 ] < 0.1 ) . toBeTruthy ( ) ;
521+ expect ( net . run ( [ 0.001 , 1 ] ) [ 0 ] [ 0 ] > 0.9 ) . toBeTruthy ( ) ;
522+ expect ( net . run ( [ 1 , 0.001 ] ) [ 0 ] [ 0 ] > 0.9 ) . toBeTruthy ( ) ;
523+ expect ( net . run ( [ 1 , 1 ] ) [ 0 ] [ 0 ] < 0.1 ) . toBeTruthy ( ) ;
528524 } ) ;
529525 test ( 'can learn 1,2,3' , ( ) => {
530526 const net = new Recurrent ( {
531527 inputLayer : ( ) => input ( { height : 1 } ) ,
532528 hiddenLayers : [
533529 ( inputLayer , recurrentInput ) =>
534- lstm ( { height : 10 } , inputLayer , recurrentInput ) ,
530+ lstmCell ( { height : 10 } , inputLayer , recurrentInput ) ,
535531 ] ,
536532 outputLayer : inputLayer => output ( { height : 1 } , inputLayer ) ,
537533 } ) ;
@@ -554,7 +550,7 @@ describe('Recurrent Class: End to End', () => {
554550 inputLayer : ( ) => input ( { height : 1 } ) ,
555551 hiddenLayers : [
556552 ( inputLayer , recurrentInput ) =>
557- recurrent ( { height : 3 } , inputLayer , recurrentInput ) ,
553+ rnnCell ( { height : 3 } , inputLayer , recurrentInput ) ,
558554 ] ,
559555 outputLayer : inputLayer => output ( { height : 1 } , inputLayer ) ,
560556 } ) ;
0 commit comments