MLI-1 Decision Trees #79
@@ -0,0 +1,17 @@

This package contains the default implementation of the decision tree algorithm.

The decision tree algorithm supports:
+ Binary classification
+ Regression
+ Information loss calculation with entropy and gini for classification and variance for regression
+ Both continuous and categorical features

# Tree improvements
+ Node model pruning
+ Printing to dot files

# Future Ensemble Extensions

+ Random forests
+ Boosting
+ Extremely randomized trees
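
For orientation, a minimal end-to-end training sketch. It is not part of the diff: the `DecisionTree.train(input, strategy)` entry point, the `Array[Double]` feature representation in `LabeledPoint`, and the `predict` signature are assumptions about the API this PR introduces, not confirmed by the excerpt.

```scala
// Hypothetical usage sketch -- DecisionTree.train, LabeledPoint(label, Array[Double])
// and DecisionTreeModel.predict are assumed, not taken verbatim from this diff.
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.impurity.Gini

object DecisionTreeExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "DecisionTreeExample")
    // Toy data set: two continuous features, binary labels.
    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Array(1.0, 0.0)),
      LabeledPoint(1.0, Array(0.0, 1.0))))
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 3)
    val model = DecisionTree.train(data, strategy) // assumed entry point
    println(model.predict(Array(0.0, 1.0)))        // assumed predict signature
    sc.stop()
  }
}
```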
@@ -0,0 +1,26 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.configuration

/**
 * Enum to select the algorithm for the decision tree
 */
object Algo extends Enumeration {
  type Algo = Value
  val Classification, Regression = Value
}
@@ -0,0 +1,26 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.configuration

/**
 * Enum to describe whether a feature is "continuous" or "categorical"
 */
object FeatureType extends Enumeration {
  type FeatureType = Value
  val Continuous, Categorical = Value
}
@@ -0,0 +1,26 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.configuration

/**
 * Enum for selecting the quantile calculation strategy
 */
object QuantileStrategy extends Enumeration {
  type QuantileStrategy = Value
  val Sort, MinMax, ApproxHist = Value
}
@@ -0,0 +1,43 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.configuration

import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._

/**
 * Stores all the configuration options for tree construction
 * @param algo classification or regression
 * @param impurity criterion used for information gain calculation
 * @param maxDepth maximum depth of the tree
 * @param maxBins maximum number of bins used for splitting features
 * @param quantileCalculationStrategy algorithm for calculating quantiles
 * @param categoricalFeaturesInfo A map storing information about the categorical variables and
 *                                the number of discrete values they take. For example, an entry
 *                                (n -> k) implies the feature n is categorical with k categories
 *                                0, 1, 2, ... , k-1. It's important to note that features are
 *                                zero-indexed.
 */
class Strategy (
    val algo: Algo,
    val impurity: Impurity,
    val maxDepth: Int,
    val maxBins: Int = 100,
    val quantileCalculationStrategy: QuantileStrategy = Sort,
    val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable
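
To make the parameter semantics concrete, a short construction sketch (illustrative values only, REPL-style):

```scala
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.impurity.Gini

// Classification with Gini impurity. Feature 2 is categorical with 4 categories
// (0, 1, 2, 3); all other features are treated as continuous.
val strategy = new Strategy(
  algo = Classification,
  impurity = Gini,
  maxDepth = 4,
  maxBins = 100,
  quantileCalculationStrategy = Sort,
  categoricalFeaturesInfo = Map(2 -> 4))
```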
@@ -0,0 +1,47 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.impurity

/**
 * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
 * binary classification.
 */
object Entropy extends Impurity {

  def log2(x: Double) = scala.math.log(x) / scala.math.log(2)

  /**
   * entropy calculation
   * @param c0 count of instances with label 0
   * @param c1 count of instances with label 1
   * @return entropy value
   */
  def calculate(c0: Double, c1: Double): Double = {
    if (c0 == 0 || c1 == 0) {
      0
    } else {
      val total = c0 + c1
      val f0 = c0 / total
      val f1 = c1 / total
      -(f0 * log2(f0)) - (f1 * log2(f1))
    }
  }

  def calculate(count: Double, sum: Double, sumSquares: Double): Double =
    throw new UnsupportedOperationException("Entropy.calculate")
}
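
As a quick sanity check on the formula (not part of the diff): a node with a 30/70 label split has entropy of roughly 0.881 bits, and a pure node is defined to have zero entropy.

```scala
import org.apache.spark.mllib.tree.impurity.Entropy

// f0 = 0.3, f1 = 0.7  =>  -(0.3 * log2(0.3)) - (0.7 * log2(0.7)) ≈ 0.8813
println(Entropy.calculate(30.0, 70.0))  // ≈ 0.8813
println(Entropy.calculate(0.0, 100.0))  // 0.0 -- pure node
```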
@@ -0,0 +1,46 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.impurity

/**
 * Class for calculating the
 * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
 * during binary classification.
 */
object Gini extends Impurity {

  /**
   * Gini coefficient calculation
   * @param c0 count of instances with label 0
   * @param c1 count of instances with label 1
   * @return Gini coefficient value
   */
  override def calculate(c0: Double, c1: Double): Double = {
    if (c0 == 0 || c1 == 0) {
      0
    } else {
      val total = c0 + c1
      val f0 = c0 / total
      val f1 = c1 / total
      1 - f0 * f0 - f1 * f1
    }
  }

  def calculate(count: Double, sum: Double, sumSquares: Double): Double =
    throw new UnsupportedOperationException("Gini.calculate")
}
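
Again as an illustrative check (not part of the diff): a 30/70 split gives 1 - 0.3^2 - 0.7^2 = 0.42, and a 50/50 split gives the binary maximum of 0.5.

```scala
import org.apache.spark.mllib.tree.impurity.Gini

println(Gini.calculate(30.0, 70.0))  // 1 - 0.09 - 0.49 = 0.42
println(Gini.calculate(50.0, 50.0))  // 0.5, maximum impurity for binary labels
```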
@@ -0,0 +1,42 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.impurity

/**
 * Trait for calculating information gain.
 */
trait Impurity extends Serializable {

> The […]. For a generic interface, an additional […]. The […]

  /**
   * information calculation for binary classification
   * @param c0 count of instances with label 0
   * @param c1 count of instances with label 1
   * @return information value
   */
  def calculate(c0: Double, c1: Double): Double

> JavaDoc for public methods.

  /**
   * information calculation for regression
   * @param count number of instances
   * @param sum sum of labels
   * @param sumSquares summation of squares of the labels
   * @return information value
   */
  def calculate(count: Double, sum: Double, sumSquares: Double): Double

> It is easy to lose precision or run into overflow in the computation of […].

> That's a nice observation. However, using the […] I see your concern with computing […]. We can calculate StatCounter per partition using […].

> I agree with Manish. Numerical stability is the first thing that comes to mind on seeing a large […]. We definitely cannot use the methods in DoubleRDDFunctions because we want to calculate the variance of various splits, which requires the stats to be "aggregable". But we may be able to modify the APIs to use (count, avg, avgSquares) as the stats and make the calculations more stable. E.g., to merge (count, avg) of two parts […].

> I agree that overflow is an issue here (particularly in the case of sumSquares), but I also agree with Manish/Hirakendu that this algorithm maintains its ability to generate a tree in a reasonable amount of time based on the property that we compute statistics for splits and then merge them together. I actually do think it makes sense to maintain (count, average, averageSumSq) for each partition in a way that's overflow-friendly and compute the combination as a count-weighted average of both, as Hirakendu suggests. This will complicate the code but should solve the overflow problem and keep things pretty efficient. That said, maybe this could be taken care of in a future PR as a bugfix, rather than in this one?

> The major loss of precision is from […]. The question is whether we should make […].

> I'm just catching up on this, but is the problem that there will be other types of Impurity later that calculate different stats (not just variance)? In that case, maybe we can have Impurity be parameterized (Impurity[T]) where T is a type it accumulates over. However, I'd also be okay with leaving this as is initially and marking the API unstable if this is an internal API. The question is how many users will call this directly.

> BTW I'd also be okay updating this API in a later pull request before we release 1.0. It's fair game to change new APIs in that time window.

> @mateiz A user needs an […].

> @mengxr The generic interface you noted is correct. However, I think implementing this generic interface and the corresponding implementations is not a minor code change. There are some assumptions in the bin aggregation code that may need to be updated, and it also requires adding partition-wise impurity calculation and aggregation. @mateiz As @mengxr noted, it's highly unlikely that a user will write their own […]. I think we all agree (please correct me if I am wrong) the […]. Is this the correct method of marking a method as unstable using the javadoc?

> Adding to the discussion on the need for a generic interface for […]: in addition to performance-oriented implementations for specific loss functions, I would still recommend a generic […]. For reference and an example of one such interface and implementations, see […].

}
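
To make the reviewers' suggestion concrete, here is a small sketch (not part of the PR; the names are hypothetical) of keeping per-partition statistics as (count, mean, mean of squares) and merging them with count-weighted averages instead of accumulating a raw sum of squares:

```scala
// Hypothetical sketch of the overflow-friendly aggregation discussed above.
case class LabelStats(count: Double, mean: Double, meanSq: Double) {
  // Merge statistics from two partitions with a count-weighted average.
  def merge(other: LabelStats): LabelStats = {
    val total = count + other.count
    if (total == 0) this
    else LabelStats(
      total,
      (count * mean + other.count * other.mean) / total,
      (count * meanSq + other.count * other.meanSq) / total)
  }
  // Population variance of the aggregated labels, usable as a regression impurity.
  def variance: Double = meanSq - mean * mean
}

object LabelStats {
  def ofLabel(label: Double): LabelStats = LabelStats(1.0, label, label * label)
}

// e.g. Seq(1.0, 2.0, 3.0).map(LabelStats.ofLabel).reduce(_ merge _).variance == 2.0 / 3.0
```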
@@ -0,0 +1,37 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.impurity

/**
 * Class for calculating variance during regression
 */
object Variance extends Impurity {
  override def calculate(c0: Double, c1: Double): Double =
    throw new UnsupportedOperationException("Variance.calculate")

  /**
   * variance calculation
   * @param count number of instances
   * @param sum sum of labels
   * @param sumSquares summation of squares of the labels
   */
  override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
    val squaredLoss = sumSquares - (sum * sum) / count
    squaredLoss / count
  }
}
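
A quick numeric check of the formula (illustrative only): for labels 1.0, 2.0 and 3.0 we have count = 3, sum = 6 and sumSquares = 14, so the result is (14 - 36/3) / 3 = 2/3, the population variance.

```scala
import org.apache.spark.mllib.tree.impurity.Variance

// Labels 1.0, 2.0, 3.0: mean = 2, population variance = (1 + 0 + 1) / 3 = 2/3
println(Variance.calculate(3.0, 6.0, 14.0))  // ≈ 0.6667
```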
@@ -0,0 +1,33 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.FeatureType._

/**
 * Used for "binning" the features for faster best split calculation. For a continuous
 * feature, a bin is determined by a low and a high "split". For a categorical feature,
 * a bin is determined by a single label value (category).
 * @param lowSplit signifying the lower threshold for the continuous feature to be
 *                 accepted in the bin
 * @param highSplit signifying the upper threshold for the continuous feature to be
 *                  accepted in the bin
 * @param featureType type of feature -- categorical or continuous
 * @param category categorical label value accepted in the bin
 */
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)

> It's not clear this class is needed in the first place. For categorical variables, the value itself is the bin index, and for continuous variables, bins are simply defined by candidate thresholds, in turn defined by quantiles. For every feature id, one can maintain a list of categories and thresholds and be done. In that case, for continuous features, the position of the threshold is the bin index.
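
To illustrate the reviewer's point (a REPL-style sketch, not code from the PR): given sorted candidate thresholds for a continuous feature, the bin index can be recovered directly as the position of the first threshold at or above the value, so an explicit `Bin` object is not strictly required.

```scala
// Illustrative only: bin index as the position of a value among sorted thresholds.
val thresholds = Array(0.5, 1.5, 2.5) // candidate splits, e.g. from quantiles
def binIndex(value: Double): Int = {
  val i = thresholds.indexWhere(value <= _)
  if (i == -1) thresholds.length else i // last bin catches values above all thresholds
}
// binIndex(0.2) == 0, binIndex(1.0) == 1, binIndex(9.9) == 3
```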
> The `Algorithm` Enumeration seems redundant given `Impurity`, which implies the `Algorithm` anyway.

> The various `Enumeration` classes in the `mllib.tree.configuration` package are neat. A uniform design pattern for parameters and options should be used for MLlib and Spark, and this could be a start. Alternatively, if there is an existing pattern in use, it should be followed for the decision tree as well.
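
As an illustration of the Enumeration pattern being discussed (a REPL-style sketch, not code from the PR): the `type Algo = Value` alias lets callers import the values and use the enumeration as a parameter type, which is exactly how `Strategy` declares `val algo: Algo`.

```scala
import org.apache.spark.mllib.tree.configuration.Algo._

// The `type Algo = Value` alias makes the enumeration usable as a parameter type.
def describe(algo: Algo): String = algo match {
  case Classification => "predict a discrete label"
  case Regression     => "predict a continuous value"
}

// describe(Classification) == "predict a discrete label"
```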