
Commit 0838a69
#1. Improve logger
#2. Using error token ratio to verify validated set performance
zhongkaifu committed Dec 24, 2015
1 parent 5ad3ee9 commit 0838a69
Showing 17 changed files with 139 additions and 121 deletions.
3 changes: 3 additions & 0 deletions ConvertCorpus/ConvertCorpus.csproj
@@ -32,6 +32,9 @@
<WarningLevel>4</WarningLevel>
</PropertyGroup>
<ItemGroup>
+<Reference Include="AdvUtils">
+<HintPath>..\dll\AdvUtils.dll</HintPath>
+</Reference>
<Reference Include="System" />
<Reference Include="System.Core" />
<Reference Include="System.Xml.Linq" />
3 changes: 2 additions & 1 deletion ConvertCorpus/Program.cs
@@ -4,6 +4,7 @@
using System.Text;
using System.Threading.Tasks;
using System.IO;
+using AdvUtils;

namespace ConvertCorpus
{
@@ -18,7 +19,7 @@ static int ArgPos(string str, string[] args)
{
if (a == args.Length - 1)
{
Console.WriteLine("Argument missing for {0}", str);
Logger.WriteLine(Logger.Level.info, "Argument missing for {0}", str);
return -1;
}
return a;
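
Every file in this commit applies the same substitution: Console.WriteLine is replaced by Logger.WriteLine with an explicit severity level. For reference, here is a minimal sketch of a logger exposing that call shape. The Logger.WriteLine(Logger.Level.info, ...) signature is taken from the diff itself; the timestamped console/file-tee behavior below is an assumption, not AdvUtils' actual implementation.

using System;
using System.IO;

// Sketch only: a logger with the same call shape as the AdvUtils.Logger
// used in this commit. The real AdvUtils implementation may differ.
public static class Logger
{
    public enum Level { info, warn, err }

    // Optional log writer; when null, messages go to the console only (assumption).
    public static TextWriter LogFile = null;

    public static void WriteLine(Level level, string format, params object[] args)
    {
        // Prefix each message with a timestamp and its severity level.
        string line = string.Format("{0} {1} {2}",
            DateTime.Now.ToString("yyyy-MM-dd HH:mm:ss"), level, string.Format(format, args));
        Console.WriteLine(line);
        if (LogFile != null)
        {
            LogFile.WriteLine(line);
            LogFile.Flush();
        }
    }
}
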
3 changes: 2 additions & 1 deletion RNNSharp/BiRNN.cs
@@ -4,6 +4,7 @@
using System.Linq;
using System.Text;
using System.Threading.Tasks;
+using AdvUtils;

namespace RNNSharp
{
@@ -432,7 +433,7 @@ public override void saveNetBin(string filename)

public override void loadNetBin(string filename)
{
Console.WriteLine("Loading bi-directional model: {0}", filename);
Logger.WriteLine(Logger.Level.info, "Loading bi-directional model: {0}", filename);

forwardRNN.loadNetBin(filename + ".forward");
backwardRNN.loadNetBin(filename + ".backward");
17 changes: 9 additions & 8 deletions RNNSharp/Featurizer.cs
@@ -5,6 +5,7 @@
using System.Threading.Tasks;
using System.IO;
using Txt2Vec;
+using AdvUtils;

namespace RNNSharp
{
@@ -68,25 +69,25 @@ public void LoadFeatureConfigFromFile(string strFileName)
string strValue = kv[1].Trim().ToLower();
if (strKey == WORDEMBEDDING_FILENAME)
{
Console.WriteLine("Loading word embedding feature set...");
Logger.WriteLine(Logger.Level.info, "Loading word embedding feature set...");
m_WordEmbedding = new WordEMWrapFeaturizer(strValue);
continue;
}
else if (strKey == TFEATURE_FILENAME)
{
Console.WriteLine("Loading template feature set...");
Logger.WriteLine(Logger.Level.info, "Loading template feature set...");
m_TFeaturizer = new TemplateFeaturizer(strValue);
continue;
}
else if (strKey == WORDEMBEDDING_COLUMN)
{
m_WordEmbeddingCloumn = int.Parse(strValue);
Console.WriteLine("Word embedding feature column: {0}", m_WordEmbeddingCloumn);
Logger.WriteLine(Logger.Level.info, "Word embedding feature column: {0}", m_WordEmbeddingCloumn);
continue;
}
else if (strKey == TFEATURE_WEIGHT_TYPE)
{
Console.WriteLine("TFeature weighting type: {0}", strValue);
Logger.WriteLine(Logger.Level.info, "TFeature weighting type: {0}", strValue);
if (strValue == "binary")
{
m_TFeatureWeightType = TFEATURE_WEIGHT_TYPE_ENUM.BINARY;
@@ -166,16 +167,16 @@ public void ShowFeatureSize()
var fc = m_FeatureConfiguration;

if (m_TFeaturizer != null)
Console.WriteLine("Template feature size: {0}", m_TFeaturizer.GetFeatureSize());
Logger.WriteLine(Logger.Level.info, "Template feature size: {0}", m_TFeaturizer.GetFeatureSize());

if (fc.ContainsKey(TFEATURE_CONTEXT) == true)
Console.WriteLine("Template feature context size: {0}", m_TFeaturizer.GetFeatureSize() * fc[TFEATURE_CONTEXT].Count);
Logger.WriteLine(Logger.Level.info, "Template feature context size: {0}", m_TFeaturizer.GetFeatureSize() * fc[TFEATURE_CONTEXT].Count);

if (fc.ContainsKey(RT_FEATURE_CONTEXT) == true)
Console.WriteLine("Run time feature size: {0}", m_TagSet.GetSize() * fc[RT_FEATURE_CONTEXT].Count);
Logger.WriteLine(Logger.Level.info, "Run time feature size: {0}", m_TagSet.GetSize() * fc[RT_FEATURE_CONTEXT].Count);

if (fc.ContainsKey(WORDEMBEDDING_CONTEXT) == true)
Console.WriteLine("Word embedding feature size: {0}", m_WordEmbedding.GetDimension() * fc[WORDEMBEDDING_CONTEXT].Count);
Logger.WriteLine(Logger.Level.info, "Word embedding feature size: {0}", m_WordEmbedding.GetDimension() * fc[WORDEMBEDDING_CONTEXT].Count);
}

void ExtractSparseFeature(int currentState, int numStates, List<string[]> features, State pState)
5 changes: 3 additions & 2 deletions RNNSharp/LSTMRNN.cs
@@ -4,6 +4,7 @@
using System.Text;
using System.Threading.Tasks;
using System.IO;
+using AdvUtils;

namespace RNNSharp
{
@@ -135,7 +136,7 @@ private void saveLSTMWeight(LSTMWeight[][] weight, BinaryWriter fo)

public override void loadNetBin(string filename)
{
Console.WriteLine("Loading LSTM-RNN model: {0}", filename);
Logger.WriteLine(Logger.Level.info, "Loading LSTM-RNN model: {0}", filename);

StreamReader sr = new StreamReader(filename);
BinaryReader br = new BinaryReader(sr.BaseStream);
@@ -367,7 +368,7 @@ public override void initMem()
}
}

Console.WriteLine("[TRACE] Initializing weights, random value is {0}", rand.NextDouble());// yy debug
Logger.WriteLine(Logger.Level.info, "[TRACE] Initializing weights, random value is {0}", rand.NextDouble());// yy debug
initWeights();
}

25 changes: 13 additions & 12 deletions RNNSharp/ModelSetting.cs
@@ -3,6 +3,7 @@
using System.Linq;
using System.Text;
using System.Threading.Tasks;
+using AdvUtils;

namespace RNNSharp
{
@@ -66,34 +67,34 @@ public long GetSaveStep()

public void DumpSetting()
{
Console.WriteLine("Model File: {0}", m_strModelFile);
Logger.WriteLine(Logger.Level.info, "Model File: {0}", m_strModelFile);
if (m_ModelType == 0)
{
Console.WriteLine("Model Structure: Simple RNN");
Console.WriteLine("BPTT: {0}", m_Bptt);
Logger.WriteLine(Logger.Level.info, "Model Structure: Simple RNN");
Logger.WriteLine(Logger.Level.info, "BPTT: {0}", m_Bptt);
}
else if (m_ModelType == 1)
{
Console.WriteLine("Model Structure: LSTM-RNN");
Logger.WriteLine(Logger.Level.info, "Model Structure: LSTM-RNN");
}

if (m_iDir == 0)
{
Console.WriteLine("RNN Direction: Forward");
Logger.WriteLine(Logger.Level.info, "RNN Direction: Forward");
}
else
{
Console.WriteLine("RNN Direction: Bi-directional");
Logger.WriteLine(Logger.Level.info, "RNN Direction: Bi-directional");
}

Console.WriteLine("Learning rate: {0}", m_LearningRate);
Console.WriteLine("Dropout: {0}", m_Dropout);
Console.WriteLine("Max Iteration: {0}", m_MaxIteration);
Console.WriteLine("Hidden layer size: {0}", m_NumHidden);
Console.WriteLine("RNN-CRF: {0}", m_bCRFTraining);
Logger.WriteLine(Logger.Level.info, "Learning rate: {0}", m_LearningRate);
Logger.WriteLine(Logger.Level.info, "Dropout: {0}", m_Dropout);
Logger.WriteLine(Logger.Level.info, "Max Iteration: {0}", m_MaxIteration);
Logger.WriteLine(Logger.Level.info, "Hidden layer size: {0}", m_NumHidden);
Logger.WriteLine(Logger.Level.info, "RNN-CRF: {0}", m_bCRFTraining);
if (m_SaveStep > 0)
{
Console.WriteLine("Save temporary model after every {0} sentences", m_SaveStep);
Logger.WriteLine(Logger.Level.info, "Save temporary model after every {0} sentences", m_SaveStep);
}
}

36 changes: 19 additions & 17 deletions RNNSharp/RNN.cs
@@ -5,6 +5,7 @@
using System.Threading.Tasks;
using System.Threading;
using System.IO;
+using AdvUtils;

namespace RNNSharp
{
@@ -483,7 +484,7 @@ public virtual double TrainNet(DataSet trainingSet, int iter)
{
DateTime start = DateTime.Now;
int[] predicted;
Console.WriteLine("[TRACE] Iter " + iter + " begins with learning rate alpha = " + alpha + " ...");
Logger.WriteLine(Logger.Level.info, "[TRACE] Iter " + iter + " begins with learning rate alpha = " + alpha + " ...");

//Initialize variables
counter = 0;
@@ -498,7 +499,7 @@ public virtual double TrainNet(DataSet trainingSet, int iter)
int numSequence = trainingSet.GetSize();
int tknErrCnt = 0;
int sentErrCnt = 0;
Console.WriteLine("[TRACE] Progress = 0/" + numSequence / 1000.0 + "K\r");
Logger.WriteLine(Logger.Level.info, "[TRACE] Progress = 0/" + numSequence / 1000.0 + "K\r");
for (int curSequence = 0; curSequence < numSequence; curSequence++)
{
Sequence pSequence = trainingSet.Get(curSequence);
@@ -532,18 +533,17 @@

if ((curSequence + 1) % 1000 == 0)
{
Console.WriteLine("[TRACE] Progress = {0} ", (curSequence + 1) / 1000 + "K/" + numSequence / 1000.0 + "K");
Console.WriteLine(" train cross-entropy = {0} ", -logp / Math.Log10(2.0) / counter);
Console.WriteLine(" Error token ratio = {0}%", (double)tknErrCnt / (double)counter * 100);
Console.WriteLine(" Error sentence ratio = {0}%", (double)sentErrCnt / (double)curSequence * 100);
Logger.WriteLine(Logger.Level.info, "[TRACE] Progress = {0} ", (curSequence + 1) / 1000 + "K/" + numSequence / 1000.0 + "K");
Logger.WriteLine(Logger.Level.info, " train cross-entropy = {0} ", -logp / Math.Log10(2.0) / counter);
Logger.WriteLine(Logger.Level.info, " Error token ratio = {0}%", (double)tknErrCnt / (double)counter * 100);
Logger.WriteLine(Logger.Level.info, " Error sentence ratio = {0}%", (double)sentErrCnt / (double)curSequence * 100);
}

if (m_SaveStep > 0 && (curSequence + 1) % m_SaveStep == 0)
{
//After processing every m_SaveStep sentences, save the current model into a temporary file
Console.Write("Saving temporary model into file...");
Logger.WriteLine(Logger.Level.info, "Saving temporary model into file...");
saveNetBin(m_strModelFile + ".tmp");
Console.WriteLine("Done.");
}
}

@@ -552,9 +552,9 @@ public virtual double TrainNet(DataSet trainingSet, int iter)

double entropy = -logp / Math.Log10(2.0) / counter;
double ppl = exp_10(-logp / counter);
Console.WriteLine("[TRACE] Iter " + iter + " completed");
Console.WriteLine("[TRACE] Sentences = " + numSequence + ", time escape = " + duration + "s, speed = " + numSequence / duration.TotalSeconds);
Console.WriteLine("[TRACE] In training: log probability = " + logp + ", cross-entropy = " + entropy + ", perplexity = " + ppl);
Logger.WriteLine(Logger.Level.info, "[TRACE] Iter " + iter + " completed");
Logger.WriteLine(Logger.Level.info, "[TRACE] Sentences = " + numSequence + ", time escape = " + duration + "s, speed = " + numSequence / duration.TotalSeconds);
Logger.WriteLine(Logger.Level.info, "[TRACE] In training: log probability = " + logp + ", cross-entropy = " + entropy + ", perplexity = " + ppl);

return ppl;
}
@@ -831,7 +831,7 @@ public void CalculateOutputLayerError(State state, int timeat)

public virtual bool ValidateNet(DataSet validationSet)
{
Console.WriteLine("[TRACE] Start validation ...");
Logger.WriteLine(Logger.Level.info, "[TRACE] Start validation ...");
int wordcn = 0;
int[] predicted;
int tknErrCnt = 0;
@@ -875,17 +875,19 @@ public virtual bool ValidateNet(DataSet validationSet)

double entropy = -logp / Math.Log10(2.0) / counter;
double ppl = exp_10(-logp / counter);
+double tknErrRatio = (double)tknErrCnt / (double)wordcn * 100;
+double sentErrRatio = (double)sentErrCnt / (double)numSequence * 100;

Console.WriteLine("[TRACE] In validation: error token ratio = {0}% error sentence ratio = {1}%", (double)tknErrCnt / (double)wordcn * 100, (double)sentErrCnt / (double)numSequence * 100);
Console.WriteLine("[TRACE] In training: log probability = " + logp + ", cross-entropy = " + entropy + ", perplexity = " + ppl);
Console.WriteLine();
Logger.WriteLine(Logger.Level.info, "[TRACE] In validation: error token ratio = {0}% error sentence ratio = {1}%", tknErrRatio, sentErrRatio);
Logger.WriteLine(Logger.Level.info, "[TRACE] In training: log probability = " + logp + ", cross-entropy = " + entropy + ", perplexity = " + ppl);
Logger.WriteLine(Logger.Level.info, "");

bool bUpdate = false;
-if (ppl < minTknErrRatio)
+if (tknErrRatio < minTknErrRatio)
{
//We have better result on validated set, save this model
bUpdate = true;
-minTknErrRatio = ppl;
+minTknErrRatio = tknErrRatio;
}

return bUpdate;
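
The substance of change #2 is in this hunk: ValidateNet now compares the validation-set token error ratio, rather than perplexity, against minTknErrRatio when deciding whether the current model is the best seen so far. A minimal self-contained sketch of that selection rule follows; the class and method names are hypothetical, and the real logic lives inside RNN.ValidateNet above.

// Sketch of the selection rule this commit adopts (hypothetical names):
// keep whichever model achieves the lowest validation token error ratio.
public class BestModelTracker
{
    private double minTknErrRatio = double.MaxValue;

    // Returns true when the current model beats the best seen so far,
    // signalling the caller to save it (cf. bUpdate in ValidateNet above).
    public bool IsNewBest(int tknErrCnt, int wordCount)
    {
        double tknErrRatio = (double)tknErrCnt / wordCount * 100.0;
        if (tknErrRatio < minTknErrRatio)
        {
            minTknErrRatio = tknErrRatio;
            return true;
        }
        return false;
    }
}
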
9 changes: 5 additions & 4 deletions RNNSharp/RNNDecoder.cs
@@ -4,6 +4,7 @@
using System.Text;
using System.Threading.Tasks;
using System.IO;
+using AdvUtils;

namespace RNNSharp
{
@@ -21,7 +22,7 @@ public RNNDecoder(string strModelFileName, Featurizer featurizer)

if (modelDir == MODELDIRECTION.BI_DIRECTIONAL)
{
Console.WriteLine("Model Structure: Bi-directional RNN");
Logger.WriteLine(Logger.Level.info, "Model Structure: Bi-directional RNN");
if (modelType == MODELTYPE.SIMPLE)
{
m_Rnn = new BiRNN(new SimpleRNN(), new SimpleRNN());
@@ -35,18 +36,18 @@ public RNNDecoder(string strModelFileName, Featurizer featurizer)
{
if (modelType == MODELTYPE.SIMPLE)
{
Console.WriteLine("Model Structure: Simple RNN");
Logger.WriteLine(Logger.Level.info, "Model Structure: Simple RNN");
m_Rnn = new SimpleRNN();
}
else
{
Console.WriteLine("Model Structure: LSTM-RNN");
Logger.WriteLine(Logger.Level.info, "Model Structure: LSTM-RNN");
m_Rnn = new LSTMRNN();
}
}

m_Rnn.loadNetBin(strModelFileName);
Console.WriteLine("CRF Model: {0}", m_Rnn.IsCRFModel());
Logger.WriteLine(Logger.Level.info, "CRF Model: {0}", m_Rnn.IsCRFModel());
m_Featurizer = featurizer;
}

33 changes: 16 additions & 17 deletions RNNSharp/RNNEncoder.cs
@@ -4,6 +4,7 @@
using System.Text;
using System.Threading.Tasks;
using System.IO;
+using AdvUtils;

namespace RNNSharp
{
@@ -101,52 +102,50 @@ public void Train()
rnn.setTagBigramTransition(m_LabelBigramTransition);
}

-Console.WriteLine();
+Logger.WriteLine(Logger.Level.info, "");

Console.WriteLine("[TRACE] Iterative training begins ...");
Logger.WriteLine(Logger.Level.info, "[TRACE] Iterative training begins ...");
double lastPPL = double.MaxValue;
double lastAlpha = rnn.Alpha;
int iter = 0;
while (true)
{
if (rnn.MaxIter > 0 && iter > rnn.MaxIter)
{
Console.WriteLine("We have trained this model {0} iteration, exit.");
Logger.WriteLine(Logger.Level.info, "We have trained this model {0} iteration, exit.");
break;
}

//Start to train model
double ppl = rnn.TrainNet(m_TrainingSet, iter);

//Validate the model by validated corpus
+bool betterValidateNet = false;
if (rnn.ValidateNet(m_ValidationSet) == true)
{
//If current model is better than before, save it into file
Console.Write("Saving better model into file {0}...", m_modelSetting.GetModelFile());
Logger.WriteLine(Logger.Level.info, "Saving better model into file {0}...", m_modelSetting.GetModelFile());
rnn.saveNetBin(m_modelSetting.GetModelFile());
Console.WriteLine("Done.");
}
//else
//{
// Console.Write("Loading previous best model from file {0}...", m_modelSetting.GetModelFile());
// rnn.loadNetBin(m_modelSetting.GetModelFile());
// Console.WriteLine("Done.");

// lastAlpha = rnn.Alpha;
// rnn.Alpha = rnn.Alpha / 2.0;
//}
betterValidateNet = true;
}
else
{
Logger.WriteLine(Logger.Level.info, "Loading previous best model from file {0}...", m_modelSetting.GetModelFile());
rnn.loadNetBin(m_modelSetting.GetModelFile());
}


if (ppl >= lastPPL && lastAlpha != rnn.Alpha)
{
//Although we reduce alpha value, we still cannot get better result.
Console.WriteLine("Current perplexity({0}) is larger than the previous one({1}). End training early.", ppl, lastPPL);
Console.WriteLine("Current alpha: {0}, the previous alpha: {1}", rnn.Alpha, lastAlpha);
Logger.WriteLine(Logger.Level.info, "Current perplexity({0}) is larger than the previous one({1}). End training early.", ppl, lastPPL);
Logger.WriteLine(Logger.Level.info, "Current alpha: {0}, the previous alpha: {1}", rnn.Alpha, lastAlpha);
break;
}

lastAlpha = rnn.Alpha;
-if (ppl >= lastPPL)
+if (betterValidateNet == false)
{
rnn.Alpha = rnn.Alpha / 2.0;
}
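
Taken together, the RNNEncoder changes give the training loop a cleaner contract: save the model when validation improves, roll back to the previous best when it does not, and halve the learning rate on stagnation rather than on rising perplexity. A sketch of the resulting control flow follows; the RNN members used here (MaxIter, Alpha, TrainNet, ValidateNet, saveNetBin, loadNetBin) all appear in the diff above, while the wrapper method and its parameter names are hypothetical.

// Sketch of RNNEncoder.Train()'s control flow after this commit (hypothetical wrapper).
static void TrainUntilConverged(RNN rnn, DataSet trainingSet, DataSet validationSet, string modelFile)
{
    double lastPPL = double.MaxValue;
    double lastAlpha = rnn.Alpha;
    int iter = 0;
    while (true)
    {
        if (rnn.MaxIter > 0 && iter > rnn.MaxIter)
            break;                              // iteration budget exhausted

        double ppl = rnn.TrainNet(trainingSet, iter);

        bool betterValidateNet = rnn.ValidateNet(validationSet);
        if (betterValidateNet)
            rnn.saveNetBin(modelFile);          // new best on the validation set
        else
            rnn.loadNetBin(modelFile);          // roll back to the previous best

        if (ppl >= lastPPL && lastAlpha != rnn.Alpha)
            break;                              // no gain even after halving alpha: stop early

        lastAlpha = rnn.Alpha;
        if (!betterValidateNet)
            rnn.Alpha = rnn.Alpha / 2.0;        // halve the learning rate on stagnation

        lastPPL = ppl;
        iter++;
    }
}
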
