Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.automl

import com.microsoft.azure.synapse.ml.core.test.base.TestBase
import org.apache.spark.ml.classification._

// scalastyle:off magic.number
class VerifyDefaultHyperparams extends TestBase {

test("LogisticRegression default range is non-empty") {
val lr = new LogisticRegression()
val params = DefaultHyperparams.defaultRange(lr)
assert(params.nonEmpty)
val paramNames = params.map(_._1.name).toSet
assert(paramNames.contains("regParam"))
assert(paramNames.contains("elasticNetParam"))
assert(paramNames.contains("maxIter"))
}

test("DecisionTreeClassifier default range is non-empty") {
val dt = new DecisionTreeClassifier()
val params = DefaultHyperparams.defaultRange(dt)
assert(params.nonEmpty)
val paramNames = params.map(_._1.name).toSet
assert(paramNames.contains("maxBins"))
assert(paramNames.contains("maxDepth"))
}

test("GBTClassifier default range is non-empty") {
val gbt = new GBTClassifier()
val params = DefaultHyperparams.defaultRange(gbt)
assert(params.nonEmpty)
assert(params.length >= 5)
}

test("RandomForestClassifier default range is non-empty") {
val rf = new RandomForestClassifier()
val params = DefaultHyperparams.defaultRange(rf)
assert(params.nonEmpty)
val paramNames = params.map(_._1.name).toSet
assert(paramNames.contains("numTrees"))
}

test("MultilayerPerceptronClassifier default range is non-empty") {
val mlp = new MultilayerPerceptronClassifier()
val params = DefaultHyperparams.defaultRange(mlp)
assert(params.nonEmpty)
val paramNames = params.map(_._1.name).toSet
assert(paramNames.contains("blockSize"))
assert(paramNames.contains("layers"))
}

test("NaiveBayes default range is non-empty") {
val nb = new NaiveBayes()
val params = DefaultHyperparams.defaultRange(nb)
assert(params.nonEmpty)
val paramNames = params.map(_._1.name).toSet
assert(paramNames.contains("smoothing"))
}

test("default ranges produce valid distributions") {
val lr = new LogisticRegression()
val params = DefaultHyperparams.defaultRange(lr)
params.foreach { case (param, dist) =>
val value = dist.getNext
assert(value != null) // scalastyle:ignore null
}
}
}
// scalastyle:on magic.number
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.automl

import com.microsoft.azure.synapse.ml.core.test.base.TestBase
import com.microsoft.azure.synapse.ml.core.metrics.MetricConstants
import com.microsoft.azure.synapse.ml.core.schema.SchemaConstants

class VerifyEvaluationUtils extends TestBase {

test("getMetricWithOperator returns correct metric for regression MSE") {
val (name, _) = EvaluationUtils.getMetricWithOperator(
SchemaConstants.RegressionKind, MetricConstants.MseSparkMetric)
assert(name === MetricConstants.MseColumnName)
}

test("getMetricWithOperator returns correct metric for regression RMSE") {
val (name, _) = EvaluationUtils.getMetricWithOperator(
SchemaConstants.RegressionKind, MetricConstants.RmseSparkMetric)
assert(name === MetricConstants.RmseColumnName)
}

test("getMetricWithOperator returns correct metric for regression R2") {
val (name, _) = EvaluationUtils.getMetricWithOperator(
SchemaConstants.RegressionKind, MetricConstants.R2SparkMetric)
assert(name === MetricConstants.R2ColumnName)
}

test("getMetricWithOperator returns correct metric for regression MAE") {
val (name, _) = EvaluationUtils.getMetricWithOperator(
SchemaConstants.RegressionKind, MetricConstants.MaeSparkMetric)
assert(name === MetricConstants.MaeColumnName)
}

test("getMetricWithOperator returns correct metric for classification AUC") {
val (name, _) = EvaluationUtils.getMetricWithOperator(
SchemaConstants.ClassificationKind, MetricConstants.AucSparkMetric)
assert(name === MetricConstants.AucColumnName)
}

test("getMetricWithOperator returns correct metric for classification accuracy") {
val (name, _) = EvaluationUtils.getMetricWithOperator(
SchemaConstants.ClassificationKind, MetricConstants.AccuracySparkMetric)
assert(name === MetricConstants.AccuracyColumnName)
}

test("regression metrics use chooseLowest ordering (except R2)") {
val (_, mseOrd) = EvaluationUtils.getMetricWithOperator(
SchemaConstants.RegressionKind, MetricConstants.MseSparkMetric)
// MSE should prefer lower values
assert(mseOrd.compare(1.0, 2.0) > 0)

val (_, r2Ord) = EvaluationUtils.getMetricWithOperator(
SchemaConstants.RegressionKind, MetricConstants.R2SparkMetric)
// R2 should prefer higher values
assert(r2Ord.compare(1.0, 2.0) < 0)
}

test("classification metrics use chooseHighest ordering") {
val (_, aucOrd) = EvaluationUtils.getMetricWithOperator(
SchemaConstants.ClassificationKind, MetricConstants.AucSparkMetric)
// AUC should prefer higher values
assert(aucOrd.compare(1.0, 2.0) < 0)
}

test("unsupported regression metric throws") {
assertThrows[Exception] {
EvaluationUtils.getMetricWithOperator(SchemaConstants.RegressionKind, "bogus_metric")
}
}

test("unsupported classification metric throws") {
assertThrows[Exception] {
EvaluationUtils.getMetricWithOperator(SchemaConstants.ClassificationKind, "bogus_metric")
}
}

test("unsupported model type throws") {
assertThrows[Exception] {
EvaluationUtils.getMetricWithOperator("unsupported_type", MetricConstants.MseSparkMetric)
}
}

test("ModelTypeUnsupportedErr constant is defined") {
assert(EvaluationUtils.ModelTypeUnsupportedErr.nonEmpty)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.automl

import com.microsoft.azure.synapse.ml.core.test.base.TestBase
import org.apache.spark.ml.param.{IntParam, DoubleParam, Param, Params, ParamMap}
import org.apache.spark.ml.util.Identifiable

import scala.collection.JavaConverters._

// scalastyle:off magic.number
class VerifyHyperparamBuilder extends TestBase {

private object TestParams extends Params {
override val uid: String = Identifiable.randomUID("TestParams") // scalastyle:ignore field.name
override def copy(extra: ParamMap): Params = this
val intParam = new IntParam(this, "intParam", "test int param") // scalastyle:ignore field.name
val doubleParam = new DoubleParam(this, "doubleParam", "test double param") // scalastyle:ignore field.name
}

test("IntRangeHyperParam generates values within range") {
val hp = new IntRangeHyperParam(5, 15, seed = 42)
val values = (1 to 100).map(_ => hp.getNext())
assert(values.forall(v => v >= 5 && v < 15))
assert(values.toSet.size > 1) // not all the same
}

test("DoubleRangeHyperParam generates values within range") {
val hp = new DoubleRangeHyperParam(0.0, 1.0, seed = 42)
val values = (1 to 100).map(_ => hp.getNext())
assert(values.forall(v => v >= 0.0 && v < 1.0))
}

test("FloatRangeHyperParam generates values within range") {
val hp = new FloatRangeHyperParam(0.0f, 1.0f, seed = 42)
val values = (1 to 100).map(_ => hp.getNext())
assert(values.forall(v => v >= 0.0f && v < 1.0f))
}

test("LongRangeHyperParam generates values") {
val hp = new LongRangeHyperParam(0L, 100L, seed = 42)
val value = hp.getNext()
assert(value.isInstanceOf[Long])
}

test("DiscreteHyperParam selects from provided values") {
val hp = new DiscreteHyperParam(List("a", "b", "c"), seed = 42)
val values = (1 to 100).map(_ => hp.getNext())
assert(values.forall(Set("a", "b", "c").contains))
assert(values.toSet.size > 1)
}

test("DiscreteHyperParam getValues returns Java list") {
val hp = new DiscreteHyperParam(List(1, 2, 3))
val javaList = hp.getValues
assert(javaList.asScala.toList === List(1, 2, 3))
}

test("HyperparamBuilder builds array of param-dist pairs") {
val hp = new HyperparamBuilder()
.addHyperparam(TestParams.intParam, new IntRangeHyperParam(1, 10))
.addHyperparam(TestParams.doubleParam, new DoubleRangeHyperParam(0.0, 1.0))
.build()
assert(hp.length === 2)
assert(hp.map(_._1.name).toSet === Set("intParam", "doubleParam"))
}

test("HyperparamBuilder empty build returns empty array") {
val hp = new HyperparamBuilder().build()
assert(hp.isEmpty)
}

test("HyperParamUtils.getRangeHyperParam matches Int type") {
val hp = HyperParamUtils.getRangeHyperParam(1, 10)
assert(hp.isInstanceOf[IntRangeHyperParam])
}

test("HyperParamUtils.getRangeHyperParam matches Double type") {
val hp = HyperParamUtils.getRangeHyperParam(0.0, 1.0)
assert(hp.isInstanceOf[DoubleRangeHyperParam])
}

test("HyperParamUtils.getRangeHyperParam matches Float type") {
val hp = HyperParamUtils.getRangeHyperParam(0.0f, 1.0f)
assert(hp.isInstanceOf[FloatRangeHyperParam])
}

test("HyperParamUtils.getRangeHyperParam matches Long type") {
val hp = HyperParamUtils.getRangeHyperParam(0L, 100L)
assert(hp.isInstanceOf[LongRangeHyperParam])
}

test("HyperParamUtils.getRangeHyperParam throws on unsupported type") {
assertThrows[Exception] {
HyperParamUtils.getRangeHyperParam("a", "z")
}
}

test("HyperParamUtils.getDiscreteHyperParam creates from Java ArrayList") {
val javaList = new java.util.ArrayList[String]()
javaList.add("x")
javaList.add("y")
val hp = HyperParamUtils.getDiscreteHyperParam(javaList)
val values = (1 to 50).map(_ => hp.getNext().toString)
assert(values.forall(v => v == "x" || v == "y"))
}

test("seeded RangeHyperParam produces deterministic sequences") {
val hp1 = new IntRangeHyperParam(0, 100, seed = 123)
val hp2 = new IntRangeHyperParam(0, 100, seed = 123)
val seq1 = (1 to 10).map(_ => hp1.getNext())
val seq2 = (1 to 10).map(_ => hp2.getNext())
assert(seq1 === seq2)
}
}
// scalastyle:on magic.number
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.automl

import com.microsoft.azure.synapse.ml.core.test.base.TestBase
import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable

// scalastyle:off magic.number
class VerifyParamSpace extends TestBase {

private object TestParams extends Params {
override val uid: String = Identifiable.randomUID("TestParams") // scalastyle:ignore field.name
override def copy(extra: ParamMap): Params = this
val intParam = new IntParam(this, "intParam", "test int param") // scalastyle:ignore field.name
}

test("GridSpace iterates over all ParamMaps") {
val pm1 = ParamMap(TestParams.intParam -> 1)
val pm2 = ParamMap(TestParams.intParam -> 2)
val pm3 = ParamMap(TestParams.intParam -> 3)
val grid = new GridSpace(Array(pm1, pm2, pm3))
val result = grid.paramMaps.toList
assert(result.length === 3)
}

test("GridSpace with empty array produces empty iterator") {
val grid = new GridSpace(Array.empty[ParamMap])
assert(!grid.paramMaps.hasNext)
}

test("RandomSpace produces infinite iterator") {
val builder = new HyperparamBuilder()
.addHyperparam(TestParams.intParam, new IntRangeHyperParam(1, 100))
val space = new RandomSpace(builder.build())
val values = space.paramMaps.take(50).toList
assert(values.length === 50)
values.foreach { pm =>
val v = pm.get(TestParams.intParam)
assert(v.isDefined)
assert(v.get >= 1 && v.get < 100)
}
}

test("RandomSpace iterator always hasNext") {
val builder = new HyperparamBuilder()
.addHyperparam(TestParams.intParam, new IntRangeHyperParam(0, 10))
val space = new RandomSpace(builder.build())
assert(space.paramMaps.hasNext)
space.paramMaps.next()
assert(space.paramMaps.hasNext)
}

test("Dist.getParamPair creates correct ParamPair") {
val dist = new IntRangeHyperParam(5, 15, seed = 42)
val pp = dist.getParamPair(TestParams.intParam)
assert(pp.param.name === TestParams.intParam.name)
val value = pp.value.asInstanceOf[Int]
assert(value >= 5)
assert(value < 15)
}
}
// scalastyle:on magic.number
Loading
Loading