Skip to content

Commit a9ca4ab

Browse files
cloud-fan authored and dongjoon-hyun committed
[SPARK-32167][SQL] Fix GetArrayStructFields to respect inner field's nullability together
### What changes were proposed in this pull request?

Fix the nullability of `GetArrayStructFields`. It should consider both the original array's `containsNull` and the inner field's nullability.

### Why are the changes needed?

Fix a correctness issue. Previously the result's `containsNull` was taken from the original array alone, so extracting a nullable field from an array declared with `containsNull = false` produced an array type that claimed to contain no nulls even though its elements could be null.

### Does this PR introduce _any_ user-facing change?

Yes. See the added test.

### How was this patch tested?

A new unit test and an end-to-end test.

Closes #28992 from cloud-fan/bug.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
(cherry picked from commit 5d296ed)
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent 6a4200c commit a9ca4ab

File tree

4 files changed

+42
-5
lines changed

4 files changed

+42
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ object ExtractValue {
5757
val fieldName = v.toString
5858
val ordinal = findField(fields, fieldName, resolver)
5959
GetArrayStructFields(child, fields(ordinal).copy(name = fieldName),
60-
ordinal, fields.length, containsNull)
60+
ordinal, fields.length, containsNull || fields(ordinal).nullable)
6161

6262
case (_: ArrayType, _) => GetArrayItem(child, extraction)
6363

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.Row
2122
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue}
2223
import org.apache.spark.sql.catalyst.dsl.expressions._
2324
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -159,6 +160,31 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
159160
checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null)
160161
}
161162

163+
test("SPARK-32167: nullability of GetArrayStructFields") {
164+
val resolver = SQLConf.get.resolver
165+
166+
val array1 = ArrayType(
167+
new StructType().add("a", "int", nullable = true),
168+
containsNull = false)
169+
val data1 = Literal.create(Seq(Row(null)), array1)
170+
val get1 = ExtractValue(data1, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
171+
assert(get1.containsNull)
172+
173+
val array2 = ArrayType(
174+
new StructType().add("a", "int", nullable = false),
175+
containsNull = true)
176+
val data2 = Literal.create(Seq(null), array2)
177+
val get2 = ExtractValue(data2, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
178+
assert(get2.containsNull)
179+
180+
val array3 = ArrayType(
181+
new StructType().add("a", "int", nullable = false),
182+
containsNull = false)
183+
val data3 = Literal.create(Seq(Row(1)), array3)
184+
val get3 = ExtractValue(data3, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
185+
assert(!get3.containsNull)
186+
}
187+
162188
test("CreateArray") {
163189
val intSeq = Seq(5, 10, 15, 20, 25)
164190
val longSeq = intSeq.map(_.toLong)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -254,13 +254,13 @@ class SelectedFieldSuite extends AnalysisTest {
254254
StructField("col3", ArrayType(StructType(
255255
StructField("field1", StructType(
256256
StructField("subfield1", IntegerType, nullable = false) :: Nil))
257-
:: Nil), containsNull = false), nullable = false)
257+
:: Nil), containsNull = true), nullable = false)
258258
}
259259

260260
testSelect(arrayWithStructAndMap, "col3.field2['foo'] as foo") {
261261
StructField("col3", ArrayType(StructType(
262262
StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false))
263-
:: Nil), containsNull = false), nullable = false)
263+
:: Nil), containsNull = true), nullable = false)
264264
}
265265

266266
// |-- col1: string (nullable = false)
@@ -471,7 +471,7 @@ class SelectedFieldSuite extends AnalysisTest {
471471
testSelect(mapWithArrayOfStructKey, "map_keys(col2)[0].field1 as foo") {
472472
StructField("col2", MapType(
473473
ArrayType(StructType(
474-
StructField("field1", StringType) :: Nil), containsNull = false),
474+
StructField("field1", StringType) :: Nil), containsNull = true),
475475
ArrayType(StructType(
476476
StructField("field3", StructType(
477477
StructField("subfield3", IntegerType) ::
@@ -482,7 +482,7 @@ class SelectedFieldSuite extends AnalysisTest {
482482
StructField("col2", MapType(
483483
ArrayType(StructType(
484484
StructField("field2", StructType(
485-
StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = false),
485+
StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = true),
486486
ArrayType(StructType(
487487
StructField("field3", StructType(
488488
StructField("subfield3", IntegerType) ::

sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,15 @@
1717

1818
package org.apache.spark.sql
1919

20+
import scala.collection.JavaConverters._
21+
2022
import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
2123
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2224
import org.apache.spark.sql.test.SharedSparkSession
25+
import org.apache.spark.sql.types.{ArrayType, StructType}
2326

2427
class ComplexTypesSuite extends QueryTest with SharedSparkSession {
28+
import testImplicits._
2529

2630
override def beforeAll(): Unit = {
2731
super.beforeAll()
@@ -106,4 +110,11 @@ class ComplexTypesSuite extends QueryTest with SharedSparkSession {
106110
checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil)
107111
checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0)
108112
}
113+
114+
test("SPARK-32167: get field from an array of struct") {
115+
val innerStruct = new StructType().add("i", "int", nullable = true)
116+
val schema = new StructType().add("arr", ArrayType(innerStruct, containsNull = false))
117+
val df = spark.createDataFrame(List(Row(Seq(Row(1), Row(null)))).asJava, schema)
118+
checkAnswer(df.select($"arr".getField("i")), Row(Seq(1, null)))
119+
}
109120
}

0 commit comments

Comments
 (0)