
[WIP][SPARK-25044][SQL] Address translation of LMF closure primitive args to Object in Scala 2.12 #22063


Closed
wants to merge 1 commit
@@ -239,7 +239,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}
// Close ServerSocket on task completion.
serverSocket.foreach { server =>
context.addTaskCompletionListener(_ => server.close())
context.addTaskCompletionListener[Unit](_ => server.close())
}
val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
if (boundPort == -1) {
4 changes: 2 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -211,8 +211,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
}

protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Any) =>
Contributor:

I'm really surprised that this worked before...

predict(features.asInstanceOf[FeaturesType])
val predictUDF = udfInternal { features: FeaturesType =>
predict(features)
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -18,7 +18,7 @@
package org.apache.spark.ml.classification

import org.apache.spark.SparkException
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
@@ -164,8 +164,8 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
var outputData = dataset
var numColsOutput = 0
if (getRawPredictionCol != "") {
val predictRawUDF = udf { (features: Any) =>
cloud-fan (Contributor), Aug 23, 2018:

I looked into this, and now I understand why it worked before.

Scala 2.11 can somehow generate a type tag for Any; Spark then derives the input schema from that type tag via Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption. That call fails, so the input schema ends up as None and no type check is applied later.

I think it makes more sense to specify the type and let Spark do the type check.
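
For reference, a minimal sketch of the behavior described above, as it could be checked in a Spark 2.x REPL (illustrative only; it uses Spark's internal ScalaReflection and is not code from this PR):

```scala
import scala.reflect.runtime.universe.typeTag
import scala.util.Try
import org.apache.spark.sql.catalyst.ScalaReflection

// The expression UDF registration uses internally: schemaFor(Any) throws
// ("Schema for type Any is not supported"), so the Option collapses to None
// and the analyzer has no input schema to type-check against.
val inputSchema = Try(ScalaReflection.schemaFor(typeTag[Any]).dataType :: Nil).toOption
// inputSchema: Option[List[DataType]] = None
```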

+1

Member (PR author):

Thanks for your review, @cloud-fan; I could really use your input here. That's a good find. It may be that we want to explicitly support UDFs where a schema isn't available -- see below. But I agree I'd rather not; it gets kind of messy.

Member (PR author):

@skonto @lrytz this might be of interest. Don't think it's a Scala issue per se but just checking if that behavior change makes sense.

skonto (Contributor), Aug 29, 2018:

@adriaanm @lrytz any more info?


No idea, but in any case the new version seems nicer :-) Both 2.11 and 2.12 will happily generate a typeTag for Any, though, so that wouldn't immediately explain it. To see what was actually inferred, you could compile with -Xprint:typer (ideally after a full compile and then just making this file recompile incrementally).
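
For anyone who wants to try that, a minimal sketch of wiring the flag into an sbt build (the flag itself is a standard scalac option; the sbt setting is an assumption about the build being used):

```scala
// In build.sbt: print each file's tree after the typer phase, so the actual
// inferred types of the closure parameters become visible on recompilation.
scalacOptions += "-Xprint:typer"
```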

Member (PR author):

I apologize, that's my mistake. In the end it isn't related to TypeTags for Any, and that isn't a difference between 2.11 and 2.12 anyway. Thanks for your input; I think we are close.


No problem! Happy to help with the 2.12 upgrade.

predictRaw(features.asInstanceOf[FeaturesType])
val predictRawUDF = udfInternal { features: FeaturesType =>
predictRaw(features)
}
outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
numColsOutput += 1
@@ -174,8 +174,8 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
val predUDF = if (getRawPredictionCol != "") {
udf(raw2prediction _).apply(col(getRawPredictionCol))
} else {
val predictUDF = udf { (features: Any) =>
predict(features.asInstanceOf[FeaturesType])
val predictUDF = udfInternal { features: FeaturesType =>
predict(features)
}
predictUDF(col(getFeaturesCol))
}
@@ -287,8 +287,8 @@ class GBTClassificationModel private[ml](

override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
val predictUDF = udfInternal { features: Vector =>
bcastModel.value.predict(features)
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -113,8 +113,8 @@ abstract class ProbabilisticClassificationModel[
var outputData = dataset
var numColsOutput = 0
if ($(rawPredictionCol).nonEmpty) {
val predictRawUDF = udf { (features: Any) =>
predictRaw(features.asInstanceOf[FeaturesType])
val predictRawUDF = udfInternal { features: FeaturesType =>
predictRaw(features)
}
outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
numColsOutput += 1
@@ -123,8 +123,8 @@ abstract class ProbabilisticClassificationModel[
val probUDF = if ($(rawPredictionCol).nonEmpty) {
udf(raw2probability _).apply(col($(rawPredictionCol)))
} else {
val probabilityUDF = udf { (features: Any) =>
predictProbability(features.asInstanceOf[FeaturesType])
val probabilityUDF = udfInternal { features: FeaturesType =>
predictProbability(features)
}
probabilityUDF(col($(featuresCol)))
}
@@ -137,8 +137,8 @@ abstract class ProbabilisticClassificationModel[
} else if ($(probabilityCol).nonEmpty) {
udf(probability2prediction _).apply(col($(probabilityCol)))
} else {
val predictUDF = udf { (features: Any) =>
predict(features.asInstanceOf[FeaturesType])
val predictUDF = udfInternal { features: FeaturesType =>
predict(features)
}
predictUDF(col($(featuresCol)))
}
@@ -209,8 +209,8 @@ class RandomForestClassificationModel private[ml] (

override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
val predictUDF = udfInternal { features: Vector =>
bcastModel.value.predict(features)
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -165,7 +165,7 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme
}
}

val hashFeatures = udf { row: Row =>
val hashFeatures = udfInternal { row: Row =>
val map = new OpenHashMap[Int, Double]()
localInputCols.foreach { colName =>
val fieldIndex = row.fieldIndex(colName)
@@ -25,7 +25,7 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.functions.{col, udfInternal}
import org.apache.spark.sql.types.{ArrayType, StructType}

/**
@@ -95,7 +95,7 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
val outputSchema = transformSchema(dataset.schema)
val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
// TODO: Make the hashingTF.transform natively in ml framework to avoid extra conversion.
val t = udf { terms: Seq[_] => hashingTF.transform(terms).asML }
val t = udfInternal { terms: Seq[_] => hashingTF.transform(terms).asML }
val metadata = outputSchema($(outputCol)).metadata
dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
}
@@ -73,7 +73,7 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext
val featureEncoders = getFeatureEncoders(inputFeatures)
val featureAttrs = getFeatureAttrs(inputFeatures)

def interactFunc = udf { row: Row =>
def interactFunc = udfInternal { row: Row =>
var indices = ArrayBuilder.make[Int]
var values = ArrayBuilder.make[Double]
var size = 1
@@ -140,7 +140,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
case VectorAssembler.ERROR_INVALID => (dataset, false)
}
// Data transformation.
val assembleFunc = udf { r: Row =>
val assembleFunc = udfInternal { r: Row =>
VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
}.asNondeterministic()
val args = $(inputCols).map { c =>
@@ -85,7 +85,7 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
* Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
* out of integer range or contains a fractional part.
*/
protected[recommendation] val checkedCast = udf { (n: Any) =>
protected[recommendation] val checkedCast = udfInternal { n: Any =>
n match {
case v: Int => v // Avoid unnecessary casting
case v: Number =>
@@ -34,7 +34,6 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._

@@ -245,8 +244,8 @@ class GBTRegressionModel private[ml](

override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
val predictUDF = udfInternal { features: Vector =>
bcastModel.value.predict(features)
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -195,8 +195,8 @@ class RandomForestRegressionModel private[ml] (

override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
val predictUDF = udfInternal { features: Vector =>
bcastModel.value.predict(features)
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
5 changes: 5 additions & 0 deletions project/MimaExcludes.scala
@@ -36,6 +36,11 @@ object MimaExcludes {

// Exclude rules for 2.4.x
lazy val v24excludes = v23excludes ++ Seq(
// [SPARK-25044] Address translation of LMF closure primitive args to Object in Scala 2.12
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"),

// [SPARK-24296][CORE] Replicate large blocks as a stream.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.this"),
// [SPARK-23528] Add numIter to ClusteringSummary
@@ -932,15 +932,6 @@ trait ScalaReflection {
tpe.dealias.erasure.typeSymbol.asClass.fullName
}

/**
* Returns classes of input parameters of scala function object.
*/
def getParameterTypes(func: AnyRef): Seq[Class[_]] = {
val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
assert(methods.length == 1)
methods.head.getParameterTypes
}
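
For context, a small REPL-style sketch of why this reflective helper is being removed (illustrative only; exactly what 2.12 emits depends on how the closure is compiled):

```scala
// Inspect the closure class the same way the removed helper did.
val f: (Int, Long) => String = (i, j) => "x"
val applyParamTypes = f.getClass.getMethods
  .filter(m => m.getName == "apply" && !m.isBridge)
  .map(_.getParameterTypes.toSeq)
// Scala 2.11 (anonymous function class): the single non-bridge apply reports (int, long).
// Scala 2.12 (LambdaMetafactory closure): primitive parameters are typically translated
// to java.lang.Object, which is exactly what this change works around.
```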

/**
* Returns the parameter names and types for the primary constructor of this type.
*
@@ -2149,28 +2149,36 @@ class Analyzer(

case p => p transformExpressionsUp {

case udf @ ScalaUDF(func, _, inputs, _, _, _, _) =>
val parameterTypes = ScalaReflection.getParameterTypes(func)
assert(parameterTypes.length == inputs.length)

// TODO: skip null handling for not-nullable primitive inputs after we can completely
// trust the `nullable` information.
// (cls, expr) => cls.isPrimitive && expr.nullable
val needsNullCheck = (cls: Class[_], expr: Expression) =>
cls.isPrimitive && !expr.isInstanceOf[KnownNotNull]
val inputsNullCheck = parameterTypes.zip(inputs)
.filter { case (cls, expr) => needsNullCheck(cls, expr) }
.map { case (_, expr) => IsNull(expr) }
.reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
// Once we add an `If` check above the udf, it is safe to mark those checked inputs
// as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning
// branch of `If` will be called if any of these checked inputs is null. Thus we can
// prevent this rule from being applied repeatedly.
val newInputs = parameterTypes.zip(inputs).map{ case (cls, expr) =>
if (needsNullCheck(cls, expr)) KnownNotNull(expr) else expr }
inputsNullCheck
.map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs)))
.getOrElse(udf)
case udf @ ScalaUDF(func, _, inputs, _, _, _, _, nullableTypes) =>

if (nullableTypes.isEmpty) {
// If no nullability info is available, do nothing. No fields will be specially
// checked for null in the plan. If nullability info is incorrect, the results
// of the UDF could be wrong.
udf

} else {
// Otherwise, add special handling of null for fields that can't accept null.
// The result of operations like this, when passed null, is generally to return null.
assert(nullableTypes.length == inputs.length)

// TODO: skip null handling for not-nullable primitive inputs after we can completely
// trust the `nullable` information.
val inputsNullCheck = nullableTypes.zip(inputs)
.filter { case (nullable, _) => !nullable }
.map { case (_, expr) => IsNull(expr) }
.reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
// Once we add an `If` check above the udf, it is safe to mark those checked inputs
// as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning
// branch of `If` will be called if any of these checked inputs is null. Thus we can
// prevent this rule from being applied repeatedly.
val newInputs = nullableTypes.zip(inputs).map { case (nullable, expr) =>
if (nullable) expr else KnownNotNull(expr)
}
inputsNullCheck
.map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs)))
.getOrElse(udf)
}
}
}
}
@@ -39,6 +39,7 @@ import org.apache.spark.sql.types.DataType
* @param nullable True if the UDF can return null value.
* @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result
* each time it is invoked with a particular input.
* @param nullableTypes which of the inputTypes are nullable (i.e. not primitive)
Member (PR author):

The approach here is to capture, at registration time, whether each argument type is primitive or nullable. It's not a great way to record this, but it might be the least hacky option for now.
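
A rough sketch of that idea, using a hypothetical helper rather than the PR's actual registration code (assumes Spark's internal ScalaReflection is reachable from the call site):

```scala
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
import org.apache.spark.sql.catalyst.ScalaReflection

// At udf(...) registration time the TypeTags are still available, so per-argument
// nullability can be recorded up front instead of reflecting on the closure later.
def nullableOf[T: TypeTag]: Boolean =
  Try(ScalaReflection.schemaFor[T].nullable).getOrElse(true) // unknown types: assume nullable

// e.g. for a (String, Double) => String UDF:
val nullableTypes = Seq(nullableOf[String], nullableOf[Double]) // Seq(true, false)
```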

*/
case class ScalaUDF(
function: AnyRef,
@@ -47,7 +48,8 @@ case class ScalaUDF(
inputTypes: Seq[DataType] = Nil,
udfName: Option[String] = None,
nullable: Boolean = true,
udfDeterministic: Boolean = true)
udfDeterministic: Boolean = true,
nullableTypes: Seq[Boolean] = Nil)
extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {

// The constructor for SPARK 2.1 and 2.2
Expand All @@ -58,7 +60,8 @@ case class ScalaUDF(
inputTypes: Seq[DataType],
udfName: Option[String]) = {
this(
function, dataType, children, inputTypes, udfName, nullable = true, udfDeterministic = true)
function, dataType, children, inputTypes, udfName, nullable = true,
udfDeterministic = true, nullableTypes = Nil)
}

override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
@@ -261,23 +261,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
}
}

test("get parameter type from a function object") {
val primitiveFunc = (i: Int, j: Long) => "x"
val primitiveTypes = getParameterTypes(primitiveFunc)
assert(primitiveTypes.forall(_.isPrimitive))
assert(primitiveTypes === Seq(classOf[Int], classOf[Long]))

val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x"
val boxedTypes = getParameterTypes(boxedFunc)
assert(boxedTypes.forall(!_.isPrimitive))
assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long]))

val anyFunc = (i: Any, j: AnyRef) => "x"
val anyTypes = getParameterTypes(anyFunc)
assert(anyTypes.forall(!_.isPrimitive))
assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object]))
}

test("SPARK-15062: Get correct serializer for List[_]") {
val list = List(1, 2, 3)
val serializer = serializerFor[List[Int]](BoundReference(
@@ -317,13 +317,15 @@ class AnalysisSuite extends AnalysisTest with Matchers {
checkUDF(udf1, expected1)

// only primitive parameter needs special null handling
val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil)
val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
nullableTypes = true :: false :: Nil)
val expected2 =
If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
checkUDF(udf2, expected2)

// special null handling should apply to all primitive parameters
val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil)
val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil,
nullableTypes = false :: false :: Nil)
val expected3 = If(
IsNull(short) || IsNull(double),
nullResult,
@@ -335,7 +337,8 @@
val udf4 = ScalaUDF(
(s: Short, d: Double) => "x",
StringType,
short :: double.withNullability(false) :: Nil)
short :: double.withNullability(false) :: Nil,
nullableTypes = false :: false :: Nil)
val expected4 = If(
IsNull(short),
nullResult,
@@ -24,7 +24,8 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.stat._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, udfInternal}
import org.apache.spark.sql.types._
import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}

@@ -375,7 +376,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
import org.apache.spark.sql.functions.{rand, udf}
val c = Column(col)
val r = rand(seed)
val f = udf { (stratum: Any, x: Double) =>
val f = udfInternal { (stratum: Any, x: Double) =>
x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
}
df.filter(f(c, r))