Skip to content

[SPARK-13784][ML] Persistence for RandomForestClassifier, RandomForestRegressor #12118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ final class GBTClassificationModel private[ml](
extends PredictionModel[Vector, GBTClassificationModel]
with TreeEnsembleModel with Serializable {

require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.")
require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")

Expand Down Expand Up @@ -227,6 +227,9 @@ final class GBTClassificationModel private[ml](
if (prediction > 0.0) 1.0 else 0.0
}

/** Number of trees in ensemble; a `val` set once from `trees.length`. */
val numTrees: Int = trees.length

@Since("1.4.0")
override def copy(extra: ParamMap): GBTClassificationModel = {
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
Expand Down Expand Up @@ -272,6 +275,6 @@ private[ml] object GBTClassificationModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@

package org.apache.spark.ml.classification

import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
Expand All @@ -43,7 +47,7 @@ import org.apache.spark.sql.functions._
final class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestParams with TreeClassifierParams {
with RandomForestClassifierParams with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfc"))
Expand Down Expand Up @@ -120,7 +124,7 @@ final class RandomForestClassifier @Since("1.4.0") (

@Since("1.4.0")
@Experimental
object RandomForestClassifier {
object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] {
/** Accessor for supported impurity settings: entropy, gini */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
Expand All @@ -129,15 +133,19 @@ object RandomForestClassifier {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies

/**
 * Loads a [[RandomForestClassifier]] from `path` by delegating to the inherited `load`.
 * The override declares the concrete return type.
 *
 * @param path location to load from
 */
@Since("2.0.0")
override def load(path: String): RandomForestClassifier = super.load(path)
}

/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification.
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
*
* @param _trees Decision trees in the ensemble.
* Warning: These have null parents.
* Warning: These have null parents.
*/
@Since("1.4.0")
@Experimental
Expand All @@ -147,12 +155,14 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
with RandomForestClassificationModelParams with TreeEnsembleModel with MLWritable
with Serializable {

require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")

/**
* Construct a random forest classification model, with all trees weighted equally.
*
* @param trees Component trees
*/
private[ml] def this(
Expand All @@ -165,7 +175,7 @@ final class RandomForestClassificationModel private[ml] (
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]

// Note: We may add support for weights (based on tree performance) later on.
private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)

@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
Expand Down Expand Up @@ -208,6 +218,15 @@ final class RandomForestClassificationModel private[ml] (
}
}

/**
 * Number of trees in ensemble; a `val` set once from `trees.length`.
 *
 * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
 */
// TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
@deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
val numTrees: Int = trees.length

@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestClassificationModel = {
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
Expand All @@ -216,7 +235,7 @@ final class RandomForestClassificationModel private[ml] (

@Since("1.4.0")
override def toString: String = {
s"RandomForestClassificationModel (uid=$uid) with $numTrees trees"
s"RandomForestClassificationModel (uid=$uid) with $getNumTrees trees"
}

/**
Expand All @@ -236,12 +255,69 @@ final class RandomForestClassificationModel private[ml] (
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
}

/**
 * Returns an [[MLWriter]] for this model: a new
 * `RandomForestClassificationModel.RandomForestClassificationModelWriter` wrapping `this`.
 */
@Since("2.0.0")
override def write: MLWriter =
new RandomForestClassificationModel.RandomForestClassificationModelWriter(this)
}

private[ml] object RandomForestClassificationModel {
@Since("2.0.0")
object RandomForestClassificationModel extends MLReadable[RandomForestClassificationModel] {

/** Returns a new `RandomForestClassificationModelReader` as the [[MLReader]] for this model. */
@Since("2.0.0")
override def read: MLReader[RandomForestClassificationModel] =
new RandomForestClassificationModelReader

/**
 * Loads a [[RandomForestClassificationModel]] from `path` by delegating to the inherited
 * `load`. The override declares the concrete return type.
 *
 * @param path location to load from
 */
@Since("2.0.0")
override def load(path: String): RandomForestClassificationModel = super.load(path)

private[RandomForestClassificationModel]
class RandomForestClassificationModelWriter(instance: RandomForestClassificationModel)
  extends MLWriter {

  /**
   * Saves `instance` under `path` by delegating to `EnsembleModelReadWrite.saveImpl`,
   * passing the model-level fields as extra JSON metadata.
   *
   * @param path location to save to
   */
  override protected def saveImpl(path: String): Unit = {
    // numTrees duplicates what the saved trees already imply; the reader in this object
    // extracts it and requires it to equal the number of trees it loads.
    val extraMetadata: JObject =
      ("numFeatures" -> instance.numFeatures) ~
        ("numClasses" -> instance.numClasses) ~
        ("numTrees" -> instance.getNumTrees)
    EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
  }
}

/**
 * [[MLReader]] that rebuilds a [[RandomForestClassificationModel]] from the ensemble data
 * returned by `EnsembleModelReadWrite.loadImpl`.
 */
private class RandomForestClassificationModelReader
extends MLReader[RandomForestClassificationModel] {

/** Checked against metadata when loading model; both names are passed to `loadImpl`. */
private val className = classOf[RandomForestClassificationModel].getName
private val treeClassName = classOf[DecisionTreeClassificationModel].getName

/**
 * Loads the model stored at `path`.
 *
 * Reads `numFeatures`, `numClasses` and `numTrees` from the ensemble metadata, builds one
 * [[DecisionTreeClassificationModel]] per loaded (metadata, root node) pair, and constructs
 * the forest from those trees.
 *
 * @param path location to load from
 * @return the reconstructed model, using the uid recorded in the ensemble metadata
 * @throws IllegalArgumentException if the stored `numTrees` differs from the number of
 *                                  trees actually loaded
 */
override def load(path: String): RandomForestClassificationModel = {
// json4s formats required implicitly by the `extract[Int]` calls below.
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)]) =
EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
// Model-level fields stored alongside the ensemble metadata.
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just IMO, maybe check numTrees == trees.length since there's redundant information.

// Each tree is built from its own uid and root node, sharing the forest-level
// numFeatures/numClasses; its metadata is then handed to getAndSetParams.
val trees: Array[DecisionTreeClassificationModel] = treesData.map {
case (treeMetadata, root) =>
val tree =
new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
DefaultParamsReader.getAndSetParams(tree, treeMetadata)
tree
}
// numTrees is redundant with the loaded tree data, so use it as a consistency check.
require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" +
s" trees based on metadata but found ${trees.length} trees.")

val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

/** (private[ml]) Convert a model from the old API */
def fromOld(
/** Convert a model from the old API */
private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
categoricalFeatures: Map[Int, Int],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ final class GBTRegressionModel private[ml](
extends PredictionModel[Vector, GBTRegressionModel]
with TreeEnsembleModel with Serializable {

require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.")
require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")

Expand Down Expand Up @@ -213,6 +213,9 @@ final class GBTRegressionModel private[ml](
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
}

/** Number of trees in ensemble; a `val` set once from `trees.length`. */
val numTrees: Int = trees.length

@Since("1.4.0")
override def copy(extra: ParamMap): GBTRegressionModel = {
copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
Expand Down Expand Up @@ -258,6 +261,6 @@ private[ml] object GBTRegressionModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
new GBTRegressionModel(uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
Loading