Skip to content

Commit 607e753

Browse files
EnricoMi authored and cloud-fan committed
[SPARK-38591][SQL][FOLLOW-UP] Fix ambiguous references for sorted cogroups
### What changes were proposed in this pull request?
Sort order for left and right cogroups must be resolved against the left and right plan, respectively. Otherwise, an ambiguous reference exception can be thrown.

### Why are the changes needed?
#39640 added sorted groups for `flatMapGroups` and `cogroup`. Sort order for the `cogroup` can be ambiguous when resolved against all children of `CoGroup`:
```Scala
leftGroupedDf.cogroup(rightGroupedDf)($"time")($"time") { ... }
```
Grouped DataFrames `leftGroupedDf` and `rightGroupedDf` both contain column `"time"`. Left and right sort order `$"time"` is ambiguous when resolved against all children. They must be resolved against the left or right child, exclusively.

### Does this PR introduce _any_ user-facing change?
This fixes errors like
```
[AMBIGUOUS_REFERENCE] Reference `time` is ambiguous, could be: [`time`, `time`].
```

### How was this patch tested?
Tested in `AnalysisSuite` on `Analyzer` level, and E2E in `DatasetSuite`.

Closes #39744 from EnricoMi/branch-sorted-groups-ambiguous-reference.

Authored-by: Enrico Minack <github@enrico.minack.dev>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 7b426ac commit 607e753

File tree

3 files changed

+186
-18
lines changed

3 files changed

+186
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,6 +1593,37 @@ class Analyzer(override val catalogManager: CatalogManager)
15931593
Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
15941594
}
15951595

1596+
// Resolves the data sort order of a `MapGroups` node while any of its sort orders is unresolved.
case mg: MapGroups if mg.dataOrder.exists(!_.resolved) =>
1597+
// Resolve against the child of `AppendColumns` rather than `AppendColumns` itself,
1598+
// because the serializer of `AppendColumns` might produce conflicting attribute
1599+
// names, leading to an ambiguous reference exception.
1600+
val planForResolve = mg.child match {
1601+
case appendColumns: AppendColumns => appendColumns.child
1602+
// Any other child is used as-is.
case plan => plan
1603+
}
1604+
// Every sort order is resolved against the output of the plan selected above.
val resolvedOrder = mg.dataOrder
1605+
.map(resolveExpressionByPlanOutput(_, planForResolve).asInstanceOf[SortOrder])
1606+
mg.copy(dataOrder = resolvedOrder)
1607+
1608+
// Left and right sort expressions have to be resolved against their respective child plan only.
1609+
case cg: CoGroup if cg.leftOrder.exists(!_.resolved) || cg.rightOrder.exists(!_.resolved) =>
1610+
// Resolve against the child of `AppendColumns` rather than `AppendColumns` itself,
1611+
// because the serializer of `AppendColumns` might produce conflicting attribute
1612+
// names, leading to an ambiguous reference exception.
1613+
// The same unwrapping is applied to both sides; the two-element Seq is then
// destructured back into a (left, right) pair.
val (leftPlanForResolve, rightPlanForResolve) = Seq(cg.left, cg.right).map {
1614+
case appendColumns: AppendColumns => appendColumns.child
1615+
case plan => plan
1616+
} match {
1617+
case Seq(left, right) => (left, right)
1618+
}
1619+
1620+
// The left sort order only sees the output of the left plan ...
val resolvedLeftOrder = cg.leftOrder
1621+
.map(resolveExpressionByPlanOutput(_, leftPlanForResolve).asInstanceOf[SortOrder])
// ... and the right sort order only sees the output of the right plan.
1622+
val resolvedRightOrder = cg.rightOrder
1623+
.map(resolveExpressionByPlanOutput(_, rightPlanForResolve).asInstanceOf[SortOrder])
1624+
1625+
cg.copy(leftOrder = resolvedLeftOrder, rightOrder = resolvedRightOrder)
1626+
15961627
// Skips plan which contains deserializer expressions, as they should be resolved by another
15971628
// rule: ResolveDeserializer.
15981629
case plan if containsDeserializer(plan.expressions) => plan

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.scalatest.matchers.must.Matchers
2828

2929
import org.apache.spark.api.python.PythonEvalType
3030
import org.apache.spark.sql.AnalysisException
31-
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
31+
import org.apache.spark.sql.catalyst.{AliasIdentifier, QueryPlanningTracker, TableIdentifier}
3232
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
3333
import org.apache.spark.sql.catalyst.dsl.expressions._
3434
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -1343,4 +1343,48 @@ class AnalysisSuite extends AnalysisTest with Matchers {
13431343
,
13441344
queryContext = Array(ExpectedContext("SELECT *\nFROM t1\nWHERE 'true'", 31, 59)))
13451345
}
1346+
1347+
// Builds a CoGroup whose two sides both select a column named `e` and share one unresolved
// sort order on `e`, then checks that analysis binds each side's sort order to its own child.
test("SPARK-38591: resolve left and right CoGroup sort order on respective side only") {
1348+
// Cogroup function that ignores its input; only the analyzed plan is inspected below.
def func(k: Int, left: Iterator[Int], right: Iterator[Int]): Iterator[Int] = {
1349+
Iterator.empty
1350+
}
1351+
1352+
implicit val intEncoder = ExpressionEncoder[Int]
1353+
1354+
val left = testRelation2.select($"e").analyze
1355+
val right = testRelation3.select($"e").analyze
1356+
val leftWithKey = AppendColumns[Int, Int]((x: Int) => x, left)
1357+
val rightWithKey = AppendColumns[Int, Int]((x: Int) => x, right)
1358+
// The same unresolved sort order is passed for both the left and the right side.
val order = SortOrder($"e", Ascending)
1359+
1360+
val cogroup = leftWithKey.cogroup[Int, Int, Int, Int](
1361+
rightWithKey,
1362+
func,
1363+
leftWithKey.newColumns,
1364+
rightWithKey.newColumns,
1365+
left.output,
1366+
right.output,
1367+
order :: Nil,
1368+
order :: Nil
1369+
)
1370+
1371+
// analyze the plan
1372+
val actualPlan = getAnalyzer.executeAndCheck(cogroup, new QueryPlanningTracker)
1373+
val cg = actualPlan.collectFirst {
1374+
case cg: CoGroup => cg
1375+
}
1376+
// assert that each sort order references only its respective plan
1377+
assert(cg.isDefined)
1378+
cg.foreach { cg =>
1379+
// Although built from the same expression, the resolved orders must differ.
assert(cg.leftOrder != cg.rightOrder)
1380+
1381+
assert(cg.leftOrder.flatMap(_.references).nonEmpty)
1382+
assert(cg.leftOrder.flatMap(_.references).forall(cg.left.output.contains))
1383+
assert(!cg.leftOrder.flatMap(_.references).exists(cg.right.output.contains))
1384+
1385+
assert(cg.rightOrder.flatMap(_.references).nonEmpty)
1386+
assert(cg.rightOrder.flatMap(_.references).forall(cg.right.output.contains))
1387+
assert(!cg.rightOrder.flatMap(_.references).exists(cg.left.output.contains))
1388+
}
1389+
}
13461390
}

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 110 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -573,12 +573,40 @@ class DatasetSuite extends QueryTest
573573
"a", "30", "b", "3", "c", "1")
574574
}
575575

576+
// Groups by the `key` column via `groupBy(...).as[...]` and checks `flatMapSortedGroups`
// with column, SQL-expression and `$"value"` sort orders.
test("groupBy, flatMapSorted") {
577+
val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1))
578+
.toDF("key", "seq", "value")
579+
val grouped = ds.groupBy($"key").as[String, (String, Int, Int)]
580+
val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)"), $"value") {
581+
// Emits the group key followed by the group's rows joined into one string.
(g, iter) => Iterator(g, iter.mkString(", "))
582+
}
583+
584+
// The expected strings list each group's rows in ascending `seq` order; the input rows
// for key "b" are given in the opposite order.
checkDatasetUnorderly(
585+
aggregated,
586+
"a", "(a,1,10), (a,2,20)",
587+
"b", "(b,1,2), (b,2,1)",
588+
"c", "(c,1,1)"
589+
)
590+
591+
// Star is not allowed as group sort column
592+
checkError(
593+
exception = intercept[AnalysisException] {
594+
grouped.flatMapSortedGroups($"*") {
595+
(g, iter) => Iterator(g, iter.mkString(", "))
596+
}
597+
},
598+
errorClass = "_LEGACY_ERROR_TEMP_1020",
599+
parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"))
600+
}
601+
576602
test("groupBy function, flatMapSorted") {
577603
val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1))
578604
.toDF("key", "seq", "value")
579-
val grouped = ds.groupByKey(v => (v.getString(0), "word"))
580-
val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)")) {
581-
(g, iter) => Iterator(g._1, iter.mkString(", "))
605+
// groupByKey with a Row => String function adds a key column named `value` to the DataFrame
606+
val grouped = ds.groupByKey(v => v.getString(0))
607+
// $"value" here is expected to not reference the key column
608+
val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)"), $"value") {
609+
(g, iter) => Iterator(g, iter.mkString(", "))
582610
}
583611

584612
checkDatasetUnorderly(
@@ -587,14 +615,42 @@ class DatasetSuite extends QueryTest
587615
"b", "[b,1,2], [b,2,1]",
588616
"c", "[c,1,1]"
589617
)
618+
619+
// Star is not allowed as group sort column
620+
checkError(
621+
exception = intercept[AnalysisException] {
622+
grouped.flatMapSortedGroups($"*") {
623+
(g, iter) => Iterator(g, iter.mkString(", "))
624+
}
625+
},
626+
errorClass = "_LEGACY_ERROR_TEMP_1020",
627+
parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"))
628+
}
629+
630+
// Same setup as the ascending `groupBy` variant, but the first sort order is `$"seq".desc`.
test("groupBy, flatMapSorted desc") {
631+
val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1))
632+
.toDF("key", "seq", "value")
633+
val grouped = ds.groupBy($"key").as[String, (String, Int, Int)]
634+
val aggregated = grouped.flatMapSortedGroups($"seq".desc, expr("length(key)"), $"value") {
635+
// Emits the group key followed by the group's rows joined into one string.
(g, iter) => Iterator(g, iter.mkString(", "))
636+
}
637+
638+
// The expected strings list each group's rows in descending `seq` order.
checkDatasetUnorderly(
639+
aggregated,
640+
"a", "(a,2,20), (a,1,10)",
641+
"b", "(b,2,1), (b,1,2)",
642+
"c", "(c,1,1)"
643+
)
590644
}
591645

592646
test("groupBy function, flatMapSorted desc") {
593647
val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1))
594648
.toDF("key", "seq", "value")
595-
val grouped = ds.groupByKey(v => (v.getString(0), "word"))
596-
val aggregated = grouped.flatMapSortedGroups($"seq".desc, expr("length(key)")) {
597-
(g, iter) => Iterator(g._1, iter.mkString(", "))
649+
// groupByKey with a Row => String function adds a key column named `value` to the DataFrame
650+
val grouped = ds.groupByKey(v => v.getString(0))
651+
// $"value" here is expected to not reference the key column
652+
val aggregated = grouped.flatMapSortedGroups($"seq".desc, expr("length(key)"), $"value") {
653+
(g, iter) => Iterator(g, iter.mkString(", "))
598654
}
599655

600656
checkDatasetUnorderly(
@@ -759,30 +815,30 @@ class DatasetSuite extends QueryTest
759815
1 -> "a", 2 -> "bc", 3 -> "d")
760816
}
761817

762-
test("cogroup sorted") {
818+
test("cogroup with groupBy and sorted") {
763819
val left = Seq(1 -> "a", 3 -> "xyz", 5 -> "hello", 3 -> "abc", 3 -> "ijk").toDS()
764820
val right = Seq(2 -> "q", 3 -> "w", 5 -> "x", 5 -> "z", 3 -> "a", 5 -> "y").toDS()
765-
val groupedLeft = left.groupByKey(_._1)
766-
val groupedRight = right.groupByKey(_._1)
821+
val groupedLeft = left.groupBy($"_1").as[Int, (Int, String)]
822+
val groupedRight = right.groupBy($"_1").as[Int, (Int, String)]
767823

768824
val neitherSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#wa", 5 -> "hello#xzy")
769825
val leftSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#wa", 5 -> "hello#xzy")
770826
val rightSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#aw", 5 -> "hello#xyz")
771827
val bothSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#aw", 5 -> "hello#xyz")
772828
val bothDescSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzijkabc#wa", 5 -> "hello#zyx")
773829

774-
val leftOrder = Seq(left("_2"))
775-
val rightOrder = Seq(right("_2"))
776-
val leftDescOrder = Seq(left("_2").desc)
777-
val rightDescOrder = Seq(right("_2").desc)
830+
val ascOrder = Seq($"_2")
831+
val descOrder = Seq($"_2".desc)
832+
val exprOrder = Seq(substring($"_2", 0, 1))
778833
val none = Seq.empty
779834

780835
Seq(
781836
("neither", none, none, neitherSortedExpected),
782-
("left", leftOrder, none, leftSortedExpected),
783-
("right", none, rightOrder, rightSortedExpected),
784-
("both", leftOrder, rightOrder, bothSortedExpected),
785-
("both desc", leftDescOrder, rightDescOrder, bothDescSortedExpected)
837+
("left", ascOrder, none, leftSortedExpected),
838+
("right", none, ascOrder, rightSortedExpected),
839+
("both", ascOrder, ascOrder, bothSortedExpected),
840+
("expr", exprOrder, exprOrder, bothSortedExpected),
841+
("both desc", descOrder, descOrder, bothDescSortedExpected)
786842
).foreach { case (label, leftOrder, rightOrder, expected) =>
787843
withClue(s"$label sorted") {
788844
val cogrouped = groupedLeft.cogroupSorted(groupedRight)(leftOrder: _*)(rightOrder: _*) {
@@ -795,6 +851,43 @@ class DatasetSuite extends QueryTest
795851
}
796852
}
797853

854+
// Runs `cogroupSorted` on datasets grouped with a key function, for several combinations of
// left and right sort orders, and compares the concatenated `_2` values per key.
test("cogroup with groupBy function and sorted") {
855+
val left = Seq(1 -> "a", 3 -> "xyz", 5 -> "hello", 3 -> "abc", 3 -> "ijk").toDS()
856+
val right = Seq(2 -> "q", 3 -> "w", 5 -> "x", 5 -> "z", 3 -> "a", 5 -> "y").toDS()
857+
// this groupByKey produces conflicting _1 and _2 columns
858+
// that should be ignored when resolving sort expressions
859+
val groupedLeft = left.groupByKey(row => (row._1, row._1))
860+
val groupedRight = right.groupByKey(row => (row._1, row._1))
861+
862+
// Expected output per scenario, formatted as key -> "<left _2 values>#<right _2 values>".
val neitherSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#wa", 5 -> "hello#xzy")
863+
val leftSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#wa", 5 -> "hello#xzy")
864+
val rightSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#aw", 5 -> "hello#xyz")
865+
val bothSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#aw", 5 -> "hello#xyz")
866+
val bothDescSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzijkabc#wa", 5 -> "hello#zyx")
867+
868+
// The same unqualified `_2` sort expressions are reused for the left and the right side.
val ascOrder = Seq($"_2")
869+
val descOrder = Seq($"_2".desc)
870+
val exprOrder = Seq(substring($"_2", 0, 1))
871+
val none = Seq.empty
872+
873+
// Each tuple is (label, left sort order, right sort order, expected result).
Seq(
874+
("neither", none, none, neitherSortedExpected),
875+
("left", ascOrder, none, leftSortedExpected),
876+
("right", none, ascOrder, rightSortedExpected),
877+
("both", ascOrder, ascOrder, bothSortedExpected),
878+
// The expression-based order is expected to give the same result as plain ascending.
("expr", exprOrder, exprOrder, bothSortedExpected),
879+
("both desc", descOrder, descOrder, bothDescSortedExpected)
880+
).foreach { case (label, leftOrder, rightOrder, expected) =>
881+
withClue(s"$label sorted") {
882+
val cogrouped = groupedLeft.cogroupSorted(groupedRight)(leftOrder: _*)(rightOrder: _*) {
883+
(key, left, right) =>
884+
// Only the first component of the tuple key is emitted.
Iterator(key._1 -> (left.map(_._2).mkString + "#" + right.map(_._2).mkString))
885+
}
886+
checkDatasetUnorderly(cogrouped, expected.toList: _*)
887+
}
888+
}
889+
}
890+
798891
test("SPARK-34806: observation on datasets") {
799892
val namedObservation = Observation("named")
800893
val unnamedObservation = Observation()

0 commit comments

Comments
 (0)