Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(isthmus): support in-predicate (#204) #205

Merged
merged 1 commit into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,13 @@ public Expression singleOrList(Expression condition, Expression... options) {
return SingleOrList.builder().condition(condition).addOptions(options).build();
}

public Expression.InPredicate inPredicate(Rel haystack, Expression... needles) {
return Expression.InPredicate.builder()
.addAllNeedles(Arrays.asList(needles))
.haystack(haystack)
.build();
}

public List<Expression.SortField> sortFields(Rel input, int... indexes) {
return Arrays.stream(indexes)
.mapToObj(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ public SubstraitRelNodeConverter(
this.scalarFunctionConverter = scalarFunctionConverter;
this.aggregateFunctionConverter = aggregateFunctionConverter;
this.expressionRexConverter = expressionRexConverter;
this.expressionRexConverter.setRelNodeConverter(this);
}

public static RelNode convert(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.substrait.expression.FunctionArg;
import io.substrait.expression.WindowBound;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.SubstraitRelNodeConverter;
import io.substrait.isthmus.TypeConverter;
import io.substrait.type.StringTypeVisitor;
import io.substrait.type.Type;
Expand All @@ -22,12 +23,14 @@
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.calcite.avatica.util.ByteString;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexFieldCollation;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.sql.SqlAggFunction;
Expand All @@ -50,6 +53,7 @@ public class ExpressionRexConverter extends AbstractExpressionVisitor<RexNode, R
protected final RexBuilder rexBuilder;
protected final ScalarFunctionConverter scalarFunctionConverter;
protected final WindowFunctionConverter windowFunctionConverter;
protected SubstraitRelNodeConverter relNodeConverter;

private static final SqlIntervalQualifier YEAR_MONTH_INTERVAL =
new SqlIntervalQualifier(
Expand Down Expand Up @@ -79,6 +83,10 @@ public ExpressionRexConverter(
this.windowFunctionConverter = windowFunctionConverter;
}

public void setRelNodeConverter(final SubstraitRelNodeConverter substraitRelNodeConverter) {
this.relNodeConverter = substraitRelNodeConverter;
}

@Override
public RexNode visit(Expression.NullLiteral expr) throws RuntimeException {
return rexBuilder.makeLiteral(null, typeConverter.toCalcite(typeFactory, expr.getType()));
Expand Down Expand Up @@ -385,6 +393,14 @@ public RexNode visit(Expression.WindowFunctionInvocation expr) throws RuntimeExc
ignoreNulls);
}

@Override
public RexNode visit(Expression.InPredicate expr) throws RuntimeException {
List<RexNode> needles =
expr.needles().stream().map(e -> e.accept(this)).collect(Collectors.toList());
RelNode rel = expr.haystack().accept(relNodeConverter);
return RexSubQuery.in(rel, ImmutableList.copyOf(needles));
}

static class ToRexWindowBound
implements WindowBound.WindowBoundVisitor<RexWindowBound, RuntimeException> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ public void mapLiteral() throws IOException, SqlParseException {
assertFullRoundTrip("select MAP[1, 'hello'] from ORDERS");
}

@Test
public void inPredicate() throws IOException, SqlParseException {
assertFullRoundTrip(
"select L_PARTKEY from LINEITEM where L_PARTKEY in "
+ "(SELECT L_SUPPKEY from LINEITEM where L_SUPPKEY < L_ORDERKEY)");
}

@Test
public void singleOrList() {
Expression singleOrList = b.singleOrList(b.fieldReference(commonTable, 0), b.i32(5), b.i32(10));
Expand Down