Skip to content

Commit 930d3db

Browse files
committed
Fix python unit test and add document.
1 parent d632135 commit 930d3db

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ final class RegressionEvaluator(override val uid: String)
3737

3838
/**
3939
* param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`)
40+
*
41+
* Because we will maximize evaluation value (ref: `CrossValidator`),
42+
* when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`),
43+
* we take and output the negative of this metric.
4044
* @group param
4145
*/
4246
val metricName: Param[String] = {

python/pyspark/ml/evaluation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,15 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
160160
...
161161
>>> evaluator = RegressionEvaluator(predictionCol="raw")
162162
>>> evaluator.evaluate(dataset)
163-
2.842...
163+
-2.842...
164164
>>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
165165
0.993...
166166
>>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
167-
2.649...
167+
-2.649...
168168
"""
169-
# a placeholder to make it appear in the generated doc
169+
# Because we will maximize evaluation value (ref: `CrossValidator`),
170+
# when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`),
171+
# we take and output the negative of this metric.
170172
metricName = Param(Params._dummy(), "metricName",
171173
"metric name in evaluation (mse|rmse|r2|mae)")
172174

0 commit comments

Comments
 (0)