Skip to content

Commit

Permalink
#1. Fix CRF training bug: learning network should not be called when …
Browse files Browse the repository at this point in the history
…running validation

#2. Support model vector quantization to reduce model size to 1/4 of the original
#3. Refactoring code and speeding up training
#4. Fixing feature extraction bug
  • Loading branch information
zhongkaifu committed Jan 27, 2016
1 parent dfe598f commit f251876
Show file tree
Hide file tree
Showing 25 changed files with 1,046 additions and 1,245 deletions.
307 changes: 175 additions & 132 deletions RNNSharp/BiRNN.cs

Large diffs are not rendered by default.

119 changes: 38 additions & 81 deletions RNNSharp/DataSet.cs
Original file line number Diff line number Diff line change
@@ -1,133 +1,90 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

/// <summary>
/// RNNSharp written by Zhongkai Fu (fuzhongkai@gmail.com)
/// </summary>
namespace RNNSharp
{
public class DataSet
{
List<Sequence> m_Data;
int m_tagSize;
List<List<double>> m_LabelBigramTransition;

/// <summary>
/// Split current corpus into two parts according given ratio
/// </summary>
/// <param name="ratio"></param>
/// <param name="ds1"></param>
/// <param name="ds2"></param>
/// <summary>
/// Randomly splits the current corpus into two data sets.
/// Each sequence is assigned to <paramref name="ds1"/> with probability
/// <paramref name="ratio"/> and to <paramref name="ds2"/> otherwise, so the
/// split is probabilistic rather than an exact ratio of the corpus size.
/// After splitting, the CRF label bigram transition tables of both parts are rebuilt.
/// </summary>
/// <param name="ratio">Probability (0.0 - 1.0) that a sequence goes into <paramref name="ds1"/>.</param>
/// <param name="ds1">Receives roughly ratio * Count sequences.</param>
/// <param name="ds2">Receives the remaining sequences.</param>
public void SplitDataSet(double ratio, out DataSet ds1, out DataSet ds2)
{
    // Use the time-seeded default constructor instead of seeding with
    // DateTime.Now.Millisecond: that value only spans 0-999, so two splits
    // performed within the same millisecond would produce identical
    // "random" partitions.
    Random rnd = new Random();
    ds1 = new DataSet(m_tagSize);
    ds2 = new DataSet(m_tagSize);

    for (int i = 0; i < m_Data.Count; i++)
    {
        if (rnd.NextDouble() < ratio)
        {
            ds1.Add(m_Data[i]);
        }
        else
        {
            ds2.Add(m_Data[i]);
        }
    }

    ds1.BuildLabelBigramTransition();
    ds2.BuildLabelBigramTransition();
}

public void Add(Sequence sequence) { m_Data.Add(sequence); }
public List<Sequence> SequenceList { get; set; }
public int TagSize { get; set; }
public List<List<float>> CRFLabelBigramTransition { get; set; }

public void Shuffle()
{
Random rnd = new Random(DateTime.Now.Millisecond);
for (int i = 0; i < m_Data.Count; i++)
for (int i = 0; i < SequenceList.Count; i++)
{
int m = rnd.Next() % m_Data.Count;
Sequence tmp = m_Data[i];
m_Data[i] = m_Data[m];
m_Data[m] = tmp;
int m = rnd.Next() % SequenceList.Count;
Sequence tmp = SequenceList[i];
SequenceList[i] = SequenceList[m];
SequenceList[m] = tmp;
}
}

public DataSet(int tagSize)
{
m_tagSize = tagSize;
m_Data = new List<Sequence>();
m_LabelBigramTransition = new List<List<double>>();
}

public int GetSize()
{
return m_Data.Count;
TagSize = tagSize;
SequenceList = new List<Sequence>();
CRFLabelBigramTransition = new List<List<float>>();
}

public Sequence Get(int i) { return m_Data[i]; }
public int GetTagSize() { return m_tagSize; }


public int GetDenseDimension()
public int DenseFeatureSize()
{
if (0 == m_Data.Count) return 0;
return m_Data[0].GetDenseDimension();
if (0 == SequenceList.Count) return 0;
return SequenceList[0].GetDenseDimension();
}

public int GetSparseDimension()
{
if (0 == m_Data.Count) return 0;
return m_Data[0].GetSparseDimension();
if (0 == SequenceList.Count) return 0;
return SequenceList[0].GetSparseDimension();
}


public List<List<double>> GetLabelBigramTransition() { return m_LabelBigramTransition; }


public void BuildLabelBigramTransition(double smooth = 1.0)
public void BuildLabelBigramTransition(float smooth = 1.0f)
{
m_LabelBigramTransition = new List<List<double>>();
CRFLabelBigramTransition = new List<List<float>>();

for (int i = 0; i < m_tagSize; i++)
for (int i = 0; i < TagSize; i++)
{
m_LabelBigramTransition.Add(new List<double>());
CRFLabelBigramTransition.Add(new List<float>());
}
for (int i = 0; i < m_tagSize; i++)
for (int i = 0; i < TagSize; i++)
{
for (int j = 0; j < m_tagSize; j++)
for (int j = 0; j < TagSize; j++)
{
m_LabelBigramTransition[i].Add(smooth);
CRFLabelBigramTransition[i].Add(smooth);
}
}

for (int i = 0; i < m_Data.Count; i++)
for (int i = 0; i < SequenceList.Count; i++)
{
var sequence = m_Data[i];
if (sequence.GetSize() <= 1)
var sequence = SequenceList[i];
if (sequence.States.Length <= 1)
continue;

int pLabel = sequence.Get(0).GetLabel();
for (int j = 1; j < sequence.GetSize(); j++)
int pLabel = sequence.States[0].Label;
for (int j = 1; j < sequence.States.Length; j++)
{
int label = sequence.Get(j).GetLabel();
m_LabelBigramTransition[label][pLabel]++;
int label = sequence.States[j].Label;
CRFLabelBigramTransition[label][pLabel]++;
pLabel = label;
}
}

for (int i = 0; i < m_tagSize; i++)
for (int i = 0; i < TagSize; i++)
{
double sum = 0;
for (int j = 0; j < m_tagSize; j++)
for (int j = 0; j < TagSize; j++)
{
sum += m_LabelBigramTransition[i][j];
sum += CRFLabelBigramTransition[i][j];
}

for (int j = 0; j < m_tagSize; j++)
for (int j = 0; j < TagSize; j++)
{
m_LabelBigramTransition[i][j] = Math.Log(m_LabelBigramTransition[i][j] / sum);
CRFLabelBigramTransition[i][j] = (float)Math.Log(CRFLabelBigramTransition[i][j] / sum);
}
}
}
Expand Down
44 changes: 17 additions & 27 deletions RNNSharp/Featurizer.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.IO;
using Txt2Vec;
using AdvUtils;

/// <summary>
/// RNNSharp written by Zhongkai Fu (fuzhongkai@gmail.com)
/// </summary>
namespace RNNSharp
{
enum TFEATURE_WEIGHT_TYPE_ENUM
Expand All @@ -17,14 +16,14 @@ enum TFEATURE_WEIGHT_TYPE_ENUM

public class Featurizer
{
public TagSet TagSet { get; set; }

Dictionary<string, List<int>> m_FeatureConfiguration;
int m_SparseDimension;
int m_DenseDimension;
int m_WordEmbeddingCloumn;
TFEATURE_WEIGHT_TYPE_ENUM m_TFeatureWeightType = TFEATURE_WEIGHT_TYPE_ENUM.BINARY;

WordEMWrapFeaturizer m_WordEmbedding;
TagSet m_TagSet;
TemplateFeaturizer m_TFeaturizer;

static string TFEATURE_CONTEXT = "TFEATURE_CONTEXT";
Expand All @@ -35,12 +34,6 @@ public class Featurizer
static string WORDEMBEDDING_COLUMN = "WORDEMBEDDING_COLUMN";
static string TFEATURE_WEIGHT_TYPE = "TFEATURE_WEIGHT_TYPE";

public TagSet GetTagSet()
{
return m_TagSet;
}


//The format of configuration file
public void LoadFeatureConfigFromFile(string strFileName)
{
Expand Down Expand Up @@ -125,7 +118,7 @@ public int TruncPosition(int current, int lower, int upper)
public Featurizer(string strFeatureConfigFileName, TagSet tagSet)
{
LoadFeatureConfigFromFile(strFeatureConfigFileName);
m_TagSet = tagSet;
TagSet = tagSet;
InitComponentFeaturizer();
}

Expand All @@ -143,7 +136,7 @@ void InitComponentFeaturizer()

if (fc.ContainsKey(RT_FEATURE_CONTEXT) == true)
{
m_SparseDimension += m_TagSet.GetSize() * fc[RT_FEATURE_CONTEXT].Count;
m_SparseDimension += TagSet.GetSize() * fc[RT_FEATURE_CONTEXT].Count;
}

m_DenseDimension = 0;
Expand Down Expand Up @@ -173,15 +166,15 @@ public void ShowFeatureSize()
Logger.WriteLine(Logger.Level.info, "Template feature context size: {0}", m_TFeaturizer.GetFeatureSize() * fc[TFEATURE_CONTEXT].Count);

if (fc.ContainsKey(RT_FEATURE_CONTEXT) == true)
Logger.WriteLine(Logger.Level.info, "Run time feature size: {0}", m_TagSet.GetSize() * fc[RT_FEATURE_CONTEXT].Count);
Logger.WriteLine(Logger.Level.info, "Run time feature size: {0}", TagSet.GetSize() * fc[RT_FEATURE_CONTEXT].Count);

if (fc.ContainsKey(WORDEMBEDDING_CONTEXT) == true)
Logger.WriteLine(Logger.Level.info, "Word embedding feature size: {0}", m_WordEmbedding.GetDimension() * fc[WORDEMBEDDING_CONTEXT].Count);
}

void ExtractSparseFeature(int currentState, int numStates, List<string[]> features, State pState)
{
Dictionary<int, double> sparseFeature = new Dictionary<int, double>();
Dictionary<int, float> sparseFeature = new Dictionary<int, float>();
int start = 0;
var fc = m_FeatureConfiguration;

Expand Down Expand Up @@ -224,14 +217,14 @@ void ExtractSparseFeature(int currentState, int numStates, List<string[]> featur
if (fc.ContainsKey(RT_FEATURE_CONTEXT) == true)
{
List<int> v = fc[RT_FEATURE_CONTEXT];
pState.SetNumRuntimeFeature(v.Count);
pState.RuntimeFeatures = new PriviousLabelFeature[v.Count];
for (int j = 0; j < v.Count; j++)
{
if (v[j] < 0)
{
pState.AddRuntimeFeaturePlacehold(j, v[j], sparseFeature.Count, start);
sparseFeature[start] = 0; //Placehold a position
start += m_TagSet.GetSize();
start += TagSet.GetSize();
}
else
{
Expand All @@ -240,7 +233,7 @@ void ExtractSparseFeature(int currentState, int numStates, List<string[]> featur
}
}

SparseVector spSparseFeature = pState.GetSparseData();
SparseVector spSparseFeature = pState.SparseData;
spSparseFeature.SetDimension(m_SparseDimension);
spSparseFeature.SetData(sparseFeature);
}
Expand Down Expand Up @@ -284,19 +277,16 @@ public Vector ExtractDenseFeature(int currentState, int numStates, List<string[]

public Sequence ExtractFeatures(Sentence sentence)
{
Sequence sequence = new Sequence();
int n = sentence.GetTokenSize();
List<string[]> features = sentence.GetFeatureSet();
int n = sentence.TokensList.Count;
Sequence sequence = new Sequence(n);

//For each token, get its sparse and dense feature set according configuration and training corpus
sequence.SetSize(n);
for (int i = 0; i < n; i++)
{
State state = sequence.Get(i);
ExtractSparseFeature(i, n, features, state);
State state = sequence.States[i];
ExtractSparseFeature(i, n, sentence.TokensList, state);

var spDenseFeature = ExtractDenseFeature(i, n, features);
state.SetDenseData(spDenseFeature);
state.DenseData = ExtractDenseFeature(i, n, sentence.TokensList);
}

return sequence;
Expand Down
Loading

0 comments on commit f251876

Please sign in to comment.