Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.time.format.DateTimeFormatter;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import org.apache.commons.io.FileUtils;
import org.apache.pinot.spi.config.table.TableConfig;
import org.apache.pinot.spi.data.Schema;
Expand Down Expand Up @@ -700,6 +701,22 @@ public void systemColumnsCanBeUsedInWhere(String systemColumn)
assertNoError(jsonNode);
}

@Test
public void testSearch()
throws Exception {
String sqlQuery = "SELECT CASE WHEN ArrDelay > 50 OR ArrDelay < 10 THEN 10 ELSE 0 END "
+ "FROM mytable LIMIT 1000";
JsonNode jsonNode = postQuery("Explain plan for " + sqlQuery);
JsonNode plan = jsonNode.get("resultTable").get("rows").get(0).get(1);

Pattern pattern = Pattern.compile("SEARCH\\(\\$7, Sarg\\[\\(-∞\\.\\.10\\), \\(50\\.\\.\\+∞\\)]\\)");
boolean matches = pattern.matcher(plan.asText()).find();
Assert.assertTrue(matches, "Plan doesn't contain the expected SEARCH");

jsonNode = postQuery(sqlQuery);
assertNoError(jsonNode);
}

@AfterClass
public void tearDown()
throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import static org.apache.pinot.common.function.scalar.StringFunctions.*;
Expand Down Expand Up @@ -222,6 +223,11 @@ public void setUp()
waitForAllDocsLoaded(600_000L);
}

@BeforeMethod
public void resetMultiStage() {
setUseMultiStageQueryEngine(false);
}

protected void startBrokers()
throws Exception {
startBrokers(getNumBrokers());
Expand Down Expand Up @@ -1956,9 +1962,10 @@ public void testCaseStatementInSelectionWithTransformFunctionInThen()
}
}

@Test
public void testCaseStatementWithLogicalTransformFunction()
@Test(dataProvider = "useBothQueryEngines")
public void testCaseStatementWithLogicalTransformFunction(boolean useMultiStageQueryEngine)
throws Exception {
setUseMultiStageQueryEngine(useMultiStageQueryEngine);
String sqlQuery = "SELECT ArrDelay" + ", CASE WHEN ArrDelay > 50 OR ArrDelay < 10 THEN 10 ELSE 0 END"
+ ", CASE WHEN ArrDelay < 50 AND ArrDelay >= 10 THEN 10 ELSE 0 END" + " FROM mytable LIMIT 1000";
JsonNode response = postQuery(sqlQuery);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
package org.apache.pinot.query.planner.logical;

import com.google.common.base.Preconditions;
import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Range;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.GregorianCalendar;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
Expand All @@ -36,8 +39,8 @@
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.util.NlsString;
import org.apache.calcite.util.Sarg;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.spi.utils.BooleanUtils;
import org.checkerframework.checker.nullness.qual.Nullable;


Expand Down Expand Up @@ -141,8 +144,7 @@ private static RexExpression handleReinterpret(RexCall rexCall) {
return fromRexNode(operands.get(0));
}

// TODO: Add support for range filter expressions (e.g. a > 0 and a < 30)
private static RexExpression.FunctionCall handleSearch(RexCall rexCall) {
private static RexExpression handleSearch(RexCall rexCall) {
List<RexNode> operands = rexCall.getOperands();
RexInputRef rexInputRef = (RexInputRef) operands.get(0);
RexLiteral rexLiteral = (RexLiteral) operands.get(1);
Expand All @@ -155,10 +157,83 @@ private static RexExpression.FunctionCall handleSearch(RexCall rexCall) {
return new RexExpression.FunctionCall(SqlKind.NOT_IN, dataType, SqlKind.NOT_IN.name(),
toFunctionOperands(rexInputRef, sarg.rangeSet.complement().asRanges(), dataType));
} else {
throw new NotImplementedException("Range is not implemented yet");
Set<Range<?>> ranges = sarg.rangeSet.asRanges();
return convertRangesToOr(dataType, rexInputRef, ranges);
}
}

private static RexExpression convertRangesToOr(ColumnDataType dataType, RexInputRef rexInputRef,
Set<Range<?>> ranges) {
RexExpression result;
Iterator<Range<?>> it = ranges.iterator();
if (!it.hasNext()) { // no disjunctions means false
return new RexExpression.Literal(ColumnDataType.BOOLEAN, 0);
}
RexExpression.InputRef rexInput = fromRexInputRef(rexInputRef);
result = convertRange(rexInput, dataType, it.next());
if (result instanceof RexExpression.Literal) {
Object value = ((RexExpression.Literal) result).getValue();
if (BooleanUtils.isTrueInternalValue(value)) { // one of the disjunctions is true => return true
return result;
}
}
while (it.hasNext()) {
Range<?> range = it.next();
RexExpression newExp = convertRange(rexInput, dataType, range);
if (newExp instanceof RexExpression.Literal) {
Object value = ((RexExpression.Literal) newExp).getValue();
if (BooleanUtils.isTrueInternalValue(value)) { // one of the disjunctions is true => return true
return newExp;
} else {
continue; // one of the disjunctions is false => ignore it
}
}
ImmutableList<RexExpression> operands = ImmutableList.of(result, newExp);
result = new RexExpression.FunctionCall(SqlKind.OR, ColumnDataType.BOOLEAN, SqlKind.OR.name(), operands);
}
return result;
}

private static RexExpression convertRange(RexExpression.InputRef rexInput, ColumnDataType dataType, Range<?> range) {
if (range.isEmpty()) {
return new RexExpression.Literal(ColumnDataType.BOOLEAN, 0);
}
if (!range.hasLowerBound()) {
if (!range.hasUpperBound()) {
return new RexExpression.Literal(ColumnDataType.BOOLEAN, 1);
}
return convertUpperBound(rexInput, dataType, range.upperBoundType(), range.upperEndpoint());
} else if (!range.hasUpperBound()) {
return convertLowerBound(rexInput, dataType, range.lowerBoundType(), range.lowerEndpoint());
} else {
RexExpression lowerConstraint =
convertLowerBound(rexInput, dataType, range.lowerBoundType(), range.lowerEndpoint());
RexExpression upperConstraint =
convertUpperBound(rexInput, dataType, range.upperBoundType(), range.upperEndpoint());
ImmutableList<RexExpression> operands = ImmutableList.of(lowerConstraint, upperConstraint);
return new RexExpression.FunctionCall(SqlKind.AND, ColumnDataType.BOOLEAN, SqlKind.AND.name(), operands);
}
}

private static RexExpression convertLowerBound(RexExpression.InputRef inputRef, ColumnDataType dataType,
BoundType boundType, Comparable<?> endpoint) {
SqlKind sqlKind = boundType == BoundType.OPEN ? SqlKind.GREATER_THAN : SqlKind.GREATER_THAN_OR_EQUAL;
RexExpression.Literal literal = new RexExpression.Literal(dataType, convertValue(dataType, endpoint));
ImmutableList<RexExpression> operands = ImmutableList.of(inputRef, literal);
return new RexExpression.FunctionCall(sqlKind, ColumnDataType.BOOLEAN, sqlKind.name(), operands);
}

private static RexExpression convertUpperBound(RexExpression.InputRef inputRef, ColumnDataType dataType,
BoundType boundType, Comparable<?> endpoint) {
SqlKind sqlKind = boundType == BoundType.OPEN ? SqlKind.LESS_THAN : SqlKind.LESS_THAN_OR_EQUAL;
RexExpression.Literal literal = new RexExpression.Literal(dataType, convertValue(dataType, endpoint));
ImmutableList<RexExpression> operands = ImmutableList.of(inputRef, literal);
return new RexExpression.FunctionCall(sqlKind, ColumnDataType.BOOLEAN, sqlKind.name(), operands);
}

/**
* Transforms a set of <b>point based</b> ranges into a list of expressions.
*/
private static List<RexExpression> toFunctionOperands(RexInputRef rexInputRef, Set<Range> ranges,
ColumnDataType dataType) {
List<RexExpression> result = new ArrayList<>(ranges.size() + 1);
Expand Down