diff --git a/benchmarks/apache-spark/src/main/scala/org/renaissance/apache/spark/LogRegression.scala b/benchmarks/apache-spark/src/main/scala/org/renaissance/apache/spark/LogRegression.scala
index a85b3ab6..63d050f4 100644
--- a/benchmarks/apache-spark/src/main/scala/org/renaissance/apache/spark/LogRegression.scala
+++ b/benchmarks/apache-spark/src/main/scala/org/renaissance/apache/spark/LogRegression.scala
@@ -7,11 +7,10 @@ import org.renaissance.Benchmark
 import org.renaissance.Benchmark._
 import org.renaissance.BenchmarkContext
 import org.renaissance.BenchmarkResult
-import org.renaissance.BenchmarkResult.Validators
+import org.renaissance.BenchmarkResult.Assert
 import org.renaissance.License
 import org.renaissance.apache.spark.ResourceUtil.duplicateLinesFromUrl
 
-import java.nio.file.Files
 import java.nio.file.Path
 
 @Name("log-regression")
@@ -34,10 +33,35 @@ import java.nio.file.Path
   defaultValue = "20",
   summary = "Maximum number of iterations of the logistic regression algorithm."
 )
-@Configuration(name = "test", settings = Array("copy_count = 5"))
+@Parameter(name = "expected_coefficient_sum", defaultValue = "-0.0653998570980114")
+@Parameter(name = "expected_coefficient_sum_squares", defaultValue = "9.401759355004592E-5")
+@Parameter(name = "expected_intercept_value", defaultValue = "2.287050116462375")
+@Parameter(name = "expected_intercept_count", defaultValue = "1")
+@Parameter(name = "expected_class_count", defaultValue = "2")
+@Configuration(
+  name = "test",
+  settings = Array(
+    "copy_count = 5",
+    "expected_coefficient_sum = -0.06538768469885561",
+    "expected_coefficient_sum_squares = 9.395555567324299E-5",
+    "expected_intercept_value = 2.286718680950285",
+    "expected_class_count = 2"
+  )
+)
 @Configuration(name = "jmh")
 final class LogRegression extends Benchmark with SparkUtil {
 
+  // Utility class for validation.
+
+  private case class ModelSummary(
+    coefficientSum: Double,
+    coefficientSumSquares: Double,
+    coefficientCount: Int,
+    interceptValue: Double,
+    interceptCount: Int,
+    classCount: Int
+  )
+
   // TODO: Consolidate benchmark parameters across the suite.
   // See: https://github.com/renaissance-benchmarks/renaissance/issues/27
 
@@ -45,17 +69,17 @@ final class LogRegression extends Benchmark with SparkUtil {
 
   private val inputFeatureCount = 692
 
-  private var maxIterationsParam: Int = _
-
   private val lrRegularizationParam = 0.1
 
   private val lrElasticNetMixingParam = 0.0
 
   private val lrConvergenceToleranceParam = 0.0
 
-  private var inputDataFrame: DataFrame = _
+  private var lrMaxIterationsParam: Int = _
+
+  private var expectedModelSummary: ModelSummary = _
 
-  private var outputLogisticRegression: LogisticRegressionModel = _
+  private var inputDataFrame: DataFrame = _
 
   private def loadData(inputFile: Path, featureCount: Int) = {
     sparkSession.read
@@ -67,7 +91,17 @@ final class LogRegression extends Benchmark with SparkUtil {
 
   override def setUpBeforeAll(bc: BenchmarkContext): Unit = {
     setUpSparkContext(bc)
-    maxIterationsParam = bc.parameter("max_iterations").toPositiveInteger
+    lrMaxIterationsParam = bc.parameter("max_iterations").toPositiveInteger
+
+    // Validation parameters.
+    expectedModelSummary = ModelSummary(
+      bc.parameter("expected_coefficient_sum").toDouble,
+      bc.parameter("expected_coefficient_sum_squares").toDouble,
+      inputFeatureCount,
+      bc.parameter("expected_intercept_value").toDouble,
+      bc.parameter("expected_intercept_count").toPositiveInteger,
+      bc.parameter("expected_class_count").toPositiveInteger
+    )
 
     val inputFile = duplicateLinesFromUrl(
       getClass.getResource(inputResource),
@@ -79,43 +113,81 @@ final class LogRegression extends Benchmark with SparkUtil {
   }
 
   override def run(bc: BenchmarkContext): BenchmarkResult = {
-    val lor = new LogisticRegression()
+    val logRegression = new LogisticRegression()
       .setElasticNetParam(lrElasticNetMixingParam)
       .setRegParam(lrRegularizationParam)
       .setTol(lrConvergenceToleranceParam)
-      .setMaxIter(maxIterationsParam)
-
-    outputLogisticRegression = lor.fit(inputDataFrame)
-
-    // TODO: add more in-depth validation
-    Validators.compound(
-      Validators.simple("class count", 2, outputLogisticRegression.numClasses),
-      Validators.simple(
-        "feature count",
-        inputFeatureCount,
-        outputLogisticRegression.numFeatures
-      )
+      .setMaxIter(lrMaxIterationsParam)
+
+    val logRegressionModel = logRegression.fit(inputDataFrame)
+    () => validate(logRegressionModel)
+  }
+
+  private def validate(model: LogisticRegressionModel): Unit = {
+    //
+    // Validation currently supports only binary classification which returns
+    // a single intercept value. If multinomial logistic regression is needed,
+    // the validation needs to be updated to support multiple intercept values.
+    //
+    val actualModelSummary = summarizeModel(model)
+    validateSummary(
+      expectedModelSummary,
+      actualModelSummary,
+      coefficientSumTolerance = 0.1e-14,
+      coefficientSumSquaresTolerance = 0.1e-17,
+      interceptTolerance = 0.1e-13
     )
   }
 
-  override def tearDownAfterAll(bc: BenchmarkContext): Unit = {
-    if (dumpResultsBeforeTearDown && outputLogisticRegression != null) {
-      val outputFile = bc.scratchDirectory().resolve("output.txt")
-      dumpResult(outputLogisticRegression, outputFile)
-    }
+  private def summarizeModel(model: LogisticRegressionModel): ModelSummary = {
+    val coefficients = model.coefficients.toArray
 
-    tearDownSparkContext()
+    ModelSummary(
+      coefficients.sum,
+      coefficients.map(num => num * num).sum,
+      coefficients.length,
+      model.interceptVector(0),
+      model.interceptVector.size,
+      model.numClasses
+    )
   }
 
-  private def dumpResult(lrm: LogisticRegressionModel, outputFile: Path) = {
-    val output = new StringBuilder
-    output.append(s"num features: ${lrm.numFeatures}\n")
-    output.append(s"num classes: ${lrm.numClasses}\n")
-    output.append(s"intercepts: ${lrm.interceptVector.toString}\n")
-    output.append(s"coefficients: ${lrm.coefficients.toString}\n")
+  private def validateSummary(
+    expected: ModelSummary,
+    actual: ModelSummary,
+    coefficientSumTolerance: Double,
+    coefficientSumSquaresTolerance: Double,
+    interceptTolerance: Double
+  ): Unit = {
+    Assert.assertEquals(
+      expected.coefficientSum,
+      actual.coefficientSum,
+      coefficientSumTolerance,
+      "coefficients sum"
+    )
+
+    Assert.assertEquals(
+      expected.coefficientSumSquares,
+      actual.coefficientSumSquares,
+      coefficientSumSquaresTolerance,
+      "coefficients sum of squares"
+    )
+
+    Assert.assertEquals(expected.coefficientCount, actual.coefficientCount, "coefficient count")
 
-    // Files.writeString() is only available from Java 11.
-    Files.write(outputFile, output.toString.getBytes)
+    Assert.assertEquals(
+      expected.interceptValue,
+      actual.interceptValue,
+      interceptTolerance,
+      "intercept value"
+    )
+
+    Assert.assertEquals(expected.interceptCount, actual.interceptCount, "intercept count")
+
+    Assert.assertEquals(expected.classCount, actual.classCount, "class count")
  }
 
+  override def tearDownAfterAll(bc: BenchmarkContext): Unit = {
+    tearDownSparkContext()
+  }
 }
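
Note, outside the patch itself: the expected_* values are tied to the generated input (the "test" configuration, which sets copy_count = 5, already overrides them), so they have to be recomputed whenever the input changes. Below is a minimal sketch of how they could be regenerated from a freshly fitted model; it uses only the Spark calls already present in summarizeModel above, and the object and method names are illustrative, not part of the benchmark.

import org.apache.spark.ml.classification.LogisticRegressionModel

object ExpectedValuesSketch {
  // Illustrative helper (not part of this patch): prints the quantities encoded
  // by the expected_* parameters so they can be pasted into the @Parameter
  // defaults or into the "test" @Configuration settings.
  def printExpectedValues(model: LogisticRegressionModel): Unit = {
    val coefficients = model.coefficients.toArray
    println(s"expected_coefficient_sum = ${coefficients.sum}")
    println(s"expected_coefficient_sum_squares = ${coefficients.map(c => c * c).sum}")
    println(s"expected_intercept_value = ${model.interceptVector(0)}")
    println(s"expected_intercept_count = ${model.interceptVector.size}")
    println(s"expected_class_count = ${model.numClasses}")
  }
}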