@@ -573,12 +573,40 @@ class DatasetSuite extends QueryTest
573573 " a" , " 30" , " b" , " 3" , " c" , " 1" )
574574 }
575575
576+ test(" groupBy, flatMapSorted" ) {
577+ val ds = Seq ((" a" , 1 , 10 ), (" a" , 2 , 20 ), (" b" , 2 , 1 ), (" b" , 1 , 2 ), (" c" , 1 , 1 ))
578+ .toDF(" key" , " seq" , " value" )
579+ val grouped = ds.groupBy($" key" ).as[String , (String , Int , Int )]
580+ val aggregated = grouped.flatMapSortedGroups($" seq" , expr(" length(key)" ), $" value" ) {
581+ (g, iter) => Iterator (g, iter.mkString(" , " ))
582+ }
583+
584+ checkDatasetUnorderly(
585+ aggregated,
586+ " a" , " (a,1,10), (a,2,20)" ,
587+ " b" , " (b,1,2), (b,2,1)" ,
588+ " c" , " (c,1,1)"
589+ )
590+
591+ // Star is not allowed as group sort column
592+ checkError(
593+ exception = intercept[AnalysisException ] {
594+ grouped.flatMapSortedGroups($" *" ) {
595+ (g, iter) => Iterator (g, iter.mkString(" , " ))
596+ }
597+ },
598+ errorClass = " _LEGACY_ERROR_TEMP_1020" ,
599+ parameters = Map (" elem" -> " '*'" , " prettyName" -> " MapGroups" ))
600+ }
601+
576602 test(" groupBy function, flatMapSorted" ) {
577603 val ds = Seq ((" a" , 1 , 10 ), (" a" , 2 , 20 ), (" b" , 2 , 1 ), (" b" , 1 , 2 ), (" c" , 1 , 1 ))
578604 .toDF(" key" , " seq" , " value" )
579- val grouped = ds.groupByKey(v => (v.getString(0 ), " word" ))
580- val aggregated = grouped.flatMapSortedGroups($" seq" , expr(" length(key)" )) {
581- (g, iter) => Iterator (g._1, iter.mkString(" , " ))
605+ // groupByKey Row => String adds key columns `value` to the dataframe
606+ val grouped = ds.groupByKey(v => v.getString(0 ))
607+ // $"value" here is expected to not reference the key column
608+ val aggregated = grouped.flatMapSortedGroups($" seq" , expr(" length(key)" ), $" value" ) {
609+ (g, iter) => Iterator (g, iter.mkString(" , " ))
582610 }
583611
584612 checkDatasetUnorderly(
@@ -587,14 +615,42 @@ class DatasetSuite extends QueryTest
587615 " b" , " [b,1,2], [b,2,1]" ,
588616 " c" , " [c,1,1]"
589617 )
618+
619+ // Star is not allowed as group sort column
620+ checkError(
621+ exception = intercept[AnalysisException ] {
622+ grouped.flatMapSortedGroups($" *" ) {
623+ (g, iter) => Iterator (g, iter.mkString(" , " ))
624+ }
625+ },
626+ errorClass = " _LEGACY_ERROR_TEMP_1020" ,
627+ parameters = Map (" elem" -> " '*'" , " prettyName" -> " MapGroups" ))
628+ }
629+
630+ test(" groupBy, flatMapSorted desc" ) {
631+ val ds = Seq ((" a" , 1 , 10 ), (" a" , 2 , 20 ), (" b" , 2 , 1 ), (" b" , 1 , 2 ), (" c" , 1 , 1 ))
632+ .toDF(" key" , " seq" , " value" )
633+ val grouped = ds.groupBy($" key" ).as[String , (String , Int , Int )]
634+ val aggregated = grouped.flatMapSortedGroups($" seq" .desc, expr(" length(key)" ), $" value" ) {
635+ (g, iter) => Iterator (g, iter.mkString(" , " ))
636+ }
637+
638+ checkDatasetUnorderly(
639+ aggregated,
640+ " a" , " (a,2,20), (a,1,10)" ,
641+ " b" , " (b,2,1), (b,1,2)" ,
642+ " c" , " (c,1,1)"
643+ )
590644 }
591645
592646 test(" groupBy function, flatMapSorted desc" ) {
593647 val ds = Seq ((" a" , 1 , 10 ), (" a" , 2 , 20 ), (" b" , 2 , 1 ), (" b" , 1 , 2 ), (" c" , 1 , 1 ))
594648 .toDF(" key" , " seq" , " value" )
595- val grouped = ds.groupByKey(v => (v.getString(0 ), " word" ))
596- val aggregated = grouped.flatMapSortedGroups($" seq" .desc, expr(" length(key)" )) {
597- (g, iter) => Iterator (g._1, iter.mkString(" , " ))
649+ // groupByKey Row => String adds key columns `value` to the dataframe
650+ val grouped = ds.groupByKey(v => v.getString(0 ))
651+ // $"value" here is expected to not reference the key column
652+ val aggregated = grouped.flatMapSortedGroups($" seq" .desc, expr(" length(key)" ), $" value" ) {
653+ (g, iter) => Iterator (g, iter.mkString(" , " ))
598654 }
599655
600656 checkDatasetUnorderly(
@@ -759,30 +815,30 @@ class DatasetSuite extends QueryTest
759815 1 -> " a" , 2 -> " bc" , 3 -> " d" )
760816 }
761817
762- test(" cogroup sorted" ) {
818+ test(" cogroup with groupBy and sorted" ) {
763819 val left = Seq (1 -> " a" , 3 -> " xyz" , 5 -> " hello" , 3 -> " abc" , 3 -> " ijk" ).toDS()
764820 val right = Seq (2 -> " q" , 3 -> " w" , 5 -> " x" , 5 -> " z" , 3 -> " a" , 5 -> " y" ).toDS()
765- val groupedLeft = left.groupByKey(_._1)
766- val groupedRight = right.groupByKey(_._1)
821+ val groupedLeft = left.groupBy($ " _1 " ).as[ Int , ( Int , String )]
822+ val groupedRight = right.groupBy($ " _1 " ).as[ Int , ( Int , String )]
767823
768824 val neitherSortedExpected = Seq (1 -> " a#" , 2 -> " #q" , 3 -> " xyzabcijk#wa" , 5 -> " hello#xzy" )
769825 val leftSortedExpected = Seq (1 -> " a#" , 2 -> " #q" , 3 -> " abcijkxyz#wa" , 5 -> " hello#xzy" )
770826 val rightSortedExpected = Seq (1 -> " a#" , 2 -> " #q" , 3 -> " xyzabcijk#aw" , 5 -> " hello#xyz" )
771827 val bothSortedExpected = Seq (1 -> " a#" , 2 -> " #q" , 3 -> " abcijkxyz#aw" , 5 -> " hello#xyz" )
772828 val bothDescSortedExpected = Seq (1 -> " a#" , 2 -> " #q" , 3 -> " xyzijkabc#wa" , 5 -> " hello#zyx" )
773829
774- val leftOrder = Seq (left(" _2" ))
775- val rightOrder = Seq (right(" _2" ))
776- val leftDescOrder = Seq (left(" _2" ).desc)
777- val rightDescOrder = Seq (right(" _2" ).desc)
830+ val ascOrder = Seq ($" _2" )
831+ val descOrder = Seq ($" _2" .desc)
832+ val exprOrder = Seq (substring($" _2" , 0 , 1 ))
778833 val none = Seq .empty
779834
780835 Seq (
781836 (" neither" , none, none, neitherSortedExpected),
782- (" left" , leftOrder, none, leftSortedExpected),
783- (" right" , none, rightOrder, rightSortedExpected),
784- (" both" , leftOrder, rightOrder, bothSortedExpected),
785- (" both desc" , leftDescOrder, rightDescOrder, bothDescSortedExpected)
837+ (" left" , ascOrder, none, leftSortedExpected),
838+ (" right" , none, ascOrder, rightSortedExpected),
839+ (" both" , ascOrder, ascOrder, bothSortedExpected),
840+ (" expr" , exprOrder, exprOrder, bothSortedExpected),
841+ (" both desc" , descOrder, descOrder, bothDescSortedExpected)
786842 ).foreach { case (label, leftOrder, rightOrder, expected) =>
787843 withClue(s " $label sorted " ) {
788844 val cogrouped = groupedLeft.cogroupSorted(groupedRight)(leftOrder : _* )(rightOrder : _* ) {
@@ -795,6 +851,43 @@ class DatasetSuite extends QueryTest
795851 }
796852 }
797853
854+ test(" cogroup with groupBy function and sorted" ) {
855+ val left = Seq (1 -> " a" , 3 -> " xyz" , 5 -> " hello" , 3 -> " abc" , 3 -> " ijk" ).toDS()
856+ val right = Seq (2 -> " q" , 3 -> " w" , 5 -> " x" , 5 -> " z" , 3 -> " a" , 5 -> " y" ).toDS()
857+ // this groupByKey produces conflicting _1 and _2 columns
858+ // that should be ignored when resolving sort expressions
859+ val groupedLeft = left.groupByKey(row => (row._1, row._1))
860+ val groupedRight = right.groupByKey(row => (row._1, row._1))
861+
862+ val neitherSortedExpected = Seq (1 -> " a#" , 2 -> " #q" , 3 -> " xyzabcijk#wa" , 5 -> " hello#xzy" )
863+ val leftSortedExpected = Seq (1 -> " a#" , 2 -> " #q" , 3 -> " abcijkxyz#wa" , 5 -> " hello#xzy" )
864+ val rightSortedExpected = Seq (1 -> " a#" , 2 -> " #q" , 3 -> " xyzabcijk#aw" , 5 -> " hello#xyz" )
865+ val bothSortedExpected = Seq (1 -> " a#" , 2 -> " #q" , 3 -> " abcijkxyz#aw" , 5 -> " hello#xyz" )
866+ val bothDescSortedExpected = Seq (1 -> " a#" , 2 -> " #q" , 3 -> " xyzijkabc#wa" , 5 -> " hello#zyx" )
867+
868+ val ascOrder = Seq ($" _2" )
869+ val descOrder = Seq ($" _2" .desc)
870+ val exprOrder = Seq (substring($" _2" , 0 , 1 ))
871+ val none = Seq .empty
872+
873+ Seq (
874+ (" neither" , none, none, neitherSortedExpected),
875+ (" left" , ascOrder, none, leftSortedExpected),
876+ (" right" , none, ascOrder, rightSortedExpected),
877+ (" both" , ascOrder, ascOrder, bothSortedExpected),
878+ (" expr" , exprOrder, exprOrder, bothSortedExpected),
879+ (" both desc" , descOrder, descOrder, bothDescSortedExpected)
880+ ).foreach { case (label, leftOrder, rightOrder, expected) =>
881+ withClue(s " $label sorted " ) {
882+ val cogrouped = groupedLeft.cogroupSorted(groupedRight)(leftOrder : _* )(rightOrder : _* ) {
883+ (key, left, right) =>
884+ Iterator (key._1 -> (left.map(_._2).mkString + " #" + right.map(_._2).mkString))
885+ }
886+ checkDatasetUnorderly(cogrouped, expected.toList: _* )
887+ }
888+ }
889+ }
890+
798891 test(" SPARK-34806: observation on datasets" ) {
799892 val namedObservation = Observation (" named" )
800893 val unnamedObservation = Observation ()
0 commit comments