[Spark-13784][ML][WIP] Model export/import for spark.ml: RandomForests #12023
@@ -17,11 +17,16 @@

package org.apache.spark.ml.classification

import org.apache.hadoop.fs.Path
import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.RandomForestModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -43,7 +48,7 @@ import org.apache.spark.sql.functions._
final class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestParams with TreeClassifierParams {
with RandomForestClassifierParams with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfc"))

@@ -120,7 +125,7 @@ final class RandomForestClassifier @Since("1.4.0") (

@Since("1.4.0")
@Experimental
object RandomForestClassifier {
object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] {
/** Accessor for supported impurity settings: entropy, gini */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities

@@ -129,14 +134,18 @@ object RandomForestClassifier {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies

@Since("2.0.0")
override def load(path: String): RandomForestClassifier = super.load(path)
}

/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification.
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
* @param _trees Decision trees in the ensemble.
*
* @param _trees Decision trees in the ensemble.
* Warning: These have null parents.
*/
@Since("1.4.0")

@@ -147,13 +156,14 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
with RandomForestClassifierParams with TreeEnsembleModel with MLWritable with Serializable {

require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")

/**
* Construct a random forest classification model, with all trees weighted equally.
* @param trees Component trees
*
* @param trees Component trees
*/
private[ml] def this(
trees: Array[DecisionTreeClassificationModel],
@@ -240,12 +250,66 @@ final class RandomForestClassificationModel private[ml] (
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
}

@Since("2.0.0")
override def write: MLWriter =
new RandomForestClassificationModel.RandomForestClassificationModelWriter(this)

@Since("2.0.0")
override def read: MLReader =
new RandomForestClassificationModel.RandomForestClassificationModelReader(this)
}

private[ml] object RandomForestClassificationModel {
@Since("2.0.0")
object RandomForestClassificationModel extends MLReadable[RandomForestClassificationModel] {

@Since("2.0.0")
override def load(path: String): RandomForestClassificationModel = super.load(path)

private[RandomForestClassificationModel]
class RandomForestClassificationModelWriter(instance: RandomForestClassificationModel)
extends MLWriter {

override protected def saveImpl(path: String): Unit = {
val extraMetadata: JObject = Map(
"numFeatures" -> instance.numFeatures,
"numClasses" -> instance.numClasses)
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
for(treeIndex <- 1 to instance.getNumTrees) {
Review comment: This is writing each tree separately. Based on our JIRA discussion, it would be better to write all trees in a single DataFrame. You could create an RDD of trees, then flatMap that to an RDD of NodeData, and then convert that to a DataFrame.

Reply: @jkbradley Sorry for the confusion. In the JIRA discussion, I meant every tree would be stored in a single dataframe. I guess I can work on storing all of them in a single dataframe.

Reply: Ok thanks--that should be much more efficient.
val (nodeData, _) = NodeData.build(instance.trees(treeIndex).rootNode, treeIndex, 0)
val dataPath = new Path(path, "data" + treeIndex).toString
sqlContext.createDataFrame(nodeData).write.parquet(dataPath)
}
}
}
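The thread above suggests flattening the whole forest into one DataFrame instead of writing one Parquet file per tree. Below is a minimal, non-authoritative sketch of that idea; the `EnsembleNodeData` wrapper and the two-argument `NodeData.build` call are illustrative assumptions, not code from this PR, and the sketch relies on the imports already present in the file above.

```scala
// Sketch only: one row per node, tagged with the index of the tree that owns it.
// EnsembleNodeData is a hypothetical wrapper; NodeData.build(rootNode, id) is assumed
// to return (serialized nodes, next node id) for a single tree.
private[ml] case class EnsembleNodeData(treeID: Int, nodeData: NodeData)

// Inside RandomForestClassificationModelWriter, saveImpl could then become:
override protected def saveImpl(path: String): Unit = {
  val extraMetadata: JObject = Map(
    "numFeatures" -> instance.numFeatures,
    "numClasses" -> instance.numClasses)
  DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
  // Flatten the ensemble: every tree contributes its nodes, keyed by treeID.
  val allNodes: Seq[EnsembleNodeData] = instance.trees.zipWithIndex.toSeq.flatMap {
    case (tree, treeID) =>
      val (nodes, _) = NodeData.build(tree.rootNode, 0)
      nodes.map(EnsembleNodeData(treeID, _))
  }
  // A single Parquet dataset instead of one directory per tree.
  val dataPath = new Path(path, "data").toString
  sqlContext.createDataFrame(allNodes).write.parquet(dataPath)
}
```

Writing one dataset avoids launching a separate Parquet write per tree, which is the efficiency gain the reviewer refers to.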

private class RandomForestClassificationModelReader(instance: RandomForestClassificationModel)
extends MLReader[RandomForestClassificationModel] {

/** Checked against metadata when loading model */
private val className = classOf[RandomForestClassificationModel].getName

override def load(path: String): RandomForestClassificationModel = {
implicit val format = DefaultFormats
implicit val root: Array[DecisionTreeClassificationModel] = _
var metadata: DefaultParamsReader.Metadata = null
for(treeIndex <- 1 to instance.getNumTrees) {
val dataPath = new Path(path, "data" + treeIndex).toString
metadata = DefaultParamsReader.loadMetadata(dataPath, sc, className)
root :+ loadTreeNodes(path, metadata, sqlContext)
}
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val model = new RandomForestClassificationModel(metadata.uid, root, numFeatures, numClasses)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
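A correspondingly hedged sketch of the reader for the single-DataFrame layout. It assumes the hypothetical `EnsembleNodeData` wrapper from the save sketch and a hypothetical `buildTreeFromNodes` helper that rebuilds a root `Node` from in-memory `NodeData` rows, playing the role that `loadTreeNodes` plays for the per-tree layout above.

```scala
// Inside RandomForestClassificationModelReader (sketch, not this PR's code):
override def load(path: String): RandomForestClassificationModel = {
  implicit val format = DefaultFormats
  // Params and the numFeatures/numClasses extras are read once from the metadata file.
  val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
  val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
  val numClasses = (metadata.metadata \ "numClasses").extract[Int]

  import sqlContext.implicits._
  val dataPath = new Path(path, "data").toString
  // Regroup the flattened nodes by the tree they belong to, keeping tree order stable.
  val trees: Array[DecisionTreeClassificationModel] =
    sqlContext.read.parquet(dataPath).as[EnsembleNodeData].rdd
      .groupBy(_.treeID)
      .collect()
      .sortBy(_._1)
      .map { case (_, nodes) =>
        // buildTreeFromNodes is a hypothetical counterpart of loadTreeNodes that works
        // on an in-memory collection of NodeData instead of a per-tree Parquet file.
        val rootNode = buildTreeFromNodes(nodes.map(_.nodeData).toSeq)
        new DecisionTreeClassificationModel(rootNode, numFeatures, numClasses)
      }

  val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses)
  DefaultParamsReader.getAndSetParams(model, metadata)
  model
}
```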

/** (private[ml]) Convert a model from the old API */
def fromOld(
private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
categoricalFeatures: Map[Int, Int],
@@ -17,12 +17,17 @@

package org.apache.spark.ml.regression

import org.apache.hadoop.fs.Path
import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.RandomForestModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -41,7 +46,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
with RandomForestParams with TreeRegressorParams {
with RandomForestRegressorParams with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfr"))

@@ -108,7 +113,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val

@Since("1.4.0")
@Experimental
object RandomForestRegressor {
object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{
/** Accessor for supported impurity settings: variance */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities

@@ -117,13 +122,18 @@ object RandomForestRegressor {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies

@Since("2.0.0")
override def load(path: String): RandomForestRegressor = super.load(path)

}

/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
* It supports both continuous and categorical features.
* @param _trees Decision trees in the ensemble.
*
* @param _trees Decision trees in the ensemble.
* @param numFeatures Number of features used by this model
*/
@Since("1.4.0")

@@ -133,13 +143,14 @@ final class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
with TreeEnsembleModel with Serializable {
with RandomForestRegressorParams with TreeEnsembleModel with MLWritable with Serializable {

require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")

/**
* Construct a random forest regression model, with all trees weighted equally.
* @param trees Component trees
*
* @param trees Component trees
*/
private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
this(Identifiable.randomUID("rfr"), trees, numFeatures)
@@ -199,21 +210,71 @@ final class RandomForestRegressionModel private[ml] (
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
}

@Since("2.0.0")
override def write: MLWriter =
new RandomForestRegressionModel.RandomForestRegressionModelWriter(this)

@Since("2.0.0")
override def read: MLReader[RandomForestRegressionModel] =
new RandomForestRegressionModel.RandomForestRegressionModelReader(this)
}

private[ml] object RandomForestRegressionModel {

/** (private[ml]) Convert a model from the old API */
def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestRegressor,
categoricalFeatures: Map[Int, Int],
numFeatures: Int = -1): RandomForestRegressionModel = {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
@Since("2.0.0")
object RandomForestRegressionModel extends MLReadable[RandomForestRegressionModel] {

@Since("2.0.0")
override def load(path: String): RandomForestRegressionModel = super.load(path)

private[RandomForestRegressionModel]
class RandomForestRegressionModelWriter(instance: RandomForestRegressionModel)
extends MLWriter {

override protected def saveImpl(path: String): Unit = {
val extraMetadata: JObject = Map(
"numFeatures" -> instance.numFeatures)
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
for ( treeIndex <- 1 to instance.getNumTrees) {
Review comment: Can you share this implementation between all tree ensembles? How about making a generic method with this signature in TreeEnsembleParams?
and a similar loadImpl method

Reply: +1

Reply: @jkbradley Should saveImpl and load methods in RandomForestClassifier and Regressor override this method? I assume loadImpl will also have same signature.
val (nodeData, _) = NodeData.build(instance.trees(treeIndex).rootNode, treeIndex, 0)
val dataPath = new Path(path, "data" + treeIndex).toString
sqlContext.createDataFrame(nodeData).write.parquet(dataPath)
}
}
}
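The thread above asks for this save logic to live in one place shared by all tree ensembles. The suggested signature itself was not captured on this page, so the following is only a hypothetical sketch of such a shared helper; the `EnsembleModelReadWrite` name and every signature in it are assumptions, and `NodeData.build` is used with the same assumed two-argument form as in the earlier sketch.

```scala
// Sketch of a shared helper that any tree-ensemble model's writer could call.
// It writes params/metadata plus one flattened "data" Parquet dataset for the
// whole ensemble, mirroring the single-DataFrame layout discussed earlier.
private[ml] object EnsembleModelReadWrite {

  def saveImpl[M <: Params with TreeEnsembleModel](
      instance: M,
      path: String,
      sql: SQLContext,
      extraMetadata: JObject): Unit = {
    DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata))
    // One row per node across all trees, keyed by the owning tree's index.
    val allNodes = instance.trees.zipWithIndex.toSeq.flatMap { case (tree, treeID) =>
      val (nodes, _) = NodeData.build(tree.rootNode, 0)  // assumed builder signature
      nodes.map(node => (treeID, node))
    }
    val dataPath = new Path(path, "data").toString
    sql.createDataFrame(allNodes).toDF("treeID", "nodeData").write.parquet(dataPath)
  }
}
```

Each concrete writer would then only assemble its model-specific extraMetadata (numFeatures, numClasses, ...) and delegate here.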

private class RandomForestRegressionModelReader(instance: RandomForestRegressionModel)
extends MLReader[RandomForestRegressionModel] {

/** Checked against metadata when loading model */
private val className = classOf[RandomForestRegressionModel].getName

override def load(path: String): RandomForestRegressionModel = {
implicit val format = DefaultFormats
implicit val root: Array[DecisionTreeRegressionModel] = _
var metadata: DefaultParamsReader.Metadata = null
for ( treeIndex <- 1 to instance.getNumTrees) {
val dataPath = new Path(path, "data" + treeIndex).toString
metadata = DefaultParamsReader.loadMetadata(dataPath, sc, className)
root :+ loadTreeNodes(path, metadata, sqlContext)
}
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val model = new RandomForestRegressionModel(metadata.uid, root, numFeatures)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
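The "similar loadImpl method" mentioned in the review thread could, under the same assumptions, return the decoded metadata plus one root node per tree and let each model's reader wrap the roots in its own tree class. Again, every name and signature here (including `EnsembleNodeData` and `buildTreeFromNodes`) is a sketch-level assumption.

```scala
// Sketch of the shared load side: read metadata and the flattened "data" dataset,
// regroup nodes per tree, and hand back the per-tree root nodes.
def loadImpl(
    path: String,
    sql: SQLContext,
    className: String): (DefaultParamsReader.Metadata, Array[Node]) = {
  val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
  import sql.implicits._
  val dataPath = new Path(path, "data").toString
  val rootNodes = sql.read.parquet(dataPath).as[EnsembleNodeData].rdd
    .groupBy(_.treeID)
    .collect()
    .sortBy(_._1)
    .map { case (_, nodes) => buildTreeFromNodes(nodes.map(_.nodeData).toSeq) }
  (metadata, rootNodes)
}
```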

/** (private[ml]) Convert a model from the old API */
private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestRegressor,
categoricalFeatures: Map[Int, Int],
numFeatures: Int = -1): RandomForestRegressionModel = {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
new RandomForestRegressionModel(parent.uid, newTrees, numFeatures)
}
Review comment: Update your IntelliJ settings: Editor -> Code Style -> Scala -> ScalaDoc tab -> uncheck "Use scaladoc indent for leading asterisk." You will probably need to manually correct these indentation changes for this PR, though.