Skip to content

[SPARK-10020][MLlib]: ML model broadcasts should be stored in private vars: mllib GeneralizedLinearModel #8249

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.mllib.regression

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.rdd.RDD
Expand All @@ -39,6 +40,8 @@ import org.apache.spark.storage.StorageLevel
abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double)
extends Serializable {

private var bcWeights: Option[Broadcast[Vector]] = None

/**
* Predict the result given a data point and the weights learned.
*
Expand All @@ -57,11 +60,17 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
def predict(testData: RDD[Vector]): RDD[Double] = {
// A small optimization to avoid serializing the entire model. Only the weightsMatrix
// and intercept is needed.
val localWeights = weights
val bcWeights = testData.context.broadcast(localWeights)
bcWeights match {
case None => {
val localWeights = weights
bcWeights = Some(testData.context.broadcast(localWeights))
}
case _ =>
}
val localBcWeights = bcWeights
val localIntercept = intercept
testData.mapPartitions { iter =>
val w = bcWeights.value
val w = localBcWeights.get.value
iter.map(v => predictPoint(v, w, localIntercept))
}
}
Expand Down