Skip to content

Commit c8c20dc

Browse files
committed
Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check
2 parents cf70b07 + 73ab7f1 commit c8c20dc

File tree

9 files changed

+615
-630
lines changed

9 files changed

+615
-630
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 386 additions & 492 deletions
Large diffs are not rendered by default.
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.tree.impl
19+
20+
import scala.collection.mutable
21+
22+
import org.apache.spark.mllib.regression.LabeledPoint
23+
import org.apache.spark.mllib.tree.configuration.Algo._
24+
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
25+
import org.apache.spark.mllib.tree.configuration.Strategy
26+
import org.apache.spark.mllib.tree.impurity.Impurity
27+
import org.apache.spark.rdd.RDD
28+
29+
30+
/**
31+
* Learning and dataset metadata for DecisionTree.
32+
*
33+
* @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
34+
* For regression: fixed at 0 (no meaning).
35+
* @param featureArity Map: categorical feature index --> arity.
36+
* I.e., the feature takes values in {0, ..., arity - 1}.
37+
*/
38+
private[tree] class DecisionTreeMetadata(
39+
val numFeatures: Int,
40+
val numExamples: Long,
41+
val numClasses: Int,
42+
val maxBins: Int,
43+
val featureArity: Map[Int, Int],
44+
val unorderedFeatures: Set[Int],
45+
val impurity: Impurity,
46+
val quantileStrategy: QuantileStrategy) extends Serializable {
47+
48+
def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
49+
50+
def isClassification: Boolean = numClasses >= 2
51+
52+
def isMulticlass: Boolean = numClasses > 2
53+
54+
def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0)
55+
56+
def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex)
57+
58+
def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
59+
60+
}
61+
62+
private[tree] object DecisionTreeMetadata {
63+
64+
def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {
65+
66+
val numFeatures = input.take(1)(0).features.size
67+
val numExamples = input.count()
68+
val numClasses = strategy.algo match {
69+
case Classification => strategy.numClassesForClassification
70+
case Regression => 0
71+
}
72+
73+
val maxBins = math.min(strategy.maxBins, numExamples).toInt
74+
val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
75+
76+
val unorderedFeatures = new mutable.HashSet[Int]()
77+
if (numClasses > 2) {
78+
strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
79+
if (k - 1 < log2MaxBinsp1) {
80+
// Note: The above check is equivalent to checking:
81+
// numUnorderedBins = (1 << (k - 1)) - 1 < maxBins
82+
unorderedFeatures.add(f)
83+
} else {
84+
// TODO: Allow this case, where we simply will know nothing about some categories?
85+
require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
86+
s"in categorical features (>= $k)")
87+
}
88+
}
89+
} else {
90+
strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
91+
require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
92+
s"in categorical features (>= $k)")
93+
}
94+
}
95+
96+
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
97+
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
98+
strategy.impurity, strategy.quantileCalculationStrategy)
99+
}
100+
101+
}

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.mllib.tree.impl
1919

2020
import org.apache.spark.mllib.regression.LabeledPoint
21-
import org.apache.spark.mllib.tree.configuration.Strategy
2221
import org.apache.spark.mllib.tree.model.Bin
2322
import org.apache.spark.rdd.RDD
2423

@@ -48,50 +47,35 @@ private[tree] object TreePoint {
4847
* Convert an input dataset into its TreePoint representation,
4948
* binning feature values in preparation for DecisionTree training.
5049
* @param input Input dataset.
51-
* @param strategy DecisionTree training info, used for dataset metadata.
5250
* @param bins Bins for features, of size (numFeatures, numBins).
51+
* @param metadata Learning and dataset metadata
5352
* @return TreePoint dataset representation
5453
*/
5554
def convertToTreeRDD(
5655
input: RDD[LabeledPoint],
57-
strategy: Strategy,
58-
bins: Array[Array[Bin]]): RDD[TreePoint] = {
56+
bins: Array[Array[Bin]],
57+
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
5958
input.map { x =>
60-
TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins,
61-
strategy.categoricalFeaturesInfo)
59+
TreePoint.labeledPointToTreePoint(x, bins, metadata)
6260
}
6361
}
6462

6563
/**
6664
* Convert one LabeledPoint into its TreePoint representation.
6765
* @param bins Bins for features, of size (numFeatures, numBins).
68-
* @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
6966
*/
7067
private def labeledPointToTreePoint(
7168
labeledPoint: LabeledPoint,
72-
isMulticlassClassification: Boolean,
7369
bins: Array[Array[Bin]],
74-
categoricalFeaturesInfo: Map[Int, Int]): TreePoint = {
70+
metadata: DecisionTreeMetadata): TreePoint = {
7571

7672
val numFeatures = labeledPoint.features.size
7773
val numBins = bins(0).size
7874
val arr = new Array[Int](numFeatures)
7975
var featureIndex = 0
8076
while (featureIndex < numFeatures) {
81-
val featureInfo = categoricalFeaturesInfo.get(featureIndex)
82-
val isFeatureContinuous = featureInfo.isEmpty
83-
if (isFeatureContinuous) {
84-
arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false,
85-
bins, categoricalFeaturesInfo)
86-
} else {
87-
val featureCategories = featureInfo.get
88-
val isSpaceSufficientForAllCategoricalSplits
89-
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
90-
val isUnorderedFeature =
91-
isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
92-
arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous,
93-
isUnorderedFeature, bins, categoricalFeaturesInfo)
94-
}
77+
arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex),
78+
metadata.isUnordered(featureIndex), bins, metadata.featureArity)
9579
featureIndex += 1
9680
}
9781

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model
2020
import org.apache.spark.mllib.tree.configuration.FeatureType._
2121

2222
/**
23-
* Used for "binning" the features bins for faster best split calculation. For a continuous
24-
* feature, a bin is determined by a low and a high "split". For a categorical feature,
25-
* the a bin is determined using a single label value (category).
23+
* Used for "binning" the feature values for faster best split calculation.
24+
*
25+
* For a continuous feature, the bin is determined by a low and a high split,
26+
* where an example with featureValue falls into the bin s.t.
27+
* lowSplit.threshold < featureValue <= highSplit.threshold.
28+
*
29+
* For ordered categorical features, there is a 1-1-1 correspondence between
30+
* bins, splits, and feature values. The bin is determined by category/feature value.
31+
* However, the bins are not necessarily ordered by feature value;
32+
* they are ordered using impurity.
33+
* For unordered categorical features, there is a 1-1 correspondence between bins and splits,
34+
* where bins and splits correspond to subsets of feature values (in highSplit.categories).
35+
*
2636
* @param lowSplit signifying the lower threshold for the continuous feature to be
2737
* accepted in the bin
2838
* @param highSplit signifying the upper threshold for the continuous feature to be
2939
* accepted in the bin
3040
* @param featureType type of feature -- categorical or continuous
31-
* @param category categorical label value accepted in the bin for binary classification
41+
* @param category categorical feature value accepted in the bin for ordered features
3242
*/
3343
private[tree]
3444
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
3939
* @return Double prediction from the trained model
4040
*/
4141
def predict(features: Vector): Double = {
42-
topNode.predictIfLeaf(features)
42+
topNode.predict(features)
4343
}
4444

4545
/**

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala

Lines changed: 0 additions & 28 deletions
This file was deleted.

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,24 +69,24 @@ class Node (
6969

7070
/**
7171
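* Predict the value for the given features: return this node's prediction if it is a leaf,
* otherwise descend into the left or right child according to the split.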
72-
* @param feature feature value
72+
* @param features feature vector of the example to predict
7373
* @return predicted value
7474
*/
75-
def predictIfLeaf(feature: Vector) : Double = {
75+
def predict(features: Vector) : Double = {
7676
if (isLeaf) {
7777
predict
7878
} else{
7979
if (split.get.featureType == Continuous) {
80-
if (feature(split.get.feature) <= split.get.threshold) {
81-
leftNode.get.predictIfLeaf(feature)
80+
if (features(split.get.feature) <= split.get.threshold) {
81+
leftNode.get.predict(features)
8282
} else {
83-
rightNode.get.predictIfLeaf(feature)
83+
rightNode.get.predict(features)
8484
}
8585
} else {
86-
if (split.get.categories.contains(feature(split.get.feature))) {
87-
leftNode.get.predictIfLeaf(feature)
86+
if (split.get.categories.contains(features(split.get.feature))) {
87+
leftNode.get.predict(features)
8888
} else {
89-
rightNode.get.predictIfLeaf(feature)
89+
rightNode.get.predict(features)
9090
}
9191
}
9292
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
2424
* :: DeveloperApi ::
2525
* Split applied to a feature
2626
* @param feature feature index
27-
* @param threshold threshold for continuous feature
27+
* @param threshold Threshold for continuous feature.
28+
* Split left if feature <= threshold, else right.
2829
* @param featureType type of feature -- categorical or continuous
29-
* @param categories accepted values for categorical variables
30+
* @param categories Split left if categorical feature value is in this set, else right.
3031
*/
3132
@DeveloperApi
3233
case class Split(

0 commit comments

Comments
 (0)