Add log-regression validation
lovisek committed Feb 9, 2024
1 parent ce39529 commit 29d9ed9
Showing 1 changed file with 107 additions and 35 deletions.
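
This commit replaces the previous Validators.simple checks on class and feature counts with a more detailed validation step: the fitted model is reduced to a summary (coefficient sum, coefficient sum of squares, coefficient count, intercept value, intercept count, class count) and compared against expected values supplied as benchmark parameters, using absolute tolerances for the floating-point fields, so a failure points at the specific statistic that drifted. The standalone sketch below only illustrates that tolerance-based comparison; the helper name is hypothetical and this is not the Renaissance Assert API, although it mirrors the shape of Assert.assertEquals(expected, actual, tolerance, subject) used in the diff. The sample numbers are taken from the default parameter values and tolerances that appear below.

// Standalone sketch (hypothetical names, not the Renaissance Assert API): compare a
// floating-point model statistic against an expected value within an absolute
// tolerance, mirroring the shape of Assert.assertEquals(expected, actual, tolerance, subject).
object ToleranceCheckSketch {

  def assertCloseTo(expected: Double, actual: Double, tolerance: Double, subject: String): Unit = {
    val difference = math.abs(expected - actual)
    require(
      difference <= tolerance,
      s"$subject: expected $expected, got $actual (|diff| = $difference > tolerance $tolerance)"
    )
  }

  def main(args: Array[String]): Unit = {
    // Within tolerance: the check passes silently.
    assertCloseTo(
      expected = -0.0653998570980114, // default expected_coefficient_sum from the diff
      actual = -0.0653998570980121,
      tolerance = 0.1e-13,
      subject = "coefficients sum"
    )

    // Out of tolerance: require() throws IllegalArgumentException carrying the message above.
    try {
      assertCloseTo(
        expected = 2.287050116462375, // default expected_intercept_value from the diff
        actual = 2.2870501175,
        tolerance = 0.1e-13,
        subject = "intercept value"
      )
    } catch {
      case e: IllegalArgumentException => println(e.getMessage)
    }
  }
}
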
@@ -7,11 +7,10 @@ import org.renaissance.Benchmark
import org.renaissance.Benchmark._
import org.renaissance.BenchmarkContext
import org.renaissance.BenchmarkResult
import org.renaissance.BenchmarkResult.Validators
import org.renaissance.BenchmarkResult.Assert
import org.renaissance.License
import org.renaissance.apache.spark.ResourceUtil.duplicateLinesFromUrl

import java.nio.file.Files
import java.nio.file.Path

@Name("log-regression")
@@ -34,28 +33,53 @@ import java.nio.file.Path
defaultValue = "20",
summary = "Maximum number of iterations of the logistic regression algorithm."
)
@Configuration(name = "test", settings = Array("copy_count = 5"))
@Parameter(name = "expected_coefficient_sum", defaultValue = "-0.0653998570980114")
@Parameter(name = "expected_coefficient_sum_squares", defaultValue = "9.401759355004592E-5")
@Parameter(name = "expected_intercept_value", defaultValue = "2.287050116462375")
@Parameter(name = "expected_intercept_count", defaultValue = "1")
@Parameter(name = "expected_class_count", defaultValue = "2")
@Configuration(
name = "test",
settings = Array(
"copy_count = 5",
"expected_coefficient_sum = -0.06538768469885561",
"expected_coefficient_sum_squares = 9.395555567324299E-5",
"expected_intercept_value = 2.286718680950285",
"expected_class_count = 2"
)
)
@Configuration(name = "jmh")
final class LogRegression extends Benchmark with SparkUtil {

// Utility class for validation.

private case class ModelSummary(
coefficientSum: Double,
coefficientSumSquares: Double,
coefficientCount: Int,
interceptValue: Double,
interceptCount: Int,
classCount: Int
)

// TODO: Consolidate benchmark parameters across the suite.
// See: https://github.com/renaissance-benchmarks/renaissance/issues/27

private val inputResource = "/sample_libsvm_data.txt"

private val inputFeatureCount = 692

private var maxIterationsParam: Int = _

private val lrRegularizationParam = 0.1

private val lrElasticNetMixingParam = 0.0

private val lrConvergenceToleranceParam = 0.0

private var inputDataFrame: DataFrame = _
private var lrMaxIterationsParam: Int = _

private var expectedModelSummary: ModelSummary = _

private var outputLogisticRegression: LogisticRegressionModel = _
private var inputDataFrame: DataFrame = _

private def loadData(inputFile: Path, featureCount: Int) = {
sparkSession.read
@@ -67,7 +91,17 @@ final class LogRegression extends Benchmark with SparkUtil {
override def setUpBeforeAll(bc: BenchmarkContext): Unit = {
setUpSparkContext(bc)

maxIterationsParam = bc.parameter("max_iterations").toPositiveInteger
lrMaxIterationsParam = bc.parameter("max_iterations").toPositiveInteger

// Validation parameters.
expectedModelSummary = ModelSummary(
bc.parameter("expected_coefficient_sum").toDouble,
bc.parameter("expected_coefficient_sum_squares").toDouble,
inputFeatureCount,
bc.parameter("expected_intercept_value").toDouble,
bc.parameter("expected_intercept_count").toPositiveInteger,
bc.parameter("expected_class_count").toPositiveInteger
)

val inputFile = duplicateLinesFromUrl(
getClass.getResource(inputResource),
@@ -79,43 +113,81 @@
}

override def run(bc: BenchmarkContext): BenchmarkResult = {
val lor = new LogisticRegression()
val logRegression = new LogisticRegression()
.setElasticNetParam(lrElasticNetMixingParam)
.setRegParam(lrRegularizationParam)
.setTol(lrConvergenceToleranceParam)
.setMaxIter(maxIterationsParam)

outputLogisticRegression = lor.fit(inputDataFrame)

// TODO: add more in-depth validation
Validators.compound(
Validators.simple("class count", 2, outputLogisticRegression.numClasses),
Validators.simple(
"feature count",
inputFeatureCount,
outputLogisticRegression.numFeatures
)
.setMaxIter(lrMaxIterationsParam)

val logRegressionModel = logRegression.fit(inputDataFrame)
() => validate(logRegressionModel)
}

private def validate(model: LogisticRegressionModel): Unit = {
//
// Validation currently supports only binary classification which returns
// a single intercept value. If multinomial logistic regression is needed,
// the validation needs to be updated to support multiple intercept values.
//
val actualModelSummary = summarizeModel(model)
validateSummary(
expectedModelSummary,
actualModelSummary,
coefficientSumTolerance = 0.1e-14,
coefficientSumSquaresTolerance = 0.1e-17,
interceptTolerance = 0.1e-13
)
}

override def tearDownAfterAll(bc: BenchmarkContext): Unit = {
if (dumpResultsBeforeTearDown && outputLogisticRegression != null) {
val outputFile = bc.scratchDirectory().resolve("output.txt")
dumpResult(outputLogisticRegression, outputFile)
}
private def summarizeModel(model: LogisticRegressionModel): ModelSummary = {
val coefficients = model.coefficients.toArray

tearDownSparkContext()
ModelSummary(
coefficients.sum,
coefficients.map(num => num * num).sum,
coefficients.length,
model.interceptVector(0),
model.interceptVector.size,
model.numClasses
)
}

private def dumpResult(lrm: LogisticRegressionModel, outputFile: Path) = {
val output = new StringBuilder
output.append(s"num features: ${lrm.numFeatures}\n")
output.append(s"num classes: ${lrm.numClasses}\n")
output.append(s"intercepts: ${lrm.interceptVector.toString}\n")
output.append(s"coefficients: ${lrm.coefficients.toString}\n")
private def validateSummary(
expected: ModelSummary,
actual: ModelSummary,
coefficientSumTolerance: Double,
coefficientSumSquaresTolerance: Double,
interceptTolerance: Double
): Unit = {
Assert.assertEquals(
expected.coefficientSum,
actual.coefficientSum,
coefficientSumTolerance,
"coefficients sum"
)

Assert.assertEquals(
expected.coefficientSumSquares,
actual.coefficientSumSquares,
coefficientSumSquaresTolerance,
"coefficients sum of squares"
)

Assert.assertEquals(expected.coefficientCount, actual.coefficientCount, "coefficient count")

// Files.writeString() is only available from Java 11.
Files.write(outputFile, output.toString.getBytes)
Assert.assertEquals(
expected.interceptValue,
actual.interceptValue,
interceptTolerance,
"intercept value"
)

Assert.assertEquals(expected.interceptCount, actual.interceptCount, "intercept count")

Assert.assertEquals(expected.classCount, actual.classCount, "class count")
}

override def tearDownAfterAll(bc: BenchmarkContext): Unit = {
tearDownSparkContext()
}
}
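
For reference, the minimal sketch below reproduces the summary arithmetic of summarizeModel() on plain arrays, with no Spark dependency; the ModelSummary shape matches the diff, but the input numbers are illustrative only. As the comment in validate() notes, binary classification yields a single intercept, so the intercept array here has length one.

// Minimal sketch of the summary arithmetic used by summarizeModel(), applied to
// plain arrays instead of model.coefficients and model.interceptVector;
// the numbers are illustrative only.
object ModelSummarySketch {

  final case class ModelSummary(
    coefficientSum: Double,
    coefficientSumSquares: Double,
    coefficientCount: Int,
    interceptValue: Double,
    interceptCount: Int,
    classCount: Int
  )

  def summarize(coefficients: Array[Double], intercepts: Array[Double], classCount: Int): ModelSummary =
    ModelSummary(
      coefficientSum = coefficients.sum,
      coefficientSumSquares = coefficients.map(c => c * c).sum,
      coefficientCount = coefficients.length,
      interceptValue = intercepts(0), // binary classification: exactly one intercept
      interceptCount = intercepts.length,
      classCount = classCount
    )

  def main(args: Array[String]): Unit = {
    val coefficients = Array(0.02, -0.05, 0.01) // illustrative values
    val intercepts = Array(2.29)                // single intercept for a binary model
    val summary = summarize(coefficients, intercepts, classCount = 2)
    // Prints approximately ModelSummary(-0.02,0.003,3,2.29,1,2), up to floating-point rounding.
    println(summary)
  }
}
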
