@@ -21,6 +21,7 @@ import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
+import org.apache.spark.sql.catalyst.planning.Casts
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
 import org.apache.spark.sql.types.LongType
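The `Casts` imported above is a pattern-match helper from `catalyst.planning` whose definition is not part of this diff. A minimal sketch of what such an extractor could look like, assuming the two-argument `Cast(child, dataType)` case class of this Spark vintage (the shape is this sketch's assumption, not the actual source):

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression}

// Hypothetical sketch of a `Casts` extractor: peel off one or more nested
// Casts and expose the attribute underneath, so `IsNotNull(Cast(a, ...))`
// can be matched against the child's output like a bare `IsNotNull(a)`.
object Casts {
  def unapply(expr: Expression): Option[Attribute] = expr match {
    case Cast(a: Attribute, _) => Some(a)        // a single cast over an attribute
    case Cast(inner: Cast, _)  => unapply(inner) // nested casts: recurse inward
    case _                     => None
  }
}
```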
@@ -80,12 +81,21 @@ case class Filter(condition: Expression, child: SparkPlan)
   // Split out all the IsNotNulls from condition.
   private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
     case IsNotNull(a) if child.output.contains(a) => true
+    case IsNotNull(a) =>
+      a match {
+        case Casts(a) if child.output.contains(a) => true
+        case _ => false
+      }
     case _ => false
   }
 
-  // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
+  // The columns that will be filtered out by `IsNotNull` could be considered as not nullable.
   private val notNullAttributes = notNullPreds.flatMap(_.references)
 
+  // Only the attributes that will be filtered by `IsNotNull` need to be evaluated before
+  // this plan; evaluation of the rest can be deferred until after the nulls are filtered out.
+  override def usedInputs: AttributeSet = AttributeSet(notNullAttributes)
+
   override def output: Seq[Attribute] = {
     child.output.map { a =>
       if (a.nullable && notNullAttributes.contains(a)) {
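The new `usedInputs` override narrows what whole-stage codegen must materialize before this plan's code runs: only the null-checked attributes. To make the effect concrete, here is a minimal plain-Scala sketch (no Spark involved; `a` stands in for a null-checked column, and parsing `bRaw` for an arbitrary, possibly expensive column evaluation):

```scala
object DeferredFilterSketch {
  final case class Row(a: Integer, bRaw: String)

  def main(args: Array[String]): Unit = {
    val rows = Seq(Row(null, "not-a-number"), Row(1, "20"), Row(2, "-5"))
    val survivors = rows.filter { r =>
      // Phase 1: the IsNotNull checks over `usedInputs` run first.
      r.a != null && {
        // Phase 2: the remaining inputs are evaluated only for rows that
        // survived phase 1, mirroring the `evaluated` block in doConsume.
        r.bRaw.toInt > 0
      }
    }
    // Prints List(Row(1,20)); the null row never pays for parsing bRaw.
    println(survivors)
  }
}
```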
@@ -110,6 +120,9 @@ case class Filter(condition: Expression, child: SparkPlan)
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
 
+    val evaluated =
+      evaluateRequiredVariables(child.output, input, references -- usedInputs)
+
     // filter out the nulls
     val filterOutNull = notNullAttributes.map { a =>
       val idx = child.output.indexOf(a)
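`evaluateRequiredVariables` comes from `CodegenSupport`; judging from this call site, it emits the pending codegen snippets for exactly the requested subset of input attributes (here `references -- usedInputs`, i.e. the condition's inputs that are not null-checked) and marks them as generated. An assumed sketch of its shape, not a verbatim copy of the trait:

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode

trait EvaluationSketch {
  // Assumed helper shape: concatenate the not-yet-emitted code of the
  // `required` attributes and blank each snippet so it is not emitted twice.
  protected def evaluateRequiredVariables(
      attributes: Seq[Attribute],
      variables: Seq[ExprCode],
      required: AttributeSet): String = {
    val evaluated = new StringBuilder
    variables.zipWithIndex.foreach { case (ev, i) =>
      if (ev.code != "" && required.contains(attributes(i))) {
        evaluated.append(ev.code.trim + "\n")
        ev.code = "" // mark as already evaluated
      }
    }
    evaluated.toString()
  }
}
```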
@@ -142,6 +155,7 @@ case class Filter(condition: Expression, child: SparkPlan)
     }
     s"""
        |$filterOutNull
+       |$evaluated
        |$predicates
        |$numOutput.add(1);
        |${consume(ctx, resultVars)}
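With this change the generated code evaluates inputs in two waves: `$filterOutNull` rejects null rows using only `usedInputs`, `$evaluated` then materializes the deferred columns for the survivors, and `$predicates` applies the rest of the condition before the row is counted and handed to `consume`.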