Commit 6e4748b

Defer input evaluation and fix Cast issue in IsNotNull filtering.
1 parent 2ef4c59 commit 6e4748b

2 files changed: +33 −1 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 18 additions & 0 deletions
@@ -202,3 +202,21 @@ object Unions {
     }
   }
 }
+
+/**
+ * A pattern that finds the original expression from a sequence of casts.
+ */
+object Casts {
+  def unapply(expr: Expression): Option[Expression] = expr match {
+    case c: Cast => Some(collectCasts(expr))
+    case _ => None
+  }
+
+  private def collectCasts(e: Expression): Expression = {
+    if (e.isInstanceOf[Cast]) {
+      collectCasts(e.children(0))
+    } else {
+      e
+    }
+  }
+}
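
For illustration, a minimal, self-contained sketch of the idea behind the new Casts extractor: match any chain of casts and return the innermost (original) expression. The simplified Expr/Attr/Cast types below are hypothetical stand-ins for this sketch, not the catalyst classes the commit uses.

// Hypothetical, simplified expression model (illustration only, not Spark's API).
object CastsSketch extends App {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Cast(child: Expr, dataType: String) extends Expr

  // Mirrors the new Casts pattern: recursively peel casts off until the
  // underlying expression is reached.
  object Casts {
    def unapply(expr: Expr): Option[Expr] = expr match {
      case _: Cast => Some(collectCasts(expr))
      case _       => None
    }

    private def collectCasts(e: Expr): Expr = e match {
      case Cast(child, _) => collectCasts(child)
      case other          => other
    }
  }

  // Cast(Cast(a, long), double) unwraps back to the attribute a.
  val a = Attr("a")
  assert(Casts.unapply(Cast(Cast(a, "long"), "double")).contains(a))
}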

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 15 additions & 1 deletion
@@ -21,6 +21,7 @@ import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
+import org.apache.spark.sql.catalyst.planning.Casts
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
 import org.apache.spark.sql.types.LongType
@@ -80,12 +81,21 @@ case class Filter(condition: Expression, child: SparkPlan)
   // Split out all the IsNotNulls from condition.
   private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
     case IsNotNull(a) if child.output.contains(a) => true
+    case IsNotNull(a) =>
+      a match {
+        case Casts(a) if child.output.contains(a) => true
+        case _ => false
+      }
     case _ => false
   }

-  // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
+  // The columns that will be filtered out by `IsNotNull` could be considered as not nullable.
   private val notNullAttributes = notNullPreds.flatMap(_.references)

+  // Only the attributes that are checked by `IsNotNull` need to be evaluated before this plan;
+  // evaluation of the other inputs can be deferred until after the nulls are filtered out.
+  override def usedInputs: AttributeSet = AttributeSet(notNullAttributes)
+
   override def output: Seq[Attribute] = {
     child.output.map { a =>
       if (a.nullable && notNullAttributes.contains(a)) {
@@ -110,6 +120,9 @@ case class Filter(condition: Expression, child: SparkPlan)
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")

+    val evaluated =
+      evaluateRequiredVariables(child.output, input, references -- usedInputs)
+
     // filter out the nulls
     val filterOutNull = notNullAttributes.map { a =>
       val idx = child.output.indexOf(a)
@@ -142,6 +155,7 @@ case class Filter(condition: Expression, child: SparkPlan)
     }
     s"""
        |$filterOutNull
+       |$evaluated
        |$predicates
        |$numOutput.add(1);
        |${consume(ctx, resultVars)}
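
Taken together, the two changes mean that a predicate such as IsNotNull(Cast(a, LongType)) is now recognised as a null check on the child attribute a, and doConsume only forces evaluation of the null-checked inputs up front; the remaining inputs (references -- usedInputs) are evaluated afterwards, between the generated null checks and the other predicates. A rough, self-contained sketch of the predicate split, again with hypothetical simplified types rather than Spark's expressions:

// Hypothetical, simplified sketch of the new predicate split (not Spark's API).
object FilterSplitSketch extends App {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Cast(child: Expr, dataType: String) extends Expr
  case class IsNotNull(child: Expr) extends Expr
  case class GreaterThan(left: Expr, right: Expr) extends Expr

  // Strip a chain of casts, as the new Casts pattern does.
  def stripCasts(e: Expr): Expr = e match {
    case Cast(child, _) => stripCasts(child)
    case other          => other
  }

  val childOutput = Seq[Expr](Attr("a"), Attr("b"))
  val predicates = Seq[Expr](
    IsNotNull(Cast(Attr("a"), "long")), // now lands in notNullPreds thanks to cast stripping
    GreaterThan(Attr("b"), Attr("a"))   // stays in otherPreds, evaluated after the null filter
  )

  // Split out the IsNotNull checks whose (possibly cast) argument is a child attribute.
  val (notNullPreds, otherPreds) = predicates.partition {
    case IsNotNull(e) => childOutput.contains(stripCasts(e))
    case _            => false
  }

  println(s"null checks evaluated first: $notNullPreds")
  println(s"deferred predicates:         $otherPreds")
}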
