Skip to content

Commit

Permalink
feat(pojo): add POJO representation and converters for ConsistentPartitionWindowRel (#231)
Browse files Browse the repository at this point in the history
  • Loading branch information
bvolpato authored Feb 16, 2024
1 parent 63dd305 commit f148bbb
Show file tree
Hide file tree
Showing 10 changed files with 411 additions and 42 deletions.
22 changes: 22 additions & 0 deletions core/src/main/java/io/substrait/expression/ExpressionCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.protobuf.ByteString;
import io.substrait.extension.SimpleExtension;
import io.substrait.relation.ConsistentPartitionWindow;
import io.substrait.type.Type;
import io.substrait.util.DecimalUtil;
import java.math.BigDecimal;
Expand Down Expand Up @@ -342,6 +343,27 @@ public static Expression.WindowFunctionInvocation windowFunction(
.build();
}

/**
 * Assembles a {@link ConsistentPartitionWindow.WindowRelFunctionInvocation} from its individual
 * components.
 *
 * @param declaration the window function variant being invoked
 * @param outputType the type produced by the invocation
 * @param phase the aggregation phase of the invocation
 * @param invocation the aggregation invocation mode
 * @param boundsType the kind of window bounds in use
 * @param lowerBound the lower bound of the window
 * @param upperBound the upper bound of the window
 * @param arguments the arguments passed to the function, in order
 * @return the built invocation
 */
public static ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunction(
    SimpleExtension.WindowFunctionVariant declaration,
    Type outputType,
    Expression.AggregationPhase phase,
    Expression.AggregationInvocation invocation,
    Expression.WindowBoundsType boundsType,
    WindowBound lowerBound,
    WindowBound upperBound,
    Iterable<? extends FunctionArg> arguments) {
  var invocationBuilder = ConsistentPartitionWindow.WindowRelFunctionInvocation.builder();

  // What is being called, and what it yields.
  invocationBuilder.declaration(declaration).outputType(outputType);

  // How the aggregation is carried out.
  invocationBuilder.aggregationPhase(phase).invocation(invocation);

  // The window frame.
  invocationBuilder.boundsType(boundsType).lowerBound(lowerBound).upperBound(upperBound);

  invocationBuilder.addAllArguments(arguments);
  return invocationBuilder.build();
}

public static Expression.WindowFunctionInvocation windowFunction(
SimpleExtension.WindowFunctionVariant declaration,
Type outputType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,10 @@ public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocat
.build();
}

static class BoundConverter
public static class BoundConverter
implements WindowBound.WindowBoundVisitor<Expression.WindowFunction.Bound, RuntimeException> {

static Expression.WindowFunction.Bound convert(WindowBound bound) {
public static Expression.WindowFunction.Bound convert(WindowBound bound) {
return bound.accept(TO_BOUND_VISITOR);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,24 @@
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.ImmutableExpression;
import io.substrait.expression.ImmutableFunctionOption;
import io.substrait.expression.WindowBound;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
import io.substrait.proto.ConsistentPartitionWindowRel;
import io.substrait.proto.FunctionArgument;
import io.substrait.proto.SortField;
import io.substrait.relation.ConsistentPartitionWindow;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.type.Type;
import io.substrait.type.proto.ProtoTypeConverter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

Expand Down Expand Up @@ -115,46 +123,7 @@ public Expression from(io.substrait.proto.Expression expr) {
.outputType(protoTypeConverter.from(scalarFunction.getOutputType()))
.build();
}
case WINDOW_FUNCTION -> {
var windowFunction = expr.getWindowFunction();
var functionReference = windowFunction.getFunctionReference();
var declaration = lookup.getWindowFunction(functionReference, extensions);

var argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter);
var args =
IntStream.range(0, windowFunction.getArgumentsCount())
.mapToObj(i -> argVisitor.convert(declaration, i, windowFunction.getArguments(i)))
.collect(java.util.stream.Collectors.toList());
var partitionExprs =
windowFunction.getPartitionsList().stream()
.map(this::from)
.collect(java.util.stream.Collectors.toList());
var sortFields =
windowFunction.getSortsList().stream()
.map(
s ->
Expression.SortField.builder()
.direction(Expression.SortDirection.fromProto(s.getDirection()))
.expr(from(s.getExpr()))
.build())
.collect(java.util.stream.Collectors.toList());

WindowBound lowerBound = toWindowBound(windowFunction.getLowerBound());
WindowBound upperBound = toWindowBound(windowFunction.getUpperBound());

yield Expression.WindowFunctionInvocation.builder()
.arguments(args)
.declaration(declaration)
.outputType(protoTypeConverter.from(windowFunction.getOutputType()))
.aggregationPhase(Expression.AggregationPhase.fromProto(windowFunction.getPhase()))
.partitionBy(partitionExprs)
.sort(sortFields)
.boundsType(Expression.WindowBoundsType.fromProto(windowFunction.getBoundsType()))
.lowerBound(lowerBound)
.upperBound(upperBound)
.invocation(Expression.AggregationInvocation.fromProto(windowFunction.getInvocation()))
.build();
}
case WINDOW_FUNCTION -> fromWindowFunction(expr.getWindowFunction());
case IF_THEN -> {
var ifThen = expr.getIfThen();
var clauses =
Expand Down Expand Up @@ -250,6 +219,80 @@ public Expression from(io.substrait.proto.Expression expr) {
};
}

/**
 * Converts a proto window function expression into its POJO counterpart.
 *
 * <p>The function declaration is resolved through the extension lookup using the proto's
 * function reference; arguments, partition expressions, sort fields, options and bounds are each
 * converted from the proto message.
 *
 * @param windowFunction the proto window function to convert
 * @return the equivalent {@link Expression.WindowFunctionInvocation}
 */
public Expression.WindowFunctionInvocation fromWindowFunction(
    io.substrait.proto.Expression.WindowFunction windowFunction) {
  var declaration = lookup.getWindowFunction(windowFunction.getFunctionReference(), extensions);

  List<FunctionArg> arguments =
      fromFunctionArgumentList(
          windowFunction.getArgumentsCount(),
          new FunctionArg.ProtoFrom(this, protoTypeConverter),
          declaration,
          windowFunction::getArguments);

  List<Expression> partitionBy = new ArrayList<>();
  for (io.substrait.proto.Expression partition : windowFunction.getPartitionsList()) {
    partitionBy.add(from(partition));
  }

  List<Expression.SortField> sortBy = new ArrayList<>();
  for (SortField sort : windowFunction.getSortsList()) {
    sortBy.add(fromSortField(sort));
  }

  // Keyed by option name; Collectors.toMap rejects duplicate names.
  var optionsByName =
      windowFunction.getOptionsList().stream()
          .map(this::fromFunctionOption)
          .collect(Collectors.toMap(FunctionOption::getName, Function.identity()));

  WindowBound lower = toWindowBound(windowFunction.getLowerBound());
  WindowBound upper = toWindowBound(windowFunction.getUpperBound());

  var invocationBuilder = Expression.WindowFunctionInvocation.builder();
  invocationBuilder.arguments(arguments).declaration(declaration);
  invocationBuilder
      .outputType(protoTypeConverter.from(windowFunction.getOutputType()))
      .aggregationPhase(Expression.AggregationPhase.fromProto(windowFunction.getPhase()));
  invocationBuilder.partitionBy(partitionBy).sort(sortBy);
  invocationBuilder
      .boundsType(Expression.WindowBoundsType.fromProto(windowFunction.getBoundsType()))
      .lowerBound(lower)
      .upperBound(upper);
  invocationBuilder
      .invocation(Expression.AggregationInvocation.fromProto(windowFunction.getInvocation()))
      .options(optionsByName);
  return invocationBuilder.build();
}

/**
 * Converts a proto {@code WindowRelFunction} (a window function belonging to a {@code
 * ConsistentPartitionWindowRel}) into its POJO counterpart.
 *
 * <p>Unlike {@link #fromWindowFunction}, no partition expressions or sort fields are read here;
 * the proto message converted by this method only supplies the function reference, arguments,
 * options, output type, phase, invocation and bounds.
 *
 * @param windowRelFunction the proto window rel function to convert
 * @return the equivalent {@link ConsistentPartitionWindow.WindowRelFunctionInvocation}
 */
public ConsistentPartitionWindow.WindowRelFunctionInvocation fromWindowRelFunction(
    ConsistentPartitionWindowRel.WindowRelFunction windowRelFunction) {
  var declaration =
      lookup.getWindowFunction(windowRelFunction.getFunctionReference(), extensions);

  List<FunctionArg> arguments =
      fromFunctionArgumentList(
          windowRelFunction.getArgumentsCount(),
          new FunctionArg.ProtoFrom(this, protoTypeConverter),
          declaration,
          windowRelFunction::getArguments);

  // Keyed by option name; Collectors.toMap rejects duplicate names.
  var optionsByName =
      windowRelFunction.getOptionsList().stream()
          .map(this::fromFunctionOption)
          .collect(Collectors.toMap(FunctionOption::getName, Function.identity()));

  WindowBound lower = toWindowBound(windowRelFunction.getLowerBound());
  WindowBound upper = toWindowBound(windowRelFunction.getUpperBound());

  var invocationBuilder = ConsistentPartitionWindow.WindowRelFunctionInvocation.builder();
  invocationBuilder.arguments(arguments).declaration(declaration);
  invocationBuilder
      .outputType(protoTypeConverter.from(windowRelFunction.getOutputType()))
      .aggregationPhase(Expression.AggregationPhase.fromProto(windowRelFunction.getPhase()));
  invocationBuilder
      .boundsType(Expression.WindowBoundsType.fromProto(windowRelFunction.getBoundsType()))
      .lowerBound(lower)
      .upperBound(upper);
  invocationBuilder
      .invocation(Expression.AggregationInvocation.fromProto(windowRelFunction.getInvocation()))
      .options(optionsByName);
  return invocationBuilder.build();
}

private WindowBound toWindowBound(io.substrait.proto.Expression.WindowFunction.Bound bound) {
return switch (bound.getKindCase()) {
case PRECEDING -> WindowBound.Preceding.of(bound.getPreceding().getOffset());
Expand Down Expand Up @@ -318,4 +361,28 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
"Unexpected value: " + literal.getLiteralTypeCase());
};
}

/**
 * Converts the proto arguments of a function invocation, one index at a time.
 *
 * @param argumentsCount number of arguments to convert; indices {@code 0..argumentsCount-1} are
 *     visited in ascending order
 * @param argVisitor converter applied to each proto argument
 * @param declaration the function declaration the arguments belong to
 * @param argFunction accessor returning the proto argument at a given index
 * @return the converted arguments, in index order
 */
private static List<FunctionArg> fromFunctionArgumentList(
    int argumentsCount,
    FunctionArg.ProtoFrom argVisitor,
    SimpleExtension.Function declaration,
    Function<Integer, FunctionArgument> argFunction) {
  List<FunctionArg> converted = new ArrayList<>();
  for (int index = 0; index < argumentsCount; index++) {
    FunctionArgument protoArgument = argFunction.apply(index);
    converted.add(argVisitor.convert(declaration, index, protoArgument));
  }
  return converted;
}

/**
 * Converts a proto sort field into an {@link Expression.SortField}.
 *
 * @param s the proto sort field
 * @return a sort field carrying the converted direction and sort expression
 */
public Expression.SortField fromSortField(SortField s) {
  var direction = Expression.SortDirection.fromProto(s.getDirection());
  var sortExpression = from(s.getExpr());

  var sortFieldBuilder = Expression.SortField.builder();
  sortFieldBuilder.direction(direction);
  sortFieldBuilder.expr(sortExpression);
  return sortFieldBuilder.build();
}

/**
 * Converts a proto function option into a {@link FunctionOption}.
 *
 * @param o the proto function option
 * @return an option with the proto's name and all of its preference entries as values
 */
public FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) {
  var optionBuilder = ImmutableFunctionOption.builder();
  optionBuilder.name(o.getName());
  // The proto "preference" list supplies the POJO's values.
  optionBuilder.addAllValues(o.getPreferenceList());
  return optionBuilder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,9 @@ public OUTPUT visit(MergeJoin mergeJoin) throws EXCEPTION {
public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
return visitFallback(nestedLoopJoin);
}

/**
 * Handles a {@link ConsistentPartitionWindow} by delegating to {@code visitFallback}.
 *
 * @param consistentPartitionWindow the relation being visited
 * @return whatever {@code visitFallback} returns for this relation
 * @throws EXCEPTION if {@code visitFallback} throws
 */
@Override
public OUTPUT visit(ConsistentPartitionWindow consistentPartitionWindow) throws EXCEPTION {
  final OUTPUT fallbackResult = visitFallback(consistentPartitionWindow);
  return fallbackResult;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package io.substrait.relation;

import io.substrait.expression.Expression;
import io.substrait.expression.Expression.SortField;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.WindowBound;
import io.substrait.extension.SimpleExtension;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.immutables.value.Value;

/**
 * POJO representation of a consistent-partition window relation: a single-input relation that
 * carries a list of window functions together with the partition expressions and sort fields
 * they share.
 */
@Value.Immutable
@Value.Enclosing
public abstract class ConsistentPartitionWindow extends SingleInputRel implements HasExtension {

  /** Returns the window functions carried by this relation. */
  public abstract List<WindowRelFunctionInvocation> getWindowFunctions();

  /** Returns the partition expressions of this relation. */
  public abstract List<Expression> getPartitionExpressions();

  /** Returns the sort fields of this relation. */
  public abstract List<SortField> getSorts();

  /**
   * Derives the output record type: every field of the input, followed by one field per window
   * function typed with that function's {@link WindowRelFunctionInvocation#outputType()}.
   *
   * <p>The partition expressions only describe how rows are grouped; they do not contribute
   * output columns, so they must not be appended here.
   *
   * @return a struct with the input's nullability holding the input fields plus the window
   *     function outputs
   */
  @Override
  protected Type.Struct deriveRecordType() {
    Type.Struct initial = getInput().getRecordType();
    return TypeCreator.of(initial.nullable())
        .struct(
            Stream.concat(
                initial.fields().stream(),
                getWindowFunctions().stream().map(WindowRelFunctionInvocation::outputType)));
  }

  @Override
  public <O, E extends Exception> O accept(RelVisitor<O, E> visitor) throws E {
    return visitor.visit(this);
  }

  /** Returns a new builder for {@link ConsistentPartitionWindow}. */
  public static ImmutableConsistentPartitionWindow.Builder builder() {
    return ImmutableConsistentPartitionWindow.builder();
  }

  /** A single window function invocation belonging to a {@link ConsistentPartitionWindow}. */
  @Value.Immutable
  public abstract static class WindowRelFunctionInvocation {

    /** Returns the declaration of the window function being invoked. */
    public abstract SimpleExtension.WindowFunctionVariant declaration();

    /** Returns the arguments passed to the function. */
    public abstract List<FunctionArg> arguments();

    /** Returns the function options, keyed by name. */
    public abstract Map<String, FunctionOption> options();

    /** Returns the type produced by this invocation. */
    public abstract Type outputType();

    /** Returns the aggregation phase of this invocation. */
    public abstract Expression.AggregationPhase aggregationPhase();

    /** Returns the aggregation invocation mode. */
    public abstract Expression.AggregationInvocation invocation();

    /** Returns the lower bound of the window. */
    public abstract WindowBound lowerBound();

    /** Returns the upper bound of the window. */
    public abstract WindowBound upperBound();

    /** Returns the kind of bounds used by the window. */
    public abstract Expression.WindowBoundsType boundsType();

    /** Returns a new builder for {@link WindowRelFunctionInvocation}. */
    public static ImmutableConsistentPartitionWindow.WindowRelFunctionInvocation.Builder builder() {
      return ImmutableConsistentPartitionWindow.WindowRelFunctionInvocation.builder();
    }
  }
}
39 changes: 39 additions & 0 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
import io.substrait.proto.AggregateRel;
import io.substrait.proto.ConsistentPartitionWindowRel;
import io.substrait.proto.CrossRel;
import io.substrait.proto.ExtensionLeafRel;
import io.substrait.proto.ExtensionMultiRel;
Expand Down Expand Up @@ -108,6 +109,9 @@ public Rel from(io.substrait.proto.Rel rel) {
case NESTED_LOOP_JOIN -> {
return newNestedLoopJoin(rel.getNestedLoopJoin());
}
case WINDOW -> {
return newConsistentPartitionWindow(rel.getWindow());
}
default -> {
throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType);
}
Expand Down Expand Up @@ -601,6 +605,41 @@ private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) {
return builder.build();
}

/**
 * Builds a {@link ConsistentPartitionWindow} from its proto message.
 *
 * <p>The input relation is converted first; its record type is handed to the expression converter
 * used for the partition expressions, sort fields and window functions.
 *
 * @param rel the proto relation to convert
 * @return the converted relation, including common extension, remap and (if present) advanced
 *     extension
 */
private ConsistentPartitionWindow newConsistentPartitionWindow(ConsistentPartitionWindowRel rel) {
  var input = from(rel.getInput());
  var converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this);

  var windowBuilder =
      ConsistentPartitionWindow.builder()
          .input(input)
          .partitionExpressions(
              rel.getPartitionExpressionsList().stream()
                  .map(converter::from)
                  .collect(Collectors.toList()))
          .sorts(
              rel.getSortsList().stream()
                  .map(converter::fromSortField)
                  .collect(Collectors.toList()))
          .windowFunctions(
              rel.getWindowFunctionsList().stream()
                  .map(converter::fromWindowRelFunction)
                  .collect(Collectors.toList()));

  var common = rel.getCommon();
  windowBuilder.commonExtension(optionalAdvancedExtension(common));
  windowBuilder.remap(optionalRelmap(common));

  // The advanced extension is only set when the proto actually carries one.
  if (rel.hasAdvancedExtension()) {
    windowBuilder.extension(advancedExtension(rel.getAdvancedExtension()));
  }
  return windowBuilder.build();
}

private static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon relCommon) {
return Optional.ofNullable(
relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null);
Expand Down
Loading

0 comments on commit f148bbb

Please sign in to comment.