Skip to content

[Spark-13784][ML][WIP] Model export/import for spark.ml: RandomForests #12023

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@

package org.apache.spark.ml.classification

import org.apache.hadoop.fs.Path
import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.RandomForestModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
Expand All @@ -43,7 +48,7 @@ import org.apache.spark.sql.functions._
final class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestParams with TreeClassifierParams {
with RandomForestClassifierParams with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfc"))
Expand Down Expand Up @@ -120,7 +125,7 @@ final class RandomForestClassifier @Since("1.4.0") (

@Since("1.4.0")
@Experimental
object RandomForestClassifier {
object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] {
/** Accessor for supported impurity settings: entropy, gini */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
Expand All @@ -129,14 +134,18 @@ object RandomForestClassifier {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies

@Since("2.0.0")
override def load(path: String): RandomForestClassifier = super.load(path)
}

/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification.
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
* @param _trees Decision trees in the ensemble.
*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update your IntelliJ settings: Editor -> Code Style -> Scala -> ScalaDoc tab -> uncheck "Use scaladoc indent for leading asterisk."

You will probably need to manually correct these indentation changes for this PR, though.

* @param _trees Decision trees in the ensemble.
* Warning: These have null parents.
*/
@Since("1.4.0")
Expand All @@ -147,13 +156,14 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
with RandomForestClassifierParams with TreeEnsembleModel with MLWritable with Serializable {

require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")

/**
* Construct a random forest classification model, with all trees weighted equally.
* @param trees Component trees
*
* @param trees Component trees
*/
private[ml] def this(
trees: Array[DecisionTreeClassificationModel],
Expand Down Expand Up @@ -240,12 +250,66 @@ final class RandomForestClassificationModel private[ml] (
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
}

@Since("2.0.0")
override def write: MLWriter =
new RandomForestClassificationModel.RandomForestClassificationModelWriter(this)

@Since("2.0.0")
override def read: MLReader =
new RandomForestClassificationModel.RandomForestClassificationModelReader(this)
}

private[ml] object RandomForestClassificationModel {
@Since("2.0.0")
object RandomForestClassificationModel extends MLReadable[RandomForestClassificationModel] {


@Since("2.0.0")
override def load(path: String): RandomForestClassificationModel = super.load(path)

private[RandomForestClassificationModel]
class RandomForestClassificationModelWriter(instance: RandomForestClassificationModel)
extends MLWriter {

override protected def saveImpl(path: String): Unit = {
val extraMetadata: JObject = Map(
"numFeatures" -> instance.numFeatures,
"numClasses" -> instance.numClasses)
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
for(treeIndex <- 1 to instance.getNumTrees) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is writing each tree separately. Based on our JIRA discussion, it would be better to write all trees in a single DataFrame. You could create an RDD of trees, then flatMap that to an RDD of NodeData, and then convert that to a DataFrame.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jkbradley Sorry for the confusion. In the JIRA discussion, I meant every tree would be stored in a single dataframe. I guess I can work on storing all of them in a single dataframe.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok thanks--that should be much more efficient.

val (nodeData, _) = NodeData.build(instance.trees(treeIndex).rootNode, treeIndex, 0)
val dataPath = new Path(path, "data" + treeIndex).toString
sqlContext.createDataFrame(nodeData).write.parquet(dataPath)
}
}
}

private class RandomForestClassificationModelReader(instance: RandomForestClassificationModel)
extends MLReader[RandomForestClassificationModel] {

/** Checked against metadata when loading model */
private val className = classOf[RandomForestClassificationModel].getName

override def load(path: String): RandomForestClassificationModel = {
implicit val format = DefaultFormats
implicit val root: Array[DecisionTreeClassificationModel] = _
var metadata: DefaultParamsReader.Metadata = null
for(treeIndex <- 1 to instance.getNumTrees) {
val dataPath = new Path(path, "data" + treeIndex).toString
metadata = DefaultParamsReader.loadMetadata(dataPath, sc, className)
root :+ loadTreeNodes(path, metadata, sqlContext)
}
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val model = new RandomForestClassificationModel(metadata.uid, root, numFeatures, numClasses)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}


/** (private[ml]) Convert a model from the old API */
def fromOld(
private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
categoricalFeatures: Map[Int, Int],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@

package org.apache.spark.ml.regression

import org.apache.hadoop.fs.Path
import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.RandomForestModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
Expand All @@ -41,7 +46,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
with RandomForestParams with TreeRegressorParams {
with RandomForestRegressorParams with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfr"))
Expand Down Expand Up @@ -108,7 +113,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val

@Since("1.4.0")
@Experimental
object RandomForestRegressor {
object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{
/** Accessor for supported impurity settings: variance */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
Expand All @@ -117,13 +122,18 @@ object RandomForestRegressor {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies

@Since("2.0.0")
override def load(path: String): RandomForestRegressor = super.load(path)

}

/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
* It supports both continuous and categorical features.
* @param _trees Decision trees in the ensemble.
*
* @param _trees Decision trees in the ensemble.
* @param numFeatures Number of features used by this model
*/
@Since("1.4.0")
Expand All @@ -133,13 +143,14 @@ final class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
with TreeEnsembleModel with Serializable {
with RandomForestRegressorParams with TreeEnsembleModel with MLWritable with Serializable {

require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")

/**
* Construct a random forest regression model, with all trees weighted equally.
* @param trees Component trees
*
* @param trees Component trees
*/
private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
this(Identifiable.randomUID("rfr"), trees, numFeatures)
Expand Down Expand Up @@ -199,21 +210,71 @@ final class RandomForestRegressionModel private[ml] (
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
}

// Returns an MLWriter that persists this model (metadata + per-tree node data).
@Since("2.0.0")
override def write: MLWriter =
new RandomForestRegressionModel.RandomForestRegressionModelWriter(this)

// NOTE(review): `read` conventionally belongs on the companion object (MLReadable),
// not on the model instance — confirm this matches the intended MLReadable design.
@Since("2.0.0")
override def read: MLReader[RandomForestRegressionModel] =
new RandomForestRegressionModel.RandomForestRegressionModelReader(this)
}

private[ml] object RandomForestRegressionModel {

/** (private[ml]) Convert a model from the old API */
def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestRegressor,
categoricalFeatures: Map[Int, Int],
numFeatures: Int = -1): RandomForestRegressionModel = {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
@Since("2.0.0")
object RandomForestRegressionModel extends MLReadable[RandomForestRegressionModel] {

@Since("2.0.0")
override def load(path: String): RandomForestRegressionModel = super.load(path)

private[RandomForestRegressionModel]
class RandomForestRegressionModelWriter(instance: RandomForestRegressionModel)
extends MLWriter {

override protected def saveImpl(path: String): Unit = {
val extraMetadata: JObject = Map(
"numFeatures" -> instance.numFeatures)
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
for ( treeIndex <- 1 to instance.getNumTrees) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you share this implementation between all tree ensembles? How about making a generic method with this signature in TreeEnsembleParams?

private[ml] def saveImpl(trees: Array[DecisionTreeModel], path: String, sql: SQLContext)

and a similar loadImpl method

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jkbradley Should saveImpl and load methods in RandomForestClassifier and Regressor over ride this method? I assume loadImpl will also have same signature.

val (nodeData, _) = NodeData.build(instance.trees(treeIndex).rootNode, treeIndex, 0)
val dataPath = new Path(path, "data" + treeIndex).toString
sqlContext.createDataFrame(nodeData).write.parquet(dataPath)
}
}
}

private class RandomForestRegressionModelReader(instance: RandomForestRegressionModel)
extends MLReader[RandomForestRegressionModel] {

/** Checked against metadata when loading model */
private val className = classOf[RandomForestRegressionModel].getName

override def load(path: String): RandomForestRegressionModel = {
implicit val format = DefaultFormats
implicit val root: Array[DecisionTreeRegressionModel] = _
var metadata: DefaultParamsReader.Metadata = null
for ( treeIndex <- 1 to instance.getNumTrees) {
val dataPath = new Path(path, "data" + treeIndex).toString
metadata = DefaultParamsReader.loadMetadata(dataPath, sc, className)
root :+ loadTreeNodes(path, metadata, sqlContext)
}
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val model = new RandomForestRegressionModel(metadata.uid, root, numFeatures)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

/** (private[ml]) Convert a model from the old API */
private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestRegressor,
categoricalFeatures: Map[Int, Int],
numFeatures: Int = -1): RandomForestRegressionModel = {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
new RandomForestRegressionModel(parent.uid, newTrees, numFeatures)
}
Expand Down
Loading