Skip to content

Commit fe59a4a

Browse files
viiryajkbradley
authored andcommitted
[SPARK-8468] [ML] Take the negative of some metrics in RegressionEvaluator to get correct cross validation
JIRA: https://issues.apache.org/jira/browse/SPARK-8468 Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #6905 from viirya/cv_min and squashes the following commits: 930d3db [Liang-Chi Hsieh] Fix python unit test and add document. d632135 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into cv_min 16e3b2c [Liang-Chi Hsieh] Take the negative instead of reciprocal. c3dd8d9 [Liang-Chi Hsieh] For comments. b5f52c1 [Liang-Chi Hsieh] Add param to CrossValidator for choosing whether to maximize evaulation value. (cherry picked from commit 0b89951) Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
1 parent 9b16508 commit fe59a4a

File tree

5 files changed

+48
-11
lines changed

5 files changed

+48
-11
lines changed

mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ final class RegressionEvaluator(override val uid: String)
3737

3838
/**
3939
* param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`)
40+
*
41+
* Because we will maximize evaluation value (ref: `CrossValidator`),
42+
* when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`),
43+
* we take and output the negative of this metric.
4044
* @group param
4145
*/
4246
val metricName: Param[String] = {
@@ -70,13 +74,13 @@ final class RegressionEvaluator(override val uid: String)
7074
val metrics = new RegressionMetrics(predictionAndLabels)
7175
val metric = $(metricName) match {
7276
case "rmse" =>
73-
metrics.rootMeanSquaredError
77+
-metrics.rootMeanSquaredError
7478
case "mse" =>
75-
metrics.meanSquaredError
79+
-metrics.meanSquaredError
7680
case "r2" =>
7781
metrics.r2
7882
case "mae" =>
79-
metrics.meanAbsoluteError
83+
-metrics.meanAbsoluteError
8084
}
8185
metric
8286
}

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
297297

298298
/**
299299
* :: Experimental ::
300-
* A param amd its value.
300+
* A param and its value.
301301
*/
302302
@Experimental
303303
case class ParamPair[T](param: Param[T], value: T) {

mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,14 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
6363

6464
// default = rmse
6565
val evaluator = new RegressionEvaluator()
66-
assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001)
66+
assert(evaluator.evaluate(predictions) ~== -0.1019382 absTol 0.001)
6767

6868
// r2 score
6969
evaluator.setMetricName("r2")
7070
assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001)
7171

7272
// mae
7373
evaluator.setMetricName("mae")
74-
assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
74+
assert(evaluator.evaluate(predictions) ~== -0.08036075 absTol 0.001)
7575
}
7676
}

mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ package org.apache.spark.ml.tuning
2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.ml.{Estimator, Model}
2222
import org.apache.spark.ml.classification.LogisticRegression
23-
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator}
23+
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
2424
import org.apache.spark.ml.param.ParamMap
2525
import org.apache.spark.ml.param.shared.HasInputCol
26+
import org.apache.spark.ml.regression.LinearRegression
2627
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
27-
import org.apache.spark.mllib.util.MLlibTestSparkContext
28+
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
2829
import org.apache.spark.sql.{DataFrame, SQLContext}
2930
import org.apache.spark.sql.types.StructType
3031

@@ -57,6 +58,36 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
5758
assert(parent.getMaxIter === 10)
5859
}
5960

61+
test("cross validation with linear regression") {
62+
val dataset = sqlContext.createDataFrame(
63+
sc.parallelize(LinearDataGenerator.generateLinearInput(
64+
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
65+
66+
val trainer = new LinearRegression
67+
val lrParamMaps = new ParamGridBuilder()
68+
.addGrid(trainer.regParam, Array(1000.0, 0.001))
69+
.addGrid(trainer.maxIter, Array(0, 10))
70+
.build()
71+
val eval = new RegressionEvaluator()
72+
val cv = new CrossValidator()
73+
.setEstimator(trainer)
74+
.setEstimatorParamMaps(lrParamMaps)
75+
.setEvaluator(eval)
76+
.setNumFolds(3)
77+
val cvModel = cv.fit(dataset)
78+
val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
79+
assert(parent.getRegParam === 0.001)
80+
assert(parent.getMaxIter === 10)
81+
assert(cvModel.avgMetrics.length === lrParamMaps.length)
82+
83+
eval.setMetricName("r2")
84+
val cvModel2 = cv.fit(dataset)
85+
val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression]
86+
assert(parent2.getRegParam === 0.001)
87+
assert(parent2.getMaxIter === 10)
88+
assert(cvModel2.avgMetrics.length === lrParamMaps.length)
89+
}
90+
6091
test("validateParams should check estimatorParamMaps") {
6192
import CrossValidatorSuite._
6293

python/pyspark/ml/evaluation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,15 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
160160
...
161161
>>> evaluator = RegressionEvaluator(predictionCol="raw")
162162
>>> evaluator.evaluate(dataset)
163-
2.842...
163+
-2.842...
164164
>>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
165165
0.993...
166166
>>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
167-
2.649...
167+
-2.649...
168168
"""
169-
# a placeholder to make it appear in the generated doc
169+
# Because we will maximize evaluation value (ref: `CrossValidator`),
170+
# when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`),
171+
# we take and output the negative of this metric.
170172
metricName = Param(Params._dummy(), "metricName",
171173
"metric name in evaluation (mse|rmse|r2|mae)")
172174

0 commit comments

Comments
 (0)