[SPARK-49924][SQL] Keep containsNull after ArrayCompact replacement #48410

Closed
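
What changed and why: `ArrayCompact.dataType` promised `containsNull = false`, but the optimizer's `ReplaceExpressions` rule rewrote the expression into a bare `ArrayFilter`, whose result type still said `containsNull = true`, so the rewrite changed the plan's schema. A minimal catalyst-level sketch of the pre-patch mismatch (the `arr` attribute and setup are illustrative, not part of the patch):

    import org.apache.spark.sql.catalyst.expressions.{ArrayCompact, AttributeReference}
    import org.apache.spark.sql.types.{ArrayType, IntegerType}

    // Illustrative input: an array column whose elements may be null.
    val arr = AttributeReference("arr", ArrayType(IntegerType, containsNull = true))()
    val compact = ArrayCompact(arr)

    // Before this patch:
    //   compact.dataType             == ArrayType(IntegerType, containsNull = false)
    //   compact.replacement.dataType == ArrayType(IntegerType, containsNull = true)
    // Once ReplaceExpressions swapped in the replacement, the schema silently
    // widened, and nullability-sensitive consumers (e.g. MapFromEntries)
    // saw a different type after optimization.
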
@@ -27,6 +27,7 @@ import org.apache.spark.SparkException.internalError
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.KnownNotContainsNull
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
@@ -5330,15 +5331,12 @@ case class ArrayCompact(child: Expression)
     child.dataType.asInstanceOf[ArrayType].elementType, true)
   lazy val lambda = LambdaFunction(isNotNull(lv), Seq(lv))
 
-  override lazy val replacement: Expression = ArrayFilter(child, lambda)
+  override lazy val replacement: Expression = KnownNotContainsNull(ArrayFilter(child, lambda))
 
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
 
   override def prettyName: String = "array_compact"
 
-  override def dataType: ArrayType =
-    child.dataType.asInstanceOf[ArrayType].copy(containsNull = false)
-
   override protected def withNewChildInternal(newChild: Expression): ArrayCompact =
     copy(child = newChild)
 }
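
With the new `KnownNotContainsNull` wrapper, the narrowed type travels with the replacement itself, and since `RuntimeReplaceable` derives `dataType` from `replacement`, the explicit `dataType` override can be dropped. Continuing the illustrative sketch above (reusing its `arr` attribute):

    // After this patch, the original and its replacement agree on nullability:
    val compacted = ArrayCompact(arr)
    // compacted.replacement == KnownNotContainsNull(ArrayFilter(arr, <lambda>))
    // compacted.dataType    == ArrayType(IntegerType, containsNull = false),
    // now derived from the replacement rather than from a separate override.
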
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{ArrayType, DataType}
 
 trait TaggingExpression extends UnaryExpression {
   override def nullable: Boolean = child.nullable
@@ -52,6 +52,17 @@ case class KnownNotNull(child: Expression) extends TaggingExpression {
     copy(child = newChild)
 }
 
+case class KnownNotContainsNull(child: Expression) extends TaggingExpression {
+  override def dataType: DataType =
+    child.dataType.asInstanceOf[ArrayType].copy(containsNull = false)
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    child.genCode(ctx)
+
+  override protected def withNewChildInternal(newChild: Expression): KnownNotContainsNull =
+    copy(child = newChild)
+}
+
 case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression {
   override protected def withNewChildInternal(newChild: Expression): KnownFloatingPointNormalized =
     copy(child = newChild)
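
Like the other tagging expressions in this file, `KnownNotContainsNull` is a runtime no-op: `doGenCode` simply forwards the child's generated code, so the wrapper carries only type information through analysis and optimization. A small usage sketch (illustrative, assuming a build that includes this patch):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, KnownNotContainsNull}
    import org.apache.spark.sql.types.{ArrayType, IntegerType}

    val input = AttributeReference("arr", ArrayType(IntegerType, containsNull = true))()
    val tagged = KnownNotContainsNull(input)

    // Evaluation and codegen delegate to the child unchanged; only the
    // statically reported element nullability narrows:
    assert(tagged.dataType == ArrayType(IntegerType, containsNull = false))
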
@@ -21,13 +21,13 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Add, Alias, AttributeReference, IntegerLiteral, Literal, Multiply, NamedExpression, Remainder}
+import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayCompact, AttributeReference, CreateArray, CreateStruct, IntegerLiteral, Literal, MapFromEntries, Multiply, NamedExpression, Remainder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, OneRowRelation, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, StructField, StructType}
 
 /**
  * A dummy optimizer rule for testing that decrements integer literals until 0.
@@ -313,4 +313,25 @@ class OptimizerSuite extends PlanTest {
       assert(message1.contains("not a valid aggregate expression"))
     }
   }
+
+  test("SPARK-49924: Keep containsNull after ArrayCompact replacement") {
+    val optimizer = new SimpleTestOptimizer() {
+      override def defaultBatches: Seq[Batch] =
+        Batch("test", fixedPoint,
+          ReplaceExpressions) :: Nil
+    }
+
+    val array1 = ArrayCompact(CreateArray(Literal(1) :: Literal.apply(null) :: Nil, false))
+    val plan1 = Project(Alias(array1, "arr")() :: Nil, OneRowRelation()).analyze
+    val optimized1 = optimizer.execute(plan1)
+    assert(optimized1.schema ===
+      StructType(StructField("arr", ArrayType(IntegerType, false), false) :: Nil))
+
+    val struct = CreateStruct(Literal(1) :: Literal(2) :: Nil)
+    val array2 = ArrayCompact(CreateArray(struct :: Literal.apply(null) :: Nil, false))
+    val plan2 = Project(Alias(MapFromEntries(array2), "map")() :: Nil, OneRowRelation()).analyze
+    val optimized2 = optimizer.execute(plan2)
+    assert(optimized2.schema ===
+      StructType(StructField("map", MapType(IntegerType, IntegerType, false), false) :: Nil))
+  }
 }
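
The second test case shows the user-visible payoff: `MapFromEntries` folds the entry array's `containsNull` into the map's `valueContainsNull`, so preserving the narrowed flag keeps the map schema tight. Roughly, at the SQL level (an illustrative spark-shell sketch, not from the patch):

    // With this patch the value side of the map stays non-nullable:
    val df = spark.sql(
      "SELECT map_from_entries(array_compact(array(struct(1, 2), null))) AS m")
    // df.schema now reports m: map<int,int> with valueContainsNull = false;
    // previously the flag widened to true after ReplaceExpressions ran.
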
@@ -1,2 +1,2 @@
-Project [filter(e#0, lambdafunction(isnotnull(lambda arg#0), lambda arg#0, false)) AS array_compact(e)#0]
+Project [knownnotcontainsnull(filter(e#0, lambdafunction(isnotnull(lambda arg#0), lambda arg#0, false))) AS array_compact(e)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]