Skip to content

Commit 632818f

Browse files
committed
removing threshold for classification predict method
1 parent 2116360 commit 632818f

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

+6-2
Original file line numberDiff line numberDiff line change
@@ -1034,8 +1034,12 @@ object DecisionTree extends Serializable with Logging {
10341034
/**
10351035
* Calculates the classifier accuracy.
10361036
*/
1037-
def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
1038-
val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
1037+
def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint],
1038+
threshold: Double = 0.5): Double = {
1039+
def predictedValue(features: Array[Double]) = {
1040+
if (model.predict(features) < threshold) 0.0 else 1.0
1041+
}
1042+
val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
10391043
val count = data.count()
10401044
logDebug("correct prediction count = " + correctCount)
10411045
logDebug("data count = " + count)

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

+1-8
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
3434
* @return Double prediction from the trained model
3535
*/
3636
def predict(features: Array[Double]): Double = {
37-
algo match {
38-
case Classification => {
39-
if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0
40-
}
41-
case Regression => {
42-
topNode.predictIfLeaf(features)
43-
}
44-
}
37+
topNode.predictIfLeaf(features)
4538
}
4639

4740
/**

0 commit comments

Comments
 (0)