Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import org.checkerframework.checker.nullness.qual.Nullable;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.AllFields;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.Field;
Expand Down Expand Up @@ -294,44 +293,9 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {

context.relBuilder.aggregate(context.relBuilder.groupKey(groupByList), aggList);

// handle normal aggregate
// TODO Should we keep alignment with V2 behaviour in new Calcite implementation?
// TODO how about add a legacy enable config to control behaviour in Calcite?
// Some behaviours between PPL and Databases are different.
// As an example, in command `stats count() by colA, colB`:
// 1. the sequence of output schema is different:
// In PPL v2, the sequence of output schema is "count, colA, colB".
// But in most databases, the sequence of output schema is "colA, colB, count".
// 2. the output order is different:
// In PPL v2, the order of output results is ordered by "colA + colB".
// But in most databases, the output order is random.
// User must add ORDER BY clause after GROUP BY clause to keep the results aligning.
// Following logic is to align with the PPL legacy behaviour.

// alignment for 1.sequence of output schema: adding order-by
// we use the groupByList instead of node.getSortExprList as input because
// the groupByList may include span column.
node.getGroupExprList()
.forEach(
g -> {
// node.getGroupExprList() should all be instance of Alias
// which defined in AstBuilder.
assert g instanceof Alias;
});
List<String> aliasesFromGroupByList =
groupByList.stream()
.map(this::extractAliasLiteral)
.flatMap(Optional::stream)
.map(ref -> ((RexLiteral) ref).getValueAs(String.class))
.toList();
List<RexNode> aliasedGroupByList =
aliasesFromGroupByList.stream()
.map(context.relBuilder::field)
.map(f -> (RexNode) f)
.toList();
context.relBuilder.sort(aliasedGroupByList);

// alignment for 2.the output order: schema reordering
// Schema reordering: place the aggregate columns before the group-by columns.
// For example, for the command `stats count() by colA, colB`,
// the output schema order is "count, colA, colB".
List<RexNode> outputFields = context.relBuilder.fields();
int numOfOutputFields = outputFields.size();
int numOfAggList = aggList.size();
Expand All @@ -341,6 +305,14 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
outputFields.subList(numOfOutputFields - numOfAggList, numOfOutputFields);
reordered.addAll(aggRexList);
// Add group by columns
List<RexNode> aliasedGroupByList =
groupByList.stream()
.map(this::extractAliasLiteral)
.flatMap(Optional::stream)
.map(ref -> ((RexLiteral) ref).getValueAs(String.class))
.map(context.relBuilder::field)
.map(f -> (RexNode) f)
.toList();
reordered.addAll(aliasedGroupByList);
context.relBuilder.project(reordered);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import static org.opensearch.sql.util.MatcherUtils.rows;
import static org.opensearch.sql.util.MatcherUtils.schema;
import static org.opensearch.sql.util.MatcherUtils.verifyDataRows;
import static org.opensearch.sql.util.MatcherUtils.verifyDataRowsInOrder;
import static org.opensearch.sql.util.MatcherUtils.verifySchema;
import static org.opensearch.sql.util.MatcherUtils.verifySchemaInOrder;

Expand Down Expand Up @@ -124,7 +123,7 @@ public void testAvgByMultipleFields() {
schema("avg(balance)", "double"),
schema("gender", "string"),
schema("city", "string"));
verifyDataRowsInOrder(
verifyDataRows(
actual1,
rows(40540.0, "F", "Nicholson"),
rows(32838.0, "F", "Nogal"),
Expand All @@ -142,7 +141,7 @@ public void testAvgByMultipleFields() {
schema("avg(balance)", "double"),
schema("city", "string"),
schema("gender", "string"));
verifyDataRowsInOrder(
verifyDataRows(
actual2,
rows(39225.0, "Brogan", "M"),
rows(5686.0, "Dante", "M"),
Expand All @@ -165,7 +164,7 @@ public void testStatsBySpanAndMultipleFields() throws IOException {
schema("span(age,10)", null, "integer"),
schema("gender", null, "string"),
schema("state", null, "string"));
verifyDataRowsInOrder(
verifyDataRows(
response,
rows(1, 20, "F", "VA"),
rows(1, 30, "F", "IN"),
Expand All @@ -178,7 +177,7 @@ public void testStatsBySpanAndMultipleFields() throws IOException {

@Test
public void testStatsByMultipleFieldsAndSpan() throws IOException {
// Use verifySchemaInOrder() and verifyDataRowsInOrder() to check that the span column is always
// Use verifySchemaInOrder() and verifyDataRows() to check that the span column is always
// the first column in the result, whether span appears first or last in the query
JSONObject response =
executeQuery(
Expand All @@ -190,7 +189,7 @@ public void testStatsByMultipleFieldsAndSpan() throws IOException {
schema("span(age,10)", null, "integer"),
schema("gender", null, "string"),
schema("state", null, "string"));
verifyDataRowsInOrder(
verifyDataRows(
response,
rows(1, 20, "F", "VA"),
rows(1, 30, "F", "IN"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ public void testLeftJoinWithRelationSubquery() {
schema("b.country", "string"),
schema("age_span", "integer"),
schema("avg(salary)", "double"));
verifyDataRowsInOrder(
verifyDataRows(
actual, rows(70000.0, 30, "USA"), rows(null, 40, null), rows(100000, 70, "England"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ public UnresolvedPlan visitStatsCommand(StatsCommandContext ctx) {
Aggregation aggregation =
new Aggregation(
aggListBuilder.build(),
groupList,
Collections.emptyList(),
groupList,
span,
ArgumentFactory.getArgumentList(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,10 @@ public void verifyResultCount(RelNode rel, int expectedRows) {

/**
 * Verifies that the given RelNode converts to the expected Spark SQL.
 *
 * <p>The plan is converted with {@code converter.visitRoot}, rendered as a statement using
 * {@code SparkSqlDialect.DEFAULT}, and compared against {@code expected} after every
 * {@code "\n"} in {@code expected} has been replaced with {@link System#lineSeparator()}.
 *
 * @param rel the relational plan to convert
 * @param expected the expected SQL text, written with {@code "\n"} line breaks
 */
public void verifyPPLToSparkSQL(RelNode rel, String expected) {
  // Expectations are written with "\n"; swap in the platform line separator before comparing.
  // NOTE(review): presumably the rendered SQL uses the platform separator - confirm.
  String normalized = expected.replace("\n", System.lineSeparator());
  SqlImplementor.Result result = converter.visitRoot(rel);
  final SqlNode sqlNode = result.asStatement();
  final String sql = sqlNode.toSqlString(SparkSqlDialect.DEFAULT).getSql();
  // Compare only against the normalized text: asserting against the raw `expected` as well
  // would fail whenever the line separator is not "\n", defeating the normalization.
  assertThat(sql, is(normalized));
}
}
Loading
Loading