From ba7b48401d60fe610aebd4df35eb014e8f429200 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Thu, 21 Dec 2023 13:49:27 -0800 Subject: [PATCH] feat(isthmus): improved Calcite support for Substrait Aggregate rels Substrait Aggregates that contain expressions that are not field references and/or grouping keys that are not in input order require extra processing to be converted to Calcite Aggregates successfully AND correctly --- .../substrait/isthmus/AggregateValidator.java | 224 ++++++++++++++++++ .../isthmus/SubstraitRelNodeConverter.java | 24 +- .../isthmus/SubstraitRelVisitor.java | 2 + .../isthmus/ComplexAggregateTest.java | 186 +++++++++++++++ 4 files changed, 430 insertions(+), 6 deletions(-) create mode 100644 isthmus/src/main/java/io/substrait/isthmus/AggregateValidator.java create mode 100644 isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java diff --git a/isthmus/src/main/java/io/substrait/isthmus/AggregateValidator.java b/isthmus/src/main/java/io/substrait/isthmus/AggregateValidator.java new file mode 100644 index 000000000..df800db52 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/AggregateValidator.java @@ -0,0 +1,224 @@ +package io.substrait.isthmus; + +import io.substrait.expression.AggregateFunctionInvocation; +import io.substrait.expression.Expression; +import io.substrait.expression.FieldReference; +import io.substrait.expression.FunctionArg; +import io.substrait.expression.ImmutableExpression; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.relation.Aggregate; +import io.substrait.relation.Project; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Not all Substrait {@link Aggregate} rels are convertable to {@link + * org.apache.calcite.rel.core.Aggregate} rels + * + *

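The code in this class can check whether an {@link Aggregate} falls into these cases and rewrite it into an equivalent form that Calcite can convert. + * + *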

+ */ +public class AggregateValidator { + + /** + * Checks that the given {@link Aggregate} is valid for use in Calcite + * + * @param aggregate + * @return + */ + public static boolean isValidCalciteAggregate(Aggregate aggregate) { + return aggregate.getMeasures().stream().allMatch(AggregateValidator::isValidCalciteMeasure) + && aggregate.getGroupings().stream().allMatch(AggregateValidator::isValidCalciteGrouping); + } + + /** + * Checks that all expressions present in the given {@link Aggregate.Measure} are {@link + * FieldReference}s, as Calcite expects all expressions in {@link + * org.apache.calcite.rel.core.Aggregate}s to be field references. + * + * @return true if the {@code measure} can be converted to a Calcite equivalent without changes, + * false otherwise. + */ + private static boolean isValidCalciteMeasure(Aggregate.Measure measure) { + return + // all function arguments to measures must be field references + measure.getFunction().arguments().stream().allMatch(farg -> isSimpleFieldReference(farg)) + && + // all sort fields must be field references + measure.getFunction().sort().stream().allMatch(sf -> isSimpleFieldReference(sf.expr())) + && + // pre-measure filter must be a field reference + measure.getPreMeasureFilter().map(f -> isSimpleFieldReference(f)).orElse(true); + } + + /** + * Checks that all expressions present in the given {@link Aggregate.Grouping} are {@link + * FieldReference}s, as Calcite expects all expressions in {@link + * org.apache.calcite.rel.core.Aggregate}s to be field references. + * + *

Additionally, checks that all grouping fields are specified in ascending order. + * + * @return true if the {@code grouping} can be converted to a Calcite equivalent without changes, + * false otherwise. + */ + private static boolean isValidCalciteGrouping(Aggregate.Grouping grouping) { + if (!grouping.getExpressions().stream().allMatch(e -> isSimpleFieldReference(e))) { + // all grouping expressions must be field references + return false; + } + + // Calcite stores grouping fields in an ImmutableBitSet and does not track the order of the + // grouping fields. The output record shape that Calcite generates ALWAYS has the groupings in + // ascending field order. This causes issues with Substrait in cases where the grouping fields + // in Substrait are not defined in ascending order. + + // For example, if a grouping is defined as (0, 2, 1) in Substrait, Calcite will output it as + // (0, 1, 2), which means that the Calcite output will no longer line up with the expectations + // of the Substrait plan. 
+ List groupingFields = + grouping.getExpressions().stream() + // isSimpleFieldReference above guarantees that the expr is a FieldReference + .map(expr -> getFieldRefOffset((FieldReference) expr)) + .collect(Collectors.toList()); + + return isOrdered(groupingFields); + } + + private static boolean isSimpleFieldReference(FunctionArg e) { + return e instanceof FieldReference fr + && fr.segments().size() == 1 + && fr.segments().get(0) instanceof FieldReference.StructField; + } + + private static int getFieldRefOffset(FieldReference fr) { + return ((FieldReference.StructField) fr.segments().get(0)).offset(); + } + + private static boolean isOrdered(List list) { + for (int i = 1; i < list.size(); i++) { + if (list.get(i - 1) > list.get(i)) { + return false; + } + } + return true; + } + + public static class AggregateTransformer { + + // New expressions to include in the project before the aggregate + final List newExpressions; + + // Tracks the offset of the next expression added + int expressionOffset; + + private AggregateTransformer(Aggregate aggregate) { + this.newExpressions = new ArrayList<>(); + // The Substrait project output includes all input fields, followed by expressions + this.expressionOffset = aggregate.getInput().getRecordType().fields().size(); + } + + /** + * Transforms an {@link Aggregate} that cannot be handled by Calcite into an equivalent that can + * be handled by: + * + *

+ */ + public static Aggregate transformToValidCalciteAggregate(Aggregate aggregate) { + var at = new AggregateTransformer(aggregate); + + List newMeasures = + aggregate.getMeasures().stream().map(at::updateMeasure).collect(Collectors.toList()); + List newGroupings = + aggregate.getGroupings().stream().map(at::updateGrouping).collect(Collectors.toList()); + + Project preAggregateProject = + Project.builder().input(aggregate.getInput()).expressions(at.newExpressions).build(); + + return Aggregate.builder() + .from(aggregate) + .input(preAggregateProject) + .measures(newMeasures) + .groupings(newGroupings) + .build(); + } + + private Aggregate.Measure updateMeasure(Aggregate.Measure measure) { + AggregateFunctionInvocation oldAggregateFunctionInvocation = measure.getFunction(); + + List newFunctionArgs = + oldAggregateFunctionInvocation.arguments().stream() + .map(this::projectOutNonFieldReference) + .collect(Collectors.toList()); + + List newSortFields = + oldAggregateFunctionInvocation.sort().stream() + .map( + sf -> + Expression.SortField.builder() + .from(sf) + .expr(projectOutNonFieldReference(sf.expr())) + .build()) + .collect(Collectors.toList()); + + Optional newPreMeasureFilter = + measure.getPreMeasureFilter().map(this::projectOutNonFieldReference); + + AggregateFunctionInvocation newAggregateFunctionInvocation = + AggregateFunctionInvocation.builder() + .from(oldAggregateFunctionInvocation) + .arguments(newFunctionArgs) + .sort(newSortFields) + .build(); + + return Aggregate.Measure.builder() + .function(newAggregateFunctionInvocation) + .preMeasureFilter(newPreMeasureFilter) + .build(); + } + + private Aggregate.Grouping updateGrouping(Aggregate.Grouping grouping) { + // project out all groupings unconditionally, even field references + // this ensures that out of order groupings are re-projected into in order groupings + List newGroupingExpressions = + grouping.getExpressions().stream().map(this::projectOut).collect(Collectors.toList()); + return 
Aggregate.Grouping.builder().expressions(newGroupingExpressions).build(); + } + + private Expression projectOutNonFieldReference(FunctionArg farg) { + if ((farg instanceof Expression e)) { + return projectOutNonFieldReference(e); + } else { + throw new IllegalArgumentException("cannot handle non-expression argument for aggregate"); + } + } + + private Expression projectOutNonFieldReference(Expression expr) { + if (isSimpleFieldReference(expr)) { + return expr; + } + return projectOut(expr); + } + + /** + * Adds a new expression to the project at {@link AggregateTransformer#expressionOffset} and + * returns a field reference to the new expression + */ + private Expression projectOut(Expression expr) { + newExpressions.add(expr); + return ImmutableFieldReference.builder() + // create a field reference to the new expression, then update the expression offset + .addSegments(FieldReference.StructField.of(expressionOffset++)) + .type(expr.getType()) + .build(); + } + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index fb31f8d22..547e483c1 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -31,6 +31,7 @@ import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitDef; import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; @@ -227,6 +228,11 @@ public RelNode visit(Set set) throws RuntimeException { @Override public RelNode visit(Aggregate aggregate) throws RuntimeException { + if (!AggregateValidator.isValidCalciteAggregate(aggregate)) { + aggregate = + AggregateValidator.AggregateTransformer.transformToValidCalciteAggregate(aggregate); + } + 
RelNode child = aggregate.getInput().accept(this); var groupExprLists = aggregate.getGroupings().stream() @@ -268,8 +274,8 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) { } List argIndex = new ArrayList<>(); for (RexNode arg : arguments) { - // TODO: rewrite compound expression into project Rel - checkRexInputRefOnly(arg, "argument", measure.getFunction().declaration().name()); + // arguments are guaranteed to be RexInputRef because of the prior call to + // AggregateValidator.AggregateTransformer.transformToValidCalciteAggregate argIndex.add(((RexInputRef) arg).getIndex()); } @@ -292,12 +298,18 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) { int filterArg = -1; if (measure.getPreMeasureFilter().isPresent()) { RexNode filter = measure.getPreMeasureFilter().get().accept(expressionRexConverter); - // TODO: rewrite compound expression into project Rel - // Calcite's AggregateCall only allow agg filter to be a direct filter from input - checkRexInputRefOnly(filter, "filter", measure.getFunction().declaration().name()); filterArg = ((RexInputRef) filter).getIndex(); } + RelCollation relCollation = RelCollations.EMPTY; + if (!measure.getFunction().sort().isEmpty()) { + relCollation = + RelCollations.of( + measure.getFunction().sort().stream() + .map(sortField -> toRelFieldCollation(sortField)) + .collect(Collectors.toList())); + } + return AggregateCall.create( aggFunction, distinct, @@ -306,7 +318,7 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) { argIndex, filterArg, null, - RelCollations.EMPTY, + relCollation, returnType, null); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 5dd07e7a3..ce8612b9b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -271,6 +271,8 @@ Aggregate.Measure fromAggCall(RelNode 
input, Type.Struct inputType, AggregateCal if (call.filterArg != -1) { builder.preMeasureFilter(FieldReference.newRootStructReference(call.filterArg, inputType)); } + // TODO: handle the collation on the AggregateCall + // https://github.com/substrait-io/substrait-java/issues/215 return builder.build(); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java new file mode 100644 index 000000000..fa3b9c3c8 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java @@ -0,0 +1,186 @@ +package io.substrait.isthmus; + +import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.AggregateFunctionInvocation; +import io.substrait.expression.Expression; +import io.substrait.relation.Aggregate; +import io.substrait.relation.NamedScan; +import io.substrait.relation.Rel; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +public class ComplexAggregateTest extends PlanTestBase { + + final TypeCreator R = TypeCreator.of(false); + SubstraitBuilder b = new SubstraitBuilder(extensions); + + /** + * Check that: + * + *
    + *
  1. The given {@code pojo} is transformed as expected by {@link + * AggregateValidator.AggregateTransformer#transformToValidCalciteAggregate} + *
  2. The original (untransformed) {@code pojo} can be converted to Calcite without issues + *
+ * + * @param pojo a pojo that requires transformation for use in Calcite + * @param expectedTransform the expected transformation output + */ + protected void validateAggregateTransformation(Aggregate pojo, Rel expectedTransform) { + var converterPojo = + AggregateValidator.AggregateTransformer.transformToValidCalciteAggregate(pojo); + assertEquals(expectedTransform, converterPojo); + + // Substrait POJO -> Calcite + new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo); + } + + private List columnTypes = List.of(R.I32, R.I32, R.I32, R.I32); + private List columnNames = List.of("a", "b", "c", "d"); + private NamedScan table = b.namedScan(List.of("example"), columnNames, columnTypes); + + private Aggregate.Grouping emptyGrouping = Aggregate.Grouping.builder().build(); + + @Test + void handleComplexMeasureArgument() { + // SELECT sum(c + 7) FROM example + var rel = + b.aggregate( + input -> emptyGrouping, + input -> List.of(b.sum(b.add(b.fieldReference(input, 2), b.i32(7)))), + table); + + var expectedFinal = + b.aggregate( + input -> emptyGrouping, + // sum call references input field + input -> List.of(b.sum(input, 4)), + b.project( + // add call is moved to child project + input -> List.of(b.add(b.fieldReference(input, 2), b.i32(7))), + table)); + + validateAggregateTransformation(rel, expectedFinal); + } + + @Test + void handleComplexPreMeasureFilter() { + // SELECT sum(a) FILTER (b = 42) FROM example + var rel = + b.aggregateM( + input -> emptyGrouping, + input -> + List.of(b.measure(b.sum(input, 0), b.equal(b.fieldReference(input, 1), b.i32(42)))), + table); + + var expectedFinal = + b.aggregateM( + input -> emptyGrouping, + input -> List.of(b.measure(b.sum(input, 0), b.fieldReference(input, 4))), + b.project(input -> List.of(b.equal(b.fieldReference(input, 1), b.i32(42))), table)); + + validateAggregateTransformation(rel, expectedFinal); + } + + @Test + void handleComplexSortingArguments() { + // SELECT sum(d ORDER BY -b ASC) FROM example + 
var rel = + b.aggregate( + input -> emptyGrouping, + input -> + List.of( + AggregateFunctionInvocation.builder() + .from(b.sum(input, 3)) + .sort( + List.of( + b.sortField( + b.negate(b.fieldReference(input, 1)), + Expression.SortDirection.ASC_NULLS_FIRST))) + .build()), + table); + + var expectedFinal = + b.aggregate( + input -> emptyGrouping, + input -> + List.of( + AggregateFunctionInvocation.builder() + // sum input does not need to be rewritten + .from(b.sum(input, 3)) + .sort( + List.of( + b.sortField( + // sort field references input + b.fieldReference(input, 4), + Expression.SortDirection.ASC_NULLS_FIRST))) + .build()), + b.project( + // negate call is moved to child project + input -> List.of(b.negate(b.fieldReference(input, 1))), + table)); + + validateAggregateTransformation(rel, expectedFinal); + } + + @Test + void handleComplexGroupingArgument() { + var rel = + b.aggregate( + input -> + b.grouping( + b.fieldReference(input, 2), b.add(b.fieldReference(input, 1), b.i32(42))), + input -> List.of(), + table); + + var expectedFinal = + b.aggregate( + // grouping exprs are now field references to input + input -> b.grouping(input, 4, 5), + input -> List.of(), + b.project( + input -> + List.of( + b.fieldReference(input, 2), b.add(b.fieldReference(input, 1), b.i32(42))), + table)); + + validateAggregateTransformation(rel, expectedFinal); + } + + @Test + void handleOutOfOrderGroupingArguments() { + var rel = b.aggregate(input -> b.grouping(input, 1, 0, 2), input -> List.of(), table); + + var expectedFinal = + b.aggregate( + // grouping exprs are now field references to input + input -> b.grouping(input, 4, 5, 6), + input -> List.of(), + b.project( + // ALL grouping exprs are added to the child projects (including field references) + input -> + List.of( + b.fieldReference(input, 1), + b.fieldReference(input, 0), + b.fieldReference(input, 2)), + table)); + + validateAggregateTransformation(rel, expectedFinal); + } + + @Test + void 
outOfOrderGroupingKeysHaveCorrectCalciteType() { + Rel rel = + b.aggregate( + input -> b.grouping(input, 2, 0), + input -> List.of(), + b.namedScan(List.of("foo"), List.of("a", "b", "c"), List.of(R.I64, R.I64, R.STRING))); + var relNode = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(rel); + assertRowMatch(relNode.getRowType(), R.STRING, R.I64); + } +}