Skip to content

Commit 6aea881

Browse files
committed
Translate LIKE predicate to Domain
1 parent f9dd802 commit 6aea881

File tree

4 files changed

+299
-4
lines changed

4 files changed

+299
-4
lines changed

presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import com.google.common.collect.ImmutableList;
1717
import com.google.common.collect.ImmutableMap;
1818
import com.google.common.collect.PeekingIterator;
19+
import io.airlift.slice.Slice;
20+
import io.airlift.slice.Slices;
1921
import io.prestosql.Session;
2022
import io.prestosql.metadata.Metadata;
2123
import io.prestosql.metadata.OperatorNotFoundException;
@@ -34,6 +36,7 @@
3436
import io.prestosql.spi.type.DoubleType;
3537
import io.prestosql.spi.type.RealType;
3638
import io.prestosql.spi.type.Type;
39+
import io.prestosql.spi.type.VarcharType;
3740
import io.prestosql.sql.ExpressionUtils;
3841
import io.prestosql.sql.InterpretedFunctionInvoker;
3942
import io.prestosql.sql.parser.SqlParser;
@@ -47,11 +50,14 @@
4750
import io.prestosql.sql.tree.InPredicate;
4851
import io.prestosql.sql.tree.IsNotNullPredicate;
4952
import io.prestosql.sql.tree.IsNullPredicate;
53+
import io.prestosql.sql.tree.LikePredicate;
5054
import io.prestosql.sql.tree.LogicalBinaryExpression;
5155
import io.prestosql.sql.tree.NodeRef;
5256
import io.prestosql.sql.tree.NotExpression;
5357
import io.prestosql.sql.tree.NullLiteral;
58+
import io.prestosql.sql.tree.StringLiteral;
5459
import io.prestosql.sql.tree.SymbolReference;
60+
import io.prestosql.type.LikeFunctions;
5561
import io.prestosql.type.TypeCoercion;
5662

5763
import javax.annotation.Nullable;
@@ -66,6 +72,10 @@
6672
import static com.google.common.collect.ImmutableList.toImmutableList;
6773
import static com.google.common.collect.Iterables.getOnlyElement;
6874
import static com.google.common.collect.Iterators.peekingIterator;
75+
import static io.airlift.slice.SliceUtf8.countCodePoints;
76+
import static io.airlift.slice.SliceUtf8.getCodePointAt;
77+
import static io.airlift.slice.SliceUtf8.lengthOfCodePoint;
78+
import static io.airlift.slice.SliceUtf8.setCodePointAt;
6979
import static io.prestosql.spi.function.OperatorType.SATURATED_FLOOR_CAST;
7080
import static io.prestosql.sql.ExpressionUtils.and;
7181
import static io.prestosql.sql.ExpressionUtils.combineConjuncts;
@@ -857,6 +867,92 @@ protected ExtractionResult visitBetweenPredicate(BetweenPredicate node, Boolean
857867
new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax())), complement);
858868
}
859869

870+
@Override
871+
protected ExtractionResult visitLikePredicate(LikePredicate node, Boolean complement)
872+
{
873+
Optional<ExtractionResult> result = tryVisitLikePredicate(node, complement);
874+
if (result.isPresent()) {
875+
return result.get();
876+
}
877+
return super.visitLikePredicate(node, complement);
878+
}
879+
880+
private Optional<ExtractionResult> tryVisitLikePredicate(LikePredicate node, Boolean complement)
881+
{
882+
if (!(node.getValue() instanceof SymbolReference)) {
883+
// LIKE not on a symbol
884+
return Optional.empty();
885+
}
886+
887+
if (!(node.getPattern() instanceof StringLiteral)) {
888+
// dynamic pattern
889+
return Optional.empty();
890+
}
891+
892+
if (node.getEscape().isPresent() && !(node.getEscape().get() instanceof StringLiteral)) {
893+
// dynamic escape
894+
return Optional.empty();
895+
}
896+
897+
Type type = typeAnalyzer.getType(session, types, node.getValue());
898+
if (!(type instanceof VarcharType)) {
899+
// TODO support CharType
900+
return Optional.empty();
901+
}
902+
VarcharType varcharType = (VarcharType) type;
903+
904+
Symbol symbol = Symbol.from(node.getValue());
905+
Slice pattern = ((StringLiteral) node.getPattern()).getSlice();
906+
Optional<Slice> escape = node.getEscape()
907+
.map(StringLiteral.class::cast)
908+
.map(StringLiteral::getSlice);
909+
910+
int patternConstantPrefixBytes = LikeFunctions.patternConstantPrefixBytes(pattern, escape);
911+
if (patternConstantPrefixBytes == pattern.length()) {
912+
// This should not actually happen, constant LIKE pattern should be converted to equality predicate before DomainTranslator is invoked.
913+
914+
Slice literal = LikeFunctions.unescapeLiteralLikePattern(pattern, escape);
915+
ValueSet valueSet;
916+
if (varcharType.isUnbounded() || countCodePoints(literal) <= varcharType.getBoundedLength()) {
917+
valueSet = ValueSet.of(type, literal);
918+
}
919+
else {
920+
// impossible to satisfy
921+
valueSet = ValueSet.none(type);
922+
}
923+
Domain domain = Domain.create(complementIfNecessary(valueSet, complement), false);
924+
return Optional.of(new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(symbol, domain)), TRUE_LITERAL));
925+
}
926+
927+
if (complement || patternConstantPrefixBytes == 0) {
928+
// TODO
929+
return Optional.empty();
930+
}
931+
932+
Slice constantPrefix = LikeFunctions.unescapeLiteralLikePattern(pattern.slice(0, patternConstantPrefixBytes), escape);
933+
934+
int lastIncrementable = -1;
935+
for (int position = 0; position < constantPrefix.length(); position += lengthOfCodePoint(constantPrefix, position)) {
936+
// Get last ASCII character to increment, so that character length in bytes does not change.
937+
// Also prefer not to produce non-ASCII if input is all-ASCII, to be on the safe side with connectors.
938+
// TODO remove those limitations
939+
if (getCodePointAt(constantPrefix, position) < 127) {
940+
lastIncrementable = position;
941+
}
942+
}
943+
944+
if (lastIncrementable == -1) {
945+
return Optional.empty();
946+
}
947+
948+
Slice lowerBound = constantPrefix;
949+
Slice upperBound = Slices.copyOf(constantPrefix.slice(0, lastIncrementable + lengthOfCodePoint(constantPrefix, lastIncrementable)));
950+
setCodePointAt(getCodePointAt(constantPrefix, lastIncrementable) + 1, upperBound, lastIncrementable);
951+
952+
Domain domain = Domain.create(ValueSet.ofRanges(Range.range(type, lowerBound, true, upperBound, false)), false);
953+
return Optional.of(new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(symbol, domain)), node));
954+
}
955+
860956
@Override
861957
protected ExtractionResult visitIsNullPredicate(IsNullPredicate node, Boolean complement)
862958
{

presto-main/src/test/java/io/prestosql/sql/planner/TestDomainTranslator.java

Lines changed: 144 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import io.prestosql.sql.tree.InListExpression;
4040
import io.prestosql.sql.tree.InPredicate;
4141
import io.prestosql.sql.tree.IsNullPredicate;
42+
import io.prestosql.sql.tree.LikePredicate;
4243
import io.prestosql.sql.tree.Literal;
4344
import io.prestosql.sql.tree.LongLiteral;
4445
import io.prestosql.sql.tree.NotExpression;
@@ -56,6 +57,7 @@
5657
import java.math.BigDecimal;
5758
import java.util.Arrays;
5859
import java.util.List;
60+
import java.util.Optional;
5961
import java.util.concurrent.TimeUnit;
6062

6163
import static io.airlift.slice.Slices.utf8Slice;
@@ -77,6 +79,7 @@
7779
import static io.prestosql.spi.type.TinyintType.TINYINT;
7880
import static io.prestosql.spi.type.VarbinaryType.VARBINARY;
7981
import static io.prestosql.spi.type.VarcharType.VARCHAR;
82+
import static io.prestosql.spi.type.VarcharType.createUnboundedVarcharType;
8083
import static io.prestosql.sql.ExpressionUtils.and;
8184
import static io.prestosql.sql.ExpressionUtils.or;
8285
import static io.prestosql.sql.analyzer.TypeSignatureTranslator.toSqlType;
@@ -1457,6 +1460,128 @@ private void testNumericTypeTranslation(NumericValues<?> columnValues, NumericVa
14571460
}
14581461
}
14591462

1463+
@Test
1464+
public void testLikePredicate()
1465+
{
1466+
Type varcharType = createUnboundedVarcharType();
1467+
1468+
// constant
1469+
testSimpleComparison(
1470+
like(C_VARCHAR, stringLiteral("abc")),
1471+
C_VARCHAR,
1472+
Domain.multipleValues(varcharType, ImmutableList.of(utf8Slice("abc"))));
1473+
1474+
// starts with pattern
1475+
assertUnsupportedPredicate(like(C_VARCHAR, stringLiteral("_def")));
1476+
assertUnsupportedPredicate(like(C_VARCHAR, stringLiteral("%def")));
1477+
1478+
// _ pattern (unless escaped)
1479+
testSimpleComparison(
1480+
like(C_VARCHAR, stringLiteral("abc_def")),
1481+
C_VARCHAR,
1482+
like(C_VARCHAR, stringLiteral("abc_def")),
1483+
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc"), true, utf8Slice("abd"), false)), false));
1484+
1485+
testSimpleComparison(
1486+
like(C_VARCHAR, stringLiteral("abc\\_def")),
1487+
C_VARCHAR,
1488+
like(C_VARCHAR, stringLiteral("abc\\_def")),
1489+
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc\\"), true, utf8Slice("abc]"), false)), false));
1490+
1491+
testSimpleComparison(
1492+
like(C_VARCHAR, stringLiteral("abc\\_def"), stringLiteral("\\")),
1493+
C_VARCHAR,
1494+
Domain.multipleValues(varcharType, ImmutableList.of(utf8Slice("abc_def"))));
1495+
1496+
testSimpleComparison(
1497+
like(C_VARCHAR, stringLiteral("abc\\_def_"), stringLiteral("\\")),
1498+
C_VARCHAR,
1499+
like(C_VARCHAR, stringLiteral("abc\\_def_"), stringLiteral("\\")),
1500+
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc_def"), true, utf8Slice("abc_deg"), false)), false));
1501+
1502+
testSimpleComparison(
1503+
like(C_VARCHAR, stringLiteral("abc^_def_"), stringLiteral("^")),
1504+
C_VARCHAR,
1505+
like(C_VARCHAR, stringLiteral("abc^_def_"), stringLiteral("^")),
1506+
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc_def"), true, utf8Slice("abc_deg"), false)), false));
1507+
1508+
// % pattern (unless escaped)
1509+
testSimpleComparison(
1510+
like(C_VARCHAR, stringLiteral("abc%")),
1511+
C_VARCHAR,
1512+
like(C_VARCHAR, stringLiteral("abc%")),
1513+
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc"), true, utf8Slice("abd"), false)), false));
1514+
1515+
testSimpleComparison(
1516+
like(C_VARCHAR, stringLiteral("abc%def")),
1517+
C_VARCHAR,
1518+
like(C_VARCHAR, stringLiteral("abc%def")),
1519+
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc"), true, utf8Slice("abd"), false)), false));
1520+
1521+
testSimpleComparison(
1522+
like(C_VARCHAR, stringLiteral("abc\\%def")),
1523+
C_VARCHAR,
1524+
like(C_VARCHAR, stringLiteral("abc\\%def")),
1525+
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc\\"), true, utf8Slice("abc]"), false)), false));
1526+
1527+
testSimpleComparison(
1528+
like(C_VARCHAR, stringLiteral("abc\\%def"), stringLiteral("\\")),
1529+
C_VARCHAR,
1530+
Domain.multipleValues(varcharType, ImmutableList.of(utf8Slice("abc%def"))));
1531+
1532+
testSimpleComparison(
1533+
like(C_VARCHAR, stringLiteral("abc\\%def_"), stringLiteral("\\")),
1534+
C_VARCHAR,
1535+
like(C_VARCHAR, stringLiteral("abc\\%def_"), stringLiteral("\\")),
1536+
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc%def"), true, utf8Slice("abc%deg"), false)), false));
1537+
1538+
testSimpleComparison(
1539+
like(C_VARCHAR, stringLiteral("abc^%def_"), stringLiteral("^")),
1540+
C_VARCHAR,
1541+
like(C_VARCHAR, stringLiteral("abc^%def_"), stringLiteral("^")),
1542+
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc%def"), true, utf8Slice("abc%deg"), false)), false));
1543+
1544+
// non-ASCII literal
1545+
testSimpleComparison(
1546+
like(C_VARCHAR, stringLiteral("abc\u007f\u0123\udbfe")),
1547+
C_VARCHAR,
1548+
Domain.multipleValues(varcharType, ImmutableList.of(utf8Slice("abc\u007f\u0123\udbfe"))));
1549+
1550+
// non-ASCII prefix
1551+
testSimpleComparison(
1552+
like(C_VARCHAR, stringLiteral("abc\u0123\ud83d\ude80def\u007e\u007f\u00ff\u0123\uccf0%")),
1553+
C_VARCHAR,
1554+
like(C_VARCHAR, stringLiteral("abc\u0123\ud83d\ude80def\u007e\u007f\u00ff\u0123\uccf0%")),
1555+
Domain.create(
1556+
ValueSet.ofRanges(Range.range(varcharType,
1557+
utf8Slice("abc\u0123\ud83d\ude80def\u007e\u007f\u00ff\u0123\uccf0"), true,
1558+
utf8Slice("abc\u0123\ud83d\ude80def\u007f"), false)),
1559+
false));
1560+
1561+
// dynamic escape
1562+
assertUnsupportedPredicate(like(C_VARCHAR, stringLiteral("abc\\_def"), C_VARCHAR_1.toSymbolReference()));
1563+
1564+
// negation with literal
1565+
testSimpleComparison(
1566+
not(like(C_VARCHAR, stringLiteral("abcdef"))),
1567+
C_VARCHAR,
1568+
Domain.create(ValueSet.ofRanges(
1569+
Range.lessThan(varcharType, utf8Slice("abcdef")),
1570+
Range.greaterThan(varcharType, utf8Slice("abcdef"))),
1571+
false));
1572+
1573+
testSimpleComparison(
1574+
not(like(C_VARCHAR, stringLiteral("abc\\_def"), stringLiteral("\\"))),
1575+
C_VARCHAR,
1576+
Domain.create(ValueSet.ofRanges(
1577+
Range.lessThan(varcharType, utf8Slice("abc_def")),
1578+
Range.greaterThan(varcharType, utf8Slice("abc_def"))),
1579+
false));
1580+
1581+
// negation with pattern
1582+
assertUnsupportedPredicate(not(like(C_VARCHAR, stringLiteral("abc\\_def"))));
1583+
}
1584+
14601585
@Test
14611586
public void testCharComparedToVarcharExpression()
14621587
{
@@ -1568,6 +1693,16 @@ private static ComparisonExpression isDistinctFrom(Symbol symbol, Expression exp
15681693
return isDistinctFrom(symbol.toSymbolReference(), expression);
15691694
}
15701695

1696+
private static LikePredicate like(Symbol symbol, Expression expression)
1697+
{
1698+
return new LikePredicate(symbol.toSymbolReference(), expression, Optional.empty());
1699+
}
1700+
1701+
private static LikePredicate like(Symbol symbol, Expression expression, Expression escape)
1702+
{
1703+
return new LikePredicate(symbol.toSymbolReference(), expression, Optional.of(escape));
1704+
}
1705+
15711706
private static Expression isNotNull(Symbol symbol)
15721707
{
15731708
return isNotNull(symbol.toSymbolReference());
@@ -1733,14 +1868,19 @@ private void testSimpleComparison(Expression expression, Symbol symbol, Range ex
17331868
testSimpleComparison(expression, symbol, Domain.create(ValueSet.ofRanges(expectedDomainRange), false));
17341869
}
17351870

1736-
private void testSimpleComparison(Expression expression, Symbol symbol, Domain domain)
1871+
private void testSimpleComparison(Expression expression, Symbol symbol, Domain expectedDomain)
1872+
{
1873+
testSimpleComparison(expression, symbol, TRUE_LITERAL, expectedDomain);
1874+
}
1875+
1876+
private void testSimpleComparison(Expression expression, Symbol symbol, Expression expectedRemainingExpression, Domain expectedDomain)
17371877
{
17381878
ExtractionResult result = fromPredicate(expression);
1739-
assertEquals(result.getRemainingExpression(), TRUE_LITERAL);
1879+
assertEquals(result.getRemainingExpression(), expectedRemainingExpression);
17401880
TupleDomain<Symbol> actual = result.getTupleDomain();
1741-
TupleDomain<Symbol> expected = withColumnDomains(ImmutableMap.of(symbol, domain));
1881+
TupleDomain<Symbol> expected = withColumnDomains(ImmutableMap.of(symbol, expectedDomain));
17421882
if (!actual.equals(expected)) {
1743-
fail(format("for comparison [%s] expected %s but found %s", expression.toString(), expected.toString(SESSION), actual.toString(SESSION)));
1883+
fail(format("for comparison [%s] expected [%s] but found [%s]", expression.toString(), expected.toString(SESSION), actual.toString(SESSION)));
17441884
}
17451885
}
17461886

presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,16 @@
1515

1616
import com.google.common.collect.ImmutableList;
1717
import com.google.common.collect.ImmutableMap;
18+
import io.airlift.slice.Slices;
1819
import io.prestosql.Session;
20+
import io.prestosql.plugin.tpch.TpchColumnHandle;
21+
import io.prestosql.plugin.tpch.TpchTableHandle;
1922
import io.prestosql.spi.block.SortOrder;
23+
import io.prestosql.spi.connector.ColumnHandle;
24+
import io.prestosql.spi.predicate.Domain;
25+
import io.prestosql.spi.predicate.Range;
26+
import io.prestosql.spi.predicate.TupleDomain;
27+
import io.prestosql.spi.predicate.ValueSet;
2028
import io.prestosql.sql.analyzer.FeaturesConfig.JoinDistributionType;
2129
import io.prestosql.sql.analyzer.FeaturesConfig.JoinReorderingStrategy;
2230
import io.prestosql.sql.planner.assertions.BasePlanTest;
@@ -51,10 +59,14 @@
5159
import org.testng.annotations.Test;
5260

5361
import java.util.List;
62+
import java.util.Map;
63+
import java.util.Map.Entry;
5464
import java.util.Optional;
5565
import java.util.function.Consumer;
5666
import java.util.function.Predicate;
5767

68+
import static com.google.common.collect.ImmutableList.toImmutableList;
69+
import static com.google.common.collect.MoreCollectors.toOptional;
5870
import static io.airlift.slice.Slices.utf8Slice;
5971
import static io.prestosql.SystemSessionProperties.DISTRIBUTED_SORT;
6072
import static io.prestosql.SystemSessionProperties.FORCE_SINGLE_NODE_OUTPUT;
@@ -138,6 +150,37 @@ public void testAnalyze()
138150
tableScan("orders", ImmutableMap.of()))))))))));
139151
}
140152

153+
@Test
154+
public void testLikePredicate()
155+
{
156+
assertPlan("SELECT type FROM part WHERE type LIKE 'LARGE PLATED %'",
157+
anyTree(
158+
tableScan(
159+
tableHandle -> {
160+
Map<ColumnHandle, Domain> domains = ((TpchTableHandle) tableHandle).getConstraint().getDomains()
161+
.orElseThrow(() -> new AssertionError("Unexpected none TupleDomain"));
162+
163+
Domain domain = domains.entrySet().stream()
164+
.filter(entry -> ((TpchColumnHandle) entry.getKey()).getColumnName().equals("type"))
165+
.map(Entry::getValue)
166+
.collect(toOptional())
167+
.orElseThrow(() -> new AssertionError("No domain for 'type'"));
168+
169+
assertEquals(domain, Domain.multipleValues(
170+
createVarcharType(25),
171+
ImmutableList.of("LARGE PLATED BRASS", "LARGE PLATED COPPER", "LARGE PLATED NICKEL", "LARGE PLATED STEEL", "LARGE PLATED TIN").stream()
172+
.map(Slices::utf8Slice)
173+
.collect(toImmutableList())));
174+
return true;
175+
},
176+
TupleDomain.withColumnDomains(ImmutableMap.of(
177+
tableHandle -> ((TpchColumnHandle) tableHandle).getColumnName().equals("type"),
178+
Domain.create(
179+
ValueSet.ofRanges(Range.range(createVarcharType(25), utf8Slice("LARGE PLATED "), true, utf8Slice("LARGE PLATED!"), false)),
180+
false))),
181+
ImmutableMap.of())));
182+
}
183+
141184
@Test
142185
public void testAggregation()
143186
{

0 commit comments

Comments
 (0)