Commit
code improvement by creating two new classes
itamposis committed Sep 24, 2019
Parent: 75ef12a · Commit: 9f08fa5
Showing 14 changed files with 394 additions and 332 deletions.
3 changes: 2 additions & 1 deletion conf/conf.demo
@@ -92,7 +92,8 @@ NOISE_EM=true
 PRIOR_TRANS=0.001
 
 #HNN OPTIONS
-window=7
+windowLeft=5
+windowRight=5
 nhidden=10
 ADD_GRAD=0.0
 DECAY=0.001
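For orientation (not part of the diff): the single window=7 option is replaced by independently sized left and right flanks around each sequence position, and the same edit is repeated in the remaining configuration files. Below is a minimal Java sketch of how such an asymmetric input window is typically indexed; every name in it is illustrative rather than taken from the repository, the padding behaviour is an assumption, and whether the old window=7 counted the full window or one flank is not visible in this diff.

    // Hypothetical sketch, not repository code: with windowLeft=5 and
    // windowRight=5, each position i contributes windowLeft + windowRight + 1
    // = 11 residues to the network input; out-of-range positions are padded.
    class WindowSketch {
        static void scan(char[] seq, int windowLeft, int windowRight) {
            for (int i = 0; i < seq.length; i++) {
                for (int k = i - windowLeft; k <= i + windowRight; k++) {
                    boolean pad = (k < 0 || k >= seq.length);
                    char input = pad ? '-' : seq[k]; // '-' stands in for a padding symbol
                    // encode 'input' into the feature vector for position i
                }
            }
        }
    }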
3 changes: 2 additions & 1 deletion conf/conf.hmmtm
@@ -92,7 +92,8 @@ NOISE_EM=true
 PRIOR_TRANS=0.001
 
 #HNN OPTIONS
-window=7
+windowLeft=5
+windowRight=5
 nhidden=10
 ADD_GRAD=0.0
 DECAY=0.001
3 changes: 2 additions & 1 deletion conf/conf.lipo
@@ -92,7 +92,8 @@ NOISE_EM=true
 PRIOR_TRANS=0.001
 
 #HNN OPTIONS
-window=7
+windowLeft=5
+windowRight=5
 nhidden=10
 ADD_GRAD=0.0
 DECAY=0.001
3 changes: 2 additions & 1 deletion conf/conf.signal
@@ -92,7 +92,8 @@ NOISE_EM=true
 PRIOR_TRANS=0.001
 
 #HNN OPTIONS
-window=7
+windowLeft=5
+windowRight=5
 nhidden=10
 ADD_GRAD=0.0
 DECAY=0.001
3 changes: 2 additions & 1 deletion conf/conf.tat
@@ -92,7 +92,8 @@ NOISE_EM=true
 PRIOR_TRANS=0.001
 
 #HNN OPTIONS
-window=7
+windowLeft=5
+windowRight=5
 nhidden=10
 ADD_GRAD=0.0
 DECAY=0.001
13 changes: 7 additions & 6 deletions conf/conf.tmbb
@@ -1,7 +1,7 @@
 # TRAINING OPTIONS
-RUN_CML=false
-RUN_GRADIENT=false
-HNN=false
+RUN_CML=true
+RUN_GRADIENT=true
+HNN=true
 ALLOW_BEGIN=true
 ALLOW_END=true
 RUN_ViterbiTraining=false
@@ -17,7 +17,7 @@ EMISSIONS=FILE
 WEIGHTS=RPROP
 
 # Multithreaded parallelization for multicores
-PARALLEL=true
+PARALLEL=false
 defCPU=true
 nCPU=10
 
@@ -92,8 +92,9 @@ NOISE_EM=true
 PRIOR_TRANS=0.001
 
 #HNN OPTIONS
-window=7
-nhidden=10
+windowLeft=5
+windowRight=5
+nhidden=7
 ADD_GRAD=0.0
 DECAY=0.001
 #1: Sigmoid, 2: Sigmoid Modified, 3: Tanh
99 changes: 31 additions & 68 deletions src/hmm/Backprop.java
@@ -3,32 +3,19 @@
import java.lang.Math;

class Backprop extends TrainAlgo {
public double valLog;
private boolean valid;
private double[][][] E;
private double[][] A;
private Probs tab;

public Backprop(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, WeightsL weightsL) throws Exception {
valid = true;
Run(trainSet, tab0, valSeqs, 0.0D);
}


public Backprop(final SeqSet trainSet, final Probs tab0, final double stopLog, WeightsL weightsL) throws Exception {
valid = false;
Run(trainSet, tab0, new SeqSet(0), stopLog);
}


public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, final double stopLog) throws Exception {
int ex = 0;

SeqSet seqs = new SeqSet(trainSet.nseqs + ex);

for (int i = 0; i < seqs.nseqs - ex; i++)
seqs.seq[i] = trainSet.seq[i];

public void Run(final SeqSet seqs, final Probs tab0, final SeqSet valSeqs, final double stopLog) throws Exception {
double loglikelihood, loglikelihoodC, loglikelihoodF;
double valLoglikelihood, valLoglikelihoodC, valLoglikelihoodF;

@@ -39,14 +26,6 @@ public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, f
valLoglikelihoodC = Double.NEGATIVE_INFINITY;
valLoglikelihoodF = Double.NEGATIVE_INFINITY;

Forward[] fwdsC = new Forward[seqs.nseqs];
Backward[] bwdsC = new Backward[seqs.nseqs];
double[] logPC = new double[seqs.nseqs];

Forward[] fwdsF = new Forward[seqs.nseqs];
Backward[] bwdsF = new Backward[seqs.nseqs];
double[] logPF = new double[seqs.nseqs];

acts = new Activ[Params.NNclassLabels][seqs.nseqs][seqs.getMaxL()];
Activ[][][] valActs;

@@ -56,32 +35,6 @@ public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, f
Weights bestw = new Weights();
tab = new Probs(tab0.aprob, tab0.weights);

//System.out.println( "*** Bootstrapping ***" );
for (int i = 1; i <= Params.BOOT; i++) {
if (Params.WEIGHTS.equals("RANDOM_NORMAL"))
tab.weights.RandomizeNormal(Params.STDEV, 0);
else if (Params.WEIGHTS.equals("RANDOM_UNIFORM"))
tab.weights.RandomizeUniform(Params.RANGE, 0);

hmm = new HMM(tab);
//hmm.print();//////
CalcActs(seqs);

loglikelihoodC = fwdbwd(fwdsC, bwdsC, logPC, false, seqs);
loglikelihoodF = fwdbwd(fwdsF, bwdsF, logPF, true, seqs);
loglikelihood = loglikelihoodC - loglikelihoodF;
System.out.println("\tC=" + loglikelihoodC + ", F=" + loglikelihoodF);//////////


System.out.println(i + "/" + Params.BOOT + "\tlog likelihood = " + loglikelihood);

if (loglikelihood > bestl) {
bestl = loglikelihood;
bestw = tab0.weights.GetClone();
}

}

if (bestl > Double.NEGATIVE_INFINITY) {
tab.weights = bestw;
System.out.println("*** Chosen " + loglikelihood + " ***");
@@ -99,21 +52,25 @@ else if (Params.WEIGHTS.equals("RANDOM_UNIFORM"))
if (valid)
CalcActs(valSeqs, valActs);

loglikelihoodC = fwdbwd(fwdsC, bwdsC, logPC, false, seqs, ex);

if (valid)
valLoglikelihoodC = fwdbwd(false, valSeqs, valActs);
ForwardBackward fwdbwdC = new ForwardBackward(hmm, false, seqs, acts);
ForwardBackward fwdbwdF = new ForwardBackward(hmm, true, seqs, acts);
loglikelihoodC = fwdbwdC.getLogProb();
loglikelihoodF = fwdbwdF.getLogProb();

if (valid) {
ForwardBackward fb = new ForwardBackward(hmm, false, valSeqs, valActs);
valLoglikelihoodC = fb.getLogProb();
}

double sum_weights = WeightsSquare(tab.weights);

loglikelihoodF = fwdbwd(fwdsF, bwdsF, logPF, true, seqs, ex);

loglikelihood = loglikelihoodC - loglikelihoodF - Params.DECAY * sum_weights;
System.out.println("\tC=" + loglikelihoodC + ", F=" + loglikelihoodF + ", SqWts=" + sum_weights);

if (valid) {
valLoglikelihoodF = fwdbwd(true, valSeqs, valActs);
ForwardBackward fb = new ForwardBackward(hmm, true, valSeqs, valActs);
valLoglikelihoodF = fb.getLogProb();

valLoglikelihood = valLoglikelihoodC - valLoglikelihoodF - Params.DECAY * sum_weights;
System.out.println("\tVC=" + valLoglikelihoodC + ", VF=" + valLoglikelihoodF + ", SqWts=" + sum_weights);
}
@@ -156,13 +113,13 @@ else if (Params.WEIGHTS.equals("RANDOM_UNIFORM"))

for (int s = 0; s < seqs.nseqs; s++) // Foreach sequence
{
Forward fwdC = fwdsC[s];
Backward bwdC = bwdsC[s];
double PC = logPC[s]; // NOT exp.
Forward fwdC = fwdbwdC.getFwds(s);
Backward bwdC = fwdbwdC.getBwds(s);
double PC = fwdbwdC.getLogP(s);

Forward fwdF = fwdsF[s];
Backward bwdF = bwdsF[s];
double PF = logPF[s]; // NOT exp.
Forward fwdF = fwdbwdF.getFwds(s);
Backward bwdF = fwdbwdF.getBwds(s);
double PF = fwdbwdF.getLogP(s);

int L = seqs.seq[s].getLen();

@@ -248,7 +205,7 @@ else if (Params.WEIGHTS.equals("RANDOM_UNIFORM"))
deriv23 = new double[Params.NNclassLabels][1][Params.nhidden+1];


ComputeDeriv(trainSet, E, deriv12, deriv23);
ComputeDeriv(seqs, E, deriv12, deriv23);
UpdateWeights(tab.weights, Params.RPROP, Params.SILVA, deriv12, deriv23);

//LineSearch();
@@ -267,22 +224,29 @@ else if (Params.WEIGHTS.equals("RANDOM_UNIFORM"))
if (valid)
CalcActs(valSeqs, valActs);

loglikelihoodC = fwdbwd(fwdsC, bwdsC, logPC, false, seqs, ex);
fwdbwdC = new ForwardBackward(hmm, false, seqs, acts);
loglikelihoodC = fwdbwdC.getLogProb();
fwdbwdF = new ForwardBackward(hmm, true, seqs, acts);
loglikelihoodF = fwdbwdF.getLogProb();

if (valid)
valLoglikelihoodC = fwdbwd(false, valSeqs, valActs);
if (valid) {
ForwardBackward fb = new ForwardBackward(hmm, false, valSeqs, valActs);
valLoglikelihoodC = fb.getLogProb();
}

////////////////////////////////////////////
sum_weights = WeightsSquare(tab.weights);
//sum_weights *=Params.DECAY;
/////////////////////////////////////////////

loglikelihoodF = fwdbwd(fwdsF, bwdsF, logPF, true, seqs, ex);
//loglikelihoodF = fwdbwd(fwdsF, bwdsF, logPF, true, seqs);
loglikelihood = loglikelihoodC - loglikelihoodF - Params.DECAY * sum_weights;
System.out.println("\tC=" + loglikelihoodC + ", F=" + loglikelihoodF + ", SqWts=" + sum_weights);//////////

if (valid) {
valLoglikelihoodF = fwdbwd(true, valSeqs, valActs);
ForwardBackward fb = new ForwardBackward(hmm, true, valSeqs, valActs);
valLoglikelihoodF = fb.getLogProb();

valLoglikelihood = valLoglikelihoodC - valLoglikelihoodF - Params.DECAY * sum_weights; //////////????????
System.out.println("\tvalC=" + valLoglikelihoodC + ", valF=" + valLoglikelihoodF + ", SqWts=" + sum_weights);//////////
}
@@ -293,7 +257,6 @@ else if (Params.WEIGHTS.equals("RANDOM_UNIFORM"))
if (valid) {
System.out.print("\tval log likelihood = " + valLoglikelihood + "\t\t diff = ");


if (valLoglikelihood > oldvalLoglikelihood || iter < Params.ITER) {
System.out.println("DOWN");
} else {
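The new ForwardBackward class itself is not in the loaded portion of the diff, but its interface can be read off the call sites above: a constructor ForwardBackward(hmm, free, seqs, acts) that runs the dynamic-programming passes for a whole sequence set, and the accessors getLogProb(), getFwds(s), getBwds(s) and getLogP(s). The C/F suffixes and the objective loglikelihood = loglikelihoodC - loglikelihoodF suggest the boolean switches between the clamped and free-running passes of conditional maximum-likelihood training. The skeleton below is a minimal sketch consistent with those call sites, not the repository's actual code; in particular, the Forward/Backward constructor signatures and the logprob() accessor are assumptions.

    // Hypothetical skeleton inferred from the call sites in Backprop.java above.
    // Forward, Backward, HMM, SeqSet and Activ are the project's own types.
    class ForwardBackward {
        private final Forward[] fwds;
        private final Backward[] bwds;
        private final double[] logP; // per-sequence log-likelihood (kept in log space)
        private double logProb;      // total log-likelihood over the set

        // 'free' selects the free-running pass (true) versus the clamped pass
        // (false), mirroring the boolean flag of the replaced fwdbwd(...) helper.
        ForwardBackward(HMM hmm, boolean free, SeqSet seqs, Activ[][][] acts) throws Exception {
            fwds = new Forward[seqs.nseqs];
            bwds = new Backward[seqs.nseqs];
            logP = new double[seqs.nseqs];
            for (int s = 0; s < seqs.nseqs; s++) {
                fwds[s] = new Forward(hmm, seqs.seq[s], free, acts);  // assumed signature
                bwds[s] = new Backward(hmm, seqs.seq[s], free, acts); // assumed signature
                logP[s] = fwds[s].logprob();                          // assumed accessor
                logProb += logP[s];
            }
        }

        double getLogProb()     { return logProb; }
        Forward getFwds(int s)  { return fwds[s]; }
        Backward getBwds(int s) { return bwds[s]; }
        double getLogP(int s)   { return logP[s]; }
    }

Whatever the exact implementation, the payoff is visible in Run(...): the six parallel arrays (fwdsC/bwdsC/logPC and fwdsF/bwdsF/logPF) collapse into the two handles fwdbwdC and fwdbwdF, and each clamped or free pass is requested with a single boolean.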