[SPARK-4502][SQL]Support parquet nested struct pruning and add relevant test #14957

Status: Closed (wants to merge 10 commits)
@@ -25,6 +25,8 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.types.{StructField, StructType}

/**
* A strategy for planning scans over collections of files that might be partitioned or bucketed
@@ -97,7 +99,19 @@ object FileSourceStrategy extends Strategy with Logging {
dataColumns
.filter(requiredAttributes.contains)
.filterNot(partitionColumns.contains)
val outputSchema = readDataColumns.toStructType
val outputSchema = if (
fsRelation.sqlContext.conf.parquetNestedColumnPruningEnabled &&
fsRelation.fileFormat.isInstanceOf[ParquetFileFormat]
) {
val fullSchema = readDataColumns.toStructType
val prunedSchema = StructType(
generateStructFieldsContainsNesting(projects, fullSchema))
// Merge the pruned per-path schemas into a single StructType, then merge in filterAttributes
prunedSchema.fields.map(f => StructType(Array(f))).reduceLeft(_ merge _)
.merge(filterAttributes.toSeq.toStructType)
} else {
readDataColumns.toStructType
}
logInfo(s"Output Data Schema: ${outputSchema.simpleString(5)}")

val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter)
@@ -126,4 +140,64 @@

case _ => Nil
}
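
To illustrate the merge step in the new outputSchema code above: each pruned access path becomes a single-field StructType, and the per-path schemas are then folded together. StructType.merge is private[sql], so this is only a sketch using a hypothetical mergeNested stand-in that mimics its behaviour for the struct-only shapes involved here:

import org.apache.spark.sql.types._

// Hypothetical stand-in for the private[sql] StructType.merge, covering only
// the nested-struct case exercised by this strategy: fields present in both
// sides are merged recursively, fields unique to either side are kept.
def mergeNested(left: StructType, right: StructType): StructType = {
  val rightByName = right.fields.map(f => f.name -> f).toMap
  val merged = left.fields.map { lf =>
    rightByName.get(lf.name) match {
      case Some(rf) =>
        (lf.dataType, rf.dataType) match {
          case (ls: StructType, rs: StructType) => lf.copy(dataType = mergeNested(ls, rs))
          case _ => lf
        }
      case None => lf
    }
  }
  val extra = right.fields.filterNot(f => left.fieldNames.contains(f.name))
  StructType(merged ++ extra)
}

// "select col.s1.s1_1, col.str" produces two single-path schemas:
val a = StructType(Seq(StructField("col", StructType(Seq(
  StructField("s1", StructType(Seq(StructField("s1_1", LongType)))))))))
val b = StructType(Seq(StructField("col", StructType(Seq(StructField("str", StringType))))))
// mergeNested(a, b) yields col: struct<s1: struct<s1_1: long>, str: string>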

private[sql] def generateStructFieldsContainsNesting(
projects: Seq[Expression],
fullSchema: StructType) : Seq[StructField] = {
// By traversing the projects we first generate the access paths of the nested
// structs, then use those paths to reconstruct the schema after pruning.
// During the traversal we handle all expressions related to complex struct
// types: GetArrayItem, GetArrayStructFields, GetMapValue and GetStructField.
def generateStructField(
curField: List[String],
node: Expression) : Seq[StructField] = {
node match {
case ai: GetArrayItem =>
// Here we drop the path accumulated so far to simplify array and map support.
// The same strategy is used for GetArrayStructFields and GetMapValue below.
generateStructField(List.empty[String], ai.child)
case asf: GetArrayStructFields =>
generateStructField(List.empty[String], asf.child)
case mv: GetMapValue =>
generateStructField(List.empty[String], mv.child)
case attr: AttributeReference =>
// We finally reach the leaf node AttributeReference; call getFieldRecursively
// with the access path accumulated for the current nested struct.
Seq(getFieldRecursively(fullSchema, attr.name :: curField))
case sf: GetStructField if !sf.child.isInstanceOf[CreateNamedStruct] &&
!sf.child.isInstanceOf[CreateStruct] =>
val name = sf.name.getOrElse(sf.child.dataType match {
case StructType(fields) =>
fields(sf.ordinal).name
})
generateStructField(name :: curField, sf.child)
case _ =>
if (node.children.nonEmpty) {
node.children.flatMap(child => generateStructField(curField, child))
} else {
Seq.empty[StructField]
}
}
}

def getFieldRecursively(schema: StructType, name: List[String]): StructField = {
if (name.length > 1) {
val curField = name.head
val curFieldType = schema(curField)
curFieldType.dataType match {
case st: StructType =>
val newField = getFieldRecursively(StructType(st.fields), name.drop(1))
StructField(curFieldType.name, StructType(Seq(newField)),
curFieldType.nullable, curFieldType.metadata)
case _ =>
throw new IllegalArgumentException(s"""Field "$curField" is not struct field.""")
}
} else {
schema(name.head)
}
}
@liancheng (Contributor) commented on Oct 21, 2016:
Actually this function can be simplified to:

def getNestedField(schema: StructType, path: Seq[String]): StructField = {
  require(path.nonEmpty, "<error message>")

  path.tail.foldLeft(schema(path.head)) { (field, name) =>
    field.dataType match {
      case t: StructType => t(name)
      case _ => ??? // Throw exception here
    }
  }
}

@xuanyuanking (Member, Author) commented on Oct 23, 2016:
The function getFieldRecursively here needs to return a StructField that keeps the whole nested relation along the access path. For example, given this fullSchema:

root
 |-- col: struct (nullable = true)
 |    |-- s1: struct (nullable = true)
 |    |    |-- s1_1: long (nullable = true)
 |    |    |-- s1_2: long (nullable = true)
 |    |-- str: string (nullable = true)
 |-- num: long (nullable = true)
 |-- str: string (nullable = true)

and when we want to get col.s1.s1_1, the function should return:

StructField(col,StructType(StructField(s1,StructType(StructField(s1_1,LongType,true)),true)),true)

So I can't use the simplified getNestedField, because it returns only the leaf StructField:

StructField(s1_1,LongType,true)
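
That said, the same traversal can preserve the nesting by rebuilding the struct wrappers on the way out. A minimal sketch (the name getNestedFieldKeepingPath is illustrative, not from this patch):

import org.apache.spark.sql.types.{StructField, StructType}

// Like the simplified getNestedField, but rebuilds the struct wrappers on the
// way back up, so the result keeps the full col -> s1 -> s1_1 nesting.
def getNestedFieldKeepingPath(schema: StructType, path: List[String]): StructField = {
  require(path.nonEmpty, "path must not be empty")
  val head = schema(path.head)
  path match {
    case _ :: Nil => head
    case _ =>
      head.dataType match {
        case t: StructType =>
          head.copy(dataType = StructType(Seq(getNestedFieldKeepingPath(t, path.tail))))
        case _ =>
          throw new IllegalArgumentException(s"""Field "${path.head}" is not struct field.""")
      }
  }
}

// getNestedFieldKeepingPath(fullSchema, List("col", "s1", "s1_1")) returns
// StructField(col, StructType(StructField(s1, StructType(StructField(s1_1, LongType, true)), true)), true)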


projects.flatMap(p => generateStructField(List.empty[String], p))
}

}
@@ -213,6 +213,11 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val PARQUET_NESTED_COLUMN_PRUNING = SQLConfigBuilder("spark.sql.parquet.nestedColumnPruning")
.doc("When true, Parquet column pruning also works for nested fields.")
.booleanConf
.createWithDefault(false)
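
For reference, a minimal sketch of opting in (assuming a running SparkSession named spark and a registered view such as the tmp_table used in the tests below; the flag defaults to false):

// Enable nested column pruning for Parquet scans (off by default).
spark.conf.set("spark.sql.parquet.nestedColumnPruning", "true")
// With the flag on, only col.s1.s1_1 should be requested from the Parquet reader.
spark.sql("SELECT col.s1.s1_1 FROM tmp_table")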

val PARQUET_CACHE_METADATA = SQLConfigBuilder("spark.sql.parquet.cacheMetadata")
.doc("Turns on caching of Parquet schema metadata. Can speed up querying of static data.")
.booleanConf
@@ -724,6 +729,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {

def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP)

def parquetNestedColumnPruningEnabled: Boolean = getConf(PARQUET_NESTED_COLUMN_PRUNING)

def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT)

def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING)
Binary file not shown (presumably the new Parquet test resource, test-data/nested-struct.snappy.parquet, referenced below).
@@ -27,16 +27,15 @@ import org.apache.hadoop.mapreduce.Job

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.{util, InternalRow}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper}
import org.apache.spark.sql.catalyst.util
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateNamedStruct, Expression, ExpressionSet, GetArrayItem, GetStructField, Literal, PredicateHelper}
import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper {
@@ -442,6 +441,132 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper {
}
}

test("[SPARK-4502] pruning nested schema by GetStructField projects") {
// Construct fullSchema like below:
// root
// |-- col: struct (nullable = true)
// | |-- s1: struct (nullable = true)
// | | |-- s1_1: long (nullable = true)
// | | |-- s1_2: long (nullable = true)
// | |-- str: string (nullable = true)
// |-- num: long (nullable = true)
// |-- str: string (nullable = true)
val nested_s1 = StructField("s1",
StructType(
Seq(
StructField("s1_1", LongType, true),
StructField("s1_2", LongType, true)
)
), true)
val flat_str = StructField("str", StringType, true)

val fullSchema = StructType(
Seq(
StructField("col", StructType(Seq(nested_s1, flat_str)), true),
StructField("num", LongType, true),
flat_str
))

// Attr of struct col
val colAttr = AttributeReference("col", StructType(
Seq(nested_s1, flat_str)), true)()
// Child expression of col.s1.s1_1
val childExp = GetStructField(
GetStructField(colAttr, 0, Some("s1")), 0, Some("s1_1"))

// Project list of "select num, col.s1.s1_1 as s1_1"
val projects = Seq(
AttributeReference("num", LongType, true)(),
Alias(childExp, "s1_1")()
)
val expectedResult =
Seq(
StructField("num", LongType, true),
StructField("col", StructType(
Seq(
StructField(
"s1",
StructType(Seq(StructField("s1_1", LongType, true))),
true)
)
), true)
)
// Call the function generateStructFieldsContainsNesting
val result = FileSourceStrategy.generateStructFieldsContainsNesting(projects,
fullSchema)
assert(result == expectedResult)
}

test("[SPARK-4502] pruning nested schema by GetArrayItem projects") {
// Construct fullSchema like below:
// root
// |-- col: struct (nullable = true)
// | |-- info_list: array (nullable = true)
// | | |-- element: struct (containsNull = true)
// | | | |-- s1: struct (nullable = true)
// | | | | |-- s1_1: long (nullable = true)
// | | | | |-- s1_2: long (nullable = true)
val nested_s1 = StructField("s1",
StructType(
Seq(
StructField("s1_1", LongType, true),
StructField("s1_2", LongType, true)
)
), true)
val nested_arr = StructField("info_list", ArrayType(StructType(Seq(nested_s1))), true)

val fullSchema = StructType(
Seq(
StructField("col", StructType(Seq(nested_arr)), true)
))

// Attr of struct col
val colAttr = AttributeReference("col", StructType(
Seq(nested_arr)), true)()
// Child expression of col.info_list[0].s1.s1_1
val arrayChildExp = GetStructField(
GetStructField(
GetArrayItem(
GetStructField(colAttr, 0, Some("info_list")),
Literal(0)
), 0, Some("s1")
), 0, Some("s1_1")
)
// Project list of "select col.info_list[0].s1.s1_1 as complex_get"
val projects = Seq(
Alias(arrayChildExp, "complex_get")()
)
val expectedResult =
Seq(
StructField("col", StructType(Seq(nested_arr)))
)
// Call the function generateStructFieldsContainsNesting
val result = FileSourceStrategy.generateStructFieldsContainsNesting(projects,
fullSchema)
assert(result == expectedResult)
}
@liancheng (Contributor) commented:

It would be nice to split this method into several test cases that test some typical but minimal cases.

BTW, I tried the following test code:

  test("foo") {
    val schema = new StructType()
      .add("f0", IntegerType)
      .add("f1", new StructType()
        .add("f10", IntegerType))

    val expr = GetStructField(
      CreateNamedStruct(Seq(
        Literal("f10"),
        AttributeReference("f0", IntegerType)()
      )),
      0,
      Some("f10")
    )

    StructType(
      FileSourceStrategy.generateStructFieldsContainsNesting(expr :: Nil, schema)
    ).printTreeString()
  }

and it fails with the following exception:

[info] - foo *** FAILED *** (37 milliseconds)
[info]   java.lang.IllegalArgumentException: Field "f0" is not struct field.
[info]   at org.apache.spark.sql.execution.datasources.FileSourceStrategy$.getFieldRecursively$1(FileSourceStrategy.scala:188)
[info]   at org.apache.spark.sql.execution.datasources.FileSourceStrategy$.org$apache$spark$sql$execution$datasources$FileSourceStrategy$$generateStructField$1(FileSourceStrategy.scala:166)
[info]   at org.apache.spark.sql.execution.datasources.FileSourceStrategy$$anonfun$org$apache$spark$sql$execution$datasources$FileSourceStrategy$$generateStructField$1$1.apply(FileSourceStrategy.scala:171)
[info]   at org.apache.spark.sql.execution.datasources.FileSourceStrategy$$anonfun$org$apache$spark$sql$execution$datasources$FileSourceStrategy$$generateStructField$1$1.apply(FileSourceStrategy.scala:171)
[info]   at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
[info]   at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
[info]   at scala.collection.immutable.List.foreach(List.scala:381)
[info]   at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241)
[info]   at scala.collection.immutable.List.flatMap(List.scala:344)
[info]   at org.apache.spark.sql.execution.datasources.FileSourceStrategy$.org$apache$spark$sql$execution$datasources$FileSourceStrategy$$generateStructField$1(FileSourceStrategy.scala:171)
[info]   at org.apache.spark.sql.execution.datasources.FileSourceStrategy$$anonfun$generateStructFieldsContainsNesting$1.apply(FileSourceStrategy.scala:195)
[info]   at org.apache.spark.sql.execution.datasources.FileSourceStrategy$$anonfun$generateStructFieldsContainsNesting$1.apply(FileSourceStrategy.scala:195)
[info]   at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
[info]   at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
[info]   at scala.collection.immutable.List.foreach(List.scala:381)
[info]   at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241)
[info]   at scala.collection.immutable.List.flatMap(List.scala:344)
[info]   at org.apache.spark.sql.execution.datasources.FileSourceStrategy$.generateStructFieldsContainsNesting(FileSourceStrategy.scala:195)
[info]   at org.apache.spark.sql.execution.datasources.FileSourceStrategySuite$$anonfun$16.apply$mcV$sp(FileSourceStrategySuite.scala:462)
[info]   at org.apache.spark.sql.execution.datasources.FileSourceStrategySuite$$anonfun$16.apply(FileSourceStrategySuite.scala:446)
[info]   at org.apache.spark.sql.execution.datasources.FileSourceStrategySuite$$anonfun$16.apply(FileSourceStrategySuite.scala:446)
[info]   at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22)
[info]   at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85)
[info]   at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104)
[info]   at org.scalatest.Transformer.apply(Transformer.scala:22)
[info]   at org.scalatest.Transformer.apply(Transformer.scala:20)
[info]   at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166)
[info]   at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:68)
[info]   at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163)
[info]   at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175)
[info]   at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175)
[info]   at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306)
[info]   at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:175)
[info]   at org.apache.spark.sql.execution.datasources.FileSourceStrategySuite.org$scalatest$BeforeAndAfterEach$$super$runTest(FileSourceStrategySuite.scala:42)
[info]   at org.scalatest.BeforeAndAfterEach$class.runTest(BeforeAndAfterEach.scala:255)
[info]   at org.apache.spark.sql.execution.datasources.FileSourceStrategySuite.runTest(FileSourceStrategySuite.scala:42)
[info]   at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208)
[info]   at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208)
[info]   at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:413)
[info]   at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:401)
[info]   at scala.collection.immutable.List.foreach(List.scala:381)
[info]   at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401)
[info]   at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:396)
[info]   at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:483)
[info]   at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:208)
[info]   at org.scalatest.FunSuite.runTests(FunSuite.scala:1555)
[info]   at org.scalatest.Suite$class.run(Suite.scala:1424)
[info]   at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1555)
[info]   at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212)
[info]   at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212)
[info]   at org.scalatest.SuperEngine.runImpl(Engine.scala:545)
[info]   at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:212)
[info]   at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:31)
[info]   at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:257)
[info]   at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:256)
[info]   at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:31)
[info]   at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:357)
[info]   at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:502)
[info]   at sbt.ForkMain$Run$2.call(ForkMain.java:296)
[info]   at sbt.ForkMain$Run$2.call(ForkMain.java:286)
[info]   at java.util.concurrent.FutureTask.run(FutureTask.java:266)
[info]   at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
[info]   at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
[info]   at java.lang.Thread.run(Thread.java:745)

Basically, we also need to consider named_struct and struct expressions to get corner cases correct.

@xuanyuanking (Member, Author) replied:

Fix done, thanks for liancheng's reminder.
Here I handle CreateStruct(Unsafe) and CreateNamedStruct(Unsafe); the other expressions in complexTypeCreator (CreateArray and CreateMap) are simply ignored.


test("[SPARK-4502] pruning nested schema while named_struct in project") {
val schema = new StructType()
.add("f0", IntegerType)
.add("f1", new StructType()
.add("f10", IntegerType))

val expr = GetStructField(
CreateNamedStruct(Seq(
Literal("f10"),
AttributeReference("f0", IntegerType)()
)),
0,
Some("f10")
)

val expect = new StructType()
.add("f0", IntegerType)

assert(FileSourceStrategy.generateStructFieldsContainsNesting(expr :: Nil, schema) == expect)
}

test("spark.files.ignoreCorruptFiles should work in SQL") {
val inputFile = File.createTempFile("input-", ".gz")
try {
@@ -571,6 +571,36 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
}
}

test("SPARK-4502 parquet nested fields pruning") {
// Schema of "test-data/nested-struct.snappy.parquet":
// root
// |-- col: struct (nullable = true)
// | |-- s1: struct (nullable = true)
// | | |-- s1_1: long (nullable = true)
// | | |-- s1_2: long (nullable = true)
// | |-- str: string (nullable = true)
// |-- num: long (nullable = true)
// |-- str: string (nullable = true)
withTempView("tmp_table") {
val df = readResourceParquetFile("test-data/nested-struct.snappy.parquet")
df.createOrReplaceTempView("tmp_table")
// normal test
val query1 = "select num,col.s1.s1_1 from tmp_table"
val result1 = sql(query1)
withSQLConf(SQLConf.PARQUET_NESTED_COLUMN_PRUNING.key -> "true") {
checkAnswer(sql(query1), result1)
}
// test for same struct meta merge
// col.s1.s1_1 and col.str should merge
// like col.[s1.s1_1, str] before pass to parquet
val query2 = "select col.s1.s1_1,col.str from tmp_table"
val result2 = sql(query2)
withSQLConf(SQLConf.PARQUET_NESTED_COLUMN_PRUNING.key -> "true") {
checkAnswer(sql(query2), result2)
}
}
}
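
A possible follow-up check (a sketch, assuming the Parquet scan node prints its ReadSchema in explain output, as FileSourceScanExec does): with pruning enabled, the printed read schema should mention only the accessed leaves rather than the whole col struct.

withSQLConf(SQLConf.PARQUET_NESTED_COLUMN_PRUNING.key -> "true") {
  // Inspect the physical plan; the scan's ReadSchema should contain
  // col.s1.s1_1 only, not col.s1.s1_2 or col.str.
  sql("select col.s1.s1_1 from tmp_table").explain()
}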

test("expand UDT in StructType") {
val schema = new StructType().add("n", new NestedStructUDT, nullable = true)
val expected = new StructType().add("n", new NestedStructUDT().sqlType, nullable = true)