#1. Bug fix: output layer is not cleared before new values are calculated
#2. Add dropout for LSTM
zhongkaifu committed Jan 9, 2016
1 parent fe925d1 commit 1d2b3be
Showing 5 changed files with 62 additions and 114 deletions.
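The first change addresses an accumulation bug: in matrixXvectorADD, the destination's cellOutput was never zeroed before the matrix-vector product was summed into it, so each call inherited stale values from the previous time step. Below is a minimal sketch of that pattern using plain double arrays instead of the repository's neuron and LSTMCell types (a deliberate simplification, not the project's actual signature):

    // Hypothetical simplification of RNNSharp's matrixXvectorADD.
    // Without the "dest[i] = 0" line, repeated calls keep adding onto
    // whatever the previous time step left behind -- the bug fixed here.
    static void MatrixXVector(double[] dest, double[] src, double[][] matrix)
    {
        for (int i = 0; i < dest.Length; i++)
        {
            dest[i] = 0; // the fix: clear the output slot before accumulating
            for (int j = 0; j < src.Length; j++)
            {
                dest[i] += src[j] * matrix[i][j];
            }
        }
    }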
6 changes: 0 additions & 6 deletions RNNSharp/BiRNN.cs
@@ -198,12 +198,6 @@ public neuron[][] InnerDecode(Sequence pSequence, out Matrix<neuron> outputHidde
return seqOutput;
}

-public override void netFlush()
-{
-forwardRNN.netFlush();
-backwardRNN.netFlush();
-}

public override Matrix<double> learnSentenceForRNNCRF(Sequence pSequence, RunningMode runningMode)
{
//Reset the network
144 changes: 56 additions & 88 deletions RNNSharp/LSTMRNN.cs
Original file line number Diff line number Diff line change
@@ -39,6 +39,7 @@ public class LSTMCell

//cell output
public double cellOutput;
+public bool mask;
}

public struct LSTMWeight
@@ -68,10 +69,6 @@ public class LSTMRNN : RNN
protected LSTMWeightDerivative[][] input2hiddenDeri;
protected LSTMWeightDerivative[][] feature2hiddenDeri;

-//for LSTM layer
-const bool NORMAL = true;
-const bool BIAS = false;

public LSTMRNN()
{
m_modeltype = MODELTYPE.LSTM;
@@ -248,29 +245,29 @@ public override void saveNetBin(string filename)
}


-double TanH(double x)
+double Sigmoid2(double x)
{
-return Math.Tanh(x);
+//sigmoid function returning a bounded output in [-2,2]
+return (4.0 / (1.0 + Math.Exp(-x))) - 2.0;
}

-double TanHDerivative(double x)
+double Sigmoid2Derivative(double x)
{
-double tmp = Math.Tanh(x);
-return 1 - tmp * tmp;
+return 4.0 * Sigmoid(x) * (1.0 - Sigmoid(x));
}

double Sigmoid(double x)
{
-return (1 / (1 + Math.Exp(-x)));
+return (1.0 / (1.0 + Math.Exp(-x)));
}

double SigmoidDerivative(double x)
{
-return Sigmoid(x) * (1 - Sigmoid(x));
+return Sigmoid(x) * (1.0 - Sigmoid(x));
}
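An aside on the new activation: Sigmoid2 is algebraically identical to 2 * tanh(x / 2), so swapping it in for TanH keeps the same slope at the origin while doubling the saturation range to (-2, 2). A standalone numerical check (not part of the commit):

    using System;

    class Sigmoid2Check
    {
        static double Sigmoid2(double x) => (4.0 / (1.0 + Math.Exp(-x))) - 2.0;

        static void Main()
        {
            // Both columns agree to machine precision for any x.
            foreach (double x in new[] { -5.0, -1.0, 0.0, 1.0, 5.0 })
            {
                Console.WriteLine($"{x,5}: {Sigmoid2(x):F12} vs {2.0 * Math.Tanh(x / 2.0):F12}");
            }
        }
    }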


-public LSTMWeight LSTMWeightInit(int iL)
+public LSTMWeight LSTMWeightInit()
{
LSTMWeight w;

@@ -292,7 +289,7 @@ public override void initWeights()
input2hidden[i] = new LSTMWeight[L0];
for (int j = 0; j < L0; j++)
{
-input2hidden[i][j] = LSTMWeightInit(L0);
+input2hidden[i][j] = LSTMWeightInit();
}
}

@@ -304,7 +301,7 @@ public override void initWeights()
feature2hidden[i] = new LSTMWeight[fea_size];
for (int j = 0; j < fea_size; j++)
{
-feature2hidden[i][j] = LSTMWeightInit(L0);
+feature2hidden[i][j] = LSTMWeightInit();
}
}
}
@@ -418,26 +415,14 @@ public void matrixXvectorADD(neuron[] dest, LSTMCell[] srcvec, Matrix<double> sr
//ac mod
Parallel.For(0, (to - from), parallelOption, i =>
{
+dest[i + from].cellOutput = 0;
for (int j = 0; j < to2 - from2; j++)
{
dest[i + from].cellOutput += srcvec[j + from2].cellOutput * srcmatrix[i][j];
}
});
}

-public void matrixXvectorADD(LSTMCell[] dest, double[] srcvec, LSTMWeight[][] srcmatrix, int from, int to, int from2, int to2)
-{
-//ac mod
-Parallel.For(0, (to - from), parallelOption, i =>
-{
-for (int j = 0; j < to2 - from2; j++)
-{
-dest[i + from].netIn += srcvec[j + from2] * srcmatrix[i][j].wInputInputGate;
-}
-});
-}


public override void LearnBackTime(State state, int numStates, int curState)
{
}
@@ -463,8 +448,8 @@ public override void learnNet(State state, int timeat, bool biRNN = false)
{
var entry = sparse.GetEntry(k);
LSTMWeightDerivative w = w_i[entry.Key];
-w_i[entry.Key].dSInputCell = w.dSInputCell * c.yForget + TanHDerivative(c.netCellState) * c.yIn * entry.Value;
-w_i[entry.Key].dSInputInputGate = w.dSInputInputGate * c.yForget + TanH(c.netCellState) * SigmoidDerivative(c.netIn) * entry.Value;
+w_i[entry.Key].dSInputCell = w.dSInputCell * c.yForget + Sigmoid2Derivative(c.netCellState) * c.yIn * entry.Value;
+w_i[entry.Key].dSInputInputGate = w.dSInputInputGate * c.yForget + Sigmoid2(c.netCellState) * SigmoidDerivative(c.netIn) * entry.Value;
w_i[entry.Key].dSInputForgetGate = w.dSInputForgetGate * c.yForget + c.previousCellState * SigmoidDerivative(c.netForget) * entry.Value;
}
@@ -475,15 +460,15 @@ public override void learnNet(State state, int timeat, bool biRNN = false)
for (int j = 0; j < fea_size; j++)
{
LSTMWeightDerivative w = w_i[j];
-w_i[j].dSInputCell = w.dSInputCell * c.yForget + TanHDerivative(c.netCellState) * c.yIn * neuFeatures[j];
-w_i[j].dSInputInputGate = w.dSInputInputGate * c.yForget + TanH(c.netCellState) * SigmoidDerivative(c.netIn) * neuFeatures[j];
+w_i[j].dSInputCell = w.dSInputCell * c.yForget + Sigmoid2Derivative(c.netCellState) * c.yIn * neuFeatures[j];
+w_i[j].dSInputInputGate = w.dSInputInputGate * c.yForget + Sigmoid2(c.netCellState) * SigmoidDerivative(c.netIn) * neuFeatures[j];
w_i[j].dSInputForgetGate = w.dSInputForgetGate * c.yForget + c.previousCellState * SigmoidDerivative(c.netForget) * neuFeatures[j];
}
}
//partial derivatives for internal connections
-c.dSWCellIn = c.dSWCellIn * c.yForget + TanH(c.netCellState) * SigmoidDerivative(c.netIn) * c.cellState;
+c.dSWCellIn = c.dSWCellIn * c.yForget + Sigmoid2(c.netCellState) * SigmoidDerivative(c.netIn) * c.cellState;
//partial derivatives for internal connections, initially zero as dS is zero and previous cell state is zero
c.dSWCellForget = c.dSWCellForget * c.yForget + c.previousCellState * SigmoidDerivative(c.netForget) * c.previousCellState;
@@ -505,18 +490,12 @@ public override void learnNet(State state, int timeat, bool biRNN = false)
weightedSum = NormalizeErr(weightedSum);
//using the error find the gradient of the output gate
-double gradientOutputGate = SigmoidDerivative(c.netOut) * TanHDerivative(c.cellState) * weightedSum;
+double gradientOutputGate = SigmoidDerivative(c.netOut) * c.cellState * weightedSum;
//internal cell state error
double cellStateError = c.yOut * weightedSum;
//weight updates
-//already done the deltas for the hidden-output connections
-//output gates. for each connection to the hidden layer
-//to the input layer
LSTMWeight[] w_i = input2hidden[i];
LSTMWeightDerivative[] wd_i = input2hiddenDeri[i];
for (int k = 0; k < sparseFeatureSize; k++)
@@ -545,30 +524,22 @@ public override void learnNet(State state, int timeat, bool biRNN = false)
}
}
-//for the internal connection
-double deltaOutputGateCell = alpha * gradientOutputGate * c.cellState;
-//using internal partial derivative
-double deltaInputGateCell = alpha * cellStateError * c.dSWCellIn;
-double deltaForgetGateCell = alpha * cellStateError * c.dSWCellForget;
//update internal weights
-c.wCellIn += deltaInputGateCell;
-c.wCellForget += deltaForgetGateCell;
-c.wCellOut += deltaOutputGateCell;
+c.wCellIn += alpha * cellStateError * c.dSWCellIn;
+c.wCellForget += alpha * cellStateError * c.dSWCellForget;
+c.wCellOut += alpha * gradientOutputGate * c.cellState;
neuHidden[i] = c;
});

//update weights for hidden to output layer
-for (int i = 0; i < L1; i++)
+Parallel.For(0, L1, parallelOption, i =>
{
for (int k = 0; k < L2; k++)
{
mat_hidden2output[k][i] += alpha * neuHidden[i].cellOutput * neuOutput[k].er;
}
-}
+});
}
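The hidden-to-output weight update also changes from a sequential loop to Parallel.For over the L1 hidden units. This is race-free because iteration i only ever writes column i of mat_hidden2output, so no two tasks touch the same element. A self-contained sketch of the same access pattern (simplified types and a hypothetical helper, not the repository's code):

    using System.Threading.Tasks;

    class ParallelUpdateSketch
    {
        // Each parallel iteration i owns column i of hidden2output,
        // so concurrent writes never collide.
        static void UpdateWeights(double[][] hidden2output, double[] hiddenOut,
                                  double[] outputErr, double alpha)
        {
            Parallel.For(0, hiddenOut.Length, i =>
            {
                for (int k = 0; k < outputErr.Length; k++)
                {
                    hidden2output[k][i] += alpha * hiddenOut[i] * outputErr[k];
                }
            });
        }
    }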


@@ -580,35 +551,16 @@ public override void computeNet(State state, double[] doutput, bool isTrain = tr
var sparse = state.GetSparseData();
int sparseFeatureSize = sparse.GetNumberOfEntries();

-//loop through all input gates in hidden layer
-//for each hidden neuron
-Parallel.For(0, L1, parallelOption, j =>
-{
-//rest the value of the net input to zero
-neuHidden[j].netIn = 0;
-//hidden(t-1) -> hidden(t)
-neuHidden[j].previousCellState = neuHidden[j].cellState;
-//for each input neuron
-for (int i = 0; i < sparseFeatureSize; i++)
-{
-var entry = sparse.GetEntry(i);
-neuHidden[j].netIn += entry.Value * input2hidden[j][entry.Key].wInputInputGate;
-}
-});
-
-//fea(t) -> hidden(t)
-if (fea_size > 0)
-{
-matrixXvectorADD(neuHidden, neuFeatures, feature2hidden, 0, L1, 0, fea_size);
-}

Parallel.For(0, L1, parallelOption, j =>
{
LSTMCell cell_j = neuHidden[j];
+//hidden(t-1) -> hidden(t)
+cell_j.previousCellState = cell_j.cellState;
+//reset the value of the net input to zero
+cell_j.netIn = 0;
cell_j.netForget = 0;
//reset each netCell state to zero
cell_j.netCellState = 0;
@@ -619,16 +571,19 @@ public override void computeNet(State state, double[] doutput, bool isTrain = tr
var entry = sparse.GetEntry(i);
LSTMWeight w = input2hidden[j][entry.Key];
//loop through all forget gates in hidden layer
+cell_j.netIn += entry.Value * w.wInputInputGate;
cell_j.netForget += entry.Value * w.wInputForgetGate;
cell_j.netCellState += entry.Value * w.wInputCell;
cell_j.netOut += entry.Value * w.wInputOutputGate;
}
+//fea(t) -> hidden(t)
if (fea_size > 0)
{
for (int i = 0; i < fea_size; i++)
{
LSTMWeight w = feature2hidden[j][i];
+cell_j.netIn += neuFeatures[i] * w.wInputInputGate;
cell_j.netForget += neuFeatures[i] * w.wInputForgetGate;
cell_j.netCellState += neuFeatures[i] * w.wInputCell;
cell_j.netOut += neuFeatures[i] * w.wInputOutputGate;
@@ -643,18 +598,24 @@ public override void computeNet(State state, double[] doutput, bool isTrain = tr
//include internal connection multiplied by the previous cell state
cell_j.netForget += cell_j.previousCellState * cell_j.wCellForget;
cell_j.yForget = Sigmoid(cell_j.netForget);
-//cell state is equal to the previous cell state multipled by the forget gate and the cell inputs multiplied by the input gate
-cell_j.cellState = cell_j.yForget * cell_j.previousCellState + cell_j.yIn * TanH(cell_j.netCellState);
+if (cell_j.mask == true)
+{
+cell_j.cellState = 0;
+}
+else
+{
+//cell state is equal to the previous cell state multiplied by the forget gate and the cell inputs multiplied by the input gate
+cell_j.cellState = cell_j.yForget * cell_j.previousCellState + cell_j.yIn * Sigmoid2(cell_j.netCellState);
+}
////include the internal connection multiplied by the CURRENT cell state
cell_j.netOut += cell_j.cellState * cell_j.wCellOut;
//squash output gate
cell_j.yOut = Sigmoid(cell_j.netOut);
-cell_j.cellOutput = TanH(cell_j.cellState) * cell_j.yOut;
+cell_j.cellOutput = cell_j.cellState * cell_j.yOut;
neuHidden[j] = cell_j;
@@ -673,18 +634,25 @@ public override void computeNet(State state, double[] doutput, bool isTrain = tr
SoftmaxLayer(neuOutput);
}
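Condensing the loop above, the per-cell forward step after this commit can be summarized as the sketch below. Parameter names are simplified; the net* sums are assumed to be pre-accumulated from the sparse and dense inputs, the input-gate peephole term mirrors the forget gate's (that line sits in the folded part of the diff), and Sigmoid/Sigmoid2 are the functions defined earlier in this file:

    // One LSTM cell's forward step as implemented after this commit:
    // peephole gates, Sigmoid2 cell input, dropout mask, and no extra
    // squashing of the cell state at the output.
    static double ForwardCell(ref double cellState, bool mask,
                              double netIn, double netForget,
                              double netCellState, double netOut,
                              double wCellIn, double wCellForget, double wCellOut)
    {
        double prev = cellState;
        double yIn = Sigmoid(netIn + prev * wCellIn);             // input gate
        double yForget = Sigmoid(netForget + prev * wCellForget); // forget gate
        cellState = mask ? 0.0
                         : yForget * prev + yIn * Sigmoid2(netCellState);
        double yOut = Sigmoid(netOut + cellState * wCellOut);     // output gate
        return cellState * yOut;                                  // cell output
    }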

-public override void netFlush() //cleans all activations and error vectors
+public override void netReset(bool updateNet = false) //cleans hidden layer activation + bptt history
{
-neuFeatures = new double[fea_size];
+for (int a = 0; a < L1; a++)
+{
+neuHidden[a].mask = false;
+}

-for (int i = 0; i < L1; i++)
+if (updateNet == true)
{
-LSTMCellInit(neuHidden[i]);
+//Train mode
+for (int a = 0; a < L1; a++)
+{
+if (rand.NextDouble() < dropout)
+{
+neuHidden[a].mask = true;
+}
+}
}
}
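For change #2, the dropout mask is sampled once per netReset call in training mode (updateNet == true) and then consulted in computeNet, where a masked cell has its cellState forced to zero for the rest of the sequence. A toy illustration of just the sampling step (standalone, not repository code):

    using System;

    class DropoutMaskSketch
    {
        static readonly Random rand = new Random();

        // Bernoulli mask over hidden cells: true means "dropped".
        static bool[] SampleMask(int hiddenSize, double dropout)
        {
            var mask = new bool[hiddenSize];
            for (int a = 0; a < hiddenSize; a++)
            {
                mask[a] = rand.NextDouble() < dropout;
            }
            return mask;
        }

        static void Main()
        {
            bool[] mask = SampleMask(8, 0.25); // ~25% of cells dropped
            Console.WriteLine(string.Join(" ", mask));
        }
    }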

-public override void netReset(bool updateNet = false) //cleans hidden layer activation + bptt history
-{
Parallel.For(0, L1, parallelOption, i =>
{
LSTMCellInit(neuHidden[i]);
Expand Down
7 changes: 1 addition & 6 deletions RNNSharp/RNN.cs
@@ -491,8 +491,6 @@ public virtual double TrainNet(DataSet trainingSet, int iter)
logp = 0;
counterTokenForLM = 0;

-netFlush();

//Shuffle training corpus
trainingSet.Shuffle();

@@ -620,8 +618,6 @@ public void matrixXvectorADD(neuron[] dest, neuron[] srcvec, Matrix<double> srcm
}
}

-public abstract void netFlush();

public int[] DecodeNN(Sequence seq)
{
Matrix<double> ys = PredictSentence(seq, RunningMode.Test);
@@ -841,8 +837,7 @@ public virtual bool ValidateNet(DataSet validationSet)
counter = 0;
logp = 0;
counterTokenForLM = 0;

-netFlush();

int numSequence = validationSet.GetSize();
for (int curSequence = 0; curSequence < numSequence; curSequence++)
{
10 changes: 5 additions & 5 deletions RNNSharp/RNNEncoder.cs
@@ -129,11 +129,11 @@ public void Train()

betterValidateNet = true;
}
-else
-{
-Logger.WriteLine(Logger.Level.info, "Loading previous best model from file {0}...", m_modelSetting.GetModelFile());
-rnn.loadNetBin(m_modelSetting.GetModelFile());
-}
+//else
+//{
+// Logger.WriteLine(Logger.Level.info, "Loading previous best model from file {0}...", m_modelSetting.GetModelFile());
+// rnn.loadNetBin(m_modelSetting.GetModelFile());
+//}


if (ppl >= lastPPL && lastAlpha != rnn.Alpha)
9 changes: 0 additions & 9 deletions RNNSharp/SimpleRNN.cs
@@ -390,13 +390,6 @@ public override void LearnBackTime(State state, int numStates, int curState)
}
}


-public override void netFlush() //cleans all activations and error vectors
-{
-neuHidden = new neuron[L1];
-neuOutput = new neuron[L2];
-}

public override void loadNetBin(string filename)
{
Logger.WriteLine(Logger.Level.info, "Loading SimpleRNN model: {0}", filename);
@@ -460,9 +453,7 @@ public override void saveNetBin(string filename)
StreamWriter sw = new StreamWriter(filename);
BinaryWriter fo = new BinaryWriter(sw.BaseStream);


fo.Write((int)m_modeltype);

fo.Write((int)m_modeldirection);

// Signature, 0 is for RNN or 1 is for RNN-CRF
