[SPARK-32159][SQL] Fix integration between Aggregator[Array[_], _, _] and UnresolvedMapObjects

Context: The fix for SPARK-27296 introduced by apache#25024 allows `Aggregator` objects to appear in queries. This works fine for aggregators with atomic input types, e.g. `Aggregator[Double, _, _]`.

However, it can cause a null pointer exception if the input type is `Array[_]`. Historically this was an ignorable case for serialization of `UnresolvedMapObjects`, but the new `ScalaAggregator` class causes these expressions to be serialized to executors, because resolve-and-bind is deferred.
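
For illustration, a minimal sketch of the failing shape (the aggregator, names, and query here are hypothetical, for illustration only):

```scala
import org.apache.spark.sql.{Encoder, Encoders, functions}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical aggregator with an Array input type. Before this fix, registering
// it via functions.udaf() and calling it from SQL could throw a NullPointerException:
// UnresolvedMapObjects in the input encoder does not serialize its 'function' field,
// and the deferred resolve-and-bind shipped the unresolved expression to executors.
object SumElements extends Aggregator[Array[Double], Double, Double] {
  def zero: Double = 0.0
  def reduce(b: Double, a: Array[Double]): Double = b + a.sum
  def merge(b1: Double, b2: Double): Double = b1 + b2
  def finish(b: Double): Double = b
  def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// spark.udf.register("sumElements", functions.udaf(SumElements))
// spark.sql("SELECT sumElements(arr) FROM t").show()  // NPE before this fix
```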

### What changes were proposed in this pull request?
A new analyzer rule, `ResolveEncodersInScalaAgg`, that resolves the expressions contained in a `ScalaAggregator`'s encoders, so that properly resolved expressions are serialized to executors.

### Why are the changes needed?
Applying an aggregator of the form `Aggregator[Array[_], _, _]` using `functions.udaf()` currently causes a NullPointerException in Catalyst.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
A unit test has been added that does aggregation with array types for input, buffer, and output. I have done additional testing with my own custom aggregators in the Spark REPL.
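
For reference, a spark-shell check along these lines (an illustrative sketch; `arraysum` and `ArrayDataAgg` mirror the test code added below):

```scala
// Illustrative spark-shell session; spark.implicits._ is in scope by default.
val df = Seq(Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0), Array(7.0, 8.0, 9.0)).toDF("data")
df.createOrReplaceTempView("t")
spark.udf.register("arraysum", org.apache.spark.sql.functions.udaf(ArrayDataAgg))
spark.sql("SELECT arraysum(data) FROM t").show()  // expect [12.0, 15.0, 18.0]
```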

Closes apache#28983 from erikerlandson/fix-spark-32159.

Authored-by: Erik Erlandson <eerlands@redhat.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
erikerlandson authored and cloud-fan committed Jul 9, 2020
1 parent c5bd073 commit 1cb5bfc
Showing 6 changed files with 68 additions and 20 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -678,6 +678,13 @@ object MapObjects {
elementType: DataType,
elementNullable: Boolean = true,
customCollectionCls: Option[Class[_]] = None): MapObjects = {
+    // UnresolvedMapObjects does not serialize its 'function' field.
+    // If an array expression or array Encoder is not correctly resolved before
+    // serialization, this exception condition may occur.
+    require(function != null,
+      "MapObjects applied with a null function. " +
+      "Likely cause is failure to resolve an array expression or encoder. " +
+      "(See UnresolvedMapObjects)")
val loopVar = LambdaVariable("MapObject", elementType, elementNullable)
MapObjects(loopVar, function(loopVar), inputData, customCollectionCls)
}
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

@@ -458,7 +460,8 @@ case class ScalaUDAF(
case class ScalaAggregator[IN, BUF, OUT](
children: Seq[Expression],
agg: Aggregator[IN, BUF, OUT],
-    inputEncoderNR: ExpressionEncoder[IN],
+    inputEncoder: ExpressionEncoder[IN],
+    bufferEncoder: ExpressionEncoder[BUF],
nullable: Boolean = true,
isDeterministic: Boolean = true,
mutableAggBufferOffset: Int = 0,
@@ -469,17 +472,16 @@ case class ScalaAggregator[IN, BUF, OUT](
with ImplicitCastInputTypes
with Logging {

-  private[this] lazy val inputDeserializer = inputEncoderNR.resolveAndBind().createDeserializer()
-  private[this] lazy val bufferEncoder =
-    agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind()
+  // input and buffer encoders are resolved by ResolveEncodersInScalaAgg
+  private[this] lazy val inputDeserializer = inputEncoder.createDeserializer()
private[this] lazy val bufferSerializer = bufferEncoder.createSerializer()
private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer()
private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]]
private[this] lazy val outputSerializer = outputEncoder.createSerializer()

def dataType: DataType = outputEncoder.objSerializer.dataType

-  def inputTypes: Seq[DataType] = inputEncoderNR.schema.map(_.dataType)
+  def inputTypes: Seq[DataType] = inputEncoder.schema.map(_.dataType)

override lazy val deterministic: Boolean = isDeterministic

@@ -517,3 +519,18 @@

override def nodeName: String = agg.getClass.getSimpleName
}

+/**
+ * An extension rule to resolve encoder expressions from a [[ScalaAggregator]]
+ */
+object ResolveEncodersInScalaAgg extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+    case p if !p.resolved => p
+    case p => p.transformExpressionsUp {
+      case agg: ScalaAggregator[_, _, _] =>
+        agg.copy(
+          inputEncoder = agg.inputEncoder.resolveAndBind(),
+          bufferEncoder = agg.bufferEncoder.resolveAndBind())
+    }
+  }
+}
sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -150,7 +150,8 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
// This is also used by udf.register(...) when it detects a UserDefinedAggregator
def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = {
val iEncoder = inputEncoder.asInstanceOf[ExpressionEncoder[IN]]
-    ScalaAggregator(exprs, aggregator, iEncoder, nullable, deterministic)
+    val bEncoder = aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]]
+    ScalaAggregator(exprs, aggregator, iEncoder, bEncoder, nullable, deterministic)
}

override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = {
sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
+import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
import org.apache.spark.sql.execution.command.CommandCheck
import org.apache.spark.sql.execution.datasources._
@@ -175,6 +176,7 @@ abstract class BaseSessionStateBuilder(
new FindDataSourceTable(session) +:
new ResolveSQLOnFile(session) +:
new FallBackFileSourceV2(session) +:
+      ResolveEncodersInScalaAgg +:
new ResolveSessionCatalog(
catalogManager, conf, catalog.isTempView, catalog.isTempFunction) +:
customResolutionRules
sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{SparkOptimizer, SparkPlanner}
+import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
import org.apache.spark.sql.execution.command.CommandCheck
import org.apache.spark.sql.execution.datasources._
@@ -76,6 +77,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None)
new FindDataSourceTable(session) +:
new ResolveSQLOnFile(session) +:
new FallBackFileSourceV2(session) +:
+      ResolveEncodersInScalaAgg +:
new ResolveSessionCatalog(
catalogManager, conf, catalog.isTempView, catalog.isTempFunction) +:
customResolutionRules
sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
@@ -119,6 +119,27 @@ object CountSerDeAgg extends Aggregator[Int, CountSerDeSQL, CountSerDeSQL] {
def outputEncoder: Encoder[CountSerDeSQL] = ExpressionEncoder[CountSerDeSQL]()
}

+object ArrayDataAgg extends Aggregator[Array[Double], Array[Double], Array[Double]] {
+  def zero: Array[Double] = Array(0.0, 0.0, 0.0)
+  def reduce(s: Array[Double], array: Array[Double]): Array[Double] = {
+    require(s.length == array.length)
+    for (j <- 0 until s.length) {
+      s(j) += array(j)
+    }
+    s
+  }
+  def merge(s1: Array[Double], s2: Array[Double]): Array[Double] = {
+    require(s1.length == s2.length)
+    for (j <- 0 until s1.length) {
+      s1(j) += s2(j)
+    }
+    s1
+  }
+  def finish(s: Array[Double]): Array[Double] = s
+  def bufferEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]()
+  def outputEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]()
+}

abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._

@@ -156,20 +177,11 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
(3, null, null)).toDF("key", "value1", "value2")
data2.write.saveAsTable("agg2")

-    val data3 = Seq[(Seq[Integer], Integer, Integer)](
-      (Seq[Integer](1, 1), 10, -10),
-      (Seq[Integer](null), -60, 60),
-      (Seq[Integer](1, 1), 30, -30),
-      (Seq[Integer](1), 30, 30),
-      (Seq[Integer](2), 1, 1),
-      (null, -10, 10),
-      (Seq[Integer](2, 3), -1, null),
-      (Seq[Integer](2, 3), 1, 1),
-      (Seq[Integer](2, 3, 4), null, 1),
-      (Seq[Integer](null), 100, -10),
-      (Seq[Integer](3), null, 3),
-      (null, null, null),
-      (Seq[Integer](3), null, null)).toDF("key", "value1", "value2")
+    val data3 = Seq[(Seq[Double], Int)](
+      (Seq(1.0, 2.0, 3.0), 0),
+      (Seq(4.0, 5.0, 6.0), 0),
+      (Seq(7.0, 8.0, 9.0), 0)
+    ).toDF("data", "dummy")
data3.write.saveAsTable("agg3")

val data4 = Seq[Boolean](true, false, true).toDF("boolvalues")
@@ -184,6 +196,7 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
spark.udf.register("mydoublesum", udaf(MyDoubleSumAgg))
spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg))
spark.udf.register("longProductSum", udaf(LongProductSumAgg))
spark.udf.register("arraysum", udaf(ArrayDataAgg))
}

override def afterAll(): Unit = {
@@ -354,6 +367,12 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
}

test("SPARK-32159: array encoders should be resolved in analyzer") {
checkAnswer(
spark.sql("SELECT arraysum(data) FROM agg3"),
Row(Seq(12.0, 15.0, 18.0)) :: Nil)
}

test("verify aggregator ser/de behavior") {
val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1")
val agg = udaf(CountSerDeAgg)
