
[MLlib] SPARK-1536: multiclass classification support for decision tree #886


Closed

80 commits
50b143a
adding support for very deep trees
manishamde Apr 20, 2014
abc5a23
Parameterizing max memory.
etrain Apr 22, 2014
2f6072c
Merge pull request #5 from etrain/deep_tree
manishamde Apr 22, 2014
2f1e093
minor: added doc for maxMemory parameter
manishamde Apr 22, 2014
0287772
Fixing scalastyle issue.
etrain Apr 22, 2014
fecf89a
Merge pull request #6 from etrain/deep_tree
manishamde Apr 22, 2014
719d009
updating user documentation
manishamde Apr 24, 2014
9dbdabe
merge from master
manishamde Apr 29, 2014
1517155
updated documentation
manishamde Apr 29, 2014
718506b
added unit test
manishamde Apr 30, 2014
e0426ee
renamed parameter
manishamde May 1, 2014
dad9652
removed unused imports
manishamde May 1, 2014
cbd9f14
modified scala.math to math
manishamde May 3, 2014
5e82202
added documentation, fixed off by 1 error in max level calculation
manishamde May 6, 2014
4731cda
formatting
manishamde May 6, 2014
5eca9e4
grammar
manishamde May 6, 2014
8053fed
more formatting
manishamde May 6, 2014
426bb28
programming guide blurb
manishamde May 6, 2014
b27ad2c
formatting
manishamde May 6, 2014
ce004a1
minor formatting
manishamde May 6, 2014
7fc9545
added docs
manishamde May 7, 2014
968ca9d
merged master
manishamde May 7, 2014
a1a6e09
added weighted point class
manishamde May 1, 2014
14aea48
changing instance format to weighted labeled point
manishamde May 1, 2014
455bea9
fixed tests
manishamde May 1, 2014
46f909c
todo for multiclass support
manishamde May 4, 2014
4d5f70c
added multiclass support for find splits bins
manishamde May 6, 2014
3f85a17
tests for multiclass classification
manishamde May 7, 2014
46e06ee
minor mods
manishamde May 7, 2014
6c7af22
prepared for multiclass without breaking binary classification
manishamde May 7, 2014
5c78e1a
added multiclass support
manishamde May 11, 2014
e006f9d
changing variable names
manishamde May 12, 2014
098e8c5
merged master
manishamde May 12, 2014
34549d0
fixing error during merge
manishamde May 12, 2014
e547151
minor modifications
manishamde May 12, 2014
75f2bfc
minor code style fix
manishamde May 12, 2014
6b912dc
added numclasses to tree runner, predict logic for multiclass, add mu…
manishamde May 12, 2014
18d2835
changing default values for num classes
manishamde May 12, 2014
d012be7
fixed while loop
manishamde May 12, 2014
ed5a2df
fixed classification requirements
manishamde May 13, 2014
d8e4a11
sample weights
manishamde May 13, 2014
ab5cb21
multiclass logic
manishamde May 18, 2014
d811425
multiclass bin aggregate logic
manishamde May 18, 2014
f16a9bb
fixing while loop
manishamde May 18, 2014
1dd2735
bin search logic for multiclass
manishamde May 18, 2014
7e5f08c
minor doc
manishamde May 18, 2014
bce835f
code cleanup
manishamde May 18, 2014
828ff16
added categorical variable test
manishamde May 21, 2014
8cfd3b6
working for categorical multiclass classification
manishamde May 22, 2014
f5f6b83
multiclass for continous variables
manishamde May 22, 2014
1892a2c
tests and use multiclass binaggregate length when atleast one categor…
manishamde May 23, 2014
9a90c93
Merge branch 'master' into multiclass
manishamde May 23, 2014
12e6d0a
minor: removing line in doc
manishamde May 26, 2014
237762d
renaming functions
manishamde May 26, 2014
34ee7b9
minor: code style
manishamde May 26, 2014
23d4268
minor: another minor code style
manishamde May 26, 2014
e3e8843
minor code formatting
manishamde May 27, 2014
adc7315
support ordered categorical splits for multiclass classification
manishamde Jun 4, 2014
8e44ab8
updated doc
manishamde Jun 4, 2014
3d7f911
updated doc
manishamde Jun 5, 2014
485eaae
implicit conversion from LabeledPoint to WeightedLabeledPoint
manishamde Jul 7, 2014
5c1b2ca
doc for PointConverter class
manishamde Jul 8, 2014
9cc3e31
added implicit conversion import
manishamde Jul 8, 2014
06b1690
fixed off-by-one error in bin to split conversion
manishamde Jul 9, 2014
2061cf5
merged from master
manishamde Jul 10, 2014
0fecd38
minor: add newline to EOF
manishamde Jul 10, 2014
d75ac32
removed WeightedLabeledPoint from this PR
manishamde Jul 11, 2014
e4c1321
using while loop for regression histograms
manishamde Jul 14, 2014
b2ae41f
minor: scalastyle
manishamde Jul 14, 2014
4e85f2c
minor: fixed scalastyle issues
manishamde Jul 14, 2014
2d85a48
minor: fixed scalastyle issues reprise
manishamde Jul 14, 2014
afced16
removed label weights support
manishamde Jul 14, 2014
c8428c4
fixing weird multiline bug
manishamde Jul 15, 2014
45e767a
adding developer api annotation for overriden methods
manishamde Jul 15, 2014
abf2901
adding classes to MimaExcludes.scala
manishamde Jul 17, 2014
e1c970d
merged master
manishamde Jul 17, 2014
10fdd82
fixing MIMA excludes
manishamde Jul 17, 2014
1ce7212
change problem filter for mima
manishamde Jul 18, 2014
c5b2d04
more MIMA fixes
manishamde Jul 18, 2014
26f8acc
another attempt at fixing mima
manishamde Jul 18, 2014
8 changes: 5 additions & 3 deletions docs/mllib-decision-tree.md
@@ -77,15 +77,17 @@ bins if the condition is not satisfied.

**Categorical features**

For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for
binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the
For `$M$` categorical feature values, one could come up with `$2^{M-1}-1$` split candidates. For
binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the
categorical feature values by the proportion of labels falling in one of the two classes (see
Section 9.2.4 in
[Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for
details). For example, for a binary classification problem with one categorical feature with three
categories A, B and C with corresponding proportions of label 1 as 0.2, 0.6 and 0.4, the categorical
feature values are ordered as A followed by C followed by B, i.e. A, C, B. The two split candidates are
A \| C, B and A , C \| B where \| denotes the split. A similar heuristic is used for multiclass classification
when `$2^{M-1}-1$` is greater than the number of bins -- the impurity for each categorical feature value
is used for ordering.
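The ordering heuristic above can be sketched in isolation. The following is a standalone illustration, not the PR's implementation; the `proportionOfLabelOne` map and category names come from the documentation example:

```scala
object CategoricalSplitOrdering {
  // Order categorical values by the proportion of examples with label 1,
  // then evaluate only the M - 1 contiguous prefix splits instead of all
  // 2^(M-1) - 1 subset splits.
  def orderedSplits(
      proportionOfLabelOne: Map[String, Double]): Seq[(Set[String], Set[String])] = {
    val ordered = proportionOfLabelOne.toSeq.sortBy(_._2).map(_._1)
    // Each candidate split places a prefix of the ordered values on the left.
    (1 until ordered.length).map { i =>
      (ordered.take(i).toSet, ordered.drop(i).toSet)
    }
  }

  def main(args: Array[String]): Unit = {
    // A -> 0.2, B -> 0.6, C -> 0.4 orders the values as A, C, B,
    // giving the two candidates A | C,B and A,C | B.
    orderedSplits(Map("A" -> 0.2, "B" -> 0.6, "C" -> 0.4)).foreach(println)
  }
}
```

With three categories this reproduces exactly the two split candidates named in the text.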

### Stopping rule

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -49,6 +49,7 @@ object DecisionTreeRunner {
case class Params(
input: String = null,
algo: Algo = Classification,
numClassesForClassification: Int = 2,
Contributor:

Do we want this to be a parameter and not inferred from the data?

Also - I'm wondering if it makes sense to subclass params with DecisionTreeParams vs. RegressionTreeParams so that we keep logically separate options separate.

Contributor Author:

Inference from a large dataset could take a lot of time. In general, most practitioners know the number of classes in advance. If not, we can add a pre-processing step.

Currently we have only numClassesForClassification as a classification specific parameter. In general, I agree with you. At the same time, didn't want to create more configuration classes for the user. Shall we leave it as is for now and handle it with the ensembles PR where we have more parameters (boosting iterations, num trees, feature subsetting, etc.) ?

Contributor:

Yeah, makes sense. If it doesn't complicate things too much we might consider adding an interface that doesn't have this specified and figures it out in one shot.

Worth noting is that in R, an object of type "factor" (the default for categorical/label data) has this information built in. It can be a big pain at load time while the system tries to figure out the cardinality of the factor, but it leads to a nice compact representation of the data and eliminates situations like this one.

I agree on doing the API separation with the ensembles PR.


Contributor Author:

Good point. Let me create a JIRA ticket for this so that we don't forget. :-)


maxDepth: Int = 5,
impurity: ImpurityType = Gini,
maxBins: Int = 100)
@@ -68,6 +69,10 @@ object DecisionTreeRunner {
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
.action((x, c) => c.copy(maxDepth = x))
opt[Int]("numClassesForClassification")
.text(s"number of classes for classification, "
+ s"default: ${defaultParams.numClassesForClassification}")
.action((x, c) => c.copy(numClassesForClassification = x))
opt[Int]("maxBins")
.text(s"max number of bins, default: ${defaultParams.maxBins}")
.action((x, c) => c.copy(maxBins = x))
@@ -118,7 +123,13 @@ object DecisionTreeRunner {
case Variance => impurity.Variance
}

val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
val strategy
= new Strategy(
algo = params.algo,
impurity = impurityCalculator,
maxDepth = params.maxDepth,
maxBins = params.maxBins,
numClassesForClassification = params.numClassesForClassification)
val model = DecisionTree.train(training, strategy)

if (params.algo == Classification) {
@@ -139,12 +150,8 @@
*/
private def accuracyScore(
model: DecisionTreeModel,
data: RDD[LabeledPoint],
threshold: Double = 0.5): Double = {
def predictedValue(features: Vector): Double = {
if (model.predict(features) < threshold) 0.0 else 1.0
}
val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
data: RDD[LabeledPoint]): Double = {
val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
val count = data.count()
correctCount.toDouble / count
}
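The change above replaces the thresholded binary accuracy with exact-match accuracy, which also works for multiclass labels. The same logic can be checked in isolation (plain Scala over `(prediction, label)` pairs; the RDD version in the diff does the equivalent with `filter` and `count`):

```scala
object AccuracyCheck {
  // Fraction of examples whose predicted label exactly equals the true label.
  def accuracy(predictionsAndLabels: Seq[(Double, Double)]): Double = {
    val correctCount = predictionsAndLabels.count { case (prediction, label) =>
      prediction == label
    }
    correctCount.toDouble / predictionsAndLabels.size
  }

  def main(args: Array[String]): Unit = {
    // Three-class labels 0.0 / 1.0 / 2.0; 3 of the 4 predictions are correct.
    println(accuracy(Seq((0.0, 0.0), (1.0, 1.0), (2.0, 1.0), (2.0, 2.0))))  // prints 0.75
  }
}
```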
732 changes: 532 additions & 200 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Large diffs are not rendered by default.

@@ -28,6 +28,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* @param algo classification or regression
* @param impurity criterion used for information gain calculation
* @param maxDepth maximum depth of the tree
* @param numClassesForClassification number of classes for classification. The default value
*                                    of 2 leads to binary classification
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
@@ -44,7 +46,15 @@ class Strategy (
val algo: Algo,
val impurity: Impurity,
val maxDepth: Int,
val numClassesForClassification: Int = 2,
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
val maxMemoryInMB: Int = 128) extends Serializable
val maxMemoryInMB: Int = 128) extends Serializable {

require(numClassesForClassification >= 2)
val isMulticlassClassification = numClassesForClassification > 2
val isMulticlassWithCategoricalFeatures
= isMulticlassClassification && (categoricalFeaturesInfo.size > 0)

}
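The two derived flags added to `Strategy` can be illustrated in isolation. This is a plain-Scala sketch mirroring the `require` and the flag logic from the diff, not the MLlib class itself:

```scala
// Hypothetical stand-in for the relevant slice of Strategy:
// numClassesForClassification must be at least 2, and the multiclass flags
// are derived rather than user-supplied.
case class StrategyFlags(
    numClassesForClassification: Int,
    categoricalFeaturesInfo: Map[Int, Int]) {
  require(numClassesForClassification >= 2)
  val isMulticlassClassification: Boolean = numClassesForClassification > 2
  val isMulticlassWithCategoricalFeatures: Boolean =
    isMulticlassClassification && categoricalFeaturesInfo.nonEmpty
}

object StrategyFlagsDemo {
  def main(args: Array[String]): Unit = {
    // Binary classification: neither flag is set.
    println(StrategyFlags(2, Map.empty).isMulticlassClassification)            // prints false
    // Three classes, feature 0 with 4 categories: both flags are set.
    println(StrategyFlags(3, Map(0 -> 4)).isMulticlassWithCategoricalFeatures) // prints true
  }
}
```

Keeping the flags as derived `val`s means callers of the real `Strategy` cannot put the two fields in an inconsistent state.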
@@ -31,23 +31,35 @@ object Entropy extends Impurity {

/**
* :: DeveloperApi ::
* entropy calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return entropy value
* information calculation for multiclass classification
* @param counts Array[Double] with counts for each label
* @param totalCount sum of counts for all labels
* @return information value
*/
@DeveloperApi
override def calculate(c0: Double, c1: Double): Double = {
if (c0 == 0 || c1 == 0) {
0
} else {
val total = c0 + c1
val f0 = c0 / total
val f1 = c1 / total
-(f0 * log2(f0)) - (f1 * log2(f1))
override def calculate(counts: Array[Double], totalCount: Double): Double = {
val numClasses = counts.length
var impurity = 0.0
var classIndex = 0
while (classIndex < numClasses) {
val classCount = counts(classIndex)
if (classCount != 0) {
val freq = classCount / totalCount
impurity -= freq * log2(freq)
}
classIndex += 1
}
impurity
}

/**
* :: DeveloperApi ::
* variance calculation
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
*/
@DeveloperApi
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Entropy.calculate")
}
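As a sanity check of the multiclass entropy loop above, here is a self-contained re-implementation outside MLlib (the `entropy` helper is ours, not an MLlib API; it follows the same while-loop style as the diff):

```scala
object EntropyCheck {
  private def log2(x: Double): Double = math.log(x) / math.log(2)

  // Multiclass entropy: -sum_i f_i * log2(f_i) over the nonzero class frequencies.
  def entropy(counts: Array[Double], totalCount: Double): Double = {
    var impurity = 0.0
    var classIndex = 0
    while (classIndex < counts.length) {
      val classCount = counts(classIndex)
      if (classCount != 0) {
        val freq = classCount / totalCount
        impurity -= freq * log2(freq)
      }
      classIndex += 1
    }
    impurity
  }

  def main(args: Array[String]): Unit = {
    // A uniform 4-class distribution has entropy log2(4) = 2 bits.
    println(entropy(Array(25.0, 25.0, 25.0, 25.0), 100.0))
  }
}
```

With two classes this reduces to the old binary formula `-(f0 * log2(f0)) - (f1 * log2(f1))`, so the generalization is backward compatible.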
@@ -30,23 +30,32 @@ object Gini extends Impurity {

/**
* :: DeveloperApi ::
* Gini coefficient calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return Gini coefficient value
* information calculation for multiclass classification
* @param counts Array[Double] with counts for each label
* @param totalCount sum of counts for all labels
* @return information value
*/
@DeveloperApi
override def calculate(c0: Double, c1: Double): Double = {
if (c0 == 0 || c1 == 0) {
0
} else {
val total = c0 + c1
val f0 = c0 / total
val f1 = c1 / total
1 - f0 * f0 - f1 * f1
override def calculate(counts: Array[Double], totalCount: Double): Double = {
val numClasses = counts.length
var impurity = 1.0
var classIndex = 0
while (classIndex < numClasses) {
val freq = counts(classIndex) / totalCount
impurity -= freq * freq
classIndex += 1
}
impurity
}

/**
* :: DeveloperApi ::
* variance calculation
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
*/
@DeveloperApi
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Gini.calculate")
}
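The Gini generalization can be checked the same way. Again this is a standalone re-implementation for illustration, not the MLlib object:

```scala
object GiniCheck {
  // Multiclass Gini impurity: 1 - sum_i f_i^2 over the class frequencies.
  def gini(counts: Array[Double], totalCount: Double): Double = {
    var impurity = 1.0
    var classIndex = 0
    while (classIndex < counts.length) {
      val freq = counts(classIndex) / totalCount
      impurity -= freq * freq
      classIndex += 1
    }
    impurity
  }

  def main(args: Array[String]): Unit = {
    // Three classes with frequencies 0.2, 0.3, 0.5:
    // 1 - 0.04 - 0.09 - 0.25 = 0.62 (up to floating point).
    println(gini(Array(2.0, 3.0, 5.0), 10.0))
  }
}
```

Note that unlike the old binary version, no zero-count special case is needed: a zero count contributes nothing to the sum.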
@@ -28,13 +28,13 @@ trait Impurity extends Serializable {

/**
* :: DeveloperApi ::
* information calculation for binary classification
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* information calculation for multiclass classification
* @param counts Array[Double] with counts for each label
* @param totalCount sum of counts for all labels
* @return information value
*/
@DeveloperApi
def calculate(c0 : Double, c1 : Double): Double
def calculate(counts: Array[Double], totalCount: Double): Double

/**
* :: DeveloperApi ::
@@ -25,7 +25,16 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
*/
@Experimental
object Variance extends Impurity {
override def calculate(c0: Double, c1: Double): Double =

/**
* :: DeveloperApi ::
* information calculation for multiclass classification
* @param counts Array[Double] with counts for each label
* @param totalCount sum of counts for all labels
* @return information value
*/
@DeveloperApi
override def calculate(counts: Array[Double], totalCount: Double): Double =
throw new UnsupportedOperationException("Variance.calculate")

/**
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
* @param highSplit signifying the upper threshold for the continuous feature to be
* accepted in the bin
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin
* @param category categorical label value accepted in the bin for binary classification
*/
private[tree]
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
@@ -27,17 +27,19 @@ import org.apache.spark.annotation.DeveloperApi
* @param leftImpurity left node impurity
* @param rightImpurity right node impurity
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
@DeveloperApi
class InformationGainStats(
val gain: Double,
val impurity: Double,
val leftImpurity: Double,
val rightImpurity: Double,
val predict: Double) extends Serializable {
val predict: Double,
val prob: Double = 0.0) extends Serializable {

override def toString = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
.format(gain, impurity, leftImpurity, rightImpurity, predict)
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
.format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
}
}