Skip to content

Commit

Permalink
All-fields as an arg of aggregator count() can be resolved after othe…
Browse files Browse the repository at this point in the history
…r fields

Signed-off-by: Lantao Jin <ltjin@amazon.com>
  • Loading branch information
LantaoJin committed Oct 25, 2024
1 parent 54f4fa7 commit 4b78cc5
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1122,4 +1122,166 @@ class FlintSparkPPLAggregationsITSuite

comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

// count() listed first: the COUNT(*) aggregator must precede SUM and AVG in the plan.
test("test count() at the first of stats clause") {
  val df = sql(s"""
       | source = $testTable | eval a = 1 | stats count() as cnt, sum(a) as sum, avg(a) as avg
       | """.stripMargin)
  // Expected result: one row holding cnt, sum and avg in the order they were requested.
  assertSameRows(Seq(Row(4, 4, 1.0)), df)

  // Builds an aliased, non-distinct aggregate call over the evaluated column `a`.
  def aggOnA(fn: String, alias: String): Alias =
    Alias(
      UnresolvedFunction(Seq(fn), Seq(UnresolvedAttribute("a")), isDistinct = false),
      alias)()

  // `eval a = 1` appends the literal column `a` on top of all source fields.
  val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val evalProject = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), relation)

  // count() without arguments is expected to translate to COUNT(*).
  val cntAgg =
    Alias(
      UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
      "cnt")()
  val sumAgg = aggOnA("SUM", "sum")
  val avgAgg = aggOnA("AVG", "avg")

  val statsAggregate = Aggregate(Seq.empty, Seq(cntAgg, sumAgg, avgAgg), evalProject)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), statsAggregate)

  val actualPlan: LogicalPlan = df.queryExecution.logical
  comparePlans(actualPlan, expectedPlan, checkAnalysis = false)
}

// count() listed between two other aggregators: COUNT(*) must keep its middle position.
test("test count() in the middle of stats clause") {
  val df = sql(s"""
       | source = $testTable | eval a = 1 | stats sum(a) as sum, count() as cnt, avg(a) as avg
       | """.stripMargin)
  // Expected result: one row holding sum, cnt and avg in the order they were requested.
  assertSameRows(Seq(Row(4, 4, 1.0)), df)

  // Builds an aliased, non-distinct aggregate call over the evaluated column `a`.
  def aggOnA(fn: String, alias: String): Alias =
    Alias(
      UnresolvedFunction(Seq(fn), Seq(UnresolvedAttribute("a")), isDistinct = false),
      alias)()

  // `eval a = 1` appends the literal column `a` on top of all source fields.
  val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val evalProject = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), relation)

  val sumAgg = aggOnA("SUM", "sum")
  // count() without arguments is expected to translate to COUNT(*).
  val cntAgg =
    Alias(
      UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
      "cnt")()
  val avgAgg = aggOnA("AVG", "avg")

  val statsAggregate = Aggregate(Seq.empty, Seq(sumAgg, cntAgg, avgAgg), evalProject)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), statsAggregate)

  val actualPlan: LogicalPlan = df.queryExecution.logical
  comparePlans(actualPlan, expectedPlan, checkAnalysis = false)
}

// count() listed last: COUNT(*) must still be resolved after the other aggregators.
test("test count() at the end of stats clause") {
  val df = sql(s"""
       | source = $testTable | eval a = 1 | stats sum(a) as sum, avg(a) as avg, count() as cnt
       | """.stripMargin)
  // Expected result: one row holding sum, avg and cnt in the order they were requested.
  assertSameRows(Seq(Row(4, 1.0, 4)), df)

  // Builds an aliased, non-distinct aggregate call over the evaluated column `a`.
  def aggOnA(fn: String, alias: String): Alias =
    Alias(
      UnresolvedFunction(Seq(fn), Seq(UnresolvedAttribute("a")), isDistinct = false),
      alias)()

  // `eval a = 1` appends the literal column `a` on top of all source fields.
  val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val evalProject = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), relation)

  val sumAgg = aggOnA("SUM", "sum")
  val avgAgg = aggOnA("AVG", "avg")
  // count() without arguments is expected to translate to COUNT(*).
  val cntAgg =
    Alias(
      UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
      "cnt")()

  val statsAggregate = Aggregate(Seq.empty, Seq(sumAgg, avgAgg, cntAgg), evalProject)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), statsAggregate)

  val actualPlan: LogicalPlan = df.queryExecution.logical
  comparePlans(actualPlan, expectedPlan, checkAnalysis = false)
}

// Grouped variant with count() listed first: COUNT(*) leads, the grouping key trails.
test("test count() at the first of stats by clause") {
  val df = sql(s"""
       | source = $testTable | eval a = 1 | stats count() as cnt, sum(a) as sum, avg(a) as avg by country
       | """.stripMargin)
  // Expected result: one row per country, with the grouping column as the last field.
  assertSameRows(Seq(Row(2, 2, 1.0, "Canada"), Row(2, 2, 1.0, "USA")), df)

  // Builds an aliased, non-distinct aggregate call over the evaluated column `a`.
  def aggOnA(fn: String, alias: String): Alias =
    Alias(
      UnresolvedFunction(Seq(fn), Seq(UnresolvedAttribute("a")), isDistinct = false),
      alias)()

  // `eval a = 1` appends the literal column `a` on top of all source fields.
  val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val evalProject = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), relation)

  // count() without arguments is expected to translate to COUNT(*).
  val cntAgg =
    Alias(
      UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
      "cnt")()
  val sumAgg = aggOnA("SUM", "sum")
  val avgAgg = aggOnA("AVG", "avg")

  // The same alias instance serves as grouping expression and as trailing output column.
  val countryKey = Alias(UnresolvedAttribute("country"), "country")()
  val statsAggregate =
    Aggregate(Seq(countryKey), Seq(cntAgg, sumAgg, avgAgg, countryKey), evalProject)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), statsAggregate)

  val actualPlan: LogicalPlan = df.queryExecution.logical
  comparePlans(actualPlan, expectedPlan, checkAnalysis = false)
}

// Grouped variant with count() in the middle: COUNT(*) sits between SUM and AVG.
test("test count() in the middle of stats by clause") {
  val df = sql(s"""
       | source = $testTable | eval a = 1 | stats sum(a) as sum, count() as cnt, avg(a) as avg by country
       | """.stripMargin)
  // Expected result: one row per country, with the grouping column as the last field.
  assertSameRows(Seq(Row(2, 2, 1.0, "Canada"), Row(2, 2, 1.0, "USA")), df)

  // Builds an aliased, non-distinct aggregate call over the evaluated column `a`.
  def aggOnA(fn: String, alias: String): Alias =
    Alias(
      UnresolvedFunction(Seq(fn), Seq(UnresolvedAttribute("a")), isDistinct = false),
      alias)()

  // `eval a = 1` appends the literal column `a` on top of all source fields.
  val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val evalProject = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), relation)

  val sumAgg = aggOnA("SUM", "sum")
  // count() without arguments is expected to translate to COUNT(*).
  val cntAgg =
    Alias(
      UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
      "cnt")()
  val avgAgg = aggOnA("AVG", "avg")

  // The same alias instance serves as grouping expression and as trailing output column.
  val countryKey = Alias(UnresolvedAttribute("country"), "country")()
  val statsAggregate =
    Aggregate(Seq(countryKey), Seq(sumAgg, cntAgg, avgAgg, countryKey), evalProject)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), statsAggregate)

  val actualPlan: LogicalPlan = df.queryExecution.logical
  comparePlans(actualPlan, expectedPlan, checkAnalysis = false)
}

// Grouped variant with count() listed last: COUNT(*) follows SUM and AVG, then the key.
test("test count() at the end of stats by clause") {
  val df = sql(s"""
       | source = $testTable | eval a = 1 | stats sum(a) as sum, avg(a) as avg, count() as cnt by country
       | """.stripMargin)
  // Expected result: one row per country, with the grouping column as the last field.
  assertSameRows(Seq(Row(2, 1.0, 2, "Canada"), Row(2, 1.0, 2, "USA")), df)

  // Builds an aliased, non-distinct aggregate call over the evaluated column `a`.
  def aggOnA(fn: String, alias: String): Alias =
    Alias(
      UnresolvedFunction(Seq(fn), Seq(UnresolvedAttribute("a")), isDistinct = false),
      alias)()

  // `eval a = 1` appends the literal column `a` on top of all source fields.
  val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val evalProject = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), relation)

  val sumAgg = aggOnA("SUM", "sum")
  val avgAgg = aggOnA("AVG", "avg")
  // count() without arguments is expected to translate to COUNT(*).
  val cntAgg =
    Alias(
      UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
      "cnt")()

  // The same alias instance serves as grouping expression and as trailing output column.
  val countryKey = Alias(UnresolvedAttribute("country"), "country")()
  val statsAggregate =
    Aggregate(Seq(countryKey), Seq(sumAgg, avgAgg, cntAgg, countryKey), evalProject)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), statsAggregate)

  val actualPlan: LogicalPlan = df.queryExecution.logical
  comparePlans(actualPlan, expectedPlan, checkAnalysis = false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -662,11 +662,7 @@ public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContex

/**
 * Translates an all-fields node into an unresolved star expression.
 *
 * <p>The star is pushed onto the context's named-parse-expression stack exactly once and
 * unconditionally. The push must not depend on the stack being empty: when all-fields is the
 * implicit argument of an aggregator such as {@code count()} that follows other aggregators,
 * earlier expressions are already on the stack and the star still has to be added so the
 * aggregator can pick it up as its argument.
 *
 * @param node    the all-fields AST node being visited (not inspected)
 * @param context the plan context whose named-parse-expression stack receives the star
 * @return the expression now on top of the stack, i.e. the star that was just pushed
 */
@Override
public Expression visitAllFields(AllFields node, CatalystPlanContext context) {
    // Single unconditional push; a guarded push followed by this one would leave a
    // duplicate star behind whenever the stack started out empty.
    context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<String>>empty()));
    return context.getNamedParseExpressions().peek();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -959,4 +959,58 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite

comparePlans(expectedPlan, logPlan, false)
}

// Translator-level check: count() as the final aggregator must become COUNT(*) placed last.
test("test count() as the last aggregator in stats clause") {
  val planContext = new CatalystPlanContext
  val parsedQuery = plan(
    pplParser,
    "source = table | eval a = 1 | stats sum(a) as sum, avg(a) as avg, count() as cnt")
  val actualPlan = planTransformer.visit(parsedQuery, planContext)

  // Builds an aliased, non-distinct aggregate call over the evaluated column `a`.
  def aggOnA(fn: String, alias: String): Alias =
    Alias(
      UnresolvedFunction(Seq(fn), Seq(UnresolvedAttribute("a")), isDistinct = false),
      alias)()

  // `eval a = 1` appends the literal column `a` on top of all source fields.
  val relation = UnresolvedRelation(Seq("table"))
  val evalProject = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), relation)

  val sumAgg = aggOnA("SUM", "sum")
  val avgAgg = aggOnA("AVG", "avg")
  // count() without arguments is expected to translate to COUNT(*).
  val cntAgg =
    Alias(
      UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
      "cnt")()

  val statsAggregate = Aggregate(Seq.empty, Seq(sumAgg, avgAgg, cntAgg), evalProject)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), statsAggregate)
  comparePlans(expectedPlan, actualPlan, checkAnalysis = false)
}

// Translator-level check for the grouped form: COUNT(*) last, followed by the grouping key.
test("test count() as the last aggregator in stats by clause") {
  val planContext = new CatalystPlanContext
  val parsedQuery = plan(
    pplParser,
    "source = table | eval a = 1 | stats sum(a) as sum, avg(a) as avg, count() as cnt by country")
  val actualPlan = planTransformer.visit(parsedQuery, planContext)

  // Builds an aliased, non-distinct aggregate call over the evaluated column `a`.
  def aggOnA(fn: String, alias: String): Alias =
    Alias(
      UnresolvedFunction(Seq(fn), Seq(UnresolvedAttribute("a")), isDistinct = false),
      alias)()

  // `eval a = 1` appends the literal column `a` on top of all source fields.
  val relation = UnresolvedRelation(Seq("table"))
  val evalProject = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), relation)

  val sumAgg = aggOnA("SUM", "sum")
  val avgAgg = aggOnA("AVG", "avg")
  // count() without arguments is expected to translate to COUNT(*).
  val cntAgg =
    Alias(
      UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
      "cnt")()

  // The same alias instance serves as grouping expression and as trailing output column.
  val countryKey = Alias(UnresolvedAttribute("country"), "country")()
  val statsAggregate =
    Aggregate(Seq(countryKey), Seq(sumAgg, avgAgg, cntAgg, countryKey), evalProject)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), statsAggregate)
  comparePlans(expectedPlan, actualPlan, checkAnalysis = false)
}
}

0 comments on commit 4b78cc5

Please sign in to comment.