From 2d58c7ac445772dfcf093aa8bea714b47a29b5e1 Mon Sep 17 00:00:00 2001 From: Zhongkai Fu Date: Sun, 28 Feb 2016 10:28:04 -0800 Subject: [PATCH] #1. Code refactoring #2. Normalize LSTM cell value in weights updating --- RNNSharp/LSTMRNN.cs | 30 ++++++++++++++++++++++++------ RNNSharp/ModelSetting.cs | 3 +++ RNNSharp/RNNEncoder.cs | 24 +++++++++++++++++++----- RNNSharp/SimpleRNN.cs | 26 +++++++++++++------------- RNNSharpConsole/Program.cs | 8 +++++++- 5 files changed, 66 insertions(+), 25 deletions(-) diff --git a/RNNSharp/LSTMRNN.cs b/RNNSharp/LSTMRNN.cs index 70c4793..a640263 100644 --- a/RNNSharp/LSTMRNN.cs +++ b/RNNSharp/LSTMRNN.cs @@ -66,6 +66,9 @@ public class LSTMRNN : RNN private new Vector4 vecMaxGrad; private new Vector4 vecMinGrad; + private new Vector3 vecMaxGrad3; + private new Vector3 vecMinGrad3; + public LSTMRNN() { ModelType = MODELTYPE.LSTM; @@ -502,6 +505,10 @@ public override void CleanStatus() vecNormalLearningRate3 = new Vector3(LearningRate, LearningRate, LearningRate); vecMaxGrad = new Vector4((float)GradientCutoff, (float)GradientCutoff, (float)GradientCutoff, (float)GradientCutoff); vecMinGrad = new Vector4((float)(-GradientCutoff), (float)(-GradientCutoff), (float)(-GradientCutoff), (float)(-GradientCutoff)); + + vecMaxGrad3 = new Vector3((float)GradientCutoff, (float)GradientCutoff, (float)GradientCutoff); + vecMinGrad3 = new Vector3((float)(-GradientCutoff), (float)(-GradientCutoff), (float)(-GradientCutoff)); + } public override void InitMem() @@ -530,7 +537,7 @@ public override void InitMem() private void CreateCell(BinaryReader br) { - neuFeatures = new SingleVector(DenseFeatureSize); + neuFeatures = null; OutputLayer = new SimpleLayer(L2); neuHidden = new LSTMCell[L1]; @@ -598,6 +605,13 @@ public override void LearnOutputWeight() }); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private Vector3 ComputeLearningRate(Vector3 vecDelta, ref Vector3 vecWeightLearningRate) + { + vecWeightLearningRate += vecDelta * vecDelta; + return vecNormalLearningRate3 / (Vector3.SquareRoot(vecWeightLearningRate) + Vector3.One); + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private Vector4 ComputeLearningRate(Vector4 vecDelta, ref Vector4 vecWeightLearningRate) { @@ -651,9 +665,11 @@ public override void LearnNet(State state, int numStates, int curState) } wd_i[entry.Key] = wd; + //Computing final err delta Vector4 vecDelta = new Vector4(wd, entry.Value); vecDelta = vecErr * vecDelta; + //Computing actual learning rate Vector4 vecLearningRate = ComputeLearningRate(vecDelta, ref wlr_i[entry.Key]); w_i[entry.Key] += vecLearningRate * vecDelta; } @@ -678,6 +694,7 @@ public override void LearnNet(State state, int numStates, int curState) Vector4 vecDelta = new Vector4(wd, feature); vecDelta = vecErr * vecDelta; + //Computing actual learning rate Vector4 vecLearningRate = ComputeLearningRate(vecDelta, ref wlr_i[j]); w_i[j] += vecLearningRate * vecDelta; } @@ -692,14 +709,15 @@ public override void LearnNet(State state, int numStates, int curState) //update internal weights Vector3 vecCellDelta = new Vector3((float)c.dSWCellIn, (float)c.dSWCellForget, (float)c.cellState); Vector3 vecCellErr = new Vector3(cellStateError, cellStateError, gradientOutputGate); - Vector3 vecCellLearningRate = CellLearningRate[i]; + + //Normalize err by gradient cut-off + vecCellErr = Vector3.Clamp(vecCellErr, vecMinGrad3, vecMaxGrad3); vecCellDelta = vecCellErr * vecCellDelta; - vecCellLearningRate += (vecCellDelta * vecCellDelta); - CellLearningRate[i] = vecCellLearningRate; - //LearningRate / (1.0 + Math.Sqrt(dg)); - vecCellLearningRate = vecNormalLearningRate3 / (Vector3.One + Vector3.SquareRoot(vecCellLearningRate)); + //Computing actual learning rate + Vector3 vecCellLearningRate = ComputeLearningRate(vecCellDelta, ref CellLearningRate[i]); + vecCellDelta = vecCellLearningRate * vecCellDelta; c.wCellIn += vecCellDelta.X; diff --git a/RNNSharp/ModelSetting.cs b/RNNSharp/ModelSetting.cs index 6cbc0fa..1736ce2 100644 --- a/RNNSharp/ModelSetting.cs +++ b/RNNSharp/ModelSetting.cs @@ -19,6 +19,7 @@ public class ModelSetting public int ModelType { get; set; } public int ModelDirection { get; set; } public int VQ { get; set; } + public float GradientCutoff { get; set; } public void DumpSetting() { @@ -49,6 +50,7 @@ public void DumpSetting() Logger.WriteLine("RNN-CRF: {0}", IsCRFTraining); Logger.WriteLine("SIMD: {0}, Size: {1}bits", System.Numerics.Vector.IsHardwareAccelerated, Vector.Count * sizeof(double) * 8); + Logger.WriteLine("Gradient cut-off: {0}", GradientCutoff); if (SaveStep > 0) { Logger.WriteLine("Save temporary model after every {0} sentences", SaveStep); @@ -60,6 +62,7 @@ public ModelSetting() MaxIteration = 20; Bptt = 4; LearningRate = 0.1f; + GradientCutoff = 15.0f; NumHidden = 200; IsCRFTraining = true; } diff --git a/RNNSharp/RNNEncoder.cs b/RNNSharp/RNNEncoder.cs index ed77c46..fd43ca2 100644 --- a/RNNSharp/RNNEncoder.cs +++ b/RNNSharp/RNNEncoder.cs @@ -64,7 +64,7 @@ public void Train() rnn.MaxIter = m_modelSetting.MaxIteration; rnn.IsCRFTraining = m_modelSetting.IsCRFTraining; rnn.LearningRate = m_modelSetting.LearningRate; - rnn.GradientCutoff = 15.0; + rnn.GradientCutoff = m_modelSetting.GradientCutoff; rnn.Dropout = m_modelSetting.Dropout; rnn.L1 = m_modelSetting.NumHidden; @@ -116,18 +116,32 @@ public void Train() betterValidateNet = rnn.ValidateNet(ValidationSet, iter); } - if ((ValidationSet != null && betterValidateNet == false) || - (ValidationSet == null && ppl >= lastPPL)) + if (ppl >= lastPPL) { + //We cannot get a better result on training corpus, so reduce learning rate rnn.LearningRate = rnn.LearningRate / 2.0f; } - else + + if (betterValidateNet == true) { - //If current model is better than before, save it into file + //We got better result on validated corpus, save this model Logger.WriteLine("Saving better model into file {0}...", m_modelSetting.ModelFile); rnn.SaveModel(m_modelSetting.ModelFile); } + + //if ((ValidationSet != null && betterValidateNet == false) || + // (ValidationSet == null && ppl >= lastPPL)) + //{ + // rnn.LearningRate = rnn.LearningRate / 2.0f; + //} + //else + //{ + // //If current model is better than before, save it into file + // Logger.WriteLine("Saving better model into file {0}...", m_modelSetting.ModelFile); + // rnn.SaveModel(m_modelSetting.ModelFile); + //} + lastPPL = ppl; iter++; diff --git a/RNNSharp/SimpleRNN.cs b/RNNSharp/SimpleRNN.cs index 04b4fbf..f574f97 100644 --- a/RNNSharp/SimpleRNN.cs +++ b/RNNSharp/SimpleRNN.cs @@ -17,9 +17,9 @@ public class SimpleRNN : RNN protected double[][] bptt_fea; protected SparseVector[] bptt_inputs = new SparseVector[MAX_RNN_HIST]; - protected Matrix mat_bptt_syn0_w; - protected Matrix mat_bptt_syn0_ph; - protected Matrix mat_bptt_synf; + protected Matrix Input2HiddenWeightsDelta; + protected Matrix HiddenBpttWeightsDelta; + protected Matrix Feature2HiddenWeightsDelta; //Last hidden layer status protected SimpleLayer neuLastHidden; @@ -263,7 +263,7 @@ private void learnBptt(State state) int i = 0; if (DenseFeatureSize > 0) { - vector_a = mat_bptt_synf[a]; + vector_a = Feature2HiddenWeightsDelta[a]; i = 0; while (i < DenseFeatureSize - Vector.Count) { @@ -283,7 +283,7 @@ private void learnBptt(State state) } //sparse weight update hidden->input - vector_a = mat_bptt_syn0_w[a]; + vector_a = Input2HiddenWeightsDelta[a]; for (i = 0; i < sparse.Count; i++) { var entry = sparse.GetEntry(i); @@ -291,7 +291,7 @@ private void learnBptt(State state) } //bptt weight update - vector_a = mat_bptt_syn0_ph[a]; + vector_a = HiddenBpttWeightsDelta[a]; i = 0; while (i < L1 - Vector.Count) { @@ -340,7 +340,7 @@ private void learnBptt(State state) //Update bptt feature weights vector_b = HiddenBpttWeights[b]; - vector_bf = mat_bptt_syn0_ph[b]; + vector_bf = HiddenBpttWeightsDelta[b]; vector_lr = HiddenBpttWeightsLearningRate[b]; int i = 0; @@ -383,7 +383,7 @@ private void learnBptt(State state) if (DenseFeatureSize > 0) { vector_b = Feature2HiddenWeights[b]; - vector_bf = mat_bptt_synf[b]; + vector_bf = Feature2HiddenWeightsDelta[b]; vector_lr = Feature2HiddenWeightsLearningRate[b]; i = 0; @@ -426,7 +426,7 @@ private void learnBptt(State state) //Update sparse feature weights vector_b = Input2HiddenWeights[b]; - vector_bf = mat_bptt_syn0_w[b]; + vector_bf = Input2HiddenWeightsDelta[b]; for (int step = 0; step < bptt + bptt_block - 2; step++) { var sparse = bptt_inputs[step]; @@ -466,9 +466,9 @@ public void resetBpttMem() bptt_fea[i] = new double[DenseFeatureSize]; } - mat_bptt_syn0_w = new Matrix(L1, L0); - mat_bptt_syn0_ph = new Matrix(L1, L1); - mat_bptt_synf = new Matrix(L1, DenseFeatureSize); + Input2HiddenWeightsDelta = new Matrix(L1, L0); + HiddenBpttWeightsDelta = new Matrix(L1, L1); + Feature2HiddenWeightsDelta = new Matrix(L1, DenseFeatureSize); } public override void CleanStatus() @@ -641,7 +641,7 @@ public override void LoadModel(string filename) private void CreateCells() { - neuFeatures = new SingleVector(DenseFeatureSize); + neuFeatures = null; OutputLayer = new SimpleLayer(L2); neuHidden = new SimpleLayer(L1); } diff --git a/RNNSharpConsole/Program.cs b/RNNSharpConsole/Program.cs index 333f1c0..e374805 100644 --- a/RNNSharpConsole/Program.cs +++ b/RNNSharpConsole/Program.cs @@ -30,6 +30,7 @@ class Program static int nBest = 1; static int iDir = 0; static int iVQ = 0; + static float gradientCutoff = 15.0f; static void UsageTitle() { @@ -94,8 +95,11 @@ static void UsageTrain() Console.WriteLine(" -vq "); Console.WriteLine("\tModel vector quantization, 0 is disable, 1 is enable. default is 0"); + Console.WriteLine(" -grad "); + Console.WriteLine("\tGradient cut-off. Default is 15.0f"); + Console.WriteLine(); - Console.WriteLine("Example: RNNSharpConsole.exe -mode train -trainfile train.txt -validfile valid.txt -modelfile model.bin -ftrfile features.txt -tagfile tags.txt -modeltype 0 -layersize 200 -alpha 0.1 -crf 1 -maxiter 20 -savestep 200K -dir 0 -vq 0"); + Console.WriteLine("Example: RNNSharpConsole.exe -mode train -trainfile train.txt -validfile valid.txt -modelfile model.bin -ftrfile features.txt -tagfile tags.txt -modeltype 0 -layersize 200 -alpha 0.1 -crf 1 -maxiter 20 -savestep 200K -dir 0 -vq 0 -grad 15.0"); } @@ -147,6 +151,7 @@ static void InitParameters(string[] args) if ((i = ArgPos("-nbest", args)) >= 0) nBest = int.Parse(args[i + 1]); if ((i = ArgPos("-dir", args)) >= 0) iDir = int.Parse(args[i + 1]); if ((i = ArgPos("-vq", args)) >= 0) iVQ = int.Parse(args[i + 1]); + if ((i = ArgPos("-grad", args)) >= 0) gradientCutoff = float.Parse(args[i + 1]); if ((i = ArgPos("-savestep", args)) >= 0) { @@ -433,6 +438,7 @@ private static void Train() RNNConfig.LearningRate = alpha; RNNConfig.Dropout = dropout; RNNConfig.Bptt = bptt; + RNNConfig.GradientCutoff = gradientCutoff; //Dump RNN setting on console RNNConfig.DumpSetting();