Skip to content

Commit e547151

Browse files
committed
minor modifications
1 parent 34549d0 commit e547151

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,9 @@ object DecisionTree extends Serializable with Logging {
242242
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
243243
}
244244

245+
// TODO: Add multiclass classification support
246+
247+
// TODO: Add sample weight support
245248

246249
/**
247250
* Method to train a decision tree model where the instances are represented as an RDD of
@@ -723,8 +726,8 @@ object DecisionTree extends Serializable with Logging {
723726
val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount)
724727
val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount)
725728

726-
val leftWeight = leftTotalCount.toDouble / (leftTotalCount + rightTotalCount)
727-
val rightWeight = rightTotalCount.toDouble / (leftTotalCount + rightTotalCount)
729+
val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount)
730+
val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount)
728731

729732
val gain = {
730733
if (level > 0) {
@@ -734,7 +737,7 @@ object DecisionTree extends Serializable with Logging {
734737
}
735738
}
736739

737-
//TODO: Make modification here
740+
//TODO: Make multiclass modification here
738741
val predict = (leftCounts(1) + rightCounts(1)) / (leftTotalCount + rightTotalCount)
739742

740743
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)

0 commit comments

Comments
 (0)