Skip to content

Commit aed7ff3

Browse files
committed
[SPARK-29258][ML][PYSPARK] parity between ml.evaluator and mllib.metrics
### What changes were proposed in this pull request? 1, expose `BinaryClassificationMetrics.numBins` in `BinaryClassificationEvaluator` 2, expose `RegressionMetrics.throughOrigin` in `RegressionEvaluator` 3, add metric `explainedVariance` in `RegressionEvaluator` ### Why are the changes needed? existing functionality in mllib.metrics should also be exposed in ml ### Does this PR introduce any user-facing change? yes, this PR adds two expert params and one metric option ### How was this patch tested? existing and added tests Closes #25940 from zhengruifeng/evaluator_add_param. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
1 parent ada3ad3 commit aed7ff3

File tree

4 files changed

+109
-26
lines changed

4 files changed

+109
-26
lines changed

mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,28 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
5959
@Since("1.2.0")
6060
def setMetricName(value: String): this.type = set(metricName, value)
6161

62+
/**
63+
* param for number of bins to down-sample the curves (ROC curve, PR curve) in area
64+
* computation. If 0, no down-sampling will occur.
65+
* Default: 1000.
66+
* @group expertParam
67+
*/
68+
@Since("3.0.0")
69+
val numBins: IntParam = new IntParam(this, "numBins", "Number of bins to down-sample " +
70+
"the curves (ROC curve, PR curve) in area computation. If 0, no down-sampling will occur. " +
71+
"Must be >= 0.",
72+
ParamValidators.gtEq(0))
73+
74+
/** @group expertGetParam */
75+
@Since("3.0.0")
76+
def getNumBins: Int = $(numBins)
77+
78+
/** @group expertSetParam */
79+
@Since("3.0.0")
80+
def setNumBins(value: Int): this.type = set(numBins, value)
81+
82+
setDefault(numBins -> 1000)
83+
6284
/** @group setParam */
6385
@Since("1.5.0")
6486
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
@@ -94,7 +116,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
94116
case Row(rawPrediction: Double, label: Double, weight: Double) =>
95117
(rawPrediction, label, weight)
96118
}
97-
val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights)
119+
val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights, $(numBins))
98120
val metric = $(metricName) match {
99121
case "areaUnderROC" => metrics.areaUnderROC()
100122
case "areaUnderPR" => metrics.areaUnderPR()
@@ -104,10 +126,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
104126
}
105127

106128
@Since("1.5.0")
107-
override def isLargerBetter: Boolean = $(metricName) match {
108-
case "areaUnderROC" => true
109-
case "areaUnderPR" => true
110-
}
129+
override def isLargerBetter: Boolean = true
111130

112131
@Since("1.4.1")
113132
override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)

mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.ml.evaluation
1919

2020
import org.apache.spark.annotation.Since
21-
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
21+
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
2222
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
2323
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
2424
import org.apache.spark.mllib.evaluation.RegressionMetrics
@@ -43,13 +43,14 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
4343
* - `"mse"`: mean squared error
4444
* - `"r2"`: R^2^ metric
4545
* - `"mae"`: mean absolute error
46+
* - `"var"`: explained variance
4647
*
4748
* @group param
4849
*/
4950
@Since("1.4.0")
5051
val metricName: Param[String] = {
51-
val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae"))
52-
new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", allowedParams)
52+
val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae", "var"))
53+
new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae|var)", allowedParams)
5354
}
5455

5556
/** @group getParam */
@@ -60,6 +61,25 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
6061
@Since("1.4.0")
6162
def setMetricName(value: String): this.type = set(metricName, value)
6263

64+
/**
65+
* param for whether the regression is through the origin.
66+
* Default: false.
67+
* @group expertParam
68+
*/
69+
@Since("3.0.0")
70+
val throughOrigin: BooleanParam = new BooleanParam(this, "throughOrigin",
71+
"Whether the regression is through the origin.")
72+
73+
/** @group expertGetParam */
74+
@Since("3.0.0")
75+
def getThroughOrigin: Boolean = $(throughOrigin)
76+
77+
/** @group expertSetParam */
78+
@Since("3.0.0")
79+
def setThroughOrigin(value: Boolean): this.type = set(throughOrigin, value)
80+
81+
setDefault(throughOrigin -> false)
82+
6383
/** @group setParam */
6484
@Since("1.4.0")
6585
def setPredictionCol(value: String): this.type = set(predictionCol, value)
@@ -86,22 +106,20 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
86106
.rdd
87107
.map { case Row(prediction: Double, label: Double, weight: Double) =>
88108
(prediction, label, weight) }
89-
val metrics = new RegressionMetrics(predictionAndLabelsWithWeights)
90-
val metric = $(metricName) match {
109+
val metrics = new RegressionMetrics(predictionAndLabelsWithWeights, $(throughOrigin))
110+
$(metricName) match {
91111
case "rmse" => metrics.rootMeanSquaredError
92112
case "mse" => metrics.meanSquaredError
93113
case "r2" => metrics.r2
94114
case "mae" => metrics.meanAbsoluteError
115+
case "var" => metrics.explainedVariance
95116
}
96-
metric
97117
}
98118

99119
@Since("1.4.0")
100120
override def isLargerBetter: Boolean = $(metricName) match {
101-
case "rmse" => false
102-
case "mse" => false
103-
case "r2" => true
104-
case "mae" => false
121+
case "r2" | "var" => true
122+
case _ => false
105123
}
106124

107125
@Since("1.5.0")

mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ class RegressionEvaluatorSuite
7676
// mae
7777
evaluator.setMetricName("mae")
7878
assert(evaluator.evaluate(predictions) ~== 0.08399089 absTol 0.01)
79+
80+
// var
81+
evaluator.setMetricName("var")
82+
assert(evaluator.evaluate(predictions) ~== 63.6944519 absTol 0.01)
7983
}
8084

8185
test("read/write") {

python/pyspark/ml/evaluation.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
139139
0.70...
140140
>>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
141141
0.82...
142+
>>> evaluator.getNumBins()
143+
1000
142144
143145
.. versionadded:: 1.4.0
144146
"""
@@ -147,17 +149,22 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
147149
"metric name in evaluation (areaUnderROC|areaUnderPR)",
148150
typeConverter=TypeConverters.toString)
149151

152+
numBins = Param(Params._dummy(), "numBins", "Number of bins to down-sample the curves "
153+
"(ROC curve, PR curve) in area computation. If 0, no down-sampling will "
154+
"occur. Must be >= 0.",
155+
typeConverter=TypeConverters.toInt)
156+
150157
@keyword_only
151158
def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
152-
metricName="areaUnderROC", weightCol=None):
159+
metricName="areaUnderROC", weightCol=None, numBins=1000):
153160
"""
154161
__init__(self, rawPredictionCol="rawPrediction", labelCol="label", \
155-
metricName="areaUnderROC", weightCol=None)
162+
metricName="areaUnderROC", weightCol=None, numBins=1000)
156163
"""
157164
super(BinaryClassificationEvaluator, self).__init__()
158165
self._java_obj = self._new_java_obj(
159166
"org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid)
160-
self._setDefault(metricName="areaUnderROC")
167+
self._setDefault(metricName="areaUnderROC", numBins=1000)
161168
kwargs = self._input_kwargs
162169
self._set(**kwargs)
163170

@@ -175,13 +182,27 @@ def getMetricName(self):
175182
"""
176183
return self.getOrDefault(self.metricName)
177184

185+
@since("3.0.0")
186+
def setNumBins(self, value):
187+
"""
188+
Sets the value of :py:attr:`numBins`.
189+
"""
190+
return self._set(numBins=value)
191+
192+
@since("3.0.0")
193+
def getNumBins(self):
194+
"""
195+
Gets the value of numBins or its default value.
196+
"""
197+
return self.getOrDefault(self.numBins)
198+
178199
@keyword_only
179200
@since("1.4.0")
180201
def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
181-
metricName="areaUnderROC", weightCol=None):
202+
metricName="areaUnderROC", weightCol=None, numBins=1000):
182203
"""
183204
setParams(self, rawPredictionCol="rawPrediction", labelCol="label", \
184-
metricName="areaUnderROC", weightCol=None)
205+
metricName="areaUnderROC", weightCol=None, numBins=1000)
185206
Sets params for binary classification evaluator.
186207
"""
187208
kwargs = self._input_kwargs
@@ -218,6 +239,8 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeigh
218239
>>> evaluator = RegressionEvaluator(predictionCol="raw", weightCol="weight")
219240
>>> evaluator.evaluate(dataset)
220241
2.740...
242+
>>> evaluator.getThroughOrigin()
243+
False
221244
222245
.. versionadded:: 1.4.0
223246
"""
@@ -226,20 +249,25 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeigh
226249
rmse - root mean squared error (default)
227250
mse - mean squared error
228251
r2 - r^2 metric
229-
mae - mean absolute error.""",
252+
mae - mean absolute error
253+
var - explained variance.""",
230254
typeConverter=TypeConverters.toString)
231255

256+
throughOrigin = Param(Params._dummy(), "throughOrigin",
257+
"whether the regression is through the origin.",
258+
typeConverter=TypeConverters.toBoolean)
259+
232260
@keyword_only
233261
def __init__(self, predictionCol="prediction", labelCol="label",
234-
metricName="rmse", weightCol=None):
262+
metricName="rmse", weightCol=None, throughOrigin=False):
235263
"""
236264
__init__(self, predictionCol="prediction", labelCol="label", \
237-
metricName="rmse", weightCol=None)
265+
metricName="rmse", weightCol=None, throughOrigin=False)
238266
"""
239267
super(RegressionEvaluator, self).__init__()
240268
self._java_obj = self._new_java_obj(
241269
"org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid)
242-
self._setDefault(metricName="rmse")
270+
self._setDefault(metricName="rmse", throughOrigin=False)
243271
kwargs = self._input_kwargs
244272
self._set(**kwargs)
245273

@@ -257,13 +285,27 @@ def getMetricName(self):
257285
"""
258286
return self.getOrDefault(self.metricName)
259287

288+
@since("3.0.0")
289+
def setThroughOrigin(self, value):
290+
"""
291+
Sets the value of :py:attr:`throughOrigin`.
292+
"""
293+
return self._set(throughOrigin=value)
294+
295+
@since("3.0.0")
296+
def getThroughOrigin(self):
297+
"""
298+
Gets the value of throughOrigin or its default value.
299+
"""
300+
return self.getOrDefault(self.throughOrigin)
301+
260302
@keyword_only
261303
@since("1.4.0")
262304
def setParams(self, predictionCol="prediction", labelCol="label",
263-
metricName="rmse", weightCol=None):
305+
metricName="rmse", weightCol=None, throughOrigin=False):
264306
"""
265307
setParams(self, predictionCol="prediction", labelCol="label", \
266-
metricName="rmse", weightCol=None)
308+
metricName="rmse", weightCol=None, throughOrigin=False)
267309
Sets params for regression evaluator.
268310
"""
269311
kwargs = self._input_kwargs

0 commit comments

Comments (0)