Skip to content

[SPARK-34638][SQL] Single field nested column prune on generator output #31966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ case class ProjectionOverSchema(schema: StructType) {
getProjection(a.child).map(p => (p, p.dataType)).map {
case (projection, ArrayType(projSchema @ StructType(_), _)) =>
// For case-sensitivity aware field resolution, we should take `ordinal` which
// points to correct struct field.
// points to correct struct field, because `ExtractValue` actually does column
// name resolving correctly.
val selectedField = a.child.dataType.asInstanceOf[ArrayType]
.elementType.asInstanceOf[StructType](a.ordinal)
val prunedField = projSchema(selectedField.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,27 @@ object NestedColumnAliasing {
* of it.
*/
object GeneratorNestedColumnAliasing {
// Partitions `attrToAliases` based on whether the attribute is in Generator's output.
private def aliasesOnGeneratorOutput(
attrToAliases: Map[ExprId, Seq[Alias]],
generatorOutput: Seq[Attribute]) = {
val generatorOutputExprId = generatorOutput.map(_.exprId)
attrToAliases.partition { k =>
generatorOutputExprId.contains(k._1)
}
}

// Partitions `nestedFieldToAlias` based on whether the attribute of nested field extractor
// is in Generator's output.
private def nestedFieldOnGeneratorOutput(
nestedFieldToAlias: Map[ExtractValue, Alias],
generatorOutput: Seq[Attribute]) = {
val generatorOutputSet = AttributeSet(generatorOutput)
nestedFieldToAlias.partition { pair =>
pair._1.references.subsetOf(generatorOutputSet)
}
}

def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
// Either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is enabled, we
// need to prune nested columns through Project and under Generate. The difference is
Expand All @@ -241,12 +262,81 @@ object GeneratorNestedColumnAliasing {
// On top on `Generate`, a `Project` that might have nested column accessors.
// We try to get alias maps for both project list and generator's children expressions.
val exprsToPrune = projectList ++ g.generator.children
NestedColumnAliasing.getAliasSubMap(exprsToPrune, g.qualifiedGeneratorOutput).map {
NestedColumnAliasing.getAliasSubMap(exprsToPrune).map {
case (nestedFieldToAlias, attrToAliases) =>
val (nestedFieldsOnGenerator, nestedFieldsNotOnGenerator) =
nestedFieldOnGeneratorOutput(nestedFieldToAlias, g.qualifiedGeneratorOutput)
val (attrToAliasesOnGenerator, attrToAliasesNotOnGenerator) =
aliasesOnGeneratorOutput(attrToAliases, g.qualifiedGeneratorOutput)

// Push nested column accessors through `Generator`.
// Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
val newChild =
NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild)
val newChild = NestedColumnAliasing.replaceWithAliases(g,
nestedFieldsNotOnGenerator, attrToAliasesNotOnGenerator)
val pushedThrough = Project(NestedColumnAliasing
.getNewProjectList(projectList, nestedFieldsNotOnGenerator), newChild)

// If the generator output is `ArrayType`, we cannot push through the extractor.
// It is because we don't allow field extractor on two-level array,
// i.e., attr.field when attr is a ArrayType(ArrayType(...)).
// Similarily, we also cannot push through if the child of generator is `MapType`.
g.generator.children.head.dataType match {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. Let me play more with this PR for a while. It seems I need more tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

case _: MapType => return Some(pushedThrough)
case ArrayType(_: ArrayType, _) => return Some(pushedThrough)
case _ =>
}

// Pruning on `Generator`'s output. We only process single field case.
// For multiple field case, we cannot directly move field extractor into
// the generator expression. A workaround is to re-construct array of struct
// from multiple fields. But it will be more complicated and may not worth.
// TODO(SPARK-34956): support multiple fields.
if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.isEmpty) {
pushedThrough
} else {
// Only one nested column accessor.
// E.g., df.select(explode($"items").as("item")).select($"item.a")
pushedThrough match {
case p @ Project(_, newG: Generate) =>
// Replace the child expression of `ExplodeBase` generator with
// nested column accessor.
// E.g., df.select(explode($"items").as("item")).select($"item.a") =>
// df.select(explode($"items.a").as("item.a"))
val rewrittenG = newG.transformExpressions {
case e: ExplodeBase =>
val extractor = nestedFieldsOnGenerator.head._1.transformUp {
case _: Attribute =>
e.child
case g: GetStructField =>
ExtractValue(g.child, Literal(g.extractFieldName), SQLConf.get.resolver)
}
e.withNewChildren(Seq(extractor))
}

// As we change the child of the generator, its output data type must be updated.
val updatedGeneratorOutput = rewrittenG.generatorOutput
.zip(rewrittenG.generator.elementSchema.toAttributes)
.map { case (oldAttr, newAttr) =>
newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
}
assert(updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
"Updated generator output must have the same length " +
"with original generator output.")
val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)

// Replace nested column accessor with generator output.
p.withNewChildren(Seq(updatedGenerate)).transformExpressions {
case f: ExtractValue if nestedFieldsOnGenerator.contains(f) =>
updatedGenerate.output
.find(a => attrToAliasesOnGenerator.contains(a.exprId))
.getOrElse(f)
}

case other =>
// We should not reach here.
throw new IllegalStateException(s"Unreasonable plan after optimization: $other")
}
}
}

case g: Generate if SQLConf.get.nestedSchemaPruningEnabled &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,14 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
comparePlans(optimized, expected)
}

test("Nested field pruning for Project and Generate: not prune on generator output") {
test("Nested field pruning for Project and Generate: multiple-field case is not supported") {
val companies = LocalRelation(
'id.int,
'employers.array(employer))

val query = companies
.generate(Explode('employers.getField("company")), outputNames = Seq("company"))
.select('company.getField("name"))
.select('company.getField("name"), 'company.getField("address"))
.analyze
val optimized = Optimize.execute(query)

Expand All @@ -347,7 +347,8 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
.generate(Explode($"${aliases(0)}"),
unrequiredChildIndex = Seq(0),
outputNames = Seq("company"))
.select('company.getField("name").as("company.name"))
.select('company.getField("name").as("company.name"),
'company.getField("address").as("company.address"))
.analyze
comparePlans(optimized, expected)
}
Expand Down Expand Up @@ -684,6 +685,29 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
).analyze
comparePlans(optimized2, expected2)
}

test("SPARK-34638: nested column prune on generator output for one field") {
val companies = LocalRelation(
'id.int,
'employers.array(employer))

val query = companies
.generate(Explode('employers.getField("company")), outputNames = Seq("company"))
.select('company.getField("name"))
.analyze
val optimized = Optimize.execute(query)

val aliases = collectGeneratedAliases(optimized)

val expected = companies
.select('employers.getField("company").getField("name").as(aliases(0)))
.generate(Explode($"${aliases(0)}"),
unrequiredChildIndex = Seq(0),
outputNames = Seq("company"))
.select('company.as("company.name"))
.analyze
comparePlans(optimized, expected)
}
}

object NestedColumnAliasingSuite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,43 @@ abstract class SchemaPruningSuite
}
}

testSchemaPruning("SPARK-34638: nested column prune on generator output") {
val query1 = spark.table("contacts")
.select(explode(col("friends")).as("friend"))
.select("friend.first")
checkScan(query1, "struct<friends:array<struct<first:string>>>")
checkAnswer(query1, Row("Susan") :: Nil)

// Currently we don't prune multiple field case.
val query2 = spark.table("contacts")
.select(explode(col("friends")).as("friend"))
.select("friend.first", "friend.middle")
checkScan(query2, "struct<friends:array<struct<first:string,middle:string,last:string>>>")
checkAnswer(query2, Row("Susan", "Z.") :: Nil)

val query3 = spark.table("contacts")
.select(explode(col("friends")).as("friend"))
.select("friend.first", "friend.middle", "friend")
checkScan(query3, "struct<friends:array<struct<first:string,middle:string,last:string>>>")
checkAnswer(query3, Row("Susan", "Z.", Row("Susan", "Z.", "Smith")) :: Nil)
}

testSchemaPruning("SPARK-34638: nested column prune on generator output - case-sensitivity") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
val query1 = spark.table("contacts")
.select(explode(col("friends")).as("friend"))
.select("friend.First")
checkScan(query1, "struct<friends:array<struct<first:string>>>")
checkAnswer(query1, Row("Susan") :: Nil)

val query2 = spark.table("contacts")
.select(explode(col("friends")).as("friend"))
.select("friend.MIDDLE")
checkScan(query2, "struct<friends:array<struct<middle:string>>>")
checkAnswer(query2, Row("Z.") :: Nil)
}
}

testSchemaPruning("select one deep nested complex field after repartition") {
val query = sql("select * from contacts")
.repartition(100)
Expand Down Expand Up @@ -816,4 +853,21 @@ abstract class SchemaPruningSuite
Row("John", "Y.") :: Nil)
}
}

test("SPARK-34638: queries should not fail on unsupported cases") {
withTable("nested_array") {
sql("select * from values array(array(named_struct('a', 1, 'b', 3), " +
"named_struct('a', 2, 'b', 4))) T(items)").write.saveAsTable("nested_array")
val query = sql("select d.a from (select explode(c) d from " +
"(select explode(items) c from nested_array))")
checkAnswer(query, Row(1) :: Row(2) :: Nil)
}

withTable("map") {
sql("select * from values map(1, named_struct('a', 1, 'b', 3), " +
"2, named_struct('a', 2, 'b', 4)) T(items)").write.saveAsTable("map")
val query = sql("select d.a from (select explode(items) (c, d) from map)")
checkAnswer(query, Row(1) :: Row(2) :: Nil)
}
}
}