Skip to content

Commit cc7af86

Browse files
committed
[SPARK-12813][SQL] Eliminate serialization for back to back operations
The goal of this PR is to eliminate unnecessary translations when there are back-to-back `MapPartitions` operations. In order to achieve this I also made the following simplifications: - Operators no longer hold encoders; instead they have only the expressions that they need. The benefits here are twofold: the expressions are visible to transformations, so they go through the normal resolution/binding process; and now that they are visible, we can change them on a case-by-case basis. - Operators no longer have type parameters. Since the engine is responsible for its own type checking, having the types visible to the compiler was an unnecessary complication. We still leverage the Scala compiler in the companion factory when constructing a new operator, but after this the types are discarded. Deferred to a follow-up PR: - Remove as much of the resolution/binding from Dataset/GroupedDataset as possible. We should still eagerly check resolution and throw an error, though, in the case of mismatches for an `as` operation. - Eliminate serializations in more cases by adding more cases to `EliminateSerialization`. Author: Michael Armbrust <michael@databricks.com> Closes apache#10747 from marmbrus/encoderExpressions.
1 parent 2578298 commit cc7af86

File tree

17 files changed

+518
-274
lines changed

17 files changed

+518
-274
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,10 @@ object CleanupAliases extends Rule[LogicalPlan] {
12141214
Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
12151215
orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
12161216

1217+
// Operators that operate on objects should only have expressions from encoders, which should
1218+
// never have extra aliases.
1219+
case o: ObjectOperator => o
1220+
12171221
case other =>
12181222
var stop = false
12191223
other transformExpressionsDown {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ abstract class Star extends LeafExpression with NamedExpression {
160160
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
161161
override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
162162
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
163+
override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
163164
override lazy val resolved = false
164165

165166
def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression]
@@ -246,6 +247,8 @@ case class MultiAlias(child: Expression, names: Seq[String])
246247

247248
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
248249

250+
override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
251+
249252
override lazy val resolved = false
250253

251254
override def toString: String = s"$child AS $names"
@@ -259,6 +262,7 @@ case class MultiAlias(child: Expression, names: Seq[String])
259262
* @param expressions Expressions to expand.
260263
*/
261264
case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable {
265+
override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
262266
override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = expressions
263267
override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
264268
}
@@ -298,6 +302,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
298302
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
299303
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
300304
override def name: String = throw new UnresolvedException(this, "name")
305+
override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
301306

302307
override lazy val resolved = false
303308
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,16 @@ case class ExpressionEncoder[T](
207207
resolve(attrs, OuterScopes.outerScopes).bind(attrs)
208208
}
209209

210+
211+
/**
212+
* Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
213+
* of this object.
214+
*/
215+
def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(toRowExpressions).map {
216+
case (_, ne: NamedExpression) => ne.newInstance()
217+
case (name, e) => Alias(e, name)()
218+
}
219+
210220
/**
211221
* Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
212222
* toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
3131
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
3232
extends LeafExpression with NamedExpression {
3333

34-
override def toString: String = s"input[$ordinal, $dataType]"
34+
override def toString: String = s"input[$ordinal, ${dataType.simpleString}]"
3535

3636
// Use special getter for primitive types (for UnsafeRow)
3737
override def eval(input: InternalRow): Any = {
@@ -66,6 +66,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
6666

6767
override def exprId: ExprId = throw new UnsupportedOperationException
6868

69+
override def newInstance(): NamedExpression = this
70+
6971
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
7072
val javaType = ctx.javaType(dataType)
7173
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ trait NamedExpression extends Expression {
7979
/** Returns the metadata when an expression is a reference to another expression with metadata. */
8080
def metadata: Metadata = Metadata.empty
8181

82+
/** Returns a copy of this expression with a new `exprId`. */
83+
def newInstance(): NamedExpression
84+
8285
protected def typeSuffix =
8386
if (resolved) {
8487
dataType match {
@@ -144,6 +147,9 @@ case class Alias(child: Expression, name: String)(
144147
}
145148
}
146149

150+
def newInstance(): NamedExpression =
151+
Alias(child, name)(qualifiers = qualifiers, explicitMetadata = explicitMetadata)
152+
147153
override def toAttribute: Attribute = {
148154
if (resolved) {
149155
AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ case class Invoke(
172172
$objNullCheck
173173
"""
174174
}
175+
176+
override def toString: String = s"$targetObject.$functionName"
175177
}
176178

177179
object NewInstance {
@@ -253,6 +255,8 @@ case class NewInstance(
253255
"""
254256
}
255257
}
258+
259+
override def toString: String = s"newInstance($cls)"
256260
}
257261

258262
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
6767
RemoveDispensableExpressions,
6868
SimplifyFilters,
6969
SimplifyCasts,
70-
SimplifyCaseConversionExpressions) ::
70+
SimplifyCaseConversionExpressions,
71+
EliminateSerialization) ::
7172
Batch("Decimal Optimizations", FixedPoint(100),
7273
DecimalAggregates) ::
7374
Batch("LocalRelation", FixedPoint(100),
@@ -96,6 +97,19 @@ object SamplePushDown extends Rule[LogicalPlan] {
9697
}
9798
}
9899

100+
/**
101+
* Removes cases where we are unnecessarily going between the object and serialized (InternalRow)
102+
* representation of a data item. For example, back-to-back map operations.
103+
*/
104+
object EliminateSerialization extends Rule[LogicalPlan] {
105+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
106+
case m @ MapPartitions(_, input, _, child: ObjectOperator)
107+
if !input.isInstanceOf[Attribute] && m.input.dataType == child.outputObject.dataType =>
108+
val childWithoutSerialization = child.withObjectOutput
109+
m.copy(input = childWithoutSerialization.output.head, child = childWithoutSerialization)
110+
}
111+
}
112+
99113
/**
100114
* Pushes certain operations to both sides of a Union, Intersect or Except operator.
101115
* Operations that are safe to pushdown are listed as follows.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 0 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical
1919

2020
import scala.collection.mutable.ArrayBuffer
2121

22-
import org.apache.spark.sql.Encoder
23-
import org.apache.spark.sql.catalyst.encoders._
2422
import org.apache.spark.sql.catalyst.expressions._
2523
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
2624
import org.apache.spark.sql.catalyst.plans._
@@ -480,120 +478,3 @@ case object OneRowRelation extends LeafNode {
480478
*/
481479
override def statistics: Statistics = Statistics(sizeInBytes = 1)
482480
}
483-
484-
/**
485-
* A relation produced by applying `func` to each partition of the `child`. tEncoder/uEncoder are
486-
* used respectively to decode/encode from the JVM object representation expected by `func.`
487-
*/
488-
case class MapPartitions[T, U](
489-
func: Iterator[T] => Iterator[U],
490-
tEncoder: ExpressionEncoder[T],
491-
uEncoder: ExpressionEncoder[U],
492-
output: Seq[Attribute],
493-
child: LogicalPlan) extends UnaryNode {
494-
override def producedAttributes: AttributeSet = outputSet
495-
}
496-
497-
/** Factory for constructing new `AppendColumns` nodes. */
498-
object AppendColumns {
499-
def apply[T, U : Encoder](
500-
func: T => U,
501-
tEncoder: ExpressionEncoder[T],
502-
child: LogicalPlan): AppendColumns[T, U] = {
503-
val attrs = encoderFor[U].schema.toAttributes
504-
new AppendColumns[T, U](func, tEncoder, encoderFor[U], attrs, child)
505-
}
506-
}
507-
508-
/**
509-
* A relation produced by applying `func` to each element of the `child`, concatenating the
510-
* resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to
511-
* decode/encode from the JVM object representation expected by `func.`
512-
*/
513-
case class AppendColumns[T, U](
514-
func: T => U,
515-
tEncoder: ExpressionEncoder[T],
516-
uEncoder: ExpressionEncoder[U],
517-
newColumns: Seq[Attribute],
518-
child: LogicalPlan) extends UnaryNode {
519-
override def output: Seq[Attribute] = child.output ++ newColumns
520-
override def producedAttributes: AttributeSet = AttributeSet(newColumns)
521-
}
522-
523-
/** Factory for constructing new `MapGroups` nodes. */
524-
object MapGroups {
525-
def apply[K, T, U : Encoder](
526-
func: (K, Iterator[T]) => TraversableOnce[U],
527-
kEncoder: ExpressionEncoder[K],
528-
tEncoder: ExpressionEncoder[T],
529-
groupingAttributes: Seq[Attribute],
530-
child: LogicalPlan): MapGroups[K, T, U] = {
531-
new MapGroups(
532-
func,
533-
kEncoder,
534-
tEncoder,
535-
encoderFor[U],
536-
groupingAttributes,
537-
encoderFor[U].schema.toAttributes,
538-
child)
539-
}
540-
}
541-
542-
/**
543-
* Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
544-
* Func is invoked with an object representation of the grouping key and an iterator containing the
545-
* object representation of all the rows with that key.
546-
*/
547-
case class MapGroups[K, T, U](
548-
func: (K, Iterator[T]) => TraversableOnce[U],
549-
kEncoder: ExpressionEncoder[K],
550-
tEncoder: ExpressionEncoder[T],
551-
uEncoder: ExpressionEncoder[U],
552-
groupingAttributes: Seq[Attribute],
553-
output: Seq[Attribute],
554-
child: LogicalPlan) extends UnaryNode {
555-
override def producedAttributes: AttributeSet = outputSet
556-
}
557-
558-
/** Factory for constructing new `CoGroup` nodes. */
559-
object CoGroup {
560-
def apply[Key, Left, Right, Result : Encoder](
561-
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
562-
keyEnc: ExpressionEncoder[Key],
563-
leftEnc: ExpressionEncoder[Left],
564-
rightEnc: ExpressionEncoder[Right],
565-
leftGroup: Seq[Attribute],
566-
rightGroup: Seq[Attribute],
567-
left: LogicalPlan,
568-
right: LogicalPlan): CoGroup[Key, Left, Right, Result] = {
569-
CoGroup(
570-
func,
571-
keyEnc,
572-
leftEnc,
573-
rightEnc,
574-
encoderFor[Result],
575-
encoderFor[Result].schema.toAttributes,
576-
leftGroup,
577-
rightGroup,
578-
left,
579-
right)
580-
}
581-
}
582-
583-
/**
584-
* A relation produced by applying `func` to each grouping key and associated values from left and
585-
* right children.
586-
*/
587-
case class CoGroup[Key, Left, Right, Result](
588-
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
589-
keyEnc: ExpressionEncoder[Key],
590-
leftEnc: ExpressionEncoder[Left],
591-
rightEnc: ExpressionEncoder[Right],
592-
resultEnc: ExpressionEncoder[Result],
593-
output: Seq[Attribute],
594-
leftGroup: Seq[Attribute],
595-
rightGroup: Seq[Attribute],
596-
left: LogicalPlan,
597-
right: LogicalPlan) extends BinaryNode {
598-
override def producedAttributes: AttributeSet = outputSet
599-
}

0 commit comments

Comments
 (0)