
Commit addb3ab

Merge pull request #23 from marmbrus/streaming-attributes

Fix attribute rewiring

2 parents: 0630d29 + 764aac9

6 files changed: +60 -22 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 22 additions & 16 deletions
@@ -1215,6 +1215,8 @@ object CleanupAliases extends Rule[LogicalPlan] {
       Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
         orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)

+    case o: ObjectOperator => o
+
     case other =>
       var stop = false
       other transformExpressionsDown {
@@ -1265,22 +1267,26 @@ object ResolveUpCast extends Rule[LogicalPlan] {
   }

   def apply(plan: LogicalPlan): LogicalPlan = {
-    plan transformAllExpressions {
-      case u @ UpCast(child, _, _) if !child.resolved => u
-
-      case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
-        case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
-          fail(child, to, walkedTypePath)
-        case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
-          fail(child, to, walkedTypePath)
-        case (from, to) if illegalNumericPrecedence(from, to) =>
-          fail(child, to, walkedTypePath)
-        case (TimestampType, DateType) =>
-          fail(child, DateType, walkedTypePath)
-        case (StringType, to: NumericType) =>
-          fail(child, to, walkedTypePath)
-        case _ => Cast(child, dataType)
-      }
+    plan transform {
+      case o: ObjectOperator => o
+      case other =>
+        other transformExpressions {
+          case u @ UpCast(child, _, _) if !child.resolved => u
+
+          case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
+            case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
+              fail(child, to, walkedTypePath)
+            case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
+              fail(child, to, walkedTypePath)
+            case (from, to) if illegalNumericPrecedence(from, to) =>
+              fail(child, to, walkedTypePath)
+            case (TimestampType, DateType) =>
+              fail(child, DateType, walkedTypePath)
+            case (StringType, to: NumericType) =>
+              fail(child, to, walkedTypePath)
+            case _ => Cast(child, dataType)
+          }
+        }
     }
   }
 }
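Both Analyzer changes apply the same fix: operators that embed ExpressionEncoders (marked by the new `ObjectOperator` trait, introduced in basicOperators.scala below) carry serializer/deserializer expressions that ordinary expression rewrites would corrupt, so each rule now returns those nodes unchanged before falling through to its usual logic. A minimal sketch of the pattern; `SkipObjectOperators` is an illustrative name, not a Catalyst rule:

    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ObjectOperator}
    import org.apache.spark.sql.catalyst.rules.Rule

    // Sketch only: shows the skip-the-marker pattern used by CleanupAliases
    // and ResolveUpCast above.
    object SkipObjectOperators extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan transform {
        // Encoder-bearing operators pass through untouched; their expressions
        // describe object (de)serialization, not ordinary column references.
        case o: ObjectOperator => o
        // Every other operator gets the rule's normal expression rewrite.
        case other => other transformExpressions {
          case e => e // rule-specific rewrites would go here
        }
      }
    }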

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 0 deletions
@@ -448,6 +448,7 @@ object NullPropagation extends Rule[LogicalPlan] {
   */
 object ConstantFolding extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case o: ObjectOperator => o
     case q: LogicalPlan => q transformExpressionsDown {
       // Skip redundant folding of literals. This rule is technically not necessary. Placing this
       // here avoids running the next rule for Literal values, which would create a new Literal
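One subtlety worth noting, since `MapPartitions` and the other object operators are themselves `LogicalPlan`s: Scala tries match arms in order, so the `ObjectOperator` case must come before the `case q: LogicalPlan` catch-all or it would never fire. A toy standalone illustration (these types are stand-ins, not Catalyst's):

    trait Marker                  // stands in for ObjectOperator
    class Plan                    // stands in for LogicalPlan
    class ObjectPlan extends Plan with Marker

    def describe(p: Plan): String = p match {
      case _: Marker => "skipped" // must come first: an ObjectPlan is also a Plan
      case _: Plan   => "folded"
    }
    // describe(new ObjectPlan) == "skipped"; describe(new Plan) == "folded"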

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 10 additions & 0 deletions
@@ -17,9 +17,11 @@

 package org.apache.spark.sql.catalyst.plans

+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn}
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.catalyst.util._

 abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] {
   self: PlanType =>
@@ -83,6 +85,14 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] {
     }

     def recursiveTransform(arg: Any): AnyRef = arg match {
+      case e: ExpressionEncoder[_] =>
+        val newEncoder = new ExpressionEncoder(
+          e.schema,
+          e.flat,
+          e.toRowExpressions.map(transformExpressionDown),
+          transformExpressionDown(e.fromRowExpression),
+          e.clsTag)
+        newEncoder
       case e: Expression => transformExpressionDown(e)
       case Some(e: Expression) => Some(transformExpressionDown(e))
       case m: Map[_, _] => m
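With the new case in place, `transformExpressions` (and everything built on it, such as attribute rewiring) reaches the expressions embedded in any `ExpressionEncoder` argument of an operator, rebuilding the encoder with its `toRowExpressions` and `fromRowExpression` rewritten. A hedged sketch of the effect; `staleAttr` and `freshAttr` are hypothetical attributes, not part of the patch:

    // Rewire one attribute everywhere in the operator, now including inside
    // its encoders (before this patch, encoder internals were skipped).
    val rewired = plan transformExpressions {
      case a: AttributeReference if a.exprId == staleAttr.exprId => freshAttr
    }

One caveat of this approach is that the encoder is rebuilt field by field, so the match above must be kept in sync if `ExpressionEncoder` ever gains constructor parameters.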

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 6 additions & 4 deletions
@@ -479,6 +479,8 @@ case object OneRowRelation extends LeafNode {
   override def statistics: Statistics = Statistics(sizeInBytes = 1)
 }

+trait ObjectOperator
+
 /**
  * A relation produced by applying `func` to each partition of the `child`. tEncoder/uEncoder are
  * used respectively to decode/encode from the JVM object representation expected by `func.`
@@ -488,7 +490,7 @@ case class MapPartitions[T, U](
     tEncoder: ExpressionEncoder[T],
     uEncoder: ExpressionEncoder[U],
     output: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan) extends UnaryNode with ObjectOperator {
   override def producedAttributes: AttributeSet = outputSet
 }

@@ -513,7 +515,7 @@ case class AppendColumns[T, U](
     tEncoder: ExpressionEncoder[T],
     uEncoder: ExpressionEncoder[U],
     newColumns: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan) extends UnaryNode with ObjectOperator {
   override def output: Seq[Attribute] = child.output ++ newColumns
   override def producedAttributes: AttributeSet = AttributeSet(newColumns)
 }
@@ -549,7 +551,7 @@ case class MapGroups[K, T, U](
     uEncoder: ExpressionEncoder[U],
     groupingAttributes: Seq[Attribute],
     output: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan) extends UnaryNode with ObjectOperator {
   override def producedAttributes: AttributeSet = outputSet
 }

@@ -592,6 +594,6 @@ case class CoGroup[Key, Left, Right, Result](
     leftGroup: Seq[Attribute],
     rightGroup: Seq[Attribute],
     left: LogicalPlan,
-    right: LogicalPlan) extends BinaryNode {
+    right: LogicalPlan) extends BinaryNode with ObjectOperator {
   override def producedAttributes: AttributeSet = outputSet
 }
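`ObjectOperator` is deliberately a bare marker trait: it has no members and exists only so rules can pattern-match on "operators that embed encoders". It also makes these operators easy to find, for example in a test; a small sketch, assuming `plan` is any `LogicalPlan`:

    // Collect every encoder-bearing operator in a plan, e.g. to assert that
    // an optimizer rule left them structurally unchanged.
    val objectOps = plan.collect { case o: ObjectOperator => o }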

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala

Lines changed: 20 additions & 1 deletion
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst

 import java.io._

+import org.apache.spark.Logging
 import org.apache.spark.util.Utils

-package object util {
+package object util extends Logging {

   /** Silences output to stderr or stdout for the duration of f */
   def quietly[A](f: => A): A = {
@@ -42,6 +43,24 @@ package object util {
     }
   }

+  private val analysisRule = """.*org\.apache\.spark\.sql\.catalyst\.analysis\.([A-Za-z]+).*""".r
+
+  /**
+   * Logs `msg` along with the name of the analyzer rule that is currently running. This is
+   * pretty expensive, so it always logs at warning level.
+   */
+  def logRule(msg: String): Unit = {
+    val error = try sys.error("") catch {
+      case e: Exception =>
+        stackTraceToString(e)
+    }
+
+    val rule = error.split("\n").collect {
+      case analysisRule(r) => r
+    }.headOption.getOrElse("unknown rule")
+    logWarning(s"$rule: $msg")
+  }
+
   def fileToString(file: File, encoding: String = "UTF-8"): String = {
     val inStream = new FileInputStream(file)
     val outStream = new ByteArrayOutputStream
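`logRule` recovers the name of the running rule by throwing and catching a throwaway exception, then scanning the resulting stack trace for the first frame under `org.apache.spark.sql.catalyst.analysis`. The extraction can be exercised on its own; the stack-frame string below is made up for illustration:

    // Standalone demo of the regex used by logRule; the frame text is invented.
    val analysisRule = """.*org\.apache\.spark\.sql\.catalyst\.analysis\.([A-Za-z]+).*""".r
    val frame = "at org.apache.spark.sql.catalyst.analysis.SomeRule$.apply(Analyzer.scala)"
    val ruleName = frame match {
      case analysisRule(r) => r   // full-string match; the group grabs the class name
      case _               => "unknown rule"
    }
    // ruleName == "SomeRule"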

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ class StreamExecution(
       val newPlan = batch.data.logicalPlan

       assert(output.size == newPlan.output.size)
-      replacements ++= newPlan.output.zip(output)
+      replacements ++= output.zip(newPlan.output)
       newPlan
     }.getOrElse {
       LocalRelation(output)
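The direction of the zip is the entire fix: `replacements` is consumed as a map keyed by the stale attributes of the pre-existing `output`, yielding the fresh attributes produced by the new batch's plan. A hedged sketch of the consuming side, simplified from the surrounding StreamExecution logic rather than quoted from it:

    // replacements holds (staleAttr, freshAttr) pairs, so the map is keyed by
    // the attributes the old plan still references.
    val replacementMap = AttributeMap(replacements)
    val rewired = logicalPlan transformAllExpressions {
      case a: Attribute if replacementMap.contains(a) => replacementMap(a)
    }
    // Zipped the other way, the map would be keyed by the *new* attributes, so
    // lookups against the old plan's references would always miss and nothing
    // would be rewired. That is the bug this commit fixes.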
