Skip to content

Commit d811425

Browse files
committed
multiclass bin aggregate logic
1 parent ab5cb21 commit d811425

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -926,16 +926,22 @@ object DecisionTree extends Serializable with Logging {
926926
while (featureIndex < numFeatures){
927927
var splitIndex = 0
928928
while (splitIndex < numBins - 1) {
929+
val totalNodeAgg = Array.ofDim[Double](numClasses)
929930
var classIndex = 0
930931
while (classIndex < numClasses) {
931932
// shift for this featureIndex
932933
val shift = numClasses * featureIndex * numBins
933-
leftNodeAgg(featureIndex)(splitIndex)(classIndex)
934-
= binData(shift + classIndex)
935-
rightNodeAgg(featureIndex)(splitIndex)(classIndex)
936-
= binData(shift + numClasses + classIndex)
934+
val binValue = binData(shift + classIndex)
935+
leftNodeAgg(featureIndex)(splitIndex)(classIndex) = binValue
936+
totalNodeAgg(classIndex) = binValue
937937
classIndex += 1
938938
}
939+
// Calculate rightNodeAgg
940+
classIndex = 0
941+
while (classIndex < numClasses) {
942+
rightNodeAgg(featureIndex)(splitIndex)(classIndex)
943+
= totalNodeAgg(classIndex) - leftNodeAgg(featureIndex)(splitIndex)(classIndex)
944+
}
939945
splitIndex += 1
940946
}
941947
featureIndex += 1

0 commit comments

Comments
 (0)