From 6d94059329ad7f3b0cc2c124bed4f8b22b85551a Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Wed, 8 Mar 2023 15:34:58 -0800 Subject: [PATCH 1/2] feat: support output order remapping (#132) * refactor: expand access to SubstraitRelNodeConverter fields * feat: add deriveRecordType to Cross * feat: substrait builder dsl * feat: self-contained Substrait to Calcite converter * chore: bump junit-jupiter * test: check for application of remappings * feat: apply remaps * refactor: move RelOutputTest to SubstraitRelNodeConverterTest --- build.gradle.kts | 2 +- core/build.gradle.kts | 4 +- .../io/substrait/dsl/SubstraitBuilder.java | 288 ++++++++++++++++++ .../java/io/substrait/relation/Cross.java | 11 + .../substrait/relation/ProtoRelConverter.java | 10 +- .../relation/RelCopyOnWriteVisitor.java | 1 - .../java/io/substrait/type/NamedStruct.java | 2 +- isthmus/build.gradle.kts | 2 +- .../isthmus/SubstraitRelNodeConverter.java | 87 ++++-- .../isthmus/SubstraitRelVisitor.java | 7 +- .../substrait/isthmus/SubstraitToCalcite.java | 119 ++++++++ .../io/substrait/isthmus/PlanTestBase.java | 13 +- .../java/io/substrait/isthmus/RelCreator.java | 2 +- .../SubstraitRelNodeConverterTest.java | 239 +++++++++++++++ 14 files changed, 740 insertions(+), 47 deletions(-) create mode 100644 core/src/main/java/io/substrait/dsl/SubstraitBuilder.java create mode 100644 isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java create mode 100644 isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java diff --git a/build.gradle.kts b/build.gradle.kts index a371caf93..47a9da29f 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -16,7 +16,7 @@ repositories { mavenCentral() } java { toolchain { languageVersion.set(JavaLanguageVersion.of(17)) } } dependencies { - testImplementation("org.junit.jupiter:junit-jupiter-api:5.6.0") + testImplementation("org.junit.jupiter:junit-jupiter-api:5.9.2") testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine") implementation("org.slf4j:slf4j-jdk14:1.7.30") annotationProcessor("org.immutables:value:2.8.8") diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 03c4dc8de..221e7cbd7 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -68,8 +68,8 @@ signing { } dependencies { - testImplementation("org.junit.jupiter:junit-jupiter-api:5.6.0") - testImplementation("org.junit.jupiter:junit-jupiter-params:5.6.0") + testImplementation("org.junit.jupiter:junit-jupiter-api:5.9.2") + testImplementation("org.junit.jupiter:junit-jupiter-params:5.9.2") testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine") implementation("com.google.protobuf:protobuf-java:3.17.3") implementation("com.fasterxml.jackson.core:jackson-databind:2.13.4") diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java new file mode 100644 index 000000000..8ce083110 --- /dev/null +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -0,0 +1,288 @@ +package io.substrait.dsl; + +import com.github.bsideup.jabel.Desugar; +import io.substrait.expression.AggregateFunctionInvocation; +import io.substrait.expression.Expression; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.function.SimpleExtension; +import io.substrait.plan.ImmutableRoot; +import io.substrait.plan.Plan; +import io.substrait.proto.AggregateFunction; +import io.substrait.relation.Aggregate; +import io.substrait.relation.Cross; +import io.substrait.relation.Fetch; +import io.substrait.relation.Filter; +import io.substrait.relation.Join; +import io.substrait.relation.NamedScan; +import io.substrait.relation.Project; +import io.substrait.relation.Rel; +import io.substrait.relation.Set; +import io.substrait.relation.Sort; +import io.substrait.type.NamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class SubstraitBuilder { + static final TypeCreator R = TypeCreator.of(false); + static final TypeCreator N = TypeCreator.of(true); + private final SimpleExtension.ExtensionCollection extensions; + + public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) { + this.extensions = extensions; + } + + // Relations + public Aggregate aggregate( + Function groupingFn, + Function> measuresFn, + Rel input) { + Function> groupingsFn = + groupingFn.andThen(g -> Stream.of(g).collect(Collectors.toList())); + return aggregate(groupingsFn, measuresFn, Optional.empty(), input); + } + + public Aggregate aggregate( + Function groupingFn, + Function> measuresFn, + Rel.Remap remap, + Rel input) { + Function> groupingsFn = + groupingFn.andThen(g -> Stream.of(g).collect(Collectors.toList())); + return aggregate(groupingsFn, measuresFn, Optional.of(remap), input); + } + + private Aggregate aggregate( + Function> groupingsFn, + Function> measuresFn, + Optional remap, + Rel input) { + var groupings = groupingsFn.apply(input); + var measures = + measuresFn.apply(input).stream() + .map(m -> Aggregate.Measure.builder().function(m).build()) + .collect(java.util.stream.Collectors.toList()); + return Aggregate.builder() + .groupings(groupings) + .measures(measures) + .remap(remap) + .input(input) + .build(); + } + + public Cross cross(Rel left, Rel right) { + return cross(left, right, Optional.empty()); + } + + public Cross cross(Rel left, Rel right, Rel.Remap remap) { + return cross(left, right, Optional.of(remap)); + } + + private Cross cross(Rel left, Rel right, Optional remap) { + return Cross.builder().left(left).right(right).remap(remap).build(); + } + + public Fetch fetch(long offset, long count, Rel input) { + return fetch(offset, count, Optional.empty(), input); + } + + public Fetch fetch(long offset, long count, Rel.Remap remap, Rel input) { + return fetch(offset, count, Optional.of(remap), input); + } + + private Fetch fetch(long offset, long count, Optional remap, Rel input) { + return Fetch.builder().offset(offset).count(count).input(input).remap(remap).build(); + } + + public Filter filter(Function conditionFn, Rel input) { + return filter(conditionFn, Optional.empty(), input); + } + + public Filter filter(Function conditionFn, Rel.Remap remap, Rel input) { + return filter(conditionFn, Optional.of(remap), input); + } + + private Filter filter( + Function conditionFn, Optional remap, Rel input) { + var condition = conditionFn.apply(input); + return Filter.builder().input(input).condition(condition).remap(remap).build(); + } + + @Desugar + public record JoinInput(Rel left, Rel right) {} + + public Join innerJoin(Function conditionFn, Rel left, Rel right) { + return join(conditionFn, Join.JoinType.INNER, left, right); + } + + public Join innerJoin( + Function conditionFn, Rel.Remap remap, Rel left, Rel right) { + return join(conditionFn, Join.JoinType.INNER, remap, left, right); + } + + public Join join( + Function conditionFn, Join.JoinType joinType, Rel left, Rel right) { + return join(conditionFn, joinType, Optional.empty(), left, right); + } + + public Join join( + Function conditionFn, + Join.JoinType joinType, + Rel.Remap remap, + Rel left, + Rel right) { + return join(conditionFn, joinType, Optional.of(remap), left, right); + } + + private Join join( + Function conditionFn, + Join.JoinType joinType, + Optional remap, + Rel left, + Rel right) { + var condition = conditionFn.apply(new JoinInput(left, right)); + return Join.builder() + .left(left) + .right(right) + .condition(condition) + .joinType(joinType) + .remap(remap) + .build(); + } + + public NamedScan namedScan( + Iterable tableName, Iterable columnNames, Iterable types) { + return namedScan(tableName, columnNames, types, Optional.empty()); + } + + public NamedScan namedScan( + Iterable tableName, + Iterable columnNames, + Iterable types, + Rel.Remap remap) { + return namedScan(tableName, columnNames, types, Optional.of(remap)); + } + + private NamedScan namedScan( + Iterable tableName, + Iterable columnNames, + Iterable types, + Optional remap) { + var struct = Type.Struct.builder().addAllFields(types).nullable(false).build(); + var namedStruct = NamedStruct.of(columnNames, struct); + return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build(); + } + + public Project project(Function> expressionsFn, Rel input) { + return project(expressionsFn, Optional.empty(), input); + } + + public Project project( + Function> expressionsFn, Rel.Remap remap, Rel input) { + return project(expressionsFn, Optional.of(remap), input); + } + + private Project project( + Function> expressionsFn, + Optional remap, + Rel input) { + var expressions = expressionsFn.apply(input); + return Project.builder().input(input).expressions(expressions).remap(remap).build(); + } + + public Set set(Set.SetOp op, Rel... inputs) { + return set(op, Optional.empty(), inputs); + } + + public Set set(Set.SetOp op, Rel.Remap remap, Rel... inputs) { + return set(op, Optional.of(remap), inputs); + } + + private Set set(Set.SetOp op, Optional remap, Rel... inputs) { + return Set.builder().setOp(op).remap(remap).addAllInputs(Arrays.asList(inputs)).build(); + } + + public Sort sort(Function> sortFieldFn, Rel input) { + return sort(sortFieldFn, Optional.empty(), input); + } + + public Sort sort( + Function> sortFieldFn, + Rel.Remap remap, + Rel input) { + return sort(sortFieldFn, Optional.of(remap), input); + } + + private Sort sort( + Function> sortFieldFn, + Optional remap, + Rel input) { + var condition = sortFieldFn.apply(input); + return Sort.builder().input(input).sortFields(condition).remap(remap).build(); + } + + // Expressions + + public Expression.BoolLiteral bool(boolean v) { + return Expression.BoolLiteral.builder().value(v).build(); + } + + public FieldReference fieldReference(Rel input, int index) { + return ImmutableFieldReference.newInputRelReference(index, input); + } + + public List fieldReferences(Rel input, int... indexes) { + return Arrays.stream(indexes) + .mapToObj(index -> fieldReference(input, index)) + .collect(java.util.stream.Collectors.toList()); + } + + public List sortFields(Rel input, int... indexes) { + return Arrays.stream(indexes) + .mapToObj( + index -> + Expression.SortField.builder() + .expr(ImmutableFieldReference.newInputRelReference(index, input)) + .direction(Expression.SortDirection.ASC_NULLS_LAST) + .build()) + .collect(java.util.stream.Collectors.toList()); + } + + // Aggregate Functions + + public Aggregate.Grouping grouping(Rel input, int... indexes) { + var columns = fieldReferences(input, indexes); + return Aggregate.Grouping.builder().addAllExpressions(columns).build(); + } + + public AggregateFunctionInvocation count(Rel input, int field) { + var declaration = + extensions.getAggregateFunction( + SimpleExtension.FunctionAnchor.of("/functions_aggregate_generic.yaml", "count:any")); + return AggregateFunctionInvocation.builder() + .arguments(fieldReferences(input, field)) + .outputType(R.I64) + .declaration(declaration) + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) + .invocation(AggregateFunction.AggregationInvocation.AGGREGATION_INVOCATION_ALL) + .build(); + } + + // Scalar Functions + + // Misc + + public Plan.Root root(Rel rel) { + return ImmutableRoot.builder().input(rel).build(); + } + + public Rel.Remap remap(Integer... fields) { + return Rel.Remap.of(Arrays.asList(fields)); + } +} diff --git a/core/src/main/java/io/substrait/relation/Cross.java b/core/src/main/java/io/substrait/relation/Cross.java index 2f250dffd..668d6b862 100644 --- a/core/src/main/java/io/substrait/relation/Cross.java +++ b/core/src/main/java/io/substrait/relation/Cross.java @@ -1,10 +1,21 @@ package io.substrait.relation; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.stream.Stream; import org.immutables.value.Value; @Value.Immutable public abstract class Cross extends BiRel { + @Override + protected Type.Struct deriveRecordType() { + return TypeCreator.REQUIRED.struct( + Stream.concat( + getLeft().getRecordType().fields().stream(), + getRight().getRecordType().fields().stream())); + } + @Override public O accept(RelVisitor visitor) throws E { return visitor.visit(this); diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 360879f44..0f02bcbbb 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -336,15 +336,7 @@ private Join newJoin(JoinRel rel) { private Rel newCross(CrossRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); - Type.Struct leftStruct = left.getRecordType(); - Type.Struct rightStruct = right.getRecordType(); - Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - return Cross.builder() - .left(left) - .right(right) - .deriveRecordType(unionedStruct) - .remap(optionalRelmap(rel.getCommon())) - .build(); + return Cross.builder().left(left).right(right).remap(optionalRelmap(rel.getCommon())).build(); } private Set newSet(SetRel rel) { diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 648b66de9..14ef98a03 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -163,7 +163,6 @@ public Optional visit(Cross cross) throws RuntimeException { .from(cross) .left(left.orElse(cross.getLeft())) .right(right.orElse(cross.getRight())) - .deriveRecordType(unionedStruct) .build()); } diff --git a/core/src/main/java/io/substrait/type/NamedStruct.java b/core/src/main/java/io/substrait/type/NamedStruct.java index 80e3f74f1..58a54215b 100644 --- a/core/src/main/java/io/substrait/type/NamedStruct.java +++ b/core/src/main/java/io/substrait/type/NamedStruct.java @@ -10,7 +10,7 @@ public interface NamedStruct { List names(); - public static NamedStruct of(List names, Type.Struct type) { + static NamedStruct of(Iterable names, Type.Struct type) { return ImmutableNamedStruct.builder().addAllNames(names).struct(type).build(); } diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index 65e0b0d53..2264c9932 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -77,7 +77,7 @@ dependencies { implementation(project(":core")) implementation("org.apache.calcite:calcite-core:${CALCITE_VERSION}") implementation("org.apache.calcite:calcite-server:${CALCITE_VERSION}") - implementation("org.junit.jupiter:junit-jupiter:5.7.0") + implementation("org.junit.jupiter:junit-jupiter:5.9.2") implementation("org.reflections:reflections:0.9.12") implementation("com.google.guava:guava:29.0-jre") implementation("org.graalvm.sdk:graal-sdk:22.0.0.2") diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index c74de0ce2..bf6961a78 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -23,8 +23,10 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.IntStream; +import java.util.stream.Stream; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitDef; import org.apache.calcite.prepare.Prepare; @@ -35,6 +37,8 @@ import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexSlot; @@ -49,13 +53,14 @@ */ public class SubstraitRelNodeConverter extends AbstractRelVisitor { - private final RelDataTypeFactory typeFactory; + protected final RelDataTypeFactory typeFactory; - private final ScalarFunctionConverter scalarFunctionConverter; - private final AggregateFunctionConverter aggregateFunctionConverter; - private final ExpressionRexConverter expressionRexConverter; + protected final ScalarFunctionConverter scalarFunctionConverter; + protected final AggregateFunctionConverter aggregateFunctionConverter; + protected final ExpressionRexConverter expressionRexConverter; - private final RelBuilder relBuilder; + protected final RelBuilder relBuilder; + protected final RexBuilder rexBuilder; public SubstraitRelNodeConverter( SimpleExtension.ExtensionCollection extensions, @@ -63,6 +68,7 @@ public SubstraitRelNodeConverter( RelBuilder relBuilder) { this.typeFactory = typeFactory; this.relBuilder = relBuilder; + this.rexBuilder = new RexBuilder(typeFactory); this.scalarFunctionConverter = new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory); @@ -98,12 +104,14 @@ public static RelNode convert( public RelNode visit(Filter filter) throws RuntimeException { RelNode input = filter.getInput().accept(this); RexNode filterCondition = filter.getCondition().accept(expressionRexConverter); - return relBuilder.push(input).filter(filterCondition).build(); + RelNode node = relBuilder.push(input).filter(filterCondition).build(); + return applyRemap(node, filter.getRemap()); } @Override public RelNode visit(NamedScan namedScan) throws RuntimeException { - return relBuilder.scan(namedScan.getNames()).build(); + RelNode node = relBuilder.scan(namedScan.getNames()).build(); + return applyRemap(node, namedScan.getRemap()); } @Override @@ -114,24 +122,29 @@ public RelNode visit(LocalFiles localFiles) throws RuntimeException { @Override public RelNode visit(Project project) throws RuntimeException { RelNode child = project.getInput().accept(this); - List rexList = - project.getExpressions().stream() - .map(expr -> expr.accept(expressionRexConverter)) - .collect(java.util.stream.Collectors.toList()); - return relBuilder.push(child).project(rexList).build(); + Stream directOutputs = + IntStream.range(0, child.getRowType().getFieldCount()) + .mapToObj(fieldIndex -> rexBuilder.makeInputRef(child, fieldIndex)); + + Stream exprs = + project.getExpressions().stream().map(expr -> expr.accept(expressionRexConverter)); + + List rexExprs = + Stream.concat(directOutputs, exprs).collect(java.util.stream.Collectors.toList()); + + RelNode node = relBuilder.push(child).project(rexExprs).build(); + return applyRemap(node, project.getRemap()); } @Override public RelNode visit(Cross cross) throws RuntimeException { - var left = cross.getLeft().accept(this); - var right = cross.getRight().accept(this); + RelNode left = cross.getLeft().accept(this); + RelNode right = cross.getRight().accept(this); // Calcite represents CROSS JOIN as the equivalent INNER JOIN with true condition - return relBuilder - .push(left) - .push(right) - .join(JoinRelType.INNER, relBuilder.literal(true)) - .build(); + RelNode node = + relBuilder.push(left).push(right).join(JoinRelType.INNER, relBuilder.literal(true)).build(); + return applyRemap(node, cross.getRemap()); } @Override @@ -153,7 +166,8 @@ public RelNode visit(Join join) throws RuntimeException { case UNKNOWN -> throw new UnsupportedOperationException( "Unknown join type is not supported"); }; - return relBuilder.push(left).push(right).join(joinType, condition).build(); + RelNode node = relBuilder.push(left).push(right).join(joinType, condition).build(); + return applyRemap(node, join.getRemap()); } @Override @@ -175,7 +189,8 @@ public RelNode visit(Set set) throws RuntimeException { case UNKNOWN -> throw new UnsupportedOperationException( "Unknown set operation is not supported"); }; - return builder.build(); + RelNode node = builder.build(); + return applyRemap(node, set.getRemap()); } @Override @@ -197,7 +212,8 @@ public RelNode visit(Aggregate aggregate) throws RuntimeException { aggregate.getMeasures().stream() .map(this::fromMeasure) .collect(java.util.stream.Collectors.toList()); - return relBuilder.push(child).aggregate(groupKey, aggregateCalls).build(); + RelNode node = relBuilder.push(child).aggregate(groupKey, aggregateCalls).build(); + return applyRemap(node, aggregate.getRemap()); } private AggregateCall fromMeasure(Aggregate.Measure measure) { @@ -278,7 +294,8 @@ public RelNode visit(Sort sort) throws RuntimeException { if (relFieldCollations.isEmpty()) { return relBuilder.push(child).sort(Collections.EMPTY_LIST).build(); } - return relBuilder.push(child).sort(RelCollations.of(relFieldCollations)).build(); + RelNode node = relBuilder.push(child).sort(RelCollations.of(relFieldCollations)).build(); + return applyRemap(node, sort.getRemap()); } @Override @@ -293,7 +310,8 @@ public RelNode visit(Fetch fetch) throws RuntimeException { if (count > Integer.MAX_VALUE) { throw new RuntimeException(String.format("count is overflowed as an integer: %d", count)); } - return relBuilder.push(child).limit((int) offset, (int) count).build(); + RelNode node = relBuilder.push(child).limit((int) offset, (int) count).build(); + return applyRemap(node, fetch.getRemap()); } private RelFieldCollation toRelFieldCollation(Expression.SortField sortField) { @@ -331,6 +349,27 @@ public RelNode visitFallback(Rel rel) throws RuntimeException { rel, rel.getClass().getCanonicalName(), this.getClass().getCanonicalName())); } + private RelNode applyRemap(RelNode relNode, Optional remap) { + if (remap.isPresent()) { + return applyRemap(relNode, remap.get()); + } + return relNode; + } + + private RelNode applyRemap(RelNode relNode, Rel.Remap remap) { + var rowType = relNode.getRowType(); + var fieldNames = rowType.getFieldNames(); + List rexList = + remap.indices().stream() + .map( + index -> { + RelDataTypeField t = rowType.getField(fieldNames.get(index), true, false); + return new RexInputRef(index, t.getValue()); + }) + .collect(java.util.stream.Collectors.toList()); + return relBuilder.push(relNode).project(rexList).build(); + } + private void checkRexInputRefOnly(RexNode rexNode, String context, String aggName) { if (!(rexNode instanceof RexInputRef)) { throw new UnsupportedOperationException( diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 3a3c44d48..b53d8679d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -185,12 +185,7 @@ public Rel visit(LogicalJoin join) { if (joinType == Join.JoinType.INNER && TRUE.equals(condition) && featureBoard.crossJoinPolicy().equals(KEEP_AS_CROSS_JOIN)) { - return Cross.builder() - .left(left) - .right(right) - .deriveRecordType( - Type.Struct.builder().from(left.getRecordType()).from(right.getRecordType()).build()) - .build(); + return Cross.builder().left(left).right(right).build(); } return Join.builder().condition(condition).joinType(joinType).left(left).right(right).build(); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java new file mode 100644 index 000000000..a06bfd29f --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java @@ -0,0 +1,119 @@ +package io.substrait.isthmus; + +import io.substrait.function.SimpleExtension; +import io.substrait.relation.AbstractRelVisitor; +import io.substrait.relation.NamedScan; +import io.substrait.relation.Rel; +import io.substrait.type.NamedStruct; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.jdbc.LookupCalciteSchema; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.schema.Table; +import org.apache.calcite.tools.Frameworks; +import org.apache.calcite.tools.RelBuilder; + +/** + * Converts between Substrait {@link Rel}s and Calcite {@link RelNode}s. + * + *

Can be extended to customize the {@link RelBuilder} and {@link SubstraitRelNodeConverter} used + * in the conversion. + */ +public class SubstraitToCalcite { + + private final SimpleExtension.ExtensionCollection extensions; + private final RelDataTypeFactory typeFactory; + + public SubstraitToCalcite( + SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) { + this.extensions = extensions; + this.typeFactory = typeFactory; + } + + /** + * Extracts a {@link CalciteSchema} from a {@link Rel} + * + *

Override this method to customize schema extraction. + */ + protected CalciteSchema toSchema(Rel rel) { + Map, NamedStruct> tableMap = NamedStructGatherer.gatherTables(rel); + Function, Table> lookup = + id -> { + NamedStruct table = tableMap.get(id); + if (table == null) { + return null; + } + return new SqlConverterBase.DefinedTable( + id.get(id.size() - 1), + typeFactory, + TypeConverter.convert(typeFactory, table.struct(), table.names())); + }; + return LookupCalciteSchema.createRootSchema(lookup); + } + + /** + * Creates a {@link RelBuilder} from the extracted {@link CalciteSchema} + * + *

Override this method to customize the {@link RelBuilder}. + */ + protected RelBuilder createRelBuilder(CalciteSchema schema) { + return RelBuilder.create(Frameworks.newConfigBuilder().defaultSchema(schema.plus()).build()); + } + + /** + * Creates a {@link SubstraitRelNodeConverter} from the {@link RelBuilder} + * + *

Override this method to customize the {@link SubstraitRelNodeConverter}. + */ + protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { + return new SubstraitRelNodeConverter(extensions, typeFactory, relBuilder); + } + + /** + * Converts a Substrait {@link Rel} to a Calcite {@link RelNode} + * + *

Generates a {@link CalciteSchema} based on the contents of the {@link Rel}, which will be + * used to construct a {@link RelBuilder} with the required schema information to build {@link + * RelNode}s, and a then a {@link SubstraitRelNodeConverter} to perform the actual conversion. + * + * @param rel {@link Rel} to convert + * @return {@link RelNode} + */ + public RelNode convert(Rel rel) { + CalciteSchema rootSchema = toSchema(rel); + RelBuilder relBuilder = createRelBuilder(rootSchema); + SubstraitRelNodeConverter converter = createSubstraitRelNodeConverter(relBuilder); + return rel.accept(converter); + } + + private static class NamedStructGatherer extends AbstractRelVisitor { + Map, NamedStruct> tableMap; + + private NamedStructGatherer() { + this.tableMap = new HashMap<>(); + } + + public static Map, NamedStruct> gatherTables(Rel rel) { + var visitor = new NamedStructGatherer(); + rel.accept(visitor); + return visitor.tableMap; + } + + @Override + public Void visit(NamedScan namedScan) { + List tableName = namedScan.getNames(); + tableMap.put(tableName, namedScan.getInitialSchema()); + return null; + } + + @Override + public Void visitFallback(Rel rel) { + for (Rel input : rel.getInputs()) input.accept(this); + return null; + } + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index 2f93a8c18..94d067108 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -5,6 +5,7 @@ import com.google.common.base.Charsets; import com.google.common.io.Resources; +import io.substrait.function.SimpleExtension; import io.substrait.plan.Plan; import io.substrait.plan.PlanProtoConverter; import io.substrait.plan.ProtoPlanConverter; @@ -23,10 +24,20 @@ import org.junit.jupiter.api.Assertions; public class PlanTestBase { + final SimpleExtension.ExtensionCollection extensions; + + { + try { + extensions = SimpleExtension.loadDefaults(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + protected final RelCreator creator = new RelCreator(); protected final RelBuilder builder = creator.createRelBuilder(); protected final RexBuilder rex = creator.rex(); - protected final RelDataTypeFactory type = creator.type(); + protected final RelDataTypeFactory typeFactory = creator.typeFactory(); public static String asString(String resource) throws IOException { return Resources.toString(Resources.getResource(resource), Charsets.UTF_8); diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelCreator.java b/isthmus/src/test/java/io/substrait/isthmus/RelCreator.java index c68ee5293..34f45dedb 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelCreator.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelCreator.java @@ -74,7 +74,7 @@ public RexBuilder rex() { return cluster.getRexBuilder(); } - public RelDataTypeFactory type() { + public RelDataTypeFactory typeFactory() { return cluster.getTypeFactory(); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java new file mode 100644 index 000000000..c7e0facdc --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java @@ -0,0 +1,239 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.plan.Plan; +import io.substrait.relation.Rel; +import io.substrait.relation.Set.SetOp; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.calcite.rel.type.RelDataType; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +public class SubstraitRelNodeConverterTest extends PlanTestBase { + + static final TypeCreator R = TypeCreator.of(false); + static final TypeCreator N = TypeCreator.of(true); + + final SubstraitBuilder b = new SubstraitBuilder(extensions); + + // Define a shared table (i.e. a NamedScan) for use in tests. + final List commonTableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN); + final List commonTableTypeTwice = + Stream.concat(commonTableType.stream(), commonTableType.stream()) + .collect(Collectors.toList()); + final Rel commonTable = + b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType); + + final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); + + void assertRowMatch(RelDataType actual, Type... expected) { + assertRowMatch(actual, Arrays.asList(expected)); + } + + void assertRowMatch(RelDataType actual, List expected) { + Type type = TypeConverter.convert(actual); + assertInstanceOf(Type.Struct.class, type); + Type.Struct struct = (Type.Struct) type; + assertEquals(expected, struct.fields()); + } + + @Nested + class Aggregate { + @Test + public void direct() { + Plan.Root root = + b.root( + b.aggregate( + input -> b.grouping(input, 0, 2), + input -> List.of(b.count(input, 0)), + commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), R.I32, N.STRING, R.I64); + } + + @Test + public void emit() { + Plan.Root root = + b.root( + b.aggregate( + input -> b.grouping(input, 0, 2), + input -> List.of(b.count(input, 0)), + b.remap(1, 2), + commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), N.STRING, R.I64); + } + } + + @Nested + class Cross { + @Test + public void direct() { + Plan.Root root = b.root(b.cross(commonTable, commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), commonTableTypeTwice); + } + + @Test + public void emit() { + Plan.Root root = b.root(b.cross(commonTable, commonTable, b.remap(0, 1, 4, 6))); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), R.I32, R.FP32, R.I32, N.STRING); + } + } + + @Nested + class Fetch { + @Test + public void direct() { + Plan.Root root = b.root(b.fetch(20, 40, commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), commonTableType); + } + + @Test + public void emit() { + Plan.Root root = b.root(b.fetch(20, 40, b.remap(0, 2), commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), R.I32, N.STRING); + } + } + + @Nested + class Filter { + @Test + public void direct() { + Plan.Root root = b.root(b.filter(input -> b.bool(true), commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), commonTableType); + } + + @Test + public void emit() { + Plan.Root root = b.root(b.filter(input -> b.bool(true), b.remap(0, 2), commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), R.I32, N.STRING); + } + } + + @Nested + class Join { + @Test + public void direct() { + Plan.Root root = b.root(b.innerJoin(input -> b.bool(true), commonTable, commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), commonTableTypeTwice); + } + + @Test + public void emit() { + Plan.Root root = + b.root(b.innerJoin(input -> b.bool(true), b.remap(0, 6), commonTable, commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), R.I32, N.STRING); + } + } + + @Nested + class NamedScan { + @Test + public void direct() { + Plan.Root root = + b.root(b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.I32, R.FP32))); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), R.I32, R.FP32); + } + + @Test + public void emit() { + Plan.Root root = + b.root( + b.namedScan( + List.of("example"), List.of("a", "b"), List.of(R.I32, R.FP32), b.remap(1))); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), R.FP32); + } + } + + @Nested + class Project { + @Test + public void direct() { + Plan.Root root = b.root(b.project(input -> b.fieldReferences(input, 1, 0, 2), commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch( + relNode.getRowType(), R.I32, R.FP32, N.STRING, N.BOOLEAN, R.FP32, R.I32, N.STRING); + } + + @Test + public void emit() { + Plan.Root root = + b.root( + b.project( + input -> b.fieldReferences(input, 1, 0, 2), b.remap(0, 2, 4, 6), commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), R.I32, N.STRING, R.FP32, N.STRING); + } + } + + @Nested + class Set { + @Test + public void direct() { + Plan.Root root = b.root(b.set(SetOp.UNION_ALL, commonTable, commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), commonTableType); + } + + @Test + public void emit() { + Plan.Root root = b.root(b.set(SetOp.UNION_ALL, b.remap(0, 2), commonTable, commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), R.I32, N.STRING); + } + } + + @Nested + class Sort { + @Test + public void direct() { + Plan.Root root = b.root(b.sort(input -> b.sortFields(input, 0, 1, 2), commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), commonTableType); + } + + @Test + public void emit() { + Plan.Root root = + b.root(b.sort(input -> b.sortFields(input, 0, 1, 2), b.remap(0, 2), commonTable)); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), R.I32, N.STRING); + } + } +} From dd069f69e9ddd7e7e668188ab535174181e61d02 Mon Sep 17 00:00:00 2001 From: semantic-release-bot Date: Sun, 12 Mar 2023 03:05:09 +0000 Subject: [PATCH 2/2] chore(release): 0.7.0 --- CHANGELOG.md | 7 +++++++ gradle.properties | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a879076e9..a116beb0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,13 @@ Release Notes --- +## [0.7.0](https://github.com/substrait-io/substrait-java/compare/v0.6.0...v0.7.0) (2023-03-12) + + +### Features + +* support output order remapping ([#132](https://github.com/substrait-io/substrait-java/issues/132)) ([6d94059](https://github.com/substrait-io/substrait-java/commit/6d94059329ad7f3b0cc2c124bed4f8b22b85551a)) + ## [0.6.0](https://github.com/substrait-io/substrait-java/compare/v0.5.0...v0.6.0) (2023-03-05) diff --git a/gradle.properties b/gradle.properties index 78f48bdef..b18fc9489 100644 --- a/gradle.properties +++ b/gradle.properties @@ -20,7 +20,7 @@ slf4j.version=1.7.25 jackson.version=2.12.4 #version that is going to be updated automatically by releases -version = 0.6.0 +version = 0.7.0 #signing SIGNING_KEY_ID = 193EAE47