 * limitations under the License.
 */

-package org.apache.spark.mllib.regression
-
-import org.apache.spark.SparkContext
-import org.apache.spark.mllib.impl.tree._
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams}
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
-import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame

/**
+ * :: AlphaComponent ::
+ *
 * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression.
 * It supports both continuous and categorical features.
 */
-class RandomForestRegressor
-  extends TreeRegressor[RandomForestRegressionModel]
-  with RandomForestParams[RandomForestRegressor]
-  with TreeRegressorParams[RandomForestRegressor] {
+@AlphaComponent
+final class RandomForestRegressor
+  extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
+  with RandomForestParams with TreeRegressorParams {

  // Override parameter setters from parent trait for Java API compatibility.

  // Parameters from TreeRegressorParams:

-  override def setMaxDepth(maxDepth: Int): RandomForestRegressor = super.setMaxDepth(maxDepth)
+  override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)

-  override def setMaxBins(maxBins: Int): RandomForestRegressor = super.setMaxBins(maxBins)
+  override def setMaxBins(value: Int): this.type = super.setMaxBins(value)

-  override def setMinInstancesPerNode(minInstancesPerNode: Int): RandomForestRegressor =
-    super.setMinInstancesPerNode(minInstancesPerNode)
+  override def setMinInstancesPerNode(value: Int): this.type =
+    super.setMinInstancesPerNode(value)

-  override def setMinInfoGain(minInfoGain: Double): RandomForestRegressor =
-    super.setMinInfoGain(minInfoGain)
+  override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)

-  override def setMaxMemoryInMB(maxMemoryInMB: Int): RandomForestRegressor =
-    super.setMaxMemoryInMB(maxMemoryInMB)
+  override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)

-  override def setCacheNodeIds(cacheNodeIds: Boolean): RandomForestRegressor =
-    super.setCacheNodeIds(cacheNodeIds)
+  override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)

-  override def setCheckpointInterval(checkpointInterval: Int): RandomForestRegressor =
-    super.setCheckpointInterval(checkpointInterval)
+  override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)

-  override def setImpurity(impurity: String): RandomForestRegressor =
-    super.setImpurity(impurity)
+  override def setImpurity(value: String): this.type = super.setImpurity(value)

  // Parameters from TreeEnsembleParams:

-  override def setSubsamplingRate(subsamplingRate: Double): RandomForestRegressor =
-    super.setSubsamplingRate(subsamplingRate)
+  override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)

-  override def setSeed(seed: Long): RandomForestRegressor = super.setSeed(seed)
+  override def setSeed(value: Long): this.type = super.setSeed(value)

  // Parameters from RandomForestParams:

-  override def setNumTrees(numTrees: Int): RandomForestRegressor = super.setNumTrees(numTrees)
+  override def setNumTrees(value: Int): this.type = super.setNumTrees(value)

-  override def setFeaturesPerNode(featuresPerNode: String): RandomForestRegressor =
-    super.setFeaturesPerNode(featuresPerNode)
+  override def setFeaturesPerNode(value: String): this.type = super.setFeaturesPerNode(value)

-  override def run(
-      input: RDD[LabeledPoint],
-      categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = {
+  override protected def train(
+      dataset: DataFrame,
+      paramMap: ParamMap): RandomForestRegressionModel = {
+    val categoricalFeatures: Map[Int, Int] =
+      MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
    val strategy = getOldStrategy(categoricalFeatures)
    val oldModel = OldRandomForest.trainRegressor(
-      input, strategy, getNumTrees, getFeaturesPerNodeStr, getSeed.toInt)
-    RandomForestRegressionModel.fromOld(oldModel)
+      oldDataset, strategy, getNumTrees, getFeaturesPerNodeStr, getSeed.toInt)
+    RandomForestRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
  }

-  /**
-   * Create a Strategy instance to use with the old API.
-   * TODO: Make this protected once we deprecate the old API.
-   */
-  private[mllib] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
-    val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0)
-    strategy.algo = OldAlgo.Regression
-    strategy.impurity = getOldImpurity
-    strategy
+  /** (private[ml]) Create a Strategy instance to use with the old API. */
+  private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
+    super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
  }
}

object RandomForestRegressor {
-
  /** Accessor for supported impurity settings */
  final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities

@@ -107,51 +103,66 @@ object RandomForestRegressor {
}

/**
+ * :: AlphaComponent ::
+ *
 * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
 * It supports both continuous and categorical features.
 * @param trees  Decision trees in the ensemble.
 */
-class RandomForestRegressionModel(val trees: Array[DecisionTreeRegressionModel])
-  extends TreeEnsembleModel with Serializable with Saveable {
+@AlphaComponent
+final class RandomForestRegressionModel private[ml] (
+    override val parent: RandomForestRegressor,
+    override val fittingParamMap: ParamMap,
+    val trees: Array[DecisionTreeRegressionModel])
+  extends PredictionModel[Vector, RandomForestRegressionModel]
+  with TreeEnsembleModel with Serializable {

  require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")

  override def getTrees: Array[DecisionTreeModel] = trees.asInstanceOf[Array[DecisionTreeModel]]

+  // Note: We may add support for weights (based on tree performance) later on.
  override lazy val getTreeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)

  override def predict(features: Vector): Double = {
+    // TODO: Override transform() to broadcast model.
+    // TODO: When we add a generic Bagging class, handle transform there. Skip single-Row predict.
    // Predict average of tree predictions.
    // Ignore the weights since all are 1.0 for now.
    trees.map(_.predict(features)).sum / numTrees
  }

-  override def toString: String = {
-    s"RandomForestRegressionModel with $numTrees trees"
+  override protected def copy(): RandomForestRegressionModel = {
+    val m = new RandomForestRegressionModel(parent, fittingParamMap, trees)
+    Params.inheritValues(this.extractParamMap(), this, m)
+    m
  }

-  override def save(sc: SparkContext, path: String): Unit = {
-    this.toOld.save(sc, path)
+  override def toString: String = {
+    s"RandomForestRegressionModel with $numTrees trees"
  }

-  override protected def formatVersion: String = OldRandomForestModel.formatVersion
-
-  /** Convert to a model in the old API */
-  private[mllib] def toOld: OldRandomForestModel = {
+  /** (private[ml]) Convert to a model in the old API */
+  private[ml] def toOld: OldRandomForestModel = {
    new OldRandomForestModel(OldAlgo.Regression, trees.map(_.toOld))
  }
}

-object RandomForestRegressionModel extends Loader[RandomForestRegressionModel] {
+private[ml] object RandomForestRegressionModel {

-  override def load(sc: SparkContext, path: String): RandomForestRegressionModel = {
-    RandomForestRegressionModel.fromOld(OldRandomForestModel.load(sc, path))
-  }
-
-  private[mllib] def fromOld(oldModel: OldRandomForestModel): RandomForestRegressionModel = {
+  /** (private[ml]) Convert a model from the old API */
+  def fromOld(
+      oldModel: OldRandomForestModel,
+      parent: RandomForestRegressor,
+      fittingParamMap: ParamMap,
+      categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = {
    require(oldModel.algo == OldAlgo.Regression,
      s"Cannot convert non-regression RandomForestModel (old API) to" +
      s" RandomForestRegressionModel (new API). Algo is: ${oldModel.algo}")
-    new RandomForestRegressionModel(oldModel.trees.map(DecisionTreeRegressionModel.fromOld))
+    val trees = oldModel.trees.map { tree =>
+      // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+      DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
+    }
+    new RandomForestRegressionModel(parent, fittingParamMap, trees)
  }
}
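
For reviewers who want to try the new estimator, here is a minimal, hypothetical usage sketch (not part of this patch). It assumes a running SparkContext `sc`, the Spark 1.x SQLContext API, and the default "label"/"features"/"prediction" column names supplied by the Predictor abstraction; the toy data, parameter values, and variable names are illustrative only.

// Hypothetical usage sketch for the new ml.regression.RandomForestRegressor.
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// LabeledPoint is a case class, so an RDD of it converts to a DataFrame with
// "label" and "features" columns, matching the Predictor defaults used by train().
val training = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
  LabeledPoint(2.0, Vectors.dense(2.0, 1.0)))).toDF()

// Parameter setters now return this.type, so calls chain as in the diff above.
val rf = new RandomForestRegressor()
  .setNumTrees(20)
  .setMaxDepth(5)
  .setSeed(42L)

// fit() comes from the Estimator/Predictor abstraction the class now extends;
// transform() on the returned model appends a "prediction" column.
val model = rf.fit(training)
model.transform(training).select("features", "prediction").show()

// The ensemble members and their (currently uniform) weights are exposed as in the diff.
println(s"Trained ${model.getTrees.length} trees, weights = ${model.getTreeWeights.mkString(",")}")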