Skip to content

[SPARK-10015][MLlib]: ML model broadcasts should be stored in private vars: spark.ml tree ensembles #8243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
Expand Down Expand Up @@ -175,14 +176,20 @@ final class GBTClassificationModel(
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")

// Cached broadcast of this model, created lazily by transformImpl and reused
// on subsequent calls so the ensemble is not re-broadcast per transform.
// NOTE(review): field is a plain var — neither @transient nor synchronized;
// confirm the model is not Java-serialized while holding a live Broadcast and
// that transformImpl is not called concurrently.
private var bcastModel: Option[Broadcast[GBTClassificationModel]] = None

// Component trees, widened to the generic DecisionTreeModel supertype.
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]

// Per-tree weights; required (by the constructor) to be parallel to `trees`.
override def treeWeights: Array[Double] = _treeWeights

/**
 * Transforms the dataset by adding a prediction column computed from this
 * ensemble. The model is broadcast to executors at most once and the
 * broadcast is cached in `bcastModel` for reuse across calls.
 *
 * NOTE(review): the lazy initialization is not thread-safe — two concurrent
 * first calls may each broadcast the model; confirm single-threaded use.
 */
override protected def transformImpl(dataset: DataFrame): DataFrame = {
  if (bcastModel.isEmpty) {
    bcastModel = Some(dataset.sqlContext.sparkContext.broadcast(this))
  }
  // Capture the Broadcast itself (not the Option, and not `this`) in the UDF
  // closure, so tasks read the broadcast value instead of serializing the
  // driver-side model, and no per-row Option unwrapping is needed.
  val modelBroadcast = bcastModel.get
  val predictUDF = udf { (features: Any) =>
    modelBroadcast.value.predict(features.asInstanceOf[Vector])
  }
  dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.ml.classification

import org.apache.spark.annotation.Experimental
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
Expand Down Expand Up @@ -132,6 +133,8 @@ final class RandomForestClassificationModel private[ml] (

require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")

// Cached broadcast of this model, created lazily by transformImpl and reused
// on subsequent calls. NOTE(review): not @transient and not synchronized —
// confirm serialization and single-threaded use.
private var bcastModel: Option[Broadcast[RandomForestClassificationModel]] = None

/**
* Construct a random forest classification model, with all trees weighted equally.
* @param trees Component trees
Expand All @@ -150,9 +153,13 @@ final class RandomForestClassificationModel private[ml] (
// Per-tree weights; parallel to the component trees of this forest.
override def treeWeights: Array[Double] = _treeWeights

/**
 * Transforms the dataset by adding a prediction column computed from this
 * forest. The model is broadcast to executors at most once and the broadcast
 * is cached in `bcastModel` for reuse across calls.
 *
 * NOTE(review): the lazy initialization is not thread-safe — two concurrent
 * first calls may each broadcast the model; confirm single-threaded use.
 */
override protected def transformImpl(dataset: DataFrame): DataFrame = {
  if (bcastModel.isEmpty) {
    bcastModel = Some(dataset.sqlContext.sparkContext.broadcast(this))
  }
  // Capture the Broadcast itself (not the Option, and not `this`) in the UDF
  // closure, so tasks read the broadcast value instead of serializing the
  // driver-side model, and no per-row Option unwrapping is needed.
  val modelBroadcast = bcastModel.get
  val predictUDF = udf { (features: Any) =>
    modelBroadcast.value.predict(features.asInstanceOf[Vector])
  }
  dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, TreeRegressorParams}
Expand Down Expand Up @@ -165,14 +166,20 @@ final class GBTRegressionModel(
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")

// Cached broadcast of this model, created lazily by transformImpl and reused
// on subsequent calls so the ensemble is not re-broadcast per transform.
// NOTE(review): not @transient and not synchronized — confirm serialization
// and single-threaded use.
private var bcastModel: Option[Broadcast[GBTRegressionModel]] = None

// Component trees, widened to the generic DecisionTreeModel supertype.
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]

// Per-tree weights; required (by the constructor) to be parallel to `trees`.
override def treeWeights: Array[Double] = _treeWeights

/**
 * Transforms the dataset by adding a prediction column computed from this
 * ensemble. The model is broadcast to executors at most once and the
 * broadcast is cached in `bcastModel` for reuse across calls.
 *
 * NOTE(review): the lazy initialization is not thread-safe — two concurrent
 * first calls may each broadcast the model; confirm single-threaded use.
 */
override protected def transformImpl(dataset: DataFrame): DataFrame = {
  if (bcastModel.isEmpty) {
    bcastModel = Some(dataset.sqlContext.sparkContext.broadcast(this))
  }
  // Capture the Broadcast itself (not the Option, and not `this`) in the UDF
  // closure, so tasks read the broadcast value instead of serializing the
  // driver-side model, and no per-row Option unwrapping is needed.
  val modelBroadcast = bcastModel.get
  val predictUDF = udf { (features: Any) =>
    modelBroadcast.value.predict(features.asInstanceOf[Vector])
  }
  dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.ml.regression

import org.apache.spark.annotation.Experimental
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
Expand Down Expand Up @@ -121,6 +122,8 @@ final class RandomForestRegressionModel private[ml] (

require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")

// Cached broadcast of this model, created lazily by transformImpl and reused
// on subsequent calls. NOTE(review): not @transient and not synchronized —
// confirm serialization and single-threaded use.
private var bcastModel: Option[Broadcast[RandomForestRegressionModel]] = None

/**
* Construct a random forest regression model, with all trees weighted equally.
* @param trees Component trees
Expand All @@ -136,9 +139,13 @@ final class RandomForestRegressionModel private[ml] (
// Per-tree weights; parallel to the component trees of this forest.
override def treeWeights: Array[Double] = _treeWeights

/**
 * Transforms the dataset by adding a prediction column computed from this
 * forest. The model is broadcast to executors at most once and the broadcast
 * is cached in `bcastModel` for reuse across calls.
 *
 * NOTE(review): the lazy initialization is not thread-safe — two concurrent
 * first calls may each broadcast the model; confirm single-threaded use.
 */
override protected def transformImpl(dataset: DataFrame): DataFrame = {
  if (bcastModel.isEmpty) {
    bcastModel = Some(dataset.sqlContext.sparkContext.broadcast(this))
  }
  // Capture the Broadcast itself (not the Option, and not `this`) in the UDF
  // closure, so tasks read the broadcast value instead of serializing the
  // driver-side model, and no per-row Option unwrapping is needed.
  val modelBroadcast = bcastModel.get
  val predictUDF = udf { (features: Any) =>
    modelBroadcast.value.predict(features.asInstanceOf[Vector])
  }
  dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
Expand Down