File tree Expand file tree Collapse file tree 1 file changed +10
-4
lines changed
mllib/src/main/scala/org/apache/spark/mllib/tree Expand file tree Collapse file tree 1 file changed +10
-4
lines changed Original file line number Diff line number Diff line change @@ -926,16 +926,22 @@ object DecisionTree extends Serializable with Logging {
926926 while (featureIndex < numFeatures){
927927 var splitIndex = 0
928928 while (splitIndex < numBins - 1 ) {
929+ val totalNodeAgg = Array .ofDim[Double ](numClasses)
929930 var classIndex = 0
930931 while (classIndex < numClasses) {
931932 // shift for this featureIndex
932933 val shift = numClasses * featureIndex * numBins
933- leftNodeAgg(featureIndex)(splitIndex)(classIndex)
934- = binData(shift + classIndex)
935- rightNodeAgg(featureIndex)(splitIndex)(classIndex)
936- = binData(shift + numClasses + classIndex)
934+ val binValue = binData(shift + classIndex)
935+ leftNodeAgg(featureIndex)(splitIndex)(classIndex) = binValue
936+ totalNodeAgg(classIndex) = binValue
937937 classIndex += 1
938938 }
939+ // Calculate rightNodeAgg
940+ classIndex = 0
941+ while (classIndex < numClasses) {
942+ rightNodeAgg(featureIndex)(splitIndex)(classIndex)
943+ = totalNodeAgg(classIndex) - leftNodeAgg(featureIndex)(splitIndex)(classIndex)
944+ }
939945 splitIndex += 1
940946 }
941947 featureIndex += 1
You can’t perform that action at this time.
0 commit comments