From 9f08fa57630ec6f352c2bc73515058ae216ea7fd Mon Sep 17 00:00:00 2001 From: itamposis Date: Tue, 24 Sep 2019 10:25:07 +0300 Subject: [PATCH] code improvement by creating two new classes --- conf/conf.demo | 3 +- conf/conf.hmmtm | 3 +- conf/conf.lipo | 3 +- conf/conf.signal | 3 +- conf/conf.tat | 3 +- conf/conf.tmbb | 13 +- src/hmm/Backprop.java | 99 +++++---------- src/hmm/CML.java | 87 ++++++------- src/hmm/ForwardBackward.java | 146 +++++++++++++++++++++ src/hmm/GEM.java | 10 +- src/hmm/{juchmme.java => Juchmme.java} | 2 +- src/hmm/ML.java | 68 ++++------ src/hmm/TrainAlgo.java | 169 +------------------------ src/hmm/ViterbiTraining.java | 117 +++++++++++++++++ 14 files changed, 394 insertions(+), 332 deletions(-) create mode 100644 src/hmm/ForwardBackward.java rename src/hmm/{juchmme.java => Juchmme.java} (98%) create mode 100644 src/hmm/ViterbiTraining.java diff --git a/conf/conf.demo b/conf/conf.demo index 0ff997c..ec41700 100644 --- a/conf/conf.demo +++ b/conf/conf.demo @@ -92,7 +92,8 @@ NOISE_EM=true PRIOR_TRANS=0.001 #HNN OPTIONS -window=7 +windowLeft=5 +windowRight=5 nhidden=10 ADD_GRAD=0.0 DECAY=0.001 diff --git a/conf/conf.hmmtm b/conf/conf.hmmtm index 06a0b59..5b90d4e 100644 --- a/conf/conf.hmmtm +++ b/conf/conf.hmmtm @@ -92,7 +92,8 @@ NOISE_EM=true PRIOR_TRANS=0.001 #HNN OPTIONS -window=7 +windowLeft=5 +windowRight=5 nhidden=10 ADD_GRAD=0.0 DECAY=0.001 diff --git a/conf/conf.lipo b/conf/conf.lipo index 0ff997c..ec41700 100644 --- a/conf/conf.lipo +++ b/conf/conf.lipo @@ -92,7 +92,8 @@ NOISE_EM=true PRIOR_TRANS=0.001 #HNN OPTIONS -window=7 +windowLeft=5 +windowRight=5 nhidden=10 ADD_GRAD=0.0 DECAY=0.001 diff --git a/conf/conf.signal b/conf/conf.signal index 0ff997c..ec41700 100644 --- a/conf/conf.signal +++ b/conf/conf.signal @@ -92,7 +92,8 @@ NOISE_EM=true PRIOR_TRANS=0.001 #HNN OPTIONS -window=7 +windowLeft=5 +windowRight=5 nhidden=10 ADD_GRAD=0.0 DECAY=0.001 diff --git a/conf/conf.tat b/conf/conf.tat index 0ff997c..ec41700 100644 --- a/conf/conf.tat +++ b/conf/conf.tat @@ -92,7 +92,8 @@ NOISE_EM=true PRIOR_TRANS=0.001 #HNN OPTIONS -window=7 +windowLeft=5 +windowRight=5 nhidden=10 ADD_GRAD=0.0 DECAY=0.001 diff --git a/conf/conf.tmbb b/conf/conf.tmbb index 0c59d14..ceee55a 100644 --- a/conf/conf.tmbb +++ b/conf/conf.tmbb @@ -1,7 +1,7 @@ # TRAINING OPTIONS -RUN_CML=false -RUN_GRADIENT=false -HNN=false +RUN_CML=true +RUN_GRADIENT=true +HNN=true ALLOW_BEGIN=true ALLOW_END=true RUN_ViterbiTraining=false @@ -17,7 +17,7 @@ EMISSIONS=FILE WEIGHTS=RPROP # Multithreaded parallelization for multicores -PARALLEL=true +PARALLEL=false defCPU=true nCPU=10 @@ -92,8 +92,9 @@ NOISE_EM=true PRIOR_TRANS=0.001 #HNN OPTIONS -window=7 -nhidden=10 +windowLeft=5 +windowRight=5 +nhidden=7 ADD_GRAD=0.0 DECAY=0.001 #1: Sigmoid, 2: Sigmoid Modified, 3: Tanh diff --git a/src/hmm/Backprop.java b/src/hmm/Backprop.java index fd94dd1..026b6ae 100644 --- a/src/hmm/Backprop.java +++ b/src/hmm/Backprop.java @@ -3,32 +3,19 @@ import java.lang.Math; class Backprop extends TrainAlgo { - public double valLog; - private boolean valid; private double[][][] E; - private double[][] A; - private Probs tab; public Backprop(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, WeightsL weightsL) throws Exception { valid = true; Run(trainSet, tab0, valSeqs, 0.0D); } - public Backprop(final SeqSet trainSet, final Probs tab0, final double stopLog, WeightsL weightsL) throws Exception { valid = false; Run(trainSet, tab0, new SeqSet(0), stopLog); } - - public void Run(final SeqSet trainSet, final Probs tab0, final 
SeqSet valSeqs, final double stopLog) throws Exception { - int ex = 0; - - SeqSet seqs = new SeqSet(trainSet.nseqs + ex); - - for (int i = 0; i < seqs.nseqs - ex; i++) - seqs.seq[i] = trainSet.seq[i]; - + public void Run(final SeqSet seqs, final Probs tab0, final SeqSet valSeqs, final double stopLog) throws Exception { double loglikelihood, loglikelihoodC, loglikelihoodF; double valLoglikelihood, valLoglikelihoodC, valLoglikelihoodF; @@ -39,14 +26,6 @@ public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, f valLoglikelihoodC = Double.NEGATIVE_INFINITY; valLoglikelihoodF = Double.NEGATIVE_INFINITY; - Forward[] fwdsC = new Forward[seqs.nseqs]; - Backward[] bwdsC = new Backward[seqs.nseqs]; - double[] logPC = new double[seqs.nseqs]; - - Forward[] fwdsF = new Forward[seqs.nseqs]; - Backward[] bwdsF = new Backward[seqs.nseqs]; - double[] logPF = new double[seqs.nseqs]; - acts = new Activ[Params.NNclassLabels][seqs.nseqs][seqs.getMaxL()]; Activ[][][] valActs; @@ -56,32 +35,6 @@ public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, f Weights bestw = new Weights(); tab = new Probs(tab0.aprob, tab0.weights); - //System.out.println( "*** Bootstrapping ***" ); - for (int i = 1; i <= Params.BOOT; i++) { - if (Params.WEIGHTS.equals("RANDOM_NORMAL")) - tab.weights.RandomizeNormal(Params.STDEV, 0); - else if (Params.WEIGHTS.equals("RANDOM_UNIFORM")) - tab.weights.RandomizeUniform(Params.RANGE, 0); - - hmm = new HMM(tab); - //hmm.print();////// - CalcActs(seqs); - - loglikelihoodC = fwdbwd(fwdsC, bwdsC, logPC, false, seqs); - loglikelihoodF = fwdbwd(fwdsF, bwdsF, logPF, true, seqs); - loglikelihood = loglikelihoodC - loglikelihoodF; - System.out.println("\tC=" + loglikelihoodC + ", F=" + loglikelihoodF);////////// - - - System.out.println(i + "/" + Params.BOOT + "\tlog likelihood = " + loglikelihood); - - if (loglikelihood > bestl) { - bestl = loglikelihood; - bestw = tab0.weights.GetClone(); - } - - } - if (bestl > Double.NEGATIVE_INFINITY) { tab.weights = bestw; System.out.println("*** Chosen " + loglikelihood + " ***"); @@ -99,21 +52,25 @@ else if (Params.WEIGHTS.equals("RANDOM_UNIFORM")) if (valid) CalcActs(valSeqs, valActs); - loglikelihoodC = fwdbwd(fwdsC, bwdsC, logPC, false, seqs, ex); - - if (valid) - valLoglikelihoodC = fwdbwd(false, valSeqs, valActs); + ForwardBackward fwdbwdC = new ForwardBackward(hmm, false, seqs, acts); + ForwardBackward fwdbwdF = new ForwardBackward(hmm, true, seqs, acts); + loglikelihoodC = fwdbwdC.getLogProb(); + loglikelihoodF = fwdbwdF.getLogProb(); + if (valid) { + ForwardBackward fb = new ForwardBackward(hmm, false, valSeqs, valActs); + valLoglikelihoodC = fb.getLogProb(); + } double sum_weights = WeightsSquare(tab.weights); - loglikelihoodF = fwdbwd(fwdsF, bwdsF, logPF, true, seqs, ex); - loglikelihood = loglikelihoodC - loglikelihoodF - Params.DECAY * sum_weights; System.out.println("\tC=" + loglikelihoodC + ", F=" + loglikelihoodF + ", SqWts=" + sum_weights); if (valid) { - valLoglikelihoodF = fwdbwd(true, valSeqs, valActs); + ForwardBackward fb = new ForwardBackward(hmm, true, valSeqs, valActs); + valLoglikelihoodF = fb.getLogProb(); + valLoglikelihood = valLoglikelihoodC - valLoglikelihoodF - Params.DECAY * sum_weights; System.out.println("\tVC=" + valLoglikelihoodC + ", VF=" + valLoglikelihoodF + ", SqWts=" + sum_weights); } @@ -156,13 +113,13 @@ else if (Params.WEIGHTS.equals("RANDOM_UNIFORM")) for (int s = 0; s < seqs.nseqs; s++) // Foreach sequence { - Forward fwdC = fwdsC[s]; - Backward bwdC = bwdsC[s]; - 
double PC = logPC[s]; // NOT exp.
+ Forward fwdC = fwdbwdC.getFwds(s);
+ Backward bwdC = fwdbwdC.getBwds(s);
+ double PC = fwdbwdC.getLogP(s);

- Forward fwdF = fwdsF[s];
- Backward bwdF = bwdsF[s];
- double PF = logPF[s]; // NOT exp.
+ Forward fwdF = fwdbwdF.getFwds(s);
+ Backward bwdF = fwdbwdF.getBwds(s);
+ double PF = fwdbwdF.getLogP(s);

int L = seqs.seq[s].getLen();

@@ -248,7 +205,7 @@ else if (Params.WEIGHTS.equals("RANDOM_UNIFORM"))

deriv23 = new double[Params.NNclassLabels][1][Params.nhidden+1];

- ComputeDeriv(trainSet, E, deriv12, deriv23);
+ ComputeDeriv(seqs, E, deriv12, deriv23);

UpdateWeights(tab.weights, Params.RPROP, Params.SILVA, deriv12, deriv23);
//LineSearch();
@@ -267,22 +224,29 @@ else if (Params.WEIGHTS.equals("RANDOM_UNIFORM"))
if (valid)
CalcActs(valSeqs, valActs);

- loglikelihoodC = fwdbwd(fwdsC, bwdsC, logPC, false, seqs, ex);
+ fwdbwdC = new ForwardBackward(hmm, false, seqs, acts);
+ loglikelihoodC = fwdbwdC.getLogProb();
+ fwdbwdF = new ForwardBackward(hmm, true, seqs, acts);
+ loglikelihoodF = fwdbwdF.getLogProb();

- if (valid)
- valLoglikelihoodC = fwdbwd(false, valSeqs, valActs);
+ if (valid) {
+ ForwardBackward fb = new ForwardBackward(hmm, false, valSeqs, valActs);
+ valLoglikelihoodC = fb.getLogProb();
+ }

////////////////////////////////////////////
sum_weights = WeightsSquare(tab.weights);
//sum_weights *=Params.DECAY;
/////////////////////////////////////////////

- loglikelihoodF = fwdbwd(fwdsF, bwdsF, logPF, true, seqs, ex);
+ //loglikelihoodF = fwdbwd(fwdsF, bwdsF, logPF, true, seqs);

loglikelihood = loglikelihoodC - loglikelihoodF - Params.DECAY * sum_weights;
System.out.println("\tC=" + loglikelihoodC + ", F=" + loglikelihoodF + ", SqWts=" + sum_weights);//////////

if (valid) {
- valLoglikelihoodF = fwdbwd(true, valSeqs, valActs);
+ ForwardBackward fb = new ForwardBackward(hmm, true, valSeqs, valActs);
+ valLoglikelihoodF = fb.getLogProb();
+
valLoglikelihood = valLoglikelihoodC - valLoglikelihoodF - Params.DECAY * sum_weights;
System.out.println("\tvalC=" + valLoglikelihoodC + ", valF=" + valLoglikelihoodF + ", SqWts=" + sum_weights);////////// } @@ -293,7 +257,6 @@ else if (Params.WEIGHTS.equals("RANDOM_UNIFORM")) if (valid) { System.out.print("\tval log likelihood = " + valLoglikelihood + "\t\t diff = "); - if (valLoglikelihood > oldvalLoglikelihood || iter < Params.ITER) { System.out.println("DOWN"); } else { diff --git a/src/hmm/CML.java b/src/hmm/CML.java index 451a37b..442ceda 100644 --- a/src/hmm/CML.java +++ b/src/hmm/CML.java @@ -3,12 +3,7 @@ import java.util.*; class CML extends TrainAlgo { - - public double valLog; - private boolean valid; - private double[][] E; - private double[][] A; - private Probs tab; + double[][] E; public CML(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, WeightsL weightsL) throws Exception { valid = true; @@ -76,24 +71,24 @@ public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, f valLoglikelihood = Double.NEGATIVE_INFINITY; double loglikelihoodC = 0, loglikelihoodF = 0; - Forward[] fwdsC = new Forward[trainSet.nseqs]; - Backward[] bwdsC = new Backward[trainSet.nseqs]; - double[] logPC = new double[trainSet.nseqs]; - String[] vPathsC = new String[trainSet.nseqs]; - - Forward[] fwdsF = new Forward[trainSet.nseqs]; - Backward[] bwdsF = new Backward[trainSet.nseqs]; - double[] logPF = new double[trainSet.nseqs]; - String[] vPathsF = new String[trainSet.nseqs]; + ViterbiTraining vtC = null; + ViterbiTraining vtF = null; + ForwardBackward fwdbwdC = null; + ForwardBackward fwdbwdF = null; //Initialization Step if (TrainingWithViterbi) { - loglikelihoodC = ViterbiTraining(trainSet, vPathsC, logPC, false, weightsL); - loglikelihoodF = ViterbiTraining(trainSet, vPathsF, logPF, true, weightsL); + vtC = new ViterbiTraining(hmm, trainSet, false, weightsL); + loglikelihoodC =vtC.getLogProb(); + + vtF = new ViterbiTraining(hmm, trainSet, true, weightsL); + loglikelihoodF =vtF.getLogProb(); } else { // Compute Forward and Backward tables for the sequences - loglikelihoodC = fwdbwd(fwdsC, bwdsC, logPC, false, trainSet, weightsL); - loglikelihoodF = fwdbwd(fwdsF, bwdsF, logPF, true, trainSet, weightsL); + fwdbwdC = new ForwardBackward(hmm, false, trainSet, weightsL); + loglikelihoodC = fwdbwdC.getLogProb(); + fwdbwdF = new ForwardBackward(hmm, true, trainSet, weightsL); + loglikelihoodF = fwdbwdF.getLogProb(); } double loglikelihood = loglikelihoodC - loglikelihoodF; @@ -103,8 +98,11 @@ public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, f System.out.println(iter + "\tlog likelihood = " + Params.fmt.format(loglikelihood)); if (valid) { - valLoglikelihoodC = fwdbwd(false, valSeqs); - valLoglikelihoodF = fwdbwd(true, valSeqs); + ForwardBackward fb = new ForwardBackward(hmm, false, trainSet); + valLoglikelihoodC = fb.getLogProb(); + + fb = new ForwardBackward(hmm, true, trainSet); + valLoglikelihoodF = fb.getLogProb(); valLoglikelihood = valLoglikelihoodC - valLoglikelihoodF; System.out.println("\tvalC=" + valLoglikelihoodC + ", valF=" + valLoglikelihoodF); @@ -158,18 +156,18 @@ public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, f //If TRUE reestimation computed using the ViterbiTraining algorithm //if FALSE reestimation computed using the Forward-Backward alogirth if (!TrainingWithViterbi) { - Forward fwdC = fwdsC[s]; - Backward bwdC = bwdsC[s]; - double PC = logPC[s]; // NOT exp. 
+ Forward fwdC = fwdbwdC.getFwds(s);
+ Backward bwdC = fwdbwdC.getBwds(s);
+ double PC = fwdbwdC.getLogP(s);

- Forward fwdF = fwdsF[s];
- Backward bwdF = bwdsF[s];
- double PF = logPF[s]; // NOT exp.
+ Forward fwdF = fwdbwdF.getFwds(s);
+ Backward bwdF = fwdbwdF.getBwds(s);
+ double PF = fwdbwdF.getLogP(s);

int seqLen = trainSet.seq[s].getLen();

for (int i = 1; i <= seqLen; i++)
- for (int k = 0; k < Model.nstate; k++) // @@ ektos begin kai end?
+ for (int k = 0; k < Model.nstate; k++) // without begin and end
{
if (esyminv[trainSet.seq[s].getSym(i - 1)] < 0)
throw new Exception("ERROR: Symbol " + trainSet.seq[s].getSym(i - 1) +
@@ -181,7 +179,7 @@ public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, f
}

//-Calc new transitions
- for (int i = 0; i <= seqLen - 1; i++) // @@ 1 prin + 1 meta?
+ for (int i = 0; i <= seqLen - 1; i++)
{
int lab = (trainSet.seq[s].getNPObs((i + 1) - 1));

@@ -219,18 +217,13 @@ public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, f
- PF);
}

- } else
- {
- String vPathC = vPathsC[s];
- String vPathF = vPathsF[s];
-
- ViterbiTrainingExp(vPathC, trainSet.seq[s], AC, EC, weightsL);
- ViterbiTrainingExp(vPathF, trainSet.seq[s], AF, EF, weightsL);
+ } else {
+ vtC.Exp(s, trainSet.seq[s], AC, EC, weightsL);
+ vtF.Exp(s, trainSet.seq[s], AF, EF, weightsL);

for (int i = 0; i < Model.nstate; i++)
for (int j = 0; j < Model.nstate; j++)
A[i][j] = AC[i][j] - AF[i][j];
-
}

//AddExpC_A( A, trainSet.seq[s], PC, fwdC, bwdC );
@@ -302,12 +295,17 @@ public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, f
hmm = new HMM(tab);

if (TrainingWithViterbi) {
- loglikelihoodC = ViterbiTraining(trainSet, vPathsC, logPC, false, weightsL);
- loglikelihoodF = ViterbiTraining(trainSet, vPathsF, logPF, true, weightsL);
+ vtC = new ViterbiTraining(hmm, trainSet, false, weightsL);
+ loglikelihoodC = vtC.getLogProb();
+
+ vtF = new ViterbiTraining(hmm, trainSet, true, weightsL);
+ loglikelihoodF = vtF.getLogProb();
} else {
// Compute Forward and Backward tables for the sequences
- loglikelihoodC = fwdbwd(fwdsC, bwdsC, logPC, false, trainSet, weightsL);
- loglikelihoodF = fwdbwd(fwdsF, bwdsF, logPF, true, trainSet, weightsL);
+ fwdbwdC = new ForwardBackward(hmm, false, trainSet, weightsL);
+ loglikelihoodC = fwdbwdC.getLogProb();
+ fwdbwdF = new ForwardBackward(hmm, true, trainSet, weightsL);
+ loglikelihoodF = fwdbwdF.getLogProb();
}

System.out.println("\tC=" + loglikelihoodC + ", F=" + loglikelihoodF);
@@ -317,8 +315,11 @@ public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, f
System.out.println(iter + "\tlog likelihood = " + loglikelihood + "\t\t diff = " + logdiff);

if (valid) {
- valLoglikelihoodC = fwdbwd(false, valSeqs);
- valLoglikelihoodF = fwdbwd(true, valSeqs);
+ ForwardBackward fb = new ForwardBackward(hmm, false, valSeqs);
+ valLoglikelihoodC = fb.getLogProb();
+
+ fb = new ForwardBackward(hmm, true, valSeqs);
+ valLoglikelihoodF = fb.getLogProb();

valLoglikelihood = valLoglikelihoodC - valLoglikelihoodF;
System.out.println("\tvalC=" + valLoglikelihoodC + ", valF=" + valLoglikelihoodF);
diff --git a/src/hmm/ForwardBackward.java b/src/hmm/ForwardBackward.java
new file mode 100644
index 0000000..738ef16
--- /dev/null
+++ b/src/hmm/ForwardBackward.java
@@ -0,0 +1,146 @@
+/*
+ * Copyright (C) 2019. Greenweaves Software Pty Ltd
+ * This is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This software is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see
+ * <http://www.gnu.org/licenses/>.
+ */
+
+package hmm;
+
+public class ForwardBackward
+{
+ private double loglikelihood = 0.D;
+ private double[] logP;
+ private Forward[] fwds;
+ private Backward[] bwds;
+
+ public ForwardBackward(HMM hmm, boolean free, SeqSet fbseqs, WeightsL weightsL) throws Exception {
+ this.fwds = new Forward[fbseqs.nseqs];
+ this.bwds = new Backward[fbseqs.nseqs];
+ this.logP = new double[fbseqs.nseqs];
+ Activ[][][] acts = null;
+
+ Run(hmm, this.fwds, this.bwds, this.logP, free, fbseqs, acts, weightsL);
+ }
+
+ public ForwardBackward(HMM hmm, boolean free, SeqSet fbseqs) throws Exception {
+ this.fwds = new Forward[fbseqs.nseqs];
+ this.bwds = new Backward[fbseqs.nseqs];
+ this.logP = new double[fbseqs.nseqs];
+ Activ[][][] acts = null;
+ WeightsL weightsL = new WeightsL(fbseqs.nseqs);
+
+ Run(hmm, this.fwds, this.bwds, this.logP, free, fbseqs, acts, weightsL);
+
+ }
+
+ public ForwardBackward(HMM hmm, boolean free, SeqSet fbseqs, Activ[][][] fbacts) throws Exception {
+ this.fwds = new Forward[fbseqs.nseqs];
+ this.bwds = new Backward[fbseqs.nseqs];
+ this.logP = new double[fbseqs.nseqs];
+ WeightsL weightsL = new WeightsL(fbseqs.nseqs);
+
+ Run(hmm, this.fwds, this.bwds, this.logP, free, fbseqs, fbacts, weightsL);
+
+ }
+
+ public ForwardBackward(HMM hmm, boolean free, SeqSet fbSeqs, Activ[][][] fbActs, WeightsL weightsL) throws Exception {
+ this.fwds = new Forward[fbSeqs.nseqs];
+ this.bwds = new Backward[fbSeqs.nseqs];
+ this.logP = new double[fbSeqs.nseqs];
+
+ Run(hmm, this.fwds, this.bwds, this.logP, free, fbSeqs, fbActs, weightsL);
+ }
+
+ public void Run(HMM hmm, Forward[] fwds, Backward[] bwds, double[] logP, boolean free,
+ SeqSet fbSeqs, Activ[][][] fbActs, WeightsL weightsL) throws Exception {
+ System.out.println("\tComputing Forward+Backward (" + ((free) ? "Free" : "Clumped") + ")");
+ System.out.print("\t");
+
+ for (int i = 0; i < fbSeqs.nseqs; i++)
+ System.out.print("-");
+
+ System.out.println("");
+ System.out.print("\t");
+
+ boolean errors = false;
+ for (int s = 0; s < fbSeqs.nseqs; s++) {
+ if (Params.HNN) {
+ double[][] lge = new double[hmm.nstte][fbSeqs.seq[s].getLen()];
+ for (int i = 0; i < fbSeqs.seq[s].getLen(); i++) {
+ lge[0][i] = hmm.getLoge(0, fbSeqs.seq[s], i);
+ lge[hmm.nstte - 1][i] = hmm.getLoge(0, fbSeqs.seq[s], i);
+
+ for (int k = 1; k < hmm.nstte - 1; k++) {
+ lge[k][i] = Math.log(fbActs[Model.slab[k]][s][i].layer3[0]);
+ }
+ }
+
+ fwds[s] = new Forward(hmm, fbSeqs.seq[s], lge, free);
+ bwds[s] = new Backward(hmm, fbSeqs.seq[s], lge, free);
+
+ } else {
+ try {
+ fwds[s] = new Forward(hmm, fbSeqs.seq[s], free);
+ bwds[s] = new Backward(hmm, fbSeqs.seq[s], free);
+ } catch (Exception e) {
+ System.out.println(e.getMessage() + " The error occurred at sequence " + (s + 1) + " " + fbSeqs.seq[s].header);
+ errors = true;
+ continue;
+ }
+ }
+
+ System.out.print("*");
+
+ logP[s] = fwds[s].logprob();
+
+ if (s < fbSeqs.nseqs) {
+ loglikelihood += logP[s] * weightsL.getWeightL(fbSeqs.seq[s].getIndexID());
+ }
+
+ }
+ System.out.println();
+
+ if (errors)
+ throw new Exception("Errors occurred...");
+
+ }
+
+ public double getLogProb(){
+ return this.loglikelihood;
+ }
+
+ public double[] getLogP(){
+ return this.logP;
+ }
+
+ public double getLogP(int s){
+ return this.logP[s];
+ }
+
+ public Forward[] getFwds(){
+ return this.fwds;
+ }
+
+ public Forward getFwds(int s){
+ return this.fwds[s];
+ }
+
+ public Backward[] getBwds(){
+ return this.bwds;
+ }
+
+ public Backward getBwds(int s){
+ return this.bwds[s];
+ }
+}
\ No newline at end of file
diff --git a/src/hmm/GEM.java b/src/hmm/GEM.java
index 9addb35..fff9177 100644
--- a/src/hmm/GEM.java
+++ b/src/hmm/GEM.java
@@ -152,7 +152,8 @@ private void EMstep(SeqSet trainSet, final Probs tab0, WeightsL weightsL) throws
for (int b = 0; b < Model.nesym; b++)
esyminv[Model.esym.charAt(b)] = b;

- loglikelihood = fwdbwd(fwds, bwds, logP, false, trainSet, weightsL);
+ ForwardBackward fwdbwd = new ForwardBackward(hmm, false, trainSet, weightsL);
+ loglikelihood = fwdbwd.getLogProb();

if (loglikelihood == Double.NEGATIVE_INFINITY)
System.out.println("Probable illegal transition found");
@@ -167,11 +168,10 @@ private void EMstep(SeqSet trainSet, final Probs tab0, WeightsL weightsL) throws
System.out.print(".");

//Compute estimates for A and E
-
- Forward fwd = fwds[s];
- Backward bwd = bwds[s];
+ Forward fwd = fwdbwd.getFwds(s);
+ Backward bwd = fwdbwd.getBwds(s);
int seqLen = trainSet.seq[s].getLen();
- double P = logP[s];
+ double P = fwdbwd.getLogP(s);

for (int i = 1; i <= seqLen; i++)
for (int k = 0; k < Model.nstate; k++) {
diff --git a/src/hmm/juchmme.java b/src/hmm/Juchmme.java
similarity index 98%
rename from src/hmm/juchmme.java
rename to src/hmm/Juchmme.java
index e596056..1ca8e86 100644
--- a/src/hmm/juchmme.java
+++ b/src/hmm/Juchmme.java
@@ -24,7 +24,7 @@ public static void main(String[] args) throws Exception {
long startTime = System.currentTimeMillis();
System.out.println("JUCHMME :: Java Utility for Class Hidden Markov Models and Extensions");
- System.out.println("Version 1.0.3; April 2019");
+ System.out.println("Version 1.0.4; September 2019");
System.out.println("Copyright (C) 2019 Pantelis Bagos");
System.out.println("Freely distributed under the GNU General Public Licence (GPLv3)");
System.out.println("--------------------------------------------------------------------------"); diff --git a/src/hmm/ML.java b/src/hmm/ML.java index b1f3cfc..b6db1c4 100644 --- a/src/hmm/ML.java +++ b/src/hmm/ML.java @@ -4,12 +4,7 @@ import java.util.*; class ML extends TrainAlgo { - public double valLog; - private boolean valid; // If true enable EARLY functionality private double[][] E; - private double[][] A; - - private Probs tab; public ML(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, WeightsL weightsL) throws Exception { valid = true; @@ -34,7 +29,7 @@ public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, f double[][] down_gradA = new double[Model.nstate][Model.nstate]; double[][] down_gradE = new double[Model.nstate][Model.nesym]; - // Set up the inverse of b -> esym.charAt(b); assume all esyms <= 'Z' + // Set up the inverse of b -> esym.charAt(b); int[] esyminv = new int[Model.esyminv]; for (int i = 0; i < esyminv.length; i++) @@ -43,15 +38,6 @@ public void Run(final SeqSet trainSet, final Probs tab0, final SeqSet valSeqs, f for (int b = 0; b < Model.nesym; b++) esyminv[Model.esym.charAt(b)] = b; - /* - // Initially use random transition and emission matrices - for (int k=0; k + * REAR Reversal Distance + */ + +package hmm; + +import java.util.*; + +public class ViterbiTraining { + private String[] vPaths; + private double[] logPs; + private double loglikelihood = 0.D; + + public ViterbiTraining(HMM hmm, SeqSet seqs, boolean free, WeightsL weightsL) { + System.out.println("\tComputing Viterbi Training (" + ((free) ? "Free" : "Clumped") + ")"); + + vPaths = new String[seqs.nseqs]; + logPs = new double[seqs.nseqs]; + + System.out.print("\t"); + + for (int i = 0; i < seqs.nseqs; i++) + System.out.print("-"); + + System.out.println(""); + + Viterbi v; + for (int s = 0; s < seqs.nseqs; s++) { + v = new Viterbi(hmm, seqs.seq[s], free); + vPaths[s] = v.getPath(); // Viterbi Path + logPs[s] = v.getProb(); // Viterbi likelihood + loglikelihood += v.getProb() * weightsL.getWeightL(s); + } + + } + + public double getLogProb() { + return this.loglikelihood; + } + + public String[] getvPaths() { + return this.vPaths; + } + + public String getvPath(int s) { + return this.vPaths[s]; + } + + public double[] getLogPs() { + return this.logPs; + } + + public double getvLogP(int s) { + return this.logPs[s]; + } + + public void Exp(int indexOfPath, Seq x, double[][] A, double[][] E, WeightsL weightsL) { + String[] vPath; + int length = x.getLen();// Sequence Length + + //States length. 
All states must have the same length + int stateLen = Model.state[0].length(); + vPath = new String[length]; + + //Split ViterbiPath to a String Array with path states + vPath = vPaths[indexOfPath].split("(?<=\\G.{" + stateLen + "})"); + + Score(vPath, A, x, E, weightsL); + } + + private void Score(String seqPath[], double A[][], Seq x, double E[][], WeightsL weightsL) { + //if exists the begin state, find the next and scoring + if (Params.ALLOW_BEGIN) { + int k = Arrays.asList(Model.state).indexOf(seqPath[0]); + A[0][k] = A[0][k] + (1 * weightsL.getWeightL(x.getIndexID())); + } + + int seqLen = x.getLen(); + int sym, row, col, i; + for (i = 1; i <= seqLen - 1; i++) { + sym = x.getNESym(i - 1); + row = Arrays.asList(Model.state).indexOf(seqPath[i - 1]); + col = Arrays.asList(Model.state).indexOf(seqPath[i]); + + A[row][col] = A[row][col] + (1 * weightsL.getWeightL(x.getIndexID())); + E[row][sym] = E[row][sym] + (1 * weightsL.getWeightL(x.getIndexID())); + + } + + //Score for the last state + sym = x.getNESym(i - 1); + row = Arrays.asList(Model.state).indexOf(seqPath[i - 1]); + E[row][sym] = E[row][sym] + (1 * weightsL.getWeightL(x.getIndexID())); + + //if exists the end state, find the previous state and scoring + if (Params.ALLOW_END) { + row = Arrays.asList(Model.state).indexOf(seqPath[i - 1]); + col = (Model.nstate) - 1; + A[row][col] = A[row][col] + (1 * weightsL.getWeightL(x.getIndexID())); + } + + } + +} \ No newline at end of file
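Usage note (not part of the diff): a minimal sketch of how the two extracted classes are meant to be driven, mirroring the call sites patched above in CML.java and GEM.java. HMM, SeqSet, WeightsL, Forward, Backward, Seq and Model are the existing JUCHMME types; the class and method names RefactoringUsageExample/example are hypothetical and exist only for illustration.

    package hmm;

    class RefactoringUsageExample {
        static void example(HMM hmm, SeqSet trainSet, WeightsL weightsL) throws Exception {
            // Forward-Backward tables, clumped (free=false) vs. free (free=true),
            // as in CML.Run: the difference of the two log-likelihoods is the
            // conditional (CML) objective.
            ForwardBackward fbC = new ForwardBackward(hmm, false, trainSet, weightsL);
            ForwardBackward fbF = new ForwardBackward(hmm, true, trainSet, weightsL);
            double loglikelihood = fbC.getLogProb() - fbF.getLogProb();
            System.out.println("CML objective: " + loglikelihood);

            // Per-sequence tables, as consumed by the reestimation loops:
            for (int s = 0; s < trainSet.nseqs; s++) {
                Forward fwd = fbC.getFwds(s);
                Backward bwd = fbC.getBwds(s);
                double logP = fbC.getLogP(s);
                // ... accumulate expected transition/emission counts here ...
            }

            // Viterbi-training alternative (RUN_ViterbiTraining=true): decode all
            // paths once, then accumulate path-based counts into A and E.
            ViterbiTraining vt = new ViterbiTraining(hmm, trainSet, false, weightsL);
            double[][] A = new double[Model.nstate][Model.nstate];
            double[][] E = new double[Model.nstate][Model.nesym];
            for (int s = 0; s < trainSet.nseqs; s++)
                vt.Exp(s, trainSet.seq[s], A, E, weightsL);
        }
    }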