Skip to content

Commit d9e0919

Browse files
zhengruifengyanboliang
authored andcommitted
[SPARK-16851][ML] Incorrect threshould length in 'setThresholds()' evoke Exception
## What changes were proposed in this pull request? Add a length checking for threshoulds' length in method `setThreshoulds()` of classification models. ## How was this patch tested? unit tests Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #14457 from zhengruifeng/check_setThresholds.
1 parent a1ff72e commit d9e0919

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,12 @@ abstract class ProbabilisticClassificationModel[
8383
def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
8484

8585
/** @group setParam */
86-
def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M]
86+
def setThresholds(value: Array[Double]): M = {
87+
require(value.length == numClasses, this.getClass.getSimpleName +
88+
".setThresholds() called with non-matching numClasses and thresholds.length." +
89+
s" numClasses=$numClasses, but thresholds has length ${value.length}")
90+
set(thresholds, value).asInstanceOf[M]
91+
}
8792

8893
/**
8994
* Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by

0 commit comments

Comments
 (0)