# [ML][MINOR] Separate estimator and model params for read/write test.
## What changes were proposed in this pull request?
Since we allow an ```Estimator``` and its ```Model``` to have different params (see ```ALSParams``` and ```ALSModelParams```), we should pass the test params for the estimator and the model separately into ```testEstimatorAndModelReadWrite```.
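
For suites where the ```Estimator``` and ```Model``` do share all params, the same map is simply passed twice. A minimal sketch of the new call shape (the param names and values below are illustrative, not drawn from any particular suite):

```scala
// Illustrative only: params an estimator might expose but its model might not.
val estimatorParams = Map("maxIter" -> 2, "regParam" -> 0.01, "predictionCol" -> "myPrediction")
// Model-side test params: a subset that the fitted model actually defines.
val modelParams = Map("predictionCol" -> "myPrediction")

testEstimatorAndModelReadWrite(estimator, dataset, estimatorParams, modelParams, checkModelData)
```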

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes apache#17151 from yanboliang/test-rw.
yanboliang committed Mar 8, 2017
1 parent 314e48a commit 1fa5886
Showing 23 changed files with 59 additions and 54 deletions.
```diff
@@ -372,16 +372,18 @@ class DecisionTreeClassifierSuite
     // Categorical splits with tree depth 2
     val categoricalData: DataFrame =
       TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
-    testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
+      allParamSettings, checkModelData)

     // Continuous splits with tree depth 2
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)

     // Continuous splits with tree depth 0
     testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
-      checkModelData)
+      allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
 }
```
```diff
@@ -374,7 +374,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext

     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }
```
```diff
@@ -232,7 +232,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     }
     val svm = new LinearSVC()
     testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings,
-      checkModelData)
+      LinearSVCSuite.allParamSettings, checkModelData)
   }
 }
```
```diff
@@ -2089,7 +2089,7 @@ class LogisticRegressionSuite
     }
     val lr = new LogisticRegression()
     testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings,
-      checkModelData)
+      LogisticRegressionSuite.allParamSettings, checkModelData)
   }

   test("should support all NumericType labels and weights, and not support other types") {
```
```diff
@@ -280,7 +280,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       assert(model.theta === model2.theta)
     }
     val nb = new NaiveBayes()
-    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings,
+      NaiveBayesSuite.allParamSettings, checkModelData)
   }

   test("should support all NumericType labels and weights, and not support other types") {
```
```diff
@@ -218,7 +218,8 @@ class RandomForestClassifierSuite

     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }
```
```diff
@@ -138,8 +138,8 @@ class BisectingKMeansSuite
       assert(model.clusterCenters === model2.clusterCenters)
     }
     val bisectingKMeans = new BisectingKMeans()
-    testEstimatorAndModelReadWrite(
-      bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings,
+      BisectingKMeansSuite.allParamSettings, checkModelData)
   }
 }
```
```diff
@@ -163,7 +163,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
       assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov))
     }
     val gm = new GaussianMixture()
-    testEstimatorAndModelReadWrite(gm, dataset,
+    testEstimatorAndModelReadWrite(gm, dataset, GaussianMixtureSuite.allParamSettings,
       GaussianMixtureSuite.allParamSettings, checkModelData)
   }
```
```diff
@@ -150,7 +150,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
       assert(model.clusterCenters === model2.clusterCenters)
     }
     val kmeans = new KMeans()
-    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
+      KMeansSuite.allParamSettings, checkModelData)
   }
 }
```
```diff
@@ -250,7 +250,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
         Vectors.dense(model2.getDocConcentration) absTol 1e-6)
     }
     val lda = new LDA()
-    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings,
+      LDASuite.allParamSettings, checkModelData)
   }

   test("read/write DistributedLDAModel") {
@@ -271,6 +272,7 @@
     }
     val lda = new LDA()
     testEstimatorAndModelReadWrite(lda, dataset,
+      LDASuite.allParamSettings ++ Map("optimizer" -> "em"),
       LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
   }
```
```diff
@@ -63,7 +63,7 @@ class BucketedRandomProjectionLSHSuite
     }
     val mh = new BucketedRandomProjectionLSH()
     val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0)
-    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+    testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
   }

   test("hashFunction") {
```
```diff
@@ -151,7 +151,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
       assert(model.selectedFeatures === model2.selectedFeatures)
     }
     val nb = new ChiSqSelector
-    testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings,
+      ChiSqSelectorSuite.allParamSettings, checkModelData)
   }

   test("should support all NumericType labels and not support other types") {
```
```diff
@@ -54,7 +54,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     }
     val mh = new MinHashLSH()
     val settings = Map("inputCol" -> "keys", "outputCol" -> "values")
-    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+    testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
   }

   test("hashFunction") {
```
```diff
@@ -99,8 +99,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
         model2.freqItemsets.sort("items").collect())
     }
     val fPGrowth = new FPGrowth()
-    testEstimatorAndModelReadWrite(
-      fPGrowth, dataset, FPGrowthSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
+      FPGrowthSuite.allParamSettings, checkModelData)
   }

 }
```
```diff
@@ -518,37 +518,26 @@ class ALSSuite
   }

   test("read/write") {
-    import ALSSuite._
-    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
-    val als = new ALS()
-    allEstimatorParamSettings.foreach { case (p, v) =>
-      als.set(als.getParam(p), v)
-    }
-    val spark = this.spark
-    import spark.implicits._
-    val model = als.fit(ratings.toDF())
-
-    // Test Estimator save/load
-    val als2 = testDefaultReadWrite(als)
-    allEstimatorParamSettings.foreach { case (p, v) =>
-      val param = als.getParam(p)
-      assert(als.get(param).get === als2.get(param).get)
-    }
+    import ALSSuite._
+    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)

-    // Test Model save/load
-    val model2 = testDefaultReadWrite(model)
-    allModelParamSettings.foreach { case (p, v) =>
-      val param = model.getParam(p)
-      assert(model.get(param).get === model2.get(param).get)
-    }
-    assert(model.rank === model2.rank)
     def getFactors(df: DataFrame): Set[(Int, Array[Float])] = {
       df.select("id", "features").collect().map { case r =>
         (r.getInt(0), r.getAs[Array[Float]](1))
       }.toSet
     }
-    assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
-    assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+
+    def checkModelData(model: ALSModel, model2: ALSModel): Unit = {
+      assert(model.rank === model2.rank)
+      assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
+      assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+    }
+
+    val als = new ALS()
+    testEstimatorAndModelReadWrite(als, ratings.toDF(), allEstimatorParamSettings,
+      allModelParamSettings, checkModelData)
   }

   test("input type validation") {
```
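`ALSSuite` is the motivating case: `ALS` carries training-only params such as `regParam` and `maxIter` that `ALSModel` does not define, which is why the suite keeps `allEstimatorParamSettings` and `allModelParamSettings` as separate maps. A sketch of how such companion-object maps might be structured (the exact contents here are illustrative, not the suite's actual values):

```scala
// Illustrative: the model-side settings form a subset of the estimator-side ones.
val allModelParamSettings: Map[String, Any] = Map("predictionCol" -> "myPredictionCol")

val allEstimatorParamSettings: Map[String, Any] = allModelParamSettings ++ Map(
  "maxIter" -> 1,     // training-only param; ALSModel does not define it
  "rank" -> 1,
  "regParam" -> 0.01)
```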
```diff
@@ -419,7 +419,8 @@ class AFTSurvivalRegressionSuite
     }
     val aft = new AFTSurvivalRegression()
     testEstimatorAndModelReadWrite(aft, datasetMultivariate,
-      AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
+      AFTSurvivalRegressionSuite.allParamSettings, AFTSurvivalRegressionSuite.allParamSettings,
+      checkModelData)
   }

   test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
```
```diff
@@ -165,16 +165,17 @@ class DecisionTreeRegressorSuite
     val categoricalData: DataFrame =
       TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0)
     testEstimatorAndModelReadWrite(dt, categoricalData,
-      TreeTests.allParamSettings, checkModelData)
+      TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)

     // Continuous splits with tree depth 2
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
     testEstimatorAndModelReadWrite(dt, continuousData,
-      TreeTests.allParamSettings, checkModelData)
+      TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)

     // Continuous splits with tree depth 0
     testEstimatorAndModelReadWrite(dt, continuousData,
+      TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
       TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
 }
```
```diff
@@ -184,7 +184,8 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
     val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
-    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }
```
```diff
@@ -1418,6 +1418,7 @@ class GeneralizedLinearRegressionSuite

     val glr = new GeneralizedLinearRegression()
     testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
+      GeneralizedLinearRegressionSuite.allParamSettings,
       GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
   }
```
```diff
@@ -178,7 +178,7 @@ class IsotonicRegressionSuite

     val ir = new IsotonicRegression()
     testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings,
-      checkModelData)
+      IsotonicRegressionSuite.allParamSettings, checkModelData)
   }

   test("should support all NumericType labels and weights, and not support other types") {
```
```diff
@@ -985,7 +985,7 @@ class LinearRegressionSuite
     }
     val lr = new LinearRegression()
     testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
-      checkModelData)
+      LinearRegressionSuite.allParamSettings, checkModelData)
   }

   test("should support all NumericType labels and weights, and not support other types") {
```
```diff
@@ -124,7 +124,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex

     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
-    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }
```
```diff
@@ -85,11 +85,12 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
    * - Check Params on Estimator and Model
    * - Compare model data
    *
-   * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+   * This requires that [[Model]]'s [[Param]]s should be a subset of [[Estimator]]'s [[Param]]s.
    *
    * @param estimator Estimator to test
    * @param dataset Dataset to pass to [[Estimator.fit()]]
-   * @param testParams Set of [[Param]] values to set in estimator
+   * @param testEstimatorParams Set of [[Param]] values to set in estimator
+   * @param testModelParams Set of [[Param]] values to set in model
    * @param checkModelData Method which takes the original and loaded [[Model]] and compares their
    *                       data. This method does not need to check [[Param]] values.
    * @tparam E Type of [[Estimator]]
@@ -99,24 +100,25 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
       estimator: E,
       dataset: Dataset[_],
-      testParams: Map[String, Any],
+      testEstimatorParams: Map[String, Any],
+      testModelParams: Map[String, Any],
       checkModelData: (M, M) => Unit): Unit = {
     // Set some Params to make sure set Params are serialized.
-    testParams.foreach { case (p, v) =>
+    testEstimatorParams.foreach { case (p, v) =>
       estimator.set(estimator.getParam(p), v)
     }
     val model = estimator.fit(dataset)

     // Test Estimator save/load
     val estimator2 = testDefaultReadWrite(estimator)
-    testParams.foreach { case (p, v) =>
+    testEstimatorParams.foreach { case (p, v) =>
       val param = estimator.getParam(p)
       assert(estimator.get(param).get === estimator2.get(param).get)
     }

     // Test Model save/load
     val model2 = testDefaultReadWrite(model)
-    testParams.foreach { case (p, v) =>
+    testModelParams.foreach { case (p, v) =>
       val param = model.getParam(p)
       assert(model.get(param).get === model2.get(param).get)
     }
```
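One consequence of the split worth noting: `testEstimatorParams` drives both the set-before-fit step and the estimator round-trip check, while `testModelParams` drives only the model round-trip check, so every key in `testModelParams` must name a `Param` the model actually defines, or `getParam` will throw. A defensive guard one could add (not part of this commit) to surface that requirement early:

```scala
// Hypothetical guard, not in the commit: fail fast if a test param names
// a Param the fitted model does not define.
testModelParams.keys.foreach { p =>
  require(model.hasParam(p), s"Model does not define param '$p' from testModelParams")
}
```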
