Skip to content

Commit

Permalink
[SPARK-14516][ML][FOLLOW-UP] Move ClusteringEvaluatorSuite test data …
Browse files Browse the repository at this point in the history
…to data/mllib.

## What changes were proposed in this pull request?
Move ```ClusteringEvaluatorSuite``` test data(iris) to data/mllib, to prevent from re-creating a new folder.

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes apache#19648 from yanboliang/spark-14516.
  • Loading branch information
yanboliang committed Nov 8, 2017
1 parent 7475a96 commit 3da3d76
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 19 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,21 @@ import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.Dataset


class ClusteringEvaluatorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

import testImplicits._

@transient var irisDataset: Dataset[_] = _

override def beforeAll(): Unit = {
super.beforeAll()
irisDataset = spark.read.format("libsvm").load("../data/mllib/iris_libsvm.txt")
}

test("params") {
ParamsSuite.checkParams(new ClusteringEvaluator)
}
Expand All @@ -53,37 +59,23 @@ class ClusteringEvaluatorSuite
0.6564679231
*/
test("squared euclidean Silhouette") {
val iris = ClusteringEvaluatorSuite.irisDataset(spark)
val evaluator = new ClusteringEvaluator()
.setFeaturesCol("features")
.setPredictionCol("label")

assert(evaluator.evaluate(iris) ~== 0.6564679231 relTol 1e-5)
assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5)
}

test("number of clusters must be greater than one") {
val iris = ClusteringEvaluatorSuite.irisDataset(spark)
.where($"label" === 0.0)
val singleClusterDataset = irisDataset.where($"label" === 0.0)
val evaluator = new ClusteringEvaluator()
.setFeaturesCol("features")
.setPredictionCol("label")

val e = intercept[AssertionError]{
evaluator.evaluate(iris)
evaluator.evaluate(singleClusterDataset)
}
assert(e.getMessage.contains("Number of clusters must be greater than one"))
}

}

object ClusteringEvaluatorSuite {
def irisDataset(spark: SparkSession): DataFrame = {

val irisPath = Thread.currentThread()
.getContextClassLoader
.getResource("test-data/iris.libsvm")
.toString

spark.read.format("libsvm").load(irisPath)
}
}

0 comments on commit 3da3d76

Please sign in to comment.