From 60575b3ddc6f9abd7040ccc3d70be69bdf19403f Mon Sep 17 00:00:00 2001 From: Bruno Volpato Date: Fri, 16 Aug 2024 17:10:44 -0400 Subject: [PATCH] fix(core): wrong type derivation for ConsistentPartitionWindow (#286) The output of a ConsistentPartitionWindow consists of: * all input columns * all window expressions The deriveRecordType() has been updated to reflect this --- .../relation/ConsistentPartitionWindow.java | 2 +- ...istentPartitionWindowRelRoundtripTest.java | 80 ++++++++++++++++++- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java b/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java index d1344581c..f1f9cbe71 100644 --- a/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java +++ b/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java @@ -29,7 +29,7 @@ protected Type.Struct deriveRecordType() { .struct( Stream.concat( initial.fields().stream(), - getPartitionExpressions().stream().map(Expression::getType))); + getWindowFunctions().stream().map(WindowRelFunctionInvocation::outputType))); } @Override diff --git a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java index 13036d048..329dc77b4 100644 --- a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java @@ -17,7 +17,7 @@ public class ConsistentPartitionWindowRelRoundtripTest extends TestBase { @Test - void consistentPartitionWindowRoundtrip() { + void consistentPartitionWindowRoundtripSingle() { var windowFunctionDeclaration = defaultExtensionCollection.getWindowFunction( SimpleExtension.FunctionAnchor.of( @@ -63,5 +63,83 @@ void consistentPartitionWindowRoundtrip() { io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel1); io.substrait.relation.Rel rel2 = protoRelConverter.from(protoRel); assertEquals(rel1, rel2); + + // Make sure that the record types match I64, I16, I32 and then the I64 from the window + // function. + assertEquals(rel2.getRecordType().fields(), Arrays.asList(R.I64, R.I16, R.I32, R.I64)); + } + + @Test + void consistentPartitionWindowRoundtripMulti() { + var windowFunctionLeadDeclaration = + defaultExtensionCollection.getWindowFunction( + SimpleExtension.FunctionAnchor.of( + DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); + var windowFunctionLagDeclaration = + defaultExtensionCollection.getWindowFunction( + SimpleExtension.FunctionAnchor.of( + DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); + Rel input = + b.namedScan( + Arrays.asList("test"), + Arrays.asList("a", "b", "c"), + Arrays.asList(R.I64, R.I16, R.I32)); + Rel rel1 = + ImmutableConsistentPartitionWindow.builder() + .input(input) + .windowFunctions( + Arrays.asList( + ConsistentPartitionWindow.WindowRelFunctionInvocation.builder() + .declaration(windowFunctionLeadDeclaration) + // lead(a) + .arguments(Arrays.asList(b.fieldReference(input, 0))) + .options( + Arrays.asList( + FunctionOption.builder() + .name("option") + .addValues("VALUE1", "VALUE2") + .build())) + .outputType(R.I64) + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) + .invocation(Expression.AggregationInvocation.ALL) + .lowerBound(ImmutableWindowBound.Unbounded.UNBOUNDED) + .upperBound(ImmutableWindowBound.Following.CURRENT_ROW) + .boundsType(Expression.WindowBoundsType.RANGE) + .build(), + ConsistentPartitionWindow.WindowRelFunctionInvocation.builder() + .declaration(windowFunctionLagDeclaration) + // lag(a) + .arguments(Arrays.asList(b.fieldReference(input, 0))) + .options( + Arrays.asList( + FunctionOption.builder() + .name("option") + .addValues("VALUE1", "VALUE2") + .build())) + .outputType(R.I64) + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) + .invocation(Expression.AggregationInvocation.ALL) + .lowerBound(ImmutableWindowBound.Unbounded.UNBOUNDED) + .upperBound(ImmutableWindowBound.Following.CURRENT_ROW) + .boundsType(Expression.WindowBoundsType.RANGE) + .build())) + // PARTITION BY b + .partitionExpressions(Arrays.asList(b.fieldReference(input, 1))) + .sorts( + Arrays.asList( + Expression.SortField.builder() + // SORT BY c + .expr(b.fieldReference(input, 2)) + .direction(Expression.SortDirection.ASC_NULLS_FIRST) + .build())) + .build(); + + io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel1); + io.substrait.relation.Rel rel2 = protoRelConverter.from(protoRel); + assertEquals(rel1, rel2); + + // Make sure that the record types match I64, I16, I32 and then the I64 and I64 from the window + // functions. + assertEquals(rel2.getRecordType().fields(), Arrays.asList(R.I64, R.I16, R.I32, R.I64, R.I64)); } }