[SPARK-32002][SQL] Support ExtractValue from nested ArrayStruct #30467

Closed · wants to merge 8 commits
@@ -51,6 +51,20 @@ case class ProjectionOverSchema(schema: StructType) {
s"unmatched child schema for GetArrayStructFields: ${projSchema.toString}"
)
}
case ExtractNestedArrayField(child, _, _, field, containsNull, containsNullSeq) =>
getProjection(child).map(p => (p, p.dataType)).map {
case (projection, ExtractNestedArrayType(projSchema @ StructType(_), _, _)) =>
ExtractNestedArrayField(projection,
projSchema.fieldIndex(field.name),
projSchema.fields.length,
projSchema(field.name),
containsNull,
containsNullSeq)
case (_, projSchema) =>
throw new IllegalStateException(
s"unmatched child schema for ExtractNestedArrayField: ${projSchema.toString}"
)
}
case MapKeys(child) =>
getProjection(child).map { projection => MapKeys(projection) }
case MapValues(child) =>
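Aside (not part of the diff): the new ExtractNestedArrayField branch above rebuilds the ordinal and the field count from the pruned schema rather than reusing the values captured before pruning, because both can change once sibling fields are dropped. A minimal sketch using only public StructType APIs; the field names and value names here are invented for illustration:

    import org.apache.spark.sql.types._

    // Element struct as defined by the source, and its pruned counterpart where
    // only the `ad` branch survives schema pruning.
    val originalStruct = StructType(Seq(
      StructField("id", StringType),
      StructField("ad", StructType(Seq(StructField("index", IntegerType))))))
    val prunedStruct = StructType(Seq(
      StructField("ad", StructType(Seq(StructField("index", IntegerType))))))

    // Both the ordinal of `ad` and the total field count differ after pruning,
    // which is what projSchema.fieldIndex(...) and projSchema.fields.length recompute.
    assert(originalStruct.fieldIndex("ad") == 1 && originalStruct.fields.length == 2)
    assert(prunedStruct.fieldIndex("ad") == 0 && prunedStruct.fields.length == 1)
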
@@ -65,7 +65,8 @@ object SelectedField {
/**
* Convert an expression into the parts of the schema (the field) it accesses.
*/
private def selectField(expr: Expression, dataTypeOpt: Option[DataType]): Option[StructField] = {
private def selectField(expr: Expression, dataTypeOpt: Option[DataType],
nestArray: Boolean = false): Option[StructField] = {
expr match {
case a: Attribute =>
dataTypeOpt.map { dt =>
@@ -81,16 +82,37 @@
// GetArrayStructFields is the top level extractor. This means its result is
// not pruned and we need to use the element type of the array it's producing.
field.dataType
case Some(ArrayType(dataType, _)) =>
case Some(ArrayType(dataType, nullable)) =>
// GetArrayStructFields is part of a chain of extractors and its result is pruned
// by a parent expression. In this case we need to use the parent element type.
dataType
if (nestArray) ArrayType(dataType, nullable) else dataType
case Some(x) =>
// This should not happen.
throw new AnalysisException(s"DataType '$x' is not supported by GetArrayStructFields.")
}
val newField = StructField(field.name, newFieldDataType, field.nullable)
selectField(child, Option(ArrayType(struct(newField), containsNull)))
case ExtractNestedArrayField(child, _, _, field @ StructField(_, _, _, _), _, _) =>
val newFieldDataType = dataTypeOpt match {
case None =>
// ExtractNestedArrayField is the top level extractor. This means its result is
// not pruned and we need to use the element type of the array it's producing.
field.dataType
case Some(dataType) =>
dataType
}
val structType = struct(StructField(field.name, newFieldDataType, field.nullable))

val newDataType = child match {
case ExtractNestedArrayField(_, _, _, childField, containsNull, _) =>
childField.dataType match {
case _: ArrayType => ArrayType(structType, containsNull)
case _ => structType
}
case GetArrayStructFields(_, _, _, _, nullable) => ArrayType(structType, nullable)
case _ => structType
}
selectField(child, Some(newDataType), nestArray = true)
case GetMapValue(child, _, _) =>
// GetMapValue does not select a field from a struct (i.e. prune the struct) so it can't be
// the top-level extractor. However it can be part of an extractor chain.
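Aside (not part of the diff): with `nestArray = true`, the array wrapper produced by an inner GetArrayStructFields is preserved, so the field selected for pruning keeps every array level of the original access path. The catalog string below is the scan schema that the new NestArraySchemaPruningSuite (further down in this PR) expects for `select positions.imps.ad.index`; parsing it is just a compact way to show the resulting shape, and the value name is invented:

    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

    // Both array levels (positions, imps) survive while sibling fields are pruned away.
    val prunedScanSchema = CatalystSqlParser.parseDataType(
      "struct<positions:array<struct<imps:array<struct<ad:struct<index:int>>>>>>")
    println(prunedScanSchema.simpleString)
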
@@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator,
CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -36,12 +37,13 @@ object ExtractValue {
* Returns the resolved `ExtractValue`. It will return one kind of concrete `ExtractValue`,
* depending on the type of `child` and `extraction`.
*
* `child` | `extraction` | concrete `ExtractValue`
* ----------------------------------------------------------------
* Struct | Literal String | GetStructField
* Array[Struct] | Literal String | GetArrayStructFields
* Array | Integral type | GetArrayItem
* Map | map key type | GetMapValue
* `child` | `extraction` | concrete `ExtractValue`
* --------------------------------------------------------------------------------
* Struct | Literal String | GetStructField
* Array[Struct] | Literal String | GetArrayStructFields
* Array[ ...Array[struct] ] | Literal String | ExtractNestedArrayField
* Array | Integral type | GetArrayItem
* Map | map key type | GetMapValue
*/
def apply(
child: Expression,
@@ -60,6 +62,13 @@
GetArrayStructFields(child, fields(ordinal).copy(name = fieldName),
ordinal, fields.length, containsNull || fields(ordinal).nullable)

case (ExtractNestedArrayType(StructType(fields), containsNull, containsNullSeq),
NonNullLiteral(v, StringType)) if containsNullSeq.nonEmpty =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
ExtractNestedArrayField(child, ordinal, fields.length,
fields(ordinal).copy(name = fieldName), containsNull, containsNullSeq)

case (_: ArrayType, _) => GetArrayItem(child, extraction)

case (MapType(kt, _, _), _) => GetMapValue(child, extraction)
@@ -218,6 +227,85 @@ case class GetArrayStructFields(
}
}

/**
* ExtractNestedArrayType matches consecutively nested array types.
*
* ReturnType: (DataType: the innermost element type, Boolean: whether the outermost array
* contains null, Seq[Boolean]: whether each layer from the second-outermost down to the
* innermost contains null)
*/
object ExtractNestedArrayType {
type ReturnType = Option[(DataType, Boolean, Seq[Boolean])]

def unapply(dataType: DataType): ReturnType = {
dataType match {
case ArrayType(dt, containsNull) =>
unapply(dt) match {
case Some((d, cn, seq)) => Some((d, containsNull, cn +: seq))
case None => Some((dt, containsNull, Seq.empty[Boolean]))
}
case _ => None
}
}
}

/**
* For a child whose data type is a multi-level nested array with structs at the innermost level,
* extracts the `ordinal`-th field of those structs and returns the values as a new nested array
* with the same nesting.
*/
case class ExtractNestedArrayField(
child: Expression,
ordinal: Int,
numFields: Int,
field: StructField,
containsNull: Boolean,
containsNullSeq: Seq[Boolean]) extends UnaryExpression
with ExtractValue with NullIntolerant with CodegenFallback {

protected override def nullSafeEval(input: Any): Any = {
val array = input.asInstanceOf[ArrayData]
new GenericArrayData(
(0 until array.numElements()).map(n => evalArrayItem(n, array, containsNullSeq.size)))
}

private def evalArrayItem(original: Int, array: ArrayData, num: Int): ArrayData = {
if (array.isNullAt(original)) {
null
} else {
val innerArray = array.get(original, nestedArrayType(num)).asInstanceOf[ArrayData]
new GenericArrayData((0 until innerArray.numElements()).map { n =>
if (num == 1) {
extractStruct(n, innerArray)
} else {
evalArrayItem(n, innerArray, num - 1)
}
})
}
}

private def extractStruct(n: Int, array: ArrayData): Any = {
if (array.isNullAt(n)) {
null
} else {
val row = array.getStruct(n, numFields)
if (row.isNullAt(ordinal)) {
null
} else {
row.get(ordinal, field.dataType)
}
}
}

override def dataType: DataType = ArrayType(nestedArrayType(0), containsNull)

def nestedArrayType(num: Int): DataType = {
(num until containsNullSeq.size).reverse
.foldLeft(field.dataType) { (e, i) => ArrayType(e, containsNullSeq(i))}
}
}

/**
* Returns the field at `ordinal` in the Array `child`.
*
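Aside (not part of the diff): a sketch of how the two new pieces above fit together, assuming the patch is applied. The type, value, and field names (`elem`, `nested`, `extractor`, `c`) are invented for the example; the expected results follow from the code shown above.

    import org.apache.spark.sql.catalyst.expressions.{ExtractNestedArrayField, ExtractNestedArrayType, Literal}
    import org.apache.spark.sql.types._

    // array<array<struct<c: int>>>: the inner array allows nulls, the outer does not.
    val elem = StructType(Seq(StructField("c", IntegerType)))
    val nested = ArrayType(ArrayType(elem, containsNull = true), containsNull = false)

    // The extractor peels off all array layers at once: the innermost element type,
    // the outermost containsNull, and the containsNull flags of the remaining layers
    // from the second-outermost down to the innermost.
    nested match {
      case ExtractNestedArrayType(innermost, outerNull, innerNulls) =>
        assert(innermost == elem && !outerNull && innerNulls == Seq(true))
    }

    // ExtractValue only chooses ExtractNestedArrayField when containsNullSeq is non-empty;
    // a single array<struct> level still resolves to GetArrayStructFields as before.
    val extractor = ExtractNestedArrayField(
      child = Literal.create(null, nested),
      ordinal = 0,
      numFields = 1,
      field = StructField("c", IntegerType),
      containsNull = false,
      containsNullSeq = Seq(true))

    // The extraction preserves the nesting of the input arrays: array<array<int>> here.
    assert(extractor.dataType ==
      ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false))
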
@@ -117,4 +117,32 @@ class ComplexTypesSuite extends QueryTest with SharedSparkSession {
val df = spark.createDataFrame(List(Row(Seq(Row(1), Row(null)))).asJava, schema)
checkAnswer(df.select($"arr".getField("i")), Row(Seq(1, null)))
}

test("SPARK-32002: Support ExtractValue from nested ArrayStruct") {
val jsonStr1 = """{"a": [{"b": [{"c": [1,2]}]}]}"""
val jsonStr2 = """{"a": [{"b": [{"c": [1]}, {"c": [2]}]}]}"""
val df = spark.read.json(Seq(jsonStr1, jsonStr2).toDS())
checkAnswer(df.select($"a.b.c"), Row(Seq(Seq(Seq(1, 2))))
:: Row(Seq(Seq(Seq(1), Seq(2)))) :: Nil)

def genJson(start: Char, end: Char, vStr: String): String = {
(start to end).map(c => s"""{"$c": [""").mkString +
vStr + (start to end).map(_ => "]}").mkString
}

def genResult(start: Char, end: Char, r: Seq[Int]): Any = {
(start until end).fold(r) { (z, _) => Seq(z)}
}

val start: Char = 'a'
for (i <- 2 to 10) {
val end: Char = (start + i).toChar
val json = genJson(start, end, "1,2,3")
val df = spark.read.json(Seq(json).toDS())
checkAnswer(df.select((start to end).mkString(".")),
Row(genResult(start, end, Seq(1, 2, 3))))
}

}

}
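Aside (not part of the diff): to make the generator loop in the test above concrete, here is a worked instance at i = 2 (columns 'a' to 'c'), derived by hand from genJson and genResult; the value names are invented and this is not additional test code.

    // genJson('a', 'c', "1,2,3") builds a three-level nested document, and
    // genResult('a', 'c', Seq(1, 2, 3)) wraps the value list once per outer column,
    // so selecting "a.b.c" must keep all three array levels.
    val json = """{"a": [{"b": [{"c": [1,2,3]}]}]}"""
    val expected = Seq(Seq(Seq(1, 2, 3)))
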
@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources

import java.io.File

import org.scalactic.Equality

import org.apache.spark.sql.{DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.SchemaPruningTest
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType


class NestArraySchemaPruningSuite
extends QueryTest
with FileBasedDataSourceTest
with SchemaPruningTest
with SharedSparkSession
with AdaptiveSparkPlanHelper {
case class AdRecord(positions: Array[Positions])
case class Positions(imps: Array[Impression])
case class Impression(id: String, ad: Advertising, clicks: Array[Clicks])
case class Advertising(index: Int)
case class Clicks(fraud_type: Int)

val adRecords = AdRecord(Array(Positions(Array(Impression("1", Advertising(1),
Array(Clicks(0), Clicks(1))))))) :: AdRecord(Array(Positions(Array(
Impression("2", Advertising(2), Array(Clicks(1), Clicks(2))))))) :: Nil

testSchemaPruning("Nested arrays for pruning schema") {
val queryIndex = sql("select positions.imps.ad.index from adRecords")
checkScan(queryIndex,
"struct<positions:array<struct<imps:array<struct<ad:struct<index:int>>>>>>")
checkAnswer(queryIndex, Row(Seq(Seq(1))) :: Row(Seq(Seq(2))) :: Nil)

val queryId = sql("select positions.imps.id from adRecords")
checkScan(queryId,
"struct<positions:array<struct<imps:array<struct<id:string>>>>>")
checkAnswer(queryId, Row(Seq(Seq("1"))) :: Row(Seq(Seq("2"))) :: Nil)

val queryIndexAndFraud =
sql("select positions.imps.ad.index, positions.imps.clicks.fraud_type from adRecords")
checkScan(queryIndexAndFraud, "struct<positions:array<struct<imps:array<struct<ad:struct" +
"<index:int>, clicks:array<struct<fraud_type:int>>>>>>>")
checkAnswer(queryIndexAndFraud, Row(Seq(Seq(1)), Seq(Seq(Seq(0, 1))))
:: Row(Seq(Seq(2)), Seq(Seq(Seq(1, 2)))) :: Nil)
}

protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = {
test(s"$testName") {
withSQLConf(vectorizedReaderEnabledKey -> "true") {
withData(testThunk)
}
withSQLConf(vectorizedReaderEnabledKey -> "false") {
withData(testThunk)
}
}
}

private def withData(testThunk: => Unit): Unit = {
withTempPath { dir =>
val path = dir.getCanonicalPath

makeDataSourceFile(adRecords, new File(path + "/ad_records/a=1"))

val schema = "`positions` ARRAY<STRUCT<`imps`: ARRAY<STRUCT<`id`: STRING, " +
"`ad`: STRUCT<`index`: INT>, `clicks`: ARRAY<STRUCT<`fraud_type`: INT>>>>>>"
spark.read.format(dataSourceName).schema(schema).load(path + "/ad_records")
.createOrReplaceTempView("adRecords")

testThunk
}
}

protected val schemaEquality = new Equality[StructType] {
override def areEqual(a: StructType, b: Any): Boolean =
b match {
case otherType: StructType => a.sameType(otherType)
case _ => false
}
}

protected def checkScan(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
checkScanSchemata(df, expectedSchemaCatalogStrings: _*)
// We check here that we can execute the query without throwing an exception. The results
// themselves are irrelevant, and should be checked elsewhere as needed
df.collect()
}

protected def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
val fileSourceScanSchemata =
collect(df.queryExecution.executedPlan) {
case scan: FileSourceScanExec => scan.requiredSchema
}
assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
s"but expected $expectedSchemaCatalogStrings")
fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach {
case (scanSchema, expectedScanSchemaCatalogString) =>
val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString)
implicit val equality = schemaEquality
assert(scanSchema === expectedScanSchema)
}
}

override protected val dataSourceName: String = "parquet"
override protected val vectorizedReaderEnabledKey: String =
SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key
}