Skip to content

Commit

Permalink
fix(core): wrong type derivation for ConsistentPartitionWindow (#286)
Browse files Browse the repository at this point in the history
The output of a ConsistentPartitionWindow consists of:
* all input columns
* all window expressions

The deriveRecordType() has been updated to reflect this
  • Loading branch information
bvolpato authored Aug 16, 2024
1 parent a0ca17b commit 60575b3
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ protected Type.Struct deriveRecordType() {
.struct(
Stream.concat(
initial.fields().stream(),
getPartitionExpressions().stream().map(Expression::getType)));
getWindowFunctions().stream().map(WindowRelFunctionInvocation::outputType)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
public class ConsistentPartitionWindowRelRoundtripTest extends TestBase {

@Test
void consistentPartitionWindowRoundtrip() {
void consistentPartitionWindowRoundtripSingle() {
var windowFunctionDeclaration =
defaultExtensionCollection.getWindowFunction(
SimpleExtension.FunctionAnchor.of(
Expand Down Expand Up @@ -63,5 +63,83 @@ void consistentPartitionWindowRoundtrip() {
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel1);
io.substrait.relation.Rel rel2 = protoRelConverter.from(protoRel);
assertEquals(rel1, rel2);

// Make sure that the record types match I64, I16, I32 and then the I64 from the window
// function.
assertEquals(rel2.getRecordType().fields(), Arrays.asList(R.I64, R.I16, R.I32, R.I64));
}

@Test
void consistentPartitionWindowRoundtripMulti() {
var windowFunctionLeadDeclaration =
defaultExtensionCollection.getWindowFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any"));
var windowFunctionLagDeclaration =
defaultExtensionCollection.getWindowFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any"));
Rel input =
b.namedScan(
Arrays.asList("test"),
Arrays.asList("a", "b", "c"),
Arrays.asList(R.I64, R.I16, R.I32));
Rel rel1 =
ImmutableConsistentPartitionWindow.builder()
.input(input)
.windowFunctions(
Arrays.asList(
ConsistentPartitionWindow.WindowRelFunctionInvocation.builder()
.declaration(windowFunctionLeadDeclaration)
// lead(a)
.arguments(Arrays.asList(b.fieldReference(input, 0)))
.options(
Arrays.asList(
FunctionOption.builder()
.name("option")
.addValues("VALUE1", "VALUE2")
.build()))
.outputType(R.I64)
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.lowerBound(ImmutableWindowBound.Unbounded.UNBOUNDED)
.upperBound(ImmutableWindowBound.Following.CURRENT_ROW)
.boundsType(Expression.WindowBoundsType.RANGE)
.build(),
ConsistentPartitionWindow.WindowRelFunctionInvocation.builder()
.declaration(windowFunctionLagDeclaration)
// lag(a)
.arguments(Arrays.asList(b.fieldReference(input, 0)))
.options(
Arrays.asList(
FunctionOption.builder()
.name("option")
.addValues("VALUE1", "VALUE2")
.build()))
.outputType(R.I64)
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.lowerBound(ImmutableWindowBound.Unbounded.UNBOUNDED)
.upperBound(ImmutableWindowBound.Following.CURRENT_ROW)
.boundsType(Expression.WindowBoundsType.RANGE)
.build()))
// PARTITION BY b
.partitionExpressions(Arrays.asList(b.fieldReference(input, 1)))
.sorts(
Arrays.asList(
Expression.SortField.builder()
// SORT BY c
.expr(b.fieldReference(input, 2))
.direction(Expression.SortDirection.ASC_NULLS_FIRST)
.build()))
.build();

io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel1);
io.substrait.relation.Rel rel2 = protoRelConverter.from(protoRel);
assertEquals(rel1, rel2);

// Make sure that the record types match I64, I16, I32 and then the I64 and I64 from the window
// functions.
assertEquals(rel2.getRecordType().fields(), Arrays.asList(R.I64, R.I16, R.I32, R.I64, R.I64));
}
}

0 comments on commit 60575b3

Please sign in to comment.