
[SPARK-12182][ML] Distributed binning for trees in spark.ml #10231

Closed
wants to merge 10 commits
110 changes: 54 additions & 56 deletions mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -479,8 +479,8 @@ private[ml] object RandomForest extends Logging {
       // Construct a nodeStatsAggregators array to hold node aggregate stats,
       // each node will have a nodeStatsAggregator
       val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
-        val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
-          Some(nodeToFeatures(nodeIndex))
+        val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
+          nodeToFeatures(nodeIndex)
         }
         new DTStatsAggregator(metadata, featuresForNode)
       }
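
Note: nodeToFeaturesBc.value is an Option here (as the flatMap/Some pairing suggests), and flatMap with a body that always returns Some is exactly map. A minimal sketch of the equivalence, with made-up values:

val nodeToFeatures: Option[Map[Int, Array[Int]]] = Some(Map(0 -> Array(1, 3)))

// For Option, flatMap(x => Some(f(x))) is equivalent to map(f).
val viaFlatMap = nodeToFeatures.flatMap(m => Some(m(0)))
val viaMap     = nodeToFeatures.map(m => m(0))
assert(viaFlatMap.get.sameElements(viaMap.get))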
@@ -832,8 +832,8 @@ private[ml] object RandomForest extends Logging {
     val numFeatures = metadata.numFeatures

     // Sample the input only if there are continuous features.
-    val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
-    val sampledInput = if (hasContinuousFeatures) {
+    val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
+    val sampledInput = if (continuousFeatures.nonEmpty) {
       // Calculate the number of samples for approximate quantile calculation.
       val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
       val fraction = if (requiredSamples < metadata.numExamples) {
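
Note: the sampling math above caps the number of rows used for quantile estimation. A small worked sketch with hypothetical numbers (not from the PR):

// maxBins^2 samples (floored at 10000) suffice for approximate quantiles,
// so only a fraction of a large input needs to be sampled.
val maxBins = 32
val numExamples = 1000000L
val requiredSamples = math.max(maxBins * maxBins, 10000)  // 10000
val fraction =
  if (requiredSamples < numExamples) requiredSamples.toDouble / numExamples
  else 1.0
// fraction == 0.01: roughly 10000 of the 1000000 rows get sampled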
@@ -842,58 +842,57 @@ private[ml] object RandomForest extends Logging {
         1.0
       }
       logDebug("fraction of data used for calculating quantiles = " + fraction)
-      input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
+      input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
     } else {
-      new Array[LabeledPoint](0)
+      input.sparkContext.emptyRDD[LabeledPoint]
     }

-    val splits = new Array[Array[Split]](numFeatures)
-
-    // Find all splits.
-    // Iterate over all features.
-    var featureIndex = 0
-    while (featureIndex < numFeatures) {
-      if (metadata.isContinuous(featureIndex)) {
-        val featureSamples = sampledInput.map(_.features(featureIndex))
-        val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)
-
-        val numSplits = featureSplits.length
-        logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
-        splits(featureIndex) = new Array[Split](numSplits)
+    findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
+  }
+
+  private def findSplitsBinsBySorting(
+      input: RDD[LabeledPoint],
+      metadata: DecisionTreeMetadata,
+      continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
+
+    val continuousSplits: scala.collection.Map[Int, Array[Split]] = {
+      // reduce the parallelism for split computations when there are less
+      // continuous features than input partitions. this prevents tasks from
+      // being spun up that will definitely do no work.
+      val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
+
+      input
+        .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
+        .groupByKey(numPartitions)
+        .map { case (idx, samples) =>
+          val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
+          val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
[Inline review comments on the line above]

Contributor: (as mentioned in jenkins): scala style long line

Contributor Author: Fixed.
logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
(idx, splits)
}.collectAsMap()
}

var splitIndex = 0
while (splitIndex < numSplits) {
val threshold = featureSplits(splitIndex)
splits(featureIndex)(splitIndex) = new ContinuousSplit(featureIndex, threshold)
splitIndex += 1
}
} else {
// Categorical feature
if (metadata.isUnordered(featureIndex)) {
val numSplits = metadata.numSplits(featureIndex)
val featureArity = metadata.featureArity(featureIndex)
// TODO: Use an implicit representation mapping each category to a subset of indices.
// I.e., track indices such that we can calculate the set of bins for which
// feature value x splits to the left.
// Unordered features
// 2^(maxFeatureValue - 1) - 1 combinations
splits(featureIndex) = new Array[Split](numSplits)
var splitIndex = 0
while (splitIndex < numSplits) {
val categories: List[Double] =
extractMultiClassCategories(splitIndex + 1, featureArity)
splits(featureIndex)(splitIndex) =
new CategoricalSplit(featureIndex, categories.toArray, featureArity)
splitIndex += 1
}
} else {
// Ordered features
// Bins correspond to feature values, so we do not need to compute splits or bins
// beforehand. Splits are constructed as needed during training.
splits(featureIndex) = new Array[Split](0)
val numFeatures = metadata.numFeatures
val splits: Array[Array[Split]] = Array.tabulate(numFeatures) {
case i if metadata.isContinuous(i) =>
val split = continuousSplits(i)
metadata.setNumSplits(i, split.length)
split

case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
// Unordered features
// 2^(maxFeatureValue - 1) - 1 combinations
val featureArity = metadata.featureArity(i)
Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
new CategoricalSplit(i, categories.toArray, featureArity)
}
}
featureIndex += 1

case i if metadata.isCategorical(i) =>
// Ordered features
// Bins correspond to feature values, so we do not need to compute splits or bins
// beforehand. Splits are constructed as needed during training.
Array.empty[Split]
}
splits
}
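
Note: the heart of this PR is the hunk above. Instead of collecting sampled rows to the driver and binning feature-by-feature in a local loop, it ships (featureIndex, value) pairs, groups them per feature, and computes each feature's thresholds in a cluster task. A self-contained toy sketch of that pattern (the object name, data, and stand-in threshold rule are made up; this is not the PR's code):

import org.apache.spark.{SparkConf, SparkContext}

object DistributedBinningSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("binning-sketch"))

    val continuousFeatures = Seq(0, 1)
    val points = sc.parallelize(Seq(
      Array(1.0, 10.0), Array(2.0, 20.0), Array(3.0, 30.0), Array(4.0, 40.0)))

    // Cap parallelism at the number of continuous features so no task is
    // spun up that has no feature to work on.
    val numPartitions = math.min(continuousFeatures.length, points.partitions.length)

    val splitsPerFeature: scala.collection.Map[Int, Array[Double]] = points
      .flatMap(p => continuousFeatures.map(idx => (idx, p(idx))))  // key by feature
      .groupByKey(numPartitions)                                   // gather each feature's samples
      .mapValues(_.toArray.distinct.sorted.dropRight(1))           // stand-in for findSplitsForContinuousFeature
      .collectAsMap()                                              // small: one entry per feature

    splitsPerFeature.toSeq.sortBy(_._1).foreach { case (idx, ts) =>
      println(s"feature $idx -> thresholds ${ts.mkString(", ")}")
    }
    sc.stop()
  }
}

The driver only receives one small threshold array per feature, rather than the sampled rows themselves.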
@@ -935,7 +934,7 @@ private[ml] object RandomForest extends Logging {
    * @return array of splits
    */
   private[tree] def findSplitsForContinuousFeature(
-      featureSamples: Array[Double],
+      featureSamples: Iterable[Double],
       metadata: DecisionTreeMetadata,
       featureIndex: Int): Array[Double] = {
     require(metadata.isContinuous(featureIndex),
@@ -945,8 +944,9 @@ private[ml] object RandomForest extends Logging {
     val numSplits = metadata.numSplits(featureIndex)

     // get count for each distinct value
-    val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
-      m + ((x, m.getOrElse(x, 0) + 1))
+    val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
+      case ((m, cnt), x) =>
+        (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
     }
     // sort distinct values
     val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
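
Note: the fold above now threads a running count through the same pass that builds the per-value counts, since an Iterable (unlike the old Array) may need a full extra traversal to compute its size. A toy run of the same fold:

val featureSamples: Iterable[Double] = Seq(1.0, 2.0, 2.0, 3.0, 3.0, 3.0)

// One traversal yields both the distinct-value counts and the total count.
val (valueCountMap, numSamples) =
  featureSamples.foldLeft((Map.empty[Double, Int], 0)) { case ((m, cnt), x) =>
    (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
  }

assert(valueCountMap == Map(1.0 -> 1, 2.0 -> 2, 3.0 -> 3))
assert(numSamples == 6)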
@@ -957,7 +957,7 @@ private[ml] object RandomForest extends Logging {
       valueCounts.map(_._1)
     } else {
       // stride between splits
-      val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+      val stride: Double = numSamples.toDouble / (numSplits + 1)
       logDebug("stride = " + stride)

       // iterate `valueCount` to find splits
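
Note: with the total now carried by the fold, the stride keeps its meaning of "samples per bin". A worked example with hypothetical numbers:

// 1000 samples and 99 target splits give 100 bins of ~10 samples each.
val numSamples = 1000
val numSplits = 99
val stride: Double = numSamples.toDouble / (numSplits + 1)  // 10.0
// Walking the sorted (value, count) pairs, a split is emitted whenever the
// cumulative count passes the next multiple of the stride.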
@@ -993,8 +993,6 @@ private[ml] object RandomForest extends Logging {
     assert(splits.length > 0,
       s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
       " Please remove this feature and then try again.")
-    // set number of splits accordingly
-    metadata.setNumSplits(featureIndex, splits.length)

     splits
   }
@@ -1007,7 +1007,7 @@ object DecisionTree extends Serializable with Logging {
       featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = {
     val splits = {
       val featureSplits = findSplitsForContinuousFeature(
-        featureSamples.toArray,
+        featureSamples,
         metadata,
         featureIndex)
       logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}")
@@ -1111,7 +1111,7 @@ object DecisionTree extends Serializable with Logging {
    * @return array of splits
    */
   private[tree] def findSplitsForContinuousFeature(
-      featureSamples: Array[Double],
+      featureSamples: Iterable[Double],
       metadata: DecisionTreeMetadata,
       featureIndex: Int): Array[Double] = {
     require(metadata.isContinuous(featureIndex),
@@ -1121,8 +1121,9 @@ object DecisionTree extends Serializable with Logging {
     val numSplits = metadata.numSplits(featureIndex)

     // get count for each distinct value
-    val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
-      m + ((x, m.getOrElse(x, 0) + 1))
+    val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
+      case ((m, cnt), x) =>
+        (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
     }
     // sort distinct values
     val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
@@ -1133,7 +1134,7 @@ object DecisionTree extends Serializable with Logging {
       valueCounts.map(_._1)
     } else {
       // stride between splits
-      val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+      val stride: Double = numSamples.toDouble / (numSplits + 1)
       logDebug("stride = " + stride)

       // iterate `valueCount` to find splits