Skip to content

Commit c0f30c1

Browse files
committed
Added random forests and test suites to spark.ml. Not tested yet. Need to add example as well
1 parent d045ebd commit c0f30c1

File tree

15 files changed

+172
-164
lines changed

15 files changed

+172
-164
lines changed

mllib/src/main/scala/org/apache/spark/ml/Model.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@ import org.apache.spark.ml.param.ParamMap
3030
abstract class Model[M <: Model[M]] extends Transformer {
3131
/**
3232
* The parent estimator that produced this model.
33+
* Note: For ensembles' component Models, this value can be null.
3334
*/
3435
val parent: Estimator[M]
3536

3637
/**
3738
* Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model.
39+
* Note: For ensembles' component Models, this value can be null.
3840
*/
3941
val fittingParamMap: ParamMap
4042
}

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ final class DecisionTreeClassificationModel private[ml] (
113113
require(rootNode != null,
114114
"DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
115115

116-
override protected def predict(features: Vector): Double = {
116+
override private[ml] def predict(features: Vector): Double = {
117117
rootNode.predict(features)
118118
}
119119

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@ package org.apache.spark.ml.classification
1919

2020
import scala.collection.mutable
2121

22-
import org.apache.spark.SparkContext
2322
import org.apache.spark.annotation.AlphaComponent
2423
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
2524
import org.apache.spark.ml.impl.tree._
26-
import org.apache.spark.ml.param.ParamMap
25+
import org.apache.spark.ml.param.{Params, ParamMap}
2726
import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
2827
import org.apache.spark.ml.util.MetadataUtils
2928
import org.apache.spark.mllib.linalg.Vector
@@ -100,11 +99,10 @@ final class RandomForestClassifier
10099
}
101100

102101
/** (private[ml]) Create a Strategy instance to use with the old API. */
103-
override private[ml] def getOldStrategy(
102+
private[ml] def getOldStrategy(
104103
categoricalFeatures: Map[Int, Int],
105104
numClasses: Int): OldStrategy = {
106-
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
107-
getSubsamplingRate)
105+
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
108106
}
109107
}
110108

@@ -123,10 +121,11 @@ object RandomForestClassifier {
123121
* It supports both binary and multiclass labels, as well as both continuous and categorical
124122
* features.
125123
* @param trees Decision trees in the ensemble.
124+
* Warning: These have null parents.
126125
*/
127126
@AlphaComponent
128127
final class RandomForestClassificationModel private[ml] (
129-
override val parent: DecisionTreeClassifier,
128+
override val parent: RandomForestClassifier,
130129
override val fittingParamMap: ParamMap,
131130
val trees: Array[DecisionTreeClassificationModel])
132131
extends PredictionModel[Vector, RandomForestClassificationModel]
@@ -140,6 +139,8 @@ final class RandomForestClassificationModel private[ml] (
140139
override lazy val getTreeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
141140

142141
override def predict(features: Vector): Double = {
142+
// TODO: Override transform() to broadcast model.
143+
// TODO: When we add a generic Bagging class, handle transform there. Skip single-Row predict.
143144
// Classifies using majority votes.
144145
// Ignore the weights since all are 1.0 for now.
145146
val votes = mutable.Map.empty[Int, Double]
@@ -150,33 +151,37 @@ final class RandomForestClassificationModel private[ml] (
150151
votes.maxBy(_._2)._1
151152
}
152153

153-
override def toString: String = {
154-
s"RandomForestClassificationModel with $numTrees trees"
154+
override protected def copy(): RandomForestClassificationModel = {
155+
val m = new RandomForestClassificationModel(parent, fittingParamMap, trees)
156+
Params.inheritValues(this.extractParamMap(), this, m)
157+
m
155158
}
156159

157-
override def save(sc: SparkContext, path: String): Unit = {
158-
this.toOld.save(sc, path)
160+
override def toString: String = {
161+
s"RandomForestClassificationModel with $numTrees trees"
159162
}
160163

161-
override protected def formatVersion: String = OldRandomForestModel.formatVersion
162-
163-
/** Convert to a model in the old API */
164+
/** (private[ml]) Convert to a model in the old API */
164165
private[ml] def toOld: OldRandomForestModel = {
165166
new OldRandomForestModel(OldAlgo.Classification, trees.map(_.toOld))
166167
}
167168
}
168169

169-
object RandomForestClassificationModel
170-
extends Loader[RandomForestClassificationModel] {
171-
172-
override def load(sc: SparkContext, path: String): RandomForestClassificationModel = {
173-
RandomForestClassificationModel.fromOld(OldRandomForestModel.load(sc, path))
174-
}
170+
private[ml] object RandomForestClassificationModel {
175171

176-
private[ml] def fromOld(oldModel: OldRandomForestModel): RandomForestClassificationModel = {
172+
/** (private[ml]) Convert a model from the old API */
173+
def fromOld(
174+
oldModel: OldRandomForestModel,
175+
parent: RandomForestClassifier,
176+
fittingParamMap: ParamMap,
177+
categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = {
177178
require(oldModel.algo == OldAlgo.Classification,
178179
s"Cannot convert non-classification RandomForestModel (old API) to" +
179180
s" RandomForestClassificationModel (new API). Algo is: ${oldModel.algo}")
180-
new RandomForestClassificationModel(oldModel.trees.map(DecisionTreeClassificationModel.fromOld))
181+
val trees = oldModel.trees.map { tree =>
182+
// parent, fittingParamMap for each tree is null since there are no good ways to set these.
183+
DecisionTreeClassificationModel.fromOld(tree, null, null, categoricalFeatures)
184+
}
185+
new RandomForestClassificationModel(parent, fittingParamMap, trees)
181186
}
182187
}

mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,12 +352,12 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
352352
* Create a Strategy instance to use with the old API.
353353
* NOTE: The caller should set impurity and seed.
354354
*/
355-
override private[ml] def getOldStrategy(
355+
private[ml] def getOldStrategy(
356356
categoricalFeatures: Map[Int, Int],
357-
numClasses: Int): OldStrategy = {
358-
val strategy = super.getOldStrategy(categoricalFeatures, numClasses)
359-
strategy.setSubsamplingRate(getSubsamplingRate)
360-
strategy
357+
numClasses: Int,
358+
oldAlgo: OldAlgo.Algo,
359+
oldImpurity: OldImpurity): OldStrategy = {
360+
super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
361361
}
362362
}
363363

mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ final class DecisionTreeRegressionModel private[ml] (
104104
require(rootNode != null,
105105
"DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
106106

107-
override protected def predict(features: Vector): Double = {
107+
override private[ml] def predict(features: Vector): Double = {
108108
rootNode.predict(features)
109109
}
110110

mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala

Lines changed: 72 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,90 +15,86 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.mllib.regression
19-
20-
import org.apache.spark.SparkContext
21-
import org.apache.spark.mllib.impl.tree._
18+
package org.apache.spark.ml.regression
19+
20+
import org.apache.spark.annotation.AlphaComponent
21+
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
22+
import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams}
23+
import org.apache.spark.ml.param.{Params, ParamMap}
24+
import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
25+
import org.apache.spark.ml.util.MetadataUtils
2226
import org.apache.spark.mllib.linalg.Vector
27+
import org.apache.spark.mllib.regression.LabeledPoint
2328
import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
2429
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
2530
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
26-
import org.apache.spark.mllib.util.{Loader, Saveable}
2731
import org.apache.spark.rdd.RDD
32+
import org.apache.spark.sql.DataFrame
2833

2934

3035
/**
36+
* :: AlphaComponent ::
37+
*
3138
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression.
3239
* It supports both continuous and categorical features.
3340
*/
34-
class RandomForestRegressor
35-
extends TreeRegressor[RandomForestRegressionModel]
36-
with RandomForestParams[RandomForestRegressor]
37-
with TreeRegressorParams[RandomForestRegressor] {
41+
@AlphaComponent
42+
final class RandomForestRegressor
43+
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
44+
with RandomForestParams with TreeRegressorParams {
3845

3946
// Override parameter setters from parent trait for Java API compatibility.
4047

4148
// Parameters from TreeRegressorParams:
4249

43-
override def setMaxDepth(maxDepth: Int): RandomForestRegressor = super.setMaxDepth(maxDepth)
50+
override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
4451

45-
override def setMaxBins(maxBins: Int): RandomForestRegressor = super.setMaxBins(maxBins)
52+
override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
4653

47-
override def setMinInstancesPerNode(minInstancesPerNode: Int): RandomForestRegressor =
48-
super.setMinInstancesPerNode(minInstancesPerNode)
54+
override def setMinInstancesPerNode(value: Int): this.type =
55+
super.setMinInstancesPerNode(value)
4956

50-
override def setMinInfoGain(minInfoGain: Double): RandomForestRegressor =
51-
super.setMinInfoGain(minInfoGain)
57+
override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
5258

53-
override def setMaxMemoryInMB(maxMemoryInMB: Int): RandomForestRegressor =
54-
super.setMaxMemoryInMB(maxMemoryInMB)
59+
override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
5560

56-
override def setCacheNodeIds(cacheNodeIds: Boolean): RandomForestRegressor =
57-
super.setCacheNodeIds(cacheNodeIds)
61+
override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
5862

59-
override def setCheckpointInterval(checkpointInterval: Int): RandomForestRegressor =
60-
super.setCheckpointInterval(checkpointInterval)
63+
override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
6164

62-
override def setImpurity(impurity: String): RandomForestRegressor =
63-
super.setImpurity(impurity)
65+
override def setImpurity(value: String): this.type = super.setImpurity(value)
6466

6567
// Parameters from TreeEnsembleParams:
6668

67-
override def setSubsamplingRate(subsamplingRate: Double): RandomForestRegressor =
68-
super.setSubsamplingRate(subsamplingRate)
69+
override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
6970

70-
override def setSeed(seed: Long): RandomForestRegressor = super.setSeed(seed)
71+
override def setSeed(value: Long): this.type = super.setSeed(value)
7172

7273
// Parameters from RandomForestParams:
7374

74-
override def setNumTrees(numTrees: Int): RandomForestRegressor = super.setNumTrees(numTrees)
75+
override def setNumTrees(value: Int): this.type = super.setNumTrees(value)
7576

76-
override def setFeaturesPerNode(featuresPerNode: String): RandomForestRegressor =
77-
super.setFeaturesPerNode(featuresPerNode)
77+
override def setFeaturesPerNode(value: String): this.type = super.setFeaturesPerNode(value)
7878

79-
override def run(
80-
input: RDD[LabeledPoint],
81-
categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = {
79+
override protected def train(
80+
dataset: DataFrame,
81+
paramMap: ParamMap): RandomForestRegressionModel = {
82+
val categoricalFeatures: Map[Int, Int] =
83+
MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
84+
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
8285
val strategy = getOldStrategy(categoricalFeatures)
8386
val oldModel = OldRandomForest.trainRegressor(
84-
input, strategy, getNumTrees, getFeaturesPerNodeStr, getSeed.toInt)
85-
RandomForestRegressionModel.fromOld(oldModel)
87+
oldDataset, strategy, getNumTrees, getFeaturesPerNodeStr, getSeed.toInt)
88+
RandomForestRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
8689
}
8790

88-
/**
89-
* Create a Strategy instance to use with the old API.
90-
* TODO: Make this protected once we deprecate the old API.
91-
*/
92-
private[mllib] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
93-
val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0)
94-
strategy.algo = OldAlgo.Regression
95-
strategy.impurity = getOldImpurity
96-
strategy
91+
/** (private[ml]) Create a Strategy instance to use with the old API. */
92+
private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
93+
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
9794
}
9895
}
9996

10097
object RandomForestRegressor {
101-
10298
/** Accessor for supported impurity settings */
10399
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
104100

@@ -107,51 +103,66 @@ object RandomForestRegressor {
107103
}
108104

109105
/**
106+
* :: AlphaComponent ::
107+
*
110108
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
111109
* It supports both continuous and categorical features.
112110
* @param trees Decision trees in the ensemble.
113111
*/
114-
class RandomForestRegressionModel(val trees: Array[DecisionTreeRegressionModel])
115-
extends TreeEnsembleModel with Serializable with Saveable {
112+
@AlphaComponent
113+
final class RandomForestRegressionModel private[ml] (
114+
override val parent: RandomForestRegressor,
115+
override val fittingParamMap: ParamMap,
116+
val trees: Array[DecisionTreeRegressionModel])
117+
extends PredictionModel[Vector, RandomForestRegressionModel]
118+
with TreeEnsembleModel with Serializable {
116119

117120
require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
118121

119122
override def getTrees: Array[DecisionTreeModel] = trees.asInstanceOf[Array[DecisionTreeModel]]
120123

124+
// Note: We may add support for weights (based on tree performance) later on.
121125
override lazy val getTreeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
122126

123127
override def predict(features: Vector): Double = {
128+
// TODO: Override transform() to broadcast model.
129+
// TODO: When we add a generic Bagging class, handle transform there. Skip single-Row predict.
124130
// Predict average of tree predictions.
125131
// Ignore the weights since all are 1.0 for now.
126132
trees.map(_.predict(features)).sum / numTrees
127133
}
128134

129-
override def toString: String = {
130-
s"RandomForestRegressionModel with $numTrees trees"
135+
override protected def copy(): RandomForestRegressionModel = {
136+
val m = new RandomForestRegressionModel(parent, fittingParamMap, trees)
137+
Params.inheritValues(this.extractParamMap(), this, m)
138+
m
131139
}
132140

133-
override def save(sc: SparkContext, path: String): Unit = {
134-
this.toOld.save(sc, path)
141+
override def toString: String = {
142+
s"RandomForestRegressionModel with $numTrees trees"
135143
}
136144

137-
override protected def formatVersion: String = OldRandomForestModel.formatVersion
138-
139-
/** Convert to a model in the old API */
140-
private[mllib] def toOld: OldRandomForestModel = {
145+
/** (private[ml]) Convert to a model in the old API */
146+
private[ml] def toOld: OldRandomForestModel = {
141147
new OldRandomForestModel(OldAlgo.Regression, trees.map(_.toOld))
142148
}
143149
}
144150

145-
object RandomForestRegressionModel extends Loader[RandomForestRegressionModel] {
151+
private[ml] object RandomForestRegressionModel {
146152

147-
override def load(sc: SparkContext, path: String): RandomForestRegressionModel = {
148-
RandomForestRegressionModel.fromOld(OldRandomForestModel.load(sc, path))
149-
}
150-
151-
private[mllib] def fromOld(oldModel: OldRandomForestModel): RandomForestRegressionModel = {
153+
/** (private[ml]) Convert a model from the old API */
154+
def fromOld(
155+
oldModel: OldRandomForestModel,
156+
parent: RandomForestRegressor,
157+
fittingParamMap: ParamMap,
158+
categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = {
152159
require(oldModel.algo == OldAlgo.Regression,
153160
s"Cannot convert non-regression RandomForestModel (old API) to" +
154161
s" RandomForestRegressionModel (new API). Algo is: ${oldModel.algo}")
155-
new RandomForestRegressionModel(oldModel.trees.map(DecisionTreeRegressionModel.fromOld))
162+
val trees = oldModel.trees.map { tree =>
163+
// parent, fittingParamMap for each tree is null since there are no good ways to set these.
164+
DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
165+
}
166+
new RandomForestRegressionModel(parent, fittingParamMap, trees)
156167
}
157168
}

mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ trait TreeEnsembleModel {
7272
// Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of
7373
// DecisionTreeModel.
7474

75-
/** Trees in this ensemble */
75+
/** Trees in this ensemble. Warning: These have null parent Estimators. */
7676
def getTrees: Array[DecisionTreeModel]
7777

7878
/** Weights for each tree, zippable with [[getTrees]] */

mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.ml.classification;
1919

20-
import java.io.File;
2120
import java.io.Serializable;
2221
import java.util.HashMap;
2322
import java.util.Map;
@@ -32,7 +31,6 @@
3231
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
3332
import org.apache.spark.mllib.regression.LabeledPoint;
3433
import org.apache.spark.sql.DataFrame;
35-
import org.apache.spark.util.Utils;
3634

3735

3836
public class JavaDecisionTreeClassifierSuite implements Serializable {

0 commit comments

Comments
 (0)