Skip to content

Commit afed67b

Browse files
committed
Updates per code review
1 parent cd07aff commit afed67b

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
298298
EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
299299
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
300300
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
301+
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
301302

302303
val trees: Array[DecisionTreeClassificationModel] = treesData.map {
303304
case (treeMetadata, root) =>
@@ -306,6 +307,8 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
306307
DefaultParamsReader.getAndSetParams(tree, treeMetadata)
307308
tree
308309
}
310+
require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" +
311+
s" trees based on metadata but found ${trees.length} trees.")
309312

310313
val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses)
311314
DefaultParamsReader.getAndSetParams(model, metadata)

mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor
131131
* :: Experimental ::
132132
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
133133
* It supports both continuous and categorical features.
134-
*
135-
* @param _trees Decision trees in the ensemble.
134+
*
135+
* @param _trees Decision trees in the ensemble.
136136
* @param numFeatures Number of features used by this model
137137
*/
138138
@Since("1.4.0")
@@ -148,8 +148,8 @@ final class RandomForestRegressionModel private[ml] (
148148

149149
/**
150150
* Construct a random forest regression model, with all trees weighted equally.
151-
*
152-
* @param trees Component trees
151+
*
152+
* @param trees Component trees
153153
*/
154154
private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
155155
this(Identifiable.randomUID("rfr"), trees, numFeatures)
@@ -251,13 +251,16 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode
251251
val (metadata: Metadata, treesData: Array[(Metadata, Node)]) =
252252
EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
253253
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
254+
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
254255

255256
val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
256257
val tree =
257258
new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
258259
DefaultParamsReader.getAndSetParams(tree, treeMetadata)
259260
tree
260261
}
262+
require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" +
263+
s" trees based on metadata but found ${trees.length} trees.")
261264

262265
val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures)
263266
DefaultParamsReader.getAndSetParams(model, metadata)

mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -452,11 +452,11 @@ private[ml] object EnsembleModelReadWrite {
452452
}
453453

454454
/**
455-
* Info for one [[Node]] in a tree ensemble
456-
*
457-
* @param treeID Tree index
458-
* @param nodeData Data for this node
459-
*/
455+
* Info for one [[Node]] in a tree ensemble
456+
*
457+
* @param treeID Tree index
458+
* @param nodeData Data for this node
459+
*/
460460
case class EnsembleNodeData(
461461
treeID: Int,
462462
nodeData: NodeData)

0 commit comments

Comments
 (0)