From 53fdc1b40b5e0362a7ab473db56a9d646b51fdd3 Mon Sep 17 00:00:00 2001
From: piotrszul
Date: Wed, 10 Apr 2019 11:21:45 +1000
Subject: [PATCH] [#104] Implement regression tests (#109)

* Fixed the issue with bootstrap sampling by using the actual sample indexes
  (with repeats) rather than the distinct set of indexes (#101)
* Tech/104/0.2 (#105)
* Added regression test case generation scripts
* Added regression cases and unit test
* Moved execution of regression tests to the 'regression-test' profile
* Enabled all regression test cases
* Updated the command line for regression tests
* Fixed the chr22 regression command line
* Clean up: removed the obsolete splitter classes
* Refactored regression tests to use Parameterized
* Updated the regression case generation scripts
---
 .gitignore                                         |   1 +
 dev/test-gen-regression-cases.sh                   |   4 +-
 .../algo/ClassificationSplitter.java               |   7 -
 .../algo/JClassificationSplitter.java              |  87 --------
 .../JConfusionClassificationSplitter.java          |  92 --------
 .../JContinousClassificationFastSplitter.java      |  50 -----
 .../algo/JMaskedClassificationSplitter.java        |  68 ------
 .../split/JOrderedFastIndexedSplitter.java         |  10 +-
 .../algo/split/JOrderedIndexedSplitter.java        |  10 +-
 .../au/csiro/variantspark/algo/Split.scala         |  10 +
 .../regression/.CNAE-9-imp_category.csv.crc        | Bin 32 -> 0 bytes
 .../regression/.chr22-imp_22_16050408.csv.crc      | Bin 36 -> 0 bytes
 ...nth_2000_500_fact_10_0.0-imp_cat10.csv.crc      | Bin 32 -> 0 bytes
 ...ynth_2000_500_fact_10_0.0-imp_cat2.csv.crc      | Bin 32 -> 0 bytes
 ...h_2000_500_fact_10_0.995-imp_cat10.csv.crc      | Bin 32 -> 0 bytes
 ...th_2000_500_fact_10_0.995-imp_cat2.csv.crc      | Bin 32 -> 0 bytes
 ...ynth_2000_500_fact_3_0.0-imp_cat10.csv.crc      | Bin 32 -> 0 bytes
 ...synth_2000_500_fact_3_0.0-imp_cat2.csv.crc      | Bin 32 -> 0 bytes
 ...th_2000_500_fact_3_0.995-imp_cat10.csv.crc      | Bin 32 -> 0 bytes
 ...nth_2000_500_fact_3_0.995-imp_cat2.csv.crc      | Bin 32 -> 0 bytes
 .../data/regression/chr22-imp_22_16050408.csv      | 200 +++++++++---------
 .../algo/ClassificationSplitterTest.scala          |  83 --------
 .../algo/split/IndexedSplitterGiniTest.scala       |  90 ++++++++
 .../perf/ClassificationSplitterPerfTest.scala      | 104 ---------
 .../ImportanceDatasetRegressionTest.scala          |  32 +++
 .../regression/ImportanceRegressionTest.scala      | 104 ++-------
 .../ImportanceSynthRegressionTest.scala            |  43 ++++
 27 files changed, 313 insertions(+), 682 deletions(-)
 delete mode 100644 src/main/java/au/csiro/variantspark/algo/ClassificationSplitter.java
 delete mode 100644 src/main/java/au/csiro/variantspark/algo/JClassificationSplitter.java
 delete mode 100644 src/main/java/au/csiro/variantspark/algo/JConfusionClassificationSplitter.java
 delete mode 100644 src/main/java/au/csiro/variantspark/algo/JContinousClassificationFastSplitter.java
 delete mode 100644 src/main/java/au/csiro/variantspark/algo/JMaskedClassificationSplitter.java
 delete mode 100644 src/test/data/regression/.CNAE-9-imp_category.csv.crc
 delete mode 100644 src/test/data/regression/.chr22-imp_22_16050408.csv.crc
 delete mode 100644 src/test/data/regression/.synth_2000_500_fact_10_0.0-imp_cat10.csv.crc
 delete mode 100644 src/test/data/regression/.synth_2000_500_fact_10_0.0-imp_cat2.csv.crc
 delete mode 100644 src/test/data/regression/.synth_2000_500_fact_10_0.995-imp_cat10.csv.crc
 delete mode 100644 src/test/data/regression/.synth_2000_500_fact_10_0.995-imp_cat2.csv.crc
 delete mode 100644 src/test/data/regression/.synth_2000_500_fact_3_0.0-imp_cat10.csv.crc
 delete mode 100644 src/test/data/regression/.synth_2000_500_fact_3_0.0-imp_cat2.csv.crc
 delete mode 100644 src/test/data/regression/.synth_2000_500_fact_3_0.995-imp_cat10.csv.crc
 delete mode 100644 src/test/data/regression/.synth_2000_500_fact_3_0.995-imp_cat2.csv.crc
 delete mode 100644 src/test/scala/au/csiro/variantspark/algo/ClassificationSplitterTest.scala
 create mode 100644 src/test/scala/au/csiro/variantspark/algo/split/IndexedSplitterGiniTest.scala
 delete mode 100644 src/test/scala/au/csiro/variantspark/perf/ClassificationSplitterPerfTest.scala
 create mode 100644 src/test/scala/au/csiro/variantspark/test/regression/ImportanceDatasetRegressionTest.scala
 create mode 100644 src/test/scala/au/csiro/variantspark/test/regression/ImportanceSynthRegressionTest.scala
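A note on the bootstrap fix referenced in the first bullet (#101): a bootstrap sample draws n indexes with replacement, and the repeats are what give bagged trees their statistical properties; collapsing them to a distinct set silently shrinks and reweights the sample. The sketch below is illustrative only (none of these names come from the codebase) and simply shows the difference the bullet describes.

```scala
// Illustrative sketch (not code from this patch): why a bootstrap sample
// must keep repeated indexes rather than collapse them to a distinct set.
import scala.util.Random

object BootstrapSketch {
  def main(args: Array[String]): Unit = {
    val rng = new Random(17)
    val n = 10
    // Bootstrap sample: n draws with replacement; repeats are expected
    // and carry weight in the downstream impurity calculations.
    val withRepeats = Array.fill(n)(rng.nextInt(n))
    // The bug fixed in #101 amounted to keeping only the distinct indexes,
    // which under-weights repeated rows and shrinks the sample.
    val distinctOnly = withRepeats.distinct
    println(s"bootstrap: ${withRepeats.mkString(",")}")  // always size n
    println(s"distinct:  ${distinctOnly.mkString(",")}") // usually < n
  }
}
```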
diff --git a/.gitignore b/.gitignore
index d905302c..3aa56429 100644
--- a/.gitignore
+++ b/.gitignore
@@ -58,3 +58,4 @@ build
 dist
 _build
 spark-warehouse
+.*.crc
diff --git a/dev/test-gen-regression-cases.sh b/dev/test-gen-regression-cases.sh
index be2eb385..85a35c88 100755
--- a/dev/test-gen-regression-cases.sh
+++ b/dev/test-gen-regression-cases.sh
@@ -32,7 +32,7 @@ PREFIX="CNAE-9"
 "${FWDIR}/bin/variant-spark" --spark --master local[2] -- importance -if "${DATA_DIR}/${PREFIX}-wide.csv" -ff "${DATA_DIR}/${PREFIX}-labels.csv" \
 -fc "${RESP}" \
 -on 100 -of "${OUTPUT_DIR}/${PREFIX}-imp_${RESP}.csv" \
--ivo 10 \
+-io "{\"defVariableType\":\"ORDINAL(10)\"}" \
 -it csv -v -ro -rn 100 -rbs 50 -sp 4 -sr 17
 
 #
@@ -49,7 +49,7 @@ for CASE in ${FWDIR}/src/test/data/synth/*-meta.txt; do
 "${FWDIR}/bin/variant-spark" --spark --master local[2] -- importance -if "${DATA_DIR}/${PREFIX}-wide.csv" -ff "${DATA_DIR}/${PREFIX}-labels.csv" \
 -fc "${RESP}" \
 -on 100 -of "${OUTPUT_DIR}/${PREFIX}-imp_${RESP}.csv" \
--ivo ${IVO} \
+-io "{\"defVariableType\":\"ORDINAL(${IVO})\"}" \
 -it csv -v -ro -rn 100 -rbs 50 -sp 4 -sr 17
 done
 done
diff --git a/src/main/java/au/csiro/variantspark/algo/ClassificationSplitter.java b/src/main/java/au/csiro/variantspark/algo/ClassificationSplitter.java
deleted file mode 100644
index 7eb1c91c..00000000
--- a/src/main/java/au/csiro/variantspark/algo/ClassificationSplitter.java
+++ /dev/null
@@ -1,7 +0,0 @@
-package au.csiro.variantspark.algo;
-
-public interface ClassificationSplitter {
-    SplitInfo findSplit(double[] data, int[] splitIndices);
-    SplitInfo findSplit(int[] data, int[] splitIndices);
-    SplitInfo findSplit(byte[] data, int[] splitIndices);
-}
diff --git a/src/main/java/au/csiro/variantspark/algo/JClassificationSplitter.java b/src/main/java/au/csiro/variantspark/algo/JClassificationSplitter.java
deleted file mode 100644
index 5dd7e0b5..00000000
--- a/src/main/java/au/csiro/variantspark/algo/JClassificationSplitter.java
+++ /dev/null
@@ -1,87 +0,0 @@
-package au.csiro.variantspark.algo;
-
-import java.util.Arrays;
-
-import au.csiro.variantspark.algo.impurity.FastGini;
-
-/**
- * Fast gini based splitter.
- * NOT MULITHREADED !!! (Caches state to avoid heap allocations)
- *
- * @author szu004
- *
- */
-@SuppressWarnings("JavaDoc")
-public class JClassificationSplitter implements ClassificationSplitter {
-    private final int[] leftSplitCounts;
-    private final int[] rightSplitCounts;
-    private final double[] leftRightGini = new double[2];
-    private final int[] labels;
-    private final int nLevels;
-
-    /**
-     * The outbounded version
-     * @param labels
-     * @param nCategories
-     */
-    public JClassificationSplitter(int[] labels, int nCategories) {
-        this(labels, nCategories, Integer.MIN_VALUE);
-    }
-
-    public JClassificationSplitter(int[] labels, int nCategories, int nLevels) {
-        this.labels = labels;
-        this.leftSplitCounts = new int[nCategories];
-        this.rightSplitCounts = new int[nCategories];
-        this.nLevels = nLevels;
-    }
-
-
-    @Override
-    public SplitInfo findSplit(double[] data,int[] splitIndices) {
-        SplitInfo result = null;
-        double minGini = Double.MAX_VALUE;
-        if (splitIndices.length < 2) {
-            return result;
-        }
-
-        int actualNLevels = (nLevels > 0) ? nLevels : getLevelCount(data);
-
-        for(int sp = 0 ; sp < actualNLevels - 1; sp ++) {
-            Arrays.fill(leftSplitCounts, 0);
-            Arrays.fill(rightSplitCounts, 0);
-            for(int i:splitIndices) {
-                if ((int)data[i] <=sp) {
-                    leftSplitCounts[labels[i]]++;
-                } else {
-                    rightSplitCounts[labels[i]]++;
-                }
-            }
-            double g = FastGini.splitGini(leftSplitCounts, rightSplitCounts, leftRightGini, true);
-            if (g < minGini ) {
-                result = new SplitInfo(sp, g, leftRightGini[0], leftRightGini[1]);
-                minGini = g;
-            }
-        }
-        return result;
-    }
-
-    private int getLevelCount(double[] data) {
-        int maxLevel = 0;
-        for(double d:data) {
-            if ((int)d > maxLevel) {
-                maxLevel = (int)d;
-            }
-        }
-        return maxLevel+1;
-    }
-
-    @Override
-    public SplitInfo findSplit(int[] data, int[] splitIndices) {
-        throw new RuntimeException("Not implemented yet");
-    }
-
-    @Override
-    public SplitInfo findSplit(byte[] data, int[] splitIndices) {
-        throw new RuntimeException("Not implemented yet");
-    }
-}
diff --git a/src/main/java/au/csiro/variantspark/algo/JConfusionClassificationSplitter.java b/src/main/java/au/csiro/variantspark/algo/JConfusionClassificationSplitter.java
deleted file mode 100644
index dda267c0..00000000
--- a/src/main/java/au/csiro/variantspark/algo/JConfusionClassificationSplitter.java
+++ /dev/null
@@ -1,92 +0,0 @@
-package au.csiro.variantspark.algo;
-
-import java.util.Arrays;
-import java.util.function.BiConsumer;
-
-import au.csiro.variantspark.algo.impurity.FastGini;
-
-/**
- * Fast gini based splitter. NOT MULITHREADED !!!
- * Caches state to avoid heap allocations
- *
- * @author szu004
- *
- */
-public class JConfusionClassificationSplitter implements ClassificationSplitter {
-    private final int[] leftSplitCounts;
-    private final int[] rightSplitCounts;
-    private final int[][] confusion;
-    private final double[] leftRightGini = new double[2];
-    private final int[] labels;
-    private final int nCategories;
-    private final int nLevels;
-
-    public JConfusionClassificationSplitter(int[] labels, int nCategories, int nLevels) {
-        this.labels = labels;
-        this.nCategories = nCategories;
-        this.nLevels = nLevels;
-        confusion = new int[nLevels][this.nCategories];
-        leftSplitCounts = new int[this.nCategories];
-        rightSplitCounts = new int[this.nCategories];
-    }
-
-
-    @Override
-    public SplitInfo findSplit(double[] data, int[] splitIndices) {
-        return dofindSplit(splitIndices, (idx, conf) -> {
-            for (int i : idx) {
-                conf[(int) data[i]][labels[i]]++;
-            }
-        });
-    }
-
-    @Override
-    public SplitInfo findSplit(int[] data, int[] splitIndices) {
-        return dofindSplit(splitIndices, (idx, conf) -> {
-            for (int i : idx) {
-                conf[data[i]][labels[i]]++;
-            }
-        });
-    }
-
-    @Override
-    public SplitInfo findSplit(byte[] data, int[] splitIndices) {
-        return dofindSplit(splitIndices, (idx, conf) -> {
-            for (int i : idx) {
-                conf[(int) data[i]][labels[i]]++;
-            }
-        });
-    }
-
-    private SplitInfo dofindSplit(int[] splitIndices, BiConsumer<int[], int[][]> confusionCalc) {
-        SplitInfo result = null;
-        double minGini = Double.MAX_VALUE;
-
-        if (splitIndices.length < 2) {
-            return result;
-        }
-
-        for (int[] aConfusion : confusion) {
-            Arrays.fill(aConfusion, 0);
-        }
-
-        confusionCalc.accept(splitIndices, confusion);
-
-        Arrays.fill(leftSplitCounts, 0);
-        Arrays.fill(rightSplitCounts, 0);
-        for (int[] l : confusion) {
-            ArrayOps.addEq(rightSplitCounts, l);
-        }
-
-        for (int sp = 0; sp < nLevels - 1; sp++) {
-            ArrayOps.addEq(leftSplitCounts, confusion[sp]);
-            ArrayOps.subEq(rightSplitCounts, confusion[sp]);
-            double g = FastGini.splitGini(leftSplitCounts, rightSplitCounts, leftRightGini, true);
-            if (g < minGini) {
-                result = new SplitInfo(sp, g, leftRightGini[0], leftRightGini[1]);
-                minGini = g;
-            }
-        }
-        return result;
-    }
-}
diff --git a/src/main/java/au/csiro/variantspark/algo/JContinousClassificationFastSplitter.java b/src/main/java/au/csiro/variantspark/algo/JContinousClassificationFastSplitter.java
deleted file mode 100644
index 00c13f94..00000000
--- a/src/main/java/au/csiro/variantspark/algo/JContinousClassificationFastSplitter.java
+++ /dev/null
@@ -1,50 +0,0 @@
-package au.csiro.variantspark.algo;
-
-import java.util.Arrays;
-
-import it.unimi.dsi.fastutil.doubles.DoubleArrays;
-
-
-/**
- * @author szu004
- * This is a naive implementation of precise (not binning) continous variable splitter
- */
-public class JContinousClassificationFastSplitter implements ClassificationSplitter {
-
-    private final int[] labels;
-    private final int noLabels;
-
-    public JContinousClassificationFastSplitter(int[] labels, int noLabels) {
-        this.labels = labels;
-        this.noLabels = noLabels;
-    }
-
-    @Override
-    public SplitInfo findSplit(final double[] data, int[] splitIndices) {
-        if (splitIndices.length < 2) {
-            // nothing to split
-            return null;
-        }
-        // create a dense rank for the data
-        // TODO: This needs to be move outside
-        int[] denseRank = new int[data.length];
-        double rankValues[] = ArrayOps.denseRank(data, denseRank);
-        JConfusionClassificationSplitter splitter = new JConfusionClassificationSplitter(this.labels, this.noLabels, rankValues.length);
-        SplitInfo split = splitter.findSplit(denseRank, splitIndices);
-        // now need to convert the rank to the actual value
-        return split == null ? split: new SplitInfo(rankValues[(int)split.splitPoint()], split.gini(), split.leftGini(), split.rightGini());
-
-    }
-
-    @Override
-    public SplitInfo findSplit(int[] data, int[] splitIndices) {
-        throw new UnsupportedOperationException("JContinousClassificationSplitter.findSplit(int[] ...");
-
-    }
-
-    @Override
-    public SplitInfo findSplit(byte[] data, int[] splitIndices) {
-        throw new UnsupportedOperationException("JContinousClassificationSplitter.findSplit(byte[] ...");
-    }
-
-}
diff --git a/src/main/java/au/csiro/variantspark/algo/JMaskedClassificationSplitter.java b/src/main/java/au/csiro/variantspark/algo/JMaskedClassificationSplitter.java
deleted file mode 100644
index 238689d7..00000000
--- a/src/main/java/au/csiro/variantspark/algo/JMaskedClassificationSplitter.java
+++ /dev/null
@@ -1,68 +0,0 @@
-package au.csiro.variantspark.algo;
-
-import java.util.Arrays;
-
-import au.csiro.variantspark.algo.impurity.FastGini;
-
-
-/**
- * Fast gini based splitter.
- * NOT MULITHREADED !!! (Caches state to avoid heap allocations)
- *
- * @author szu004
- *
- */
-public class JMaskedClassificationSplitter {
-    private final int[] leftSplitCounts;
-    private final int[] rightSplitCounts;
-    private final double[] leftRightGini = new double[2];
-    private final int[] labels;
-
-    public JMaskedClassificationSplitter(int[] labels, int nCategories) {
-        this.labels = labels;
-        this.leftSplitCounts = new int[nCategories];
-        this.rightSplitCounts = new int[nCategories];
-    }
-
-    public SplitInfo findSplit(double[] data,int[] splitIndices) {
-        SplitInfo result = null;
-        double minGini = 1.0;
-
-        /* TODO (review and test implementation)
-         * on the first pass we calculate the splits
-         * AND determine which split points are in this dataset
-         * because 0 is most likely we will do 0 as the initial pass */
-        long splitCandidateSet = 0L;
-        for(int i:splitIndices) {
-            splitCandidateSet|=(1 << (int)data[i]);
-        }
-
-        int sp = 0;
-        while(splitCandidateSet != 0L) {
-            while (splitCandidateSet != 0L && (splitCandidateSet & 1) == 0) {
-                sp ++;
-                splitCandidateSet >>= 1;
-            }
-            splitCandidateSet >>= 1;
-
-            if (splitCandidateSet != 0L) {
-                Arrays.fill(leftSplitCounts, 0);
-                Arrays.fill(rightSplitCounts, 0);
-                for(int i:splitIndices) {
-                    if ((int)data[i] <=sp) {
-                        leftSplitCounts[labels[i]]++;
-                    } else {
-                        rightSplitCounts[labels[i]]++;
-                    }
-                }
-                double g = FastGini.splitGini(leftSplitCounts, rightSplitCounts, leftRightGini);
-                if (g < minGini ) {
-                    result = new SplitInfo(sp, g, leftRightGini[0], leftRightGini[1]);
-                    minGini = g;
-                }
-                sp++;
-            }
-        }
-        return result;
-    }
-}
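All of the removed splitters (and the indexed splitters that replace them) share the same search pattern: sweep candidate split points over the ordinal levels, maintain left/right class counts for `data <= sp` versus `data > sp`, and keep the split point with the lowest size-weighted Gini impurity. The following self-contained sketch illustrates that pattern; `gini` here is the textbook formula rather than the repository's `FastGini`, and all names are invented for the example.

```scala
// Illustrative sketch of the level-sweep split search used by these splitters.
// Self-contained and deliberately naive; the real code reuses count buffers
// and confusion matrices to avoid reallocations.
object GiniSweepSketch {
  /** Standard Gini impurity of a class-count vector. */
  def gini(counts: Array[Int]): Double = {
    val n = counts.sum.toDouble
    if (n == 0) 0.0 else 1.0 - counts.map(c => (c / n) * (c / n)).sum
  }

  /** Finds the level sp minimizing the weighted Gini of (data <= sp, data > sp). */
  def bestSplit(data: Array[Int], labels: Array[Int],
      nCategories: Int, nLevels: Int): Option[(Int, Double)] = {
    val candidates = for (sp <- 0 until nLevels - 1) yield {
      val left = new Array[Int](nCategories)
      val right = new Array[Int](nCategories)
      for (i <- data.indices) {
        if (data(i) <= sp) left(labels(i)) += 1 else right(labels(i)) += 1
      }
      // Weight each side's impurity by its share of the samples.
      val weighted = (left.sum * gini(left) + right.sum * gini(right)) / data.length
      (sp, weighted)
    }
    if (candidates.isEmpty) None else Some(candidates.minBy(_._2))
  }
}
```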
diff --git a/src/main/java/au/csiro/variantspark/algo/split/JOrderedFastIndexedSplitter.java b/src/main/java/au/csiro/variantspark/algo/split/JOrderedFastIndexedSplitter.java
index b5760596..ac719f3e 100644
--- a/src/main/java/au/csiro/variantspark/algo/split/JOrderedFastIndexedSplitter.java
+++ b/src/main/java/au/csiro/variantspark/algo/split/JOrderedFastIndexedSplitter.java
@@ -46,10 +46,12 @@ protected SplitInfo doFindSplit(int[] splitIndices) {
             if (!thisAggregator.isEmpty()) {
                 // only consider value that appeared at least once in the split
                 impurityCalc.update(thisAggregator);
-                double thisImpurity = impurityCalc.getValue(leftRightImpurity);
-                if (thisImpurity < minImpurity) {
-                    result = new SplitInfo(sp, thisImpurity, leftRightImpurity.left(), leftRightImpurity.right());
-                    minImpurity = thisImpurity;
+                if (impurityCalc.hasProperSplit()) {
+                    double thisImpurity = impurityCalc.getValue(leftRightImpurity);
+                    if (thisImpurity < minImpurity) {
+                        result = new SplitInfo(sp, thisImpurity, leftRightImpurity.left(), leftRightImpurity.right());
+                        minImpurity = thisImpurity;
+                    }
                 }
             }
         }
diff --git a/src/main/java/au/csiro/variantspark/algo/split/JOrderedIndexedSplitter.java b/src/main/java/au/csiro/variantspark/algo/split/JOrderedIndexedSplitter.java
index 835628fa..4639ae87 100644
--- a/src/main/java/au/csiro/variantspark/algo/split/JOrderedIndexedSplitter.java
+++ b/src/main/java/au/csiro/variantspark/algo/split/JOrderedIndexedSplitter.java
@@ -49,10 +49,12 @@ public SplitInfo doFindSplit(int[] splitIndices) {
                 impurityCalc.update(i);
             }
         }
-        double g = impurityCalc.getValue(leftRightImpurity);
-        if (g < minImpurity ) {
-            result = new SplitInfo(sp, g, leftRightImpurity.left(), leftRightImpurity.right());
-            minImpurity = g;
+        if (impurityCalc.hasProperSplit()) {
+            double g = impurityCalc.getValue(leftRightImpurity);
+            if (g < minImpurity ) {
+                result = new SplitInfo(sp, g, leftRightImpurity.left(), leftRightImpurity.right());
+                minImpurity = g;
+            }
         }
     }
     return result;
diff --git a/src/main/scala/au/csiro/variantspark/algo/Split.scala b/src/main/scala/au/csiro/variantspark/algo/Split.scala
index aabfb268..0e0cb682 100644
--- a/src/main/scala/au/csiro/variantspark/algo/Split.scala
+++ b/src/main/scala/au/csiro/variantspark/algo/Split.scala
@@ -28,6 +28,16 @@ trait IndexedSplitAggregator {
     left.add(agg)
     right.sub(agg)
   }
+
+  /**
+   * Checks if this is a proper split, that is one that does not
+   * put all the elements on one side.
+   */
+  def hasProperSplit:Boolean = !left.isEmpty() && !right.isEmpty()
+
+  /**
+   * Gets the split impurity value.
+   */
   def getValue(outSplitImp:SplitImpurity):Double = {
     left.splitValue(right, outSplitImp)
   }
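The new `hasProperSplit` guard exists because a degenerate "split" with every element on one side still has a well-defined impurity value, so without the guard it could win the minimum search while being useless for growing the tree. A toy illustration follows; `ToyAggregator` is invented for the example and is not the repository's `IndexedSplitAggregator`.

```scala
// Toy illustration (not the repo's classes) of the hasProperSplit guard:
// a split is only eligible when both sides are non-empty.
case class ToyAggregator(left: Vector[Int], right: Vector[Int]) {
  def hasProperSplit: Boolean = left.nonEmpty && right.nonEmpty
}

object HasProperSplitSketch extends App {
  val degenerate = ToyAggregator(Vector(0, 1, 1), Vector.empty)
  val proper = ToyAggregator(Vector(0, 0), Vector(1, 1))
  assert(!degenerate.hasProperSplit) // rejected by the new guard
  assert(proper.hasProperSplit)      // eligible for the impurity comparison
}
```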
diff --git a/src/test/data/regression/.CNAE-9-imp_category.csv.crc b/src/test/data/regression/.CNAE-9-imp_category.csv.crc
deleted file mode 100644
index c0577b43b4edc799d0693360ca3bb16b4c106d13..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 32
ocmYc;N@ieSU}9*@>0i3WFF$u>*P^*jZgmzr^Ac~|>3KC10KY;GEdT%j

diff --git a/src/test/data/regression/.chr22-imp_22_16050408.csv.crc b/src/test/data/regression/.chr22-imp_22_16050408.csv.crc
deleted file mode 100644
index 0a7a8b26406f1c971bae8643e9e7cccd937a6050..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 36
scmYc;N@ieSU}CV`aD0))wfgtR8dwFIbhqtV;#Xr-DI&?nIbC-<0PzeBQvd(}

diff --git a/src/test/data/regression/.synth_2000_500_fact_10_0.0-imp_cat10.csv.crc b/src/test/data/regression/.synth_2000_500_fact_10_0.0-imp_cat10.csv.crc
deleted file mode 100644
index 296190d39f4d74ea60d813aa85799c80cd153c62..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 32
ocmYc;N@ieSU}Ctd9hwqk;rdto*2(RMGIO2?wY^fgHOrs{0K7yEk^lez

diff --git a/src/test/data/regression/.synth_2000_500_fact_10_0.0-imp_cat2.csv.crc b/src/test/data/regression/.synth_2000_500_fact_10_0.0-imp_cat2.csv.crc
deleted file mode 100644
index 26c1610c7161d15962e4c42cf2e3a61745d458aa..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 32
ocmYc;N@ieSU}E5z|Cl?bE&bJuU6;T38HswW@Cpvyp`^SN0JeAxI{*Lx

diff --git a/src/test/data/regression/.synth_2000_500_fact_10_0.995-imp_cat10.csv.crc b/src/test/data/regression/.synth_2000_500_fact_10_0.995-imp_cat10.csv.crc
deleted file mode 100644
index f1b3523bdf47a9e9ae7fd2c9f14d9ca4cd77bc11..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 32
qcmV+*0N?*(a$^7h00IE!

( l & split ).size).toArray
-  //      for (i <- 0 until data.length -1) {
-  //        val leftCount = labels.map(l => ( l & split & data(i)).size).toArray
-  //      }
-  //    }
-  //
-  //
-  //  def testBits {
-  //    val rg = new XorShift1024StarRandomGenerator(13)
-  //    val nLabels = 10000
-  //    val labels = Array.fill(nLabels)(Math.abs(rg.nextInt) % 2)
-  //    val splitIndexes = Range(0, 10).toArray
-  //    val data = Array.fill(nLabels)((Math.abs(rg.nextInt()) % 3).toByte)
-  //    // encode labels as bytes
-  //    val bSplit = BitSet(splitIndexes:_*)
-  //    val bLabels = Range(0,2).map(i => BitSet(labels.indices.filter(labels(_) == i).toArray:_*)).toArray
-  //    val bData = Range(0,3).map(i => BitSet(data.indices.filter(labels(_) == i).toArray:_*)).toArray
-  //    Timed.time {
-  //      for (i <- 0 until 50000) {
-  //        findBitmapSplit(bData, bLabels, bSplit)
-  //      }
-  //    }.report("Splitting")
-  //    Timed.time {
-  //      for (i <- 0 until 50000) {
-  //        findBitmapSplit(bData, bLabels, bSplit)
-  //      }
-  //    }.report("Splitting1")
-  //    Timed.time {
-  //      for (i <- 0 until 50000) {
-  //        findBitmapSplit(bData, bLabels, bSplit)
-  //      }
-  //    }.report("Splitting2")
-  //  }
-
-
-}
\ No newline at end of file
diff --git a/src/test/scala/au/csiro/variantspark/test/regression/ImportanceDatasetRegressionTest.scala b/src/test/scala/au/csiro/variantspark/test/regression/ImportanceDatasetRegressionTest.scala
new file mode 100644
index 00000000..ff527d72
--- /dev/null
+++ b/src/test/scala/au/csiro/variantspark/test/regression/ImportanceDatasetRegressionTest.scala
@@ -0,0 +1,32 @@
+package au.csiro.variantspark.test.regression
+
+import java.util.Collection
+
+import scala.collection.JavaConverters.asJavaCollectionConverter
+
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+import org.junit.runners.Parameterized.Parameters
+
+
+/**
+ * Runs regression tests for real-world datasets.
+ */
+@RunWith(classOf[Parameterized])
+class ImportanceDatasetRegressionTest(filenameWithExpected:String, cmdLine:String) extends ImportanceRegressionTest {
+
+  @Test
+  def testDatasetImportanceOutputMatches() {
+    runRegression(cmdLine, filenameWithExpected)
+  }
+}
+
+object ImportanceDatasetRegressionTest {
+
+  @Parameters
+  def datasets():Collection[Array[Object]] = List(
+    Array[Object]("chr22-imp_22_16050408.csv", "importance -if data/chr22_1000.vcf -ff data/chr22-labels.csv -fc 22_16050408 -v -rn 100 -rbs 50 -ro -sr 17 -on 100 -sp 4 -of ${outputFile}"),
+    Array[Object]("CNAE-9-imp_category.csv", """importance -if data/CNAE-9-wide.csv -it csv -ff data/CNAE-9-labels.csv -fc category -v -ro -rn 100 -rbs 50 -sr 17 -io {"defVariableType":"ORDINAL(10)"} -sp 4 -on 100 -of ${outputFile}""")
+  ).asJavaCollection
+}
\ No newline at end of file
diff --git a/src/test/scala/au/csiro/variantspark/test/regression/ImportanceRegressionTest.scala b/src/test/scala/au/csiro/variantspark/test/regression/ImportanceRegressionTest.scala
index 7f52cdac..51a2c6de 100644
--- a/src/test/scala/au/csiro/variantspark/test/regression/ImportanceRegressionTest.scala
+++ b/src/test/scala/au/csiro/variantspark/test/regression/ImportanceRegressionTest.scala
@@ -10,8 +10,29 @@ import org.apache.spark.sql.SparkSession
 import org.junit.BeforeClass
 import org.apache.commons.lang3.text.StrSubstitutor
 import collection.JavaConverters._
-import org.junit.runner.RunWith
-import org.junit.Ignore
+
+
+/**
+ * Base class for regression tests that compare importance output for known
+ * datasets and parameters against the recorded output, assumed to be correct.
+ * The expected output can be updated with the `dev/test-gen-regression-cases.sh` script.
+ */
+abstract class ImportanceRegressionTest {
+
+  import ImportanceRegressionTest._
+  def expected(fileName:String):String = new File(ExpectedDir, fileName).getPath
+  def synth(fileName:String):String = new File(SynthDataDir, fileName).getPath
+  def actual(fileName:String):String = new File(ActualDir, fileName).getPath
+
+  def runRegression(cmdLine:String, expectedFileName:String, sessionBuilder:SparkSession.Builder = MasterLocal2) {
+    withSessionBuilder(sessionBuilder) { _ =>
+      val outputFile = actual(expectedFileName)
+      val sub = new StrSubstitutor(Map("outputFile" -> outputFile).asJava)
+      VariantSparkApp.main(sub.replace(cmdLine).split(" "))
+      assertSameContent(expected(expectedFileName), outputFile)
+    }
+  }
+}
 
 object ImportanceRegressionTest {
@@ -44,84 +65,5 @@
   }
 }
 
-class ImportanceRegressionTest {
-
-  import ImportanceRegressionTest._
-
-  def expected(fileName:String):String = new File(ExpectedDir, fileName).getPath
-  def synth(fileName:String):String = new File(SynthDataDir, fileName).getPath
-  def actual(fileName:String):String = new File(ActualDir, fileName).getPath
-
-  //TODO: Refactor with ParametrizedTest: see: https://www.tutorialspoint.com/junit/junit_parameterized_test.htm
-  def runRegression(cmdLine:String, expextedFileName:String, sessionBuilder:SparkSession.Builder = MasterLocal2) {
-    withSessionBuilder(MasterLocal2) { _ =>
-      val outputFile = actual(expextedFileName)
-      val sub = new StrSubstitutor(Map("outputFile" -> outputFile).asJava)
-      VariantSparkApp.main(sub.replace(cmdLine).split(" "))
-      assertSameContent(expected(expextedFileName), outputFile)
-    }
-  }
-
-  def runSynthRegression(caseFile:String) {
-    // synth_2000_500_fact_3_0.995-imp_cat2.csv
-    val caseFileRE = """(synth_([^_]+)_([^_]+)_fact_([^_]+)_([^_]+))-imp_([^_]+).csv""".r
-    caseFile match {
-      case caseFileRE(prefix,_,_,ivo,_,response) => runRegression(s"importance -if ${synth(prefix)}-wide.csv -ff ${synth(prefix)}-labels.csv -fc ${response} -it csv -ivo ${ivo} -v -rn 100 -rbs 50 -ro -sr 17 -on 100 -sp 4 -of $${outputFile}",
-          caseFile)
-    }
-  }
-
-  @Test
-  def testVFCImportance() {
-    runRegression("importance -if data/chr22_1000.vcf -ff data/chr22-labels.csv -fc 22_16050408 -v -rn 100 -rbs 50 -ro -sr 17 -on 100 -sp 4 -of ${outputFile}",
-      "chr22-imp_22_16050408.csv")
-  }
-
-  @Test
-  def testCNAEImportance() {
-    runRegression("importance -if data/CNAE-9-wide.csv -it csv -ff data/CNAE-9-labels.csv -fc category -v -ro -rn 100 -rbs 50 -sr 17 -ivo 10 -sp 4 -on 100 -of ${outputFile}",
-      "CNAE-9-imp_category.csv")
-  }
-
-  @Test
-  def test_synth_2000_500_fact_3_0_995_imp_cat2() {
-    runSynthRegression("synth_2000_500_fact_3_0.995-imp_cat2.csv")
-  }
-
-  @Test
-  def test_synth_2000_500_fact_3_0_995_imp_cat10() {
-    runSynthRegression("synth_2000_500_fact_3_0.995-imp_cat10.csv")
-  }
-
-  @Test
-  def test_synth_2000_500_fact_3_0_imp_cat2() {
-    runSynthRegression("synth_2000_500_fact_3_0.0-imp_cat2.csv")
-  }
-
-  @Test
-  def test_synth_2000_500_fact_3_0_imp_cat10() {
-    runSynthRegression("synth_2000_500_fact_3_0.0-imp_cat10.csv")
-  }
-
-  @Test
-  def test_synth_2000_500_fact_10_0_995_imp_cat2() {
-    runSynthRegression("synth_2000_500_fact_10_0.995-imp_cat2.csv")
-  }
-
-  @Test
-  def test_synth_2000_500_fact_10_0_995_imp_cat10() {
-    runSynthRegression("synth_2000_500_fact_10_0.995-imp_cat10.csv")
-  }
-
-  @Test
-  def test_synth_2000_500_fact_10_0_imp_cat2() {
-    runSynthRegression("synth_2000_500_fact_10_0.0-imp_cat2.csv")
-  }
-
-  @Test
-  def test_synth_2000_500_fact_10_0_imp_cat10() {
-    runSynthRegression("synth_2000_500_fact_10_0.0-imp_cat10.csv")
-  }
-}
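The refactor above replaces one hand-written `@Test` method per case with JUnit's `Parameterized` runner, which instantiates the test class once per row returned by a `@Parameters` provider. Here is the pattern reduced to a self-contained example (the class name and case values are invented; the structure mirrors the new test classes in this patch, including the Scala idiom of putting the `@Parameters` method on the companion object so it compiles to a static method):

```scala
// Minimal standalone illustration of the JUnit Parameterized pattern
// adopted by this refactor (names here are invented for the example).
import java.util.Collection
import scala.collection.JavaConverters.asJavaCollectionConverter
import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters

@RunWith(classOf[Parameterized])
class ParameterizedPatternExample(caseName: String) {
  @Test
  def testCaseNameIsNonEmpty() {
    // This single test method runs once per entry returned by cases().
    assertTrue(caseName.nonEmpty)
  }
}

object ParameterizedPatternExample {
  // JUnit calls this provider and constructs the class once per row,
  // passing each Array's elements as constructor arguments.
  @Parameters
  def cases(): Collection[Array[Object]] =
    List(Array[Object]("case-a"), Array[Object]("case-b")).asJavaCollection
}
```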
runSynthRegression("synth_2000_500_fact_10_0.0-imp_cat2.csv") - } - - @Test - def test_synth_2000_500_fact_10_0_imp_cat10() { - runSynthRegression("synth_2000_500_fact_10_0.0-imp_cat10.csv") - } -} diff --git a/src/test/scala/au/csiro/variantspark/test/regression/ImportanceSynthRegressionTest.scala b/src/test/scala/au/csiro/variantspark/test/regression/ImportanceSynthRegressionTest.scala new file mode 100644 index 00000000..21d52b26 --- /dev/null +++ b/src/test/scala/au/csiro/variantspark/test/regression/ImportanceSynthRegressionTest.scala @@ -0,0 +1,43 @@ +package au.csiro.variantspark.test.regression + +import java.util.Collection + +import scala.collection.JavaConverters.asJavaCollectionConverter + +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters +import com.google.common.io.PatternFilenameFilter + + +/** + * Runs regression test for syntetic datasets + * The datasets are generated with `dev/test-get-synth-data.sh` + */ + +@RunWith(classOf[Parameterized]) +class ImportanceSynthRegressionTest(caseFile:String) extends ImportanceRegressionTest { + import ImportanceSynthRegressionTest.caseFileRE + + @Test + def testCaseImportanceOutputMatches() { + caseFile match { + case caseFileRE(prefix,_,_,ivo,_,response) => runRegression(s"""importance -if ${synth(prefix)}-wide.csv -ff ${synth(prefix)}-labels.csv -fc ${response} -it csv -io {"defVariableType":"ORDINAL(${ivo})"} -v -rn 100 -rbs 50 -ro -sr 17 -on 100 -sp 4 -of $${outputFile}""", + caseFile) + } + } +} + +object ImportanceSynthRegressionTest { + import ImportanceRegressionTest._ + + /** + * Match test cases from such as: synth_2000_500_fact_3_0.995-imp_cat2.csv + */ + val caseFileRE = """(synth_([^_]+)_([^_]+)_fact_([^_]+)_([^_]+))-imp_([^_]+).csv""".r + + @Parameters + def testCases:Collection[Array[Object]] = ExpectedDir.listFiles(new PatternFilenameFilter(caseFileRE.pattern)) + .map(f => Array[Object](f.getName)).toList.asJavaCollection +} \ No newline at end of file