
[SPARK-3160] [SPARK-3494] [mllib] DecisionTree: eliminate pre-allocated nodes, parentImpurities arrays. Memory calc bug fix. #2341

Closed
wants to merge 10 commits into from
191 changes: 80 additions & 111 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Large diffs are not rendered by default.

@@ -75,6 +75,9 @@ class Strategy (
if (algo == Classification) {
require(numClassesForClassification >= 2)
}
require(minInstancesPerNode >= 1,
s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")

val isMulticlassClassification =
algo == Classification && numClassesForClassification > 2
val isMulticlassWithCategoricalFeatures
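The new require fails fast when a Strategy is constructed with an invalid minInstancesPerNode. A rough usage sketch, assuming the 1.x Strategy constructor (the parameter names shown are illustrative and not part of this patch):

import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

// Throws IllegalArgumentException at construction time with the message
// "DecisionTree Strategy requires minInstancesPerNode >= 1 but was given 0".
val strategy = new Strategy(algo = Algo.Classification, impurity = Gini, maxDepth = 5,
  numClassesForClassification = 2, minInstancesPerNode = 0)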
@@ -65,14 +65,7 @@ private[tree] class DTStatsAggregator(
* Offset for each feature for calculating indices into the [[allStats]] array.
*/
private val featureOffsets: Array[Int] = {
def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
if (isUnordered(featureIndex)) {
total + 2 * numBins(featureIndex)
} else {
total + numBins(featureIndex)
}
}
Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
}
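With unordered features no longer doubling their bin count here (numBins is assumed to already cover both child halves), the cumulative offsets reduce to a single scanLeft over numBins. A small worked example of the arithmetic with made-up sizes:

// Illustrative values only, not taken from the patch.
val statsSize = 3
val numBins = Array(4, 8, 2)
val featureOffsets = numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
// featureOffsets == Array(0, 12, 36, 42); feature i's statistics for a node
// begin at featureOffsets(i) within that node's slice of allStats.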

/**
@@ -149,7 +142,7 @@ private[tree] class DTStatsAggregator(
s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
s" but was called for ordered feature $featureIndex.")
val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
(baseOffset, baseOffset + numBins(featureIndex) * statsSize)
(baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
}

/**
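Since an unordered feature's numBins now counts its left-child and right-child bins together, the right-child statistics begin halfway through that feature's block, hence the shift by one. A sketch of the offset arithmetic with assumed sizes:

// Illustrative values only, not taken from the patch.
val statsSize = 3
val nBins = 8                      // e.g. 4 splits, each with a left bin and a right bin
val baseOffset = 100               // hypothetical start of this feature's block for a node
val leftChildOffset = baseOffset   // left-child bins come first
val rightChildOffset = baseOffset + (nBins >> 1) * statsSize
// rightChildOffset == 112: the right-child bins start after the 4 left-child bins.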
@@ -46,6 +46,7 @@ private[tree] class DecisionTreeMetadata(
val numBins: Array[Int],
val impurity: Impurity,
val quantileStrategy: QuantileStrategy,
val maxDepth: Int,
val minInstancesPerNode: Int,
val minInfoGain: Double) extends Serializable {

@@ -129,7 +130,7 @@ private[tree] object DecisionTreeMetadata {

new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy,
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
strategy.minInstancesPerNode, strategy.minInfoGain)
}

@@ -46,7 +46,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
* Predict values for the given data set using the model trained.
*
* @param features RDD representing data points to be predicted
* @return RDD[Int] where each entry contains the corresponding prediction
* @return RDD of predictions for each of the given data points
*/
def predict(features: RDD[Vector]): RDD[Double] = {
features.map(x => predict(x))
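The clarified @return doc reflects that predictions come back as an RDD[Double], one per input vector. A minimal usage sketch, assuming an existing SparkContext sc and a trained DecisionTreeModel model:

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

val points: RDD[Vector] = sc.parallelize(Seq(
  Vectors.dense(0.0, 1.0),
  Vectors.dense(1.0, 0.0)))
val predictions: RDD[Double] = model.predict(points)  // one prediction per input point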
37 changes: 37 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -55,6 +55,8 @@ class Node (
* build the left node and right nodes if not leaf
* @param nodes array of nodes
*/
@deprecated("build should no longer be used since trees are constructed on-the-fly in training",
"1.2.0")
def build(nodes: Array[Node]): Unit = {
logDebug("building node " + id + " at level " + Node.indexToLevel(id))
logDebug("id = " + id + ", split = " + split)
@@ -93,6 +95,23 @@
}
}

/**
* Returns a deep copy of the subtree rooted at this node.
*/
private[tree] def deepCopy(): Node = {
val leftNodeCopy = if (leftNode.isEmpty) {
None
} else {
Some(leftNode.get.deepCopy())
}
val rightNodeCopy = if (rightNode.isEmpty) {
None
} else {
Some(rightNode.get.deepCopy())
}
new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
}
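Presumably the deep copy exists so a caller can snapshot a subtree without aliasing the mutable leftNode/rightNode links that on-the-fly training keeps updating. A sketch of that behavior (only callable from within the tree package, since deepCopy is private[tree]; root is a hypothetical Node):

val snapshot = root.deepCopy()       // recursively clones every Node in the subtree
root.leftNode = None                 // later mutations of the original tree...
assert(snapshot.leftNode.isDefined)  // ...leave the snapshot untouched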

/**
* Get the number of nodes in tree below this node, including leaf nodes.
* E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
@@ -190,4 +209,22 @@ private[tree] object Node {
*/
def startIndexInLevel(level: Int): Int = 1 << level

/**
* Traces down from a root node to get the node with the given node index.
* This assumes the node exists.
*/
def getNode(nodeIndex: Int, rootNode: Node): Node = {
var tmpNode: Node = rootNode
var levelsToGo = indexToLevel(nodeIndex)
while (levelsToGo > 0) {
if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
tmpNode = tmpNode.leftNode.get
} else {
tmpNode = tmpNode.rightNode.get
}
levelsToGo -= 1
}
tmpNode
}
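getNode relies on the 1-based heap-style node numbering (root = 1, left child = 2 * i, right child = 2 * i + 1): the bits of nodeIndex below its highest set bit, read from high to low, spell out the left/right turns from the root. A self-contained sketch of the same bit logic (indexToLevel here is a local stand-in for illustration):

// Not part of the patch: standalone illustration of the index-to-path arithmetic.
def indexToLevel(nodeIndex: Int): Int = 31 - Integer.numberOfLeadingZeros(nodeIndex)

def pathFromRoot(nodeIndex: Int): Seq[String] =
  (indexToLevel(nodeIndex) to 1 by -1).map { levelsToGo =>
    if ((nodeIndex & (1 << (levelsToGo - 1))) == 0) "left" else "right"
  }

// pathFromRoot(5) == Seq("left", "right"): node 5 is reached via
// rootNode.leftNode.get.rightNode.get, which is exactly what getNode(5, rootNode) traverses.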

}