[SPARK-48510][2/2] Support UDAF toColumn API in Spark Connect
### What changes were proposed in this pull request?

This PR follows apache#46245 to add support for the `udaf.toColumn` API in Spark Connect.

Here we introduce a new Protobuf message, `proto.TypedAggregateExpression`, that includes a serialized UDF packet. On the server, we unpack it into an `Aggregator` object and generate a real `TypedAggregateExpression` instance with the encoder information passed along with the UDF.
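
For illustration, here is a minimal sketch of the kind of `Aggregator` this change targets (the names `Item` and `SumBoth` are illustrative; the shape mirrors the `CompleteUdafTestInputAggregator` added in the tests below):

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

case class Item(id: Long, extra: Long)

// Sums `id` and `extra` across all rows using a (Long, Long) buffer.
class SumBoth extends Aggregator[Item, (Long, Long), Long] {
  override def zero: (Long, Long) = (0L, 0L)
  override def reduce(b: (Long, Long), a: Item): (Long, Long) =
    (b._1 + a.id, b._2 + a.extra)
  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
    (b1._1 + b2._1, b1._2 + b2._2)
  override def finish(r: (Long, Long)): Long = r._1 + r._2
  override def bufferEncoder: Encoder[(Long, Long)] =
    Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
```

Calling `toColumn` on such an aggregator serializes it into a `ScalarScalaUDF` payload, wraps that payload in the new `TypedAggregateExpression` message, and returns a `TypedColumn`; the planner then rebuilds the server-side `TypedAggregateExpression` and binds the input encoder against the base relation's output.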

### Why are the changes needed?

Because the `toColumn` API was not supported in the previous PR.

### Does this PR introduce _any_ user-facing change?

Yes. From now on, users can create typed UDAFs using the `udaf.toColumn` API.
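
For example (a minimal sketch against a Spark Connect session `spark`, reusing the illustrative `Item` and `SumBoth` from above):

```scala
import org.apache.spark.sql.functions.col
import spark.implicits._

// Aggregate a typed Dataset with the Aggregator-backed column.
val ds = spark.range(10).withColumn("extra", col("id") * 2).as[Item]
val total = ds.select(new SumBoth().toColumn).head() // 45 + 90 = 135
```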

### How was this patch tested?

New tests.

### Was this patch authored or co-authored using generative AI tooling?

Nope.

Closes apache#46849 from xupefei/connect-udaf-tocolumn.

Authored-by: Paddy Xu <xupaddy@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
xupefei authored and HyukjinKwon committed Jul 12, 2024
1 parent 8d3d4f9 commit e20db13
Showing 7 changed files with 278 additions and 94 deletions.
@@ -51,6 +51,7 @@ message Expression {
CallFunction call_function = 16;
NamedArgumentExpression named_argument_expression = 17;
MergeAction merge_action = 19;
TypedAggregateExpression typed_aggregate_expression = 20;

// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// relations they can add them here. During the planning the correct resolution is done.
@@ -402,6 +403,11 @@ message JavaUDF {
bool aggregate = 3;
}

message TypedAggregateExpression {
// (Required) The aggregate function object packed into bytes.
ScalarScalaUDF scalar_scala_udf = 1;
}

message CallFunction {
// (Required) Unparsed name of the SQL function.
string function_name = 1;
@@ -51,7 +51,7 @@ import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, Mu
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
@@ -67,6 +67,7 @@ import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, Spark
import org.apache.spark.sql.connect.utils.MetricGenerator
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.command.CreateViewCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
@@ -1455,7 +1456,7 @@ class SparkConnectPlanner(
}

val projection = rel.getExpressionsList.asScala.toSeq
.map(transformExpression)
.map(transformExpression(_, Some(baseRel)))
.map(toNamedExpression)

logical.Project(projectList = projection, child = baseRel)
@@ -1472,21 +1473,40 @@ class SparkConnectPlanner(
* Catalyst expression
*/
@DeveloperApi
def transformExpression(exp: proto.Expression): Expression = if (exp.hasCommon) {
def transformExpression(exp: proto.Expression): Expression = transformExpression(exp, None)

/**
* Transforms an input protobuf expression into the Catalyst expression. This is usually not
* called directly. Typically the planner will traverse the expressions automatically; only
* plugins are expected to manually perform expression transformations.
*
* @param exp
* the input expression
* @param baseRelationOpt
* inputs of the base relation that contains this expression
* @return
* Catalyst expression
*/
@DeveloperApi
def transformExpression(
exp: proto.Expression,
baseRelationOpt: Option[LogicalPlan]): Expression = if (exp.hasCommon) {
try {
val origin = exp.getCommon.getOrigin
PySparkCurrentOrigin.set(
origin.getPythonOrigin.getFragment,
origin.getPythonOrigin.getCallSite)
withOrigin { doTransformExpression(exp) }
withOrigin { doTransformExpression(exp, baseRelationOpt) }
} finally {
PySparkCurrentOrigin.clear()
}
} else {
doTransformExpression(exp)
doTransformExpression(exp, baseRelationOpt)
}

private def doTransformExpression(exp: proto.Expression): Expression = {
private def doTransformExpression(
exp: proto.Expression,
baseRelationOpt: Option[LogicalPlan]): Expression = {
exp.getExprTypeCase match {
case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral)
case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
@@ -1523,6 +1543,8 @@ class SparkConnectPlanner(
transformNamedArgumentExpression(exp.getNamedArgumentExpression)
case proto.Expression.ExprTypeCase.MERGE_ACTION =>
transformMergeAction(exp.getMergeAction)
case proto.Expression.ExprTypeCase.TYPED_AGGREGATE_EXPRESSION =>
transformTypedAggregateExpression(exp.getTypedAggregateExpression, baseRelationOpt)
case _ =>
throw InvalidPlanInput(
s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported")
@@ -2584,8 +2606,35 @@ class SparkConnectPlanner(
if expr.getUnresolvedFunction.getFunctionName == "reduce" =>
// The reduce func needs the input data attribute, thus handle it specially here
transformTypedReduceExpression(expr.getUnresolvedFunction, plan.output)
case _ => transformExpression(expr)
case _ => transformExpression(expr, Some(plan))
}
}

private def transformTypedAggregateExpression(
expr: proto.TypedAggregateExpression,
baseRelationOpt: Option[LogicalPlan]): AggregateExpression = {
val udf = expr.getScalarScalaUdf
assert(udf.getAggregate)

val udfPacket = unpackScalaUDF[UdfPacket](udf)
assert(udfPacket.inputEncoders.size == 1, "UDAF should have exactly one input encoder")

val aggregator = udfPacket.function.asInstanceOf[Aggregator[Any, Any, Any]]
val tae =
TypedAggregateExpression(aggregator)(aggregator.bufferEncoder, aggregator.outputEncoder)
val taeWithInput = baseRelationOpt match {
case Some(baseRelation) =>
val inputEncoder = TypedScalaUdf.encoderFor(
udfPacket.inputEncoders.head,
"input",
Some(baseRelation.output))
TypedAggUtils
.withInputType(tae, inputEncoder, baseRelation.output)
.asInstanceOf[TypedAggregateExpression]
case _ =>
tae
}
taeWithInput.toAggregateExpression()
}

private def transformMergeAction(action: proto.MergeAction): MergeAction = {
@@ -17,7 +17,11 @@

package org.apache.spark.sql.expressions

import org.apache.spark.sql.{Encoder, TypedColumn}
import scala.reflect.runtime.universe._

import org.apache.spark.connect.proto
import org.apache.spark.sql.{encoderFor, Encoder, TypedColumn}
import org.apache.spark.sql.catalyst.ScalaReflection

/**
* A base class for user-defined aggregations, which can be used in `Dataset` operations to take
@@ -92,9 +96,52 @@ abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
def outputEncoder: Encoder[OUT]

/**
* Returns this `Aggregator` as a `TypedColumn` that can be used in `Dataset`. operations.
* Returns this `Aggregator` as a `TypedColumn` that can be used in `Dataset` operations.
* @since 4.0.0
*/
def toColumn: TypedColumn[IN, OUT] = {
throw new UnsupportedOperationException("toColumn is not implemented.")
val ttpe = getInputTypeTag[IN]
val inputEncoder = ScalaReflection.encoderFor(ttpe)
val udaf =
ScalaUserDefinedFunction(
this,
Seq(inputEncoder),
encoderFor(outputEncoder),
aggregate = true)

val builder = proto.TypedAggregateExpression.newBuilder()
builder.setScalarScalaUdf(udaf.udf)
val expr = proto.Expression.newBuilder().setTypedAggregateExpression(builder).build()

new TypedColumn(expr, encoderFor(outputEncoder))
}

private final def getInputTypeTag[T]: TypeTag[T] = {
val mirror = runtimeMirror(this.getClass.getClassLoader)
val tpe = mirror.classSymbol(this.getClass).toType
// Find the most generic (last in the tree) Aggregator class
val baseAgg =
tpe.baseClasses
.findLast(_.asClass.toType <:< typeOf[Aggregator[_, _, _]])
.getOrElse(throw new IllegalStateException("Could not find the Aggregator base class."))
val typeArgs = tpe.baseType(baseAgg).typeArgs
assert(
typeArgs.length == 3,
s"Aggregator should have 3 type arguments, " +
s"but found ${typeArgs.length}: ${typeArgs.mkString}.")
val inType = typeArgs.head

import scala.reflect.api._
TypeTag(
mirror,
new TypeCreator {
def apply[U <: Universe with Singleton](m: Mirror[U]): U#Type =
if (m eq mirror) {
inType.asInstanceOf[U#Type]
} else {
throw new IllegalArgumentException(
s"Type tag defined in $mirror cannot be migrated to other mirrors.")
}
})
}
}
@@ -107,7 +107,7 @@ case class ScalaUserDefinedFunction private[sql] (
aggregate: Boolean)
extends UserDefinedFunction {

private[this] lazy val udf = {
private[expressions] lazy val udf = {
val scalaUdfBuilder = proto.ScalarScalaUDF
.newBuilder()
.setPayload(ByteString.copyFrom(serializedUdfPacket))
@@ -367,17 +367,7 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession
test("UDAF custom Aggregator - case class as input types") {
val session: SparkSession = spark
import session.implicits._
val agg = new Aggregator[UdafTestInput, (Long, Long), Long] {
override def zero: (Long, Long) = (0L, 0L)
override def reduce(b: (Long, Long), a: UdafTestInput): (Long, Long) =
(b._1 + a.id, b._2 + a.extra)
override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
(b1._1 + b2._1, b1._2 + b2._2)
override def finish(reduction: (Long, Long)): Long = reduction._1 + reduction._2
override def bufferEncoder: Encoder[(Long, Long)] =
Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
val agg = new CompleteUdafTestInputAggregator()
spark.udf.register("agg", udaf(agg))
val result = spark
.range(10)
@@ -388,6 +378,66 @@
.head()
assert(result == 135) // 45 + 90
}

test("UDAF custom Aggregator - toColumn") {
val session: SparkSession = spark
import session.implicits._
val aggCol = new CompleteUdafTestInputAggregator().toColumn
val ds = spark.range(10).withColumn("extra", col("id") * 2).as[UdafTestInput]

assert(ds.select(aggCol).head() == 135) // 45 + 90
assert(ds.agg(aggCol).head().getLong(0) == 135) // 45 + 90
}

test("UDAF custom Aggregator - multiple extends - toColumn") {
val session: SparkSession = spark
import session.implicits._
val aggCol = new CompleteGrandChildUdafTestInputAggregator().toColumn
val ds = spark.range(10).withColumn("extra", col("id") * 2).as[UdafTestInput]

assert(ds.select(aggCol).head() == 540) // (45 + 90) * 4
assert(ds.agg(aggCol).head().getLong(0) == 540) // (45 + 90) * 4
}
}

case class UdafTestInput(id: Long, extra: Long)

// An Aggregator that takes [[UdafTestInput]] as input.
final class CompleteUdafTestInputAggregator
extends Aggregator[UdafTestInput, (Long, Long), Long] {
override def zero: (Long, Long) = (0L, 0L)
override def reduce(b: (Long, Long), a: UdafTestInput): (Long, Long) =
(b._1 + a.id, b._2 + a.extra)
override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
(b1._1 + b2._1, b1._2 + b2._2)
override def finish(reduction: (Long, Long)): Long = reduction._1 + reduction._2
override def bufferEncoder: Encoder[(Long, Long)] =
Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// Same as [[CompleteUdafTestInputAggregator]] but the input type is not defined.
abstract class IncompleteUdafTestInputAggregator[T] extends Aggregator[T, (Long, Long), Long] {
override def zero: (Long, Long) = (0L, 0L)
override def reduce(b: (Long, Long), a: T): (Long, Long) // Incomplete!
override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
(b1._1 + b2._1, b1._2 + b2._2)
override def finish(reduction: (Long, Long)): Long = reduction._1 + reduction._2
override def bufferEncoder: Encoder[(Long, Long)] =
Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// A layer over [[IncompleteUdafTestInputAggregator]] but the input type is still not defined.
abstract class IncompleteChildUdafTestInputAggregator[T]
extends IncompleteUdafTestInputAggregator[T] {
override def finish(reduction: (Long, Long)): Long = (reduction._1 + reduction._2) * 2
}

// Another layer that finally defines the input type.
final class CompleteGrandChildUdafTestInputAggregator
extends IncompleteChildUdafTestInputAggregator[UdafTestInput] {
override def reduce(b: (Long, Long), a: UdafTestInput): (Long, Long) =
(b._1 + a.id, b._2 + a.extra)
override def finish(reduction: (Long, Long)): Long = (reduction._1 + reduction._2) * 4
}