Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core support for enums, take #2 #15219

Merged
merged 4 commits into from
Sep 29, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Support enum literals in queries
  • Loading branch information
daniel-ohayon committed Sep 24, 2020
commit f0fd12d02de62bc5f604487ce87f53e3204710f5
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ private static Object fixValue(TypeSignature signature, Object value)
}
return fixedValue;
}
if (signature.isVarcharEnum()) {
return String.class.cast(value);
}
switch (signature.getBase()) {
case BIGINT:
if (value instanceof String) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.facebook.presto.common.type.CharType;
import com.facebook.presto.common.type.DecimalParseResult;
import com.facebook.presto.common.type.Decimals;
import com.facebook.presto.common.type.EnumType;
import com.facebook.presto.common.type.FunctionType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.StandardTypes;
Expand Down Expand Up @@ -58,6 +59,7 @@
import com.facebook.presto.sql.tree.DecimalLiteral;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.DoubleLiteral;
import com.facebook.presto.sql.tree.EnumLiteral;
import com.facebook.presto.sql.tree.ExistsPredicate;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.Extract;
Expand Down Expand Up @@ -140,6 +142,7 @@
import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy;
import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoAggregateWindowOrGroupingFunctions;
import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoExternalFunctions;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.tryResolveEnumLiteralType;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.EXPRESSION_NOT_CONSTANT;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_LITERAL;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE;
Expand Down Expand Up @@ -432,14 +435,21 @@ protected Type visitDereferenceExpression(DereferenceExpression node, StackableA
{
QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(node);

// If this Dereference looks like column reference, try match it to column first.
// Handle qualified name
if (qualifiedName != null) {
// first, try to match it to a column name
Scope scope = context.getContext().getScope();
Optional<ResolvedField> resolvedField = scope.tryResolveField(node, qualifiedName);
if (resolvedField.isPresent()) {
return handleResolvedField(node, resolvedField.get(), context);
}
// otherwise, try to match it to an enum literal (eg Mood.HAPPY)
if (!scope.isColumnReference(qualifiedName)) {
Optional<EnumType> enumType = tryResolveEnumLiteralType(qualifiedName, typeManager);
if (enumType.isPresent()) {
setExpressionType(node.getBase(), enumType.get());
return setExpressionType(node, enumType.get());
}
throw missingAttributeException(node, qualifiedName);
}
}
Expand Down Expand Up @@ -773,6 +783,20 @@ protected Type visitGenericLiteral(GenericLiteral node, StackableAstVisitorConte
return setExpressionType(node, type);
}

@Override
protected Type visitEnumLiteral(EnumLiteral node, StackableAstVisitorContext<Context> context)
{
Type type;
try {
type = typeManager.getType(parseTypeSignature(node.getType()));
}
catch (IllegalArgumentException e) {
throw new SemanticException(TYPE_MISMATCH, node, "Unknown type: " + node.getType());
}

return setExpressionType(node, type);
}

@Override
protected Type visitTimeLiteral(TimeLiteral node, StackableAstVisitorContext<Context> context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,34 @@
*/
package com.facebook.presto.sql.analyzer;

import com.facebook.presto.common.type.EnumType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;

import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Predicates.alwaysTrue;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.slice.Slices.utf8Slice;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;

public final class ExpressionTreeUtils
Expand Down Expand Up @@ -112,4 +123,50 @@ public static boolean isEqualComparisonExpression(Expression expression)
{
return expression instanceof ComparisonExpression && ((ComparisonExpression) expression).getOperator() == ComparisonExpression.Operator.EQUAL;
}

static Optional<EnumType> tryResolveEnumLiteralType(QualifiedName qualifiedName, TypeManager typeManager)
{
Optional<QualifiedName> prefix = qualifiedName.getPrefix();
if (!prefix.isPresent()) {
// an enum literal should be of the form `MyEnum.my_key`
return Optional.empty();
}
try {
Type baseType = typeManager.getType(parseTypeSignature(prefix.get().toString()));
if (baseType instanceof EnumType) {
return Optional.of((EnumType) baseType);
}
}
catch (IllegalArgumentException e) {
return Optional.empty();
}
return Optional.empty();
}

private static boolean isEnumLiteral(DereferenceExpression node, Type nodeType)
{
if (!(nodeType instanceof EnumType)) {
return false;
}
QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(node);
if (qualifiedName == null) {
return false;
}
Optional<QualifiedName> prefix = qualifiedName.getPrefix();
return prefix.isPresent()
&& prefix.get().toString().equalsIgnoreCase(nodeType.getTypeSignature().getBase());
}

public static Optional<Object> tryResolveEnumLiteral(DereferenceExpression node, Type nodeType)
{
QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(node);
if (!isEnumLiteral(node, nodeType)) {
return Optional.empty();
}
EnumType enumType = (EnumType) nodeType;
String enumKey = qualifiedName.getSuffix().toUpperCase(ENGLISH);
checkArgument(enumType.getEnumMap().containsKey(enumKey), format("No key '%s' in enum '%s'", enumKey, nodeType.getDisplayName()));
Object enumValue = enumType.getEnumMap().get(enumKey);
return enumValue instanceof String ? Optional.of(utf8Slice((String) enumValue)) : Optional.of(enumValue);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ private SemanticExceptions() {}

public static SemanticException missingAttributeException(Expression node, QualifiedName name)
{
throw new SemanticException(MISSING_ATTRIBUTE, node, "Column '%s' cannot be resolved", name);
throw new SemanticException(
MISSING_ATTRIBUTE,
node,
name.getPrefix().isPresent() ? "'%s' cannot be resolved" : "Column '%s' cannot be resolved",
name);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
import static com.facebook.presto.sql.analyzer.ConstantExpressionVerifier.verifyExpressionIsConstant;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.createConstantAnalyzer;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.tryResolveEnumLiteral;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.gen.VarArgsToMapAdapterGenerator.generateVarArgsToMapAdapter;
import static com.facebook.presto.sql.planner.ExpressionDeterminismEvaluator.isDeterministic;
Expand Down Expand Up @@ -298,6 +299,12 @@ public Object visitFieldReference(FieldReference node, Object context)
@Override
protected Object visitDereferenceExpression(DereferenceExpression node, Object context)
{
Type returnType = type(node);
Optional<Object> maybeEnumValue = tryResolveEnumLiteral(node, returnType);
if (maybeEnumValue.isPresent()) {
return maybeEnumValue.get();
}

Type type = type(node.getBase());
// if there is no type for the base of Dereference, it must be QualifiedName
if (type == null) {
Expand All @@ -315,7 +322,6 @@ protected Object visitDereferenceExpression(DereferenceExpression node, Object c
}

RowType rowType = (RowType) type;
Type returnType = type(node);
String fieldName = node.getField().getValue();
List<Field> fields = rowType.getFields();
int index = -1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import com.facebook.presto.sql.tree.CharLiteral;
import com.facebook.presto.sql.tree.DecimalLiteral;
import com.facebook.presto.sql.tree.DoubleLiteral;
import com.facebook.presto.sql.tree.EnumLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.IntervalLiteral;
Expand Down Expand Up @@ -235,6 +236,12 @@ protected Slice visitBinaryLiteral(BinaryLiteral node, ConnectorSession session)
return node.getValue();
}

@Override
protected Object visitEnumLiteral(EnumLiteral node, ConnectorSession context)
{
return node.getValue();
}

@Override
protected Object visitGenericLiteral(GenericLiteral node, ConnectorSession session)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,11 @@
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.DefaultTraversalVisitor;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.EnumLiteral;
import com.facebook.presto.sql.tree.Except;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.InPredicate;
Expand Down Expand Up @@ -97,6 +100,7 @@
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isEqualComparisonExpression;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.tryResolveEnumLiteral;
import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference;
Expand Down Expand Up @@ -670,22 +674,39 @@ protected RelationPlan visitValues(Values node, Void context)
ImmutableList.Builder<RowExpression> values = ImmutableList.builder();
if (row instanceof Row) {
for (Expression item : ((Row) row).getItems()) {
Expression expression = Coercer.addCoercions(item, analysis);
values.add(castToRowExpression(ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(analysis.getParameters(), analysis), expression)));
values.add(rewriteRow(item));
}
}
else {
Expression expression = Coercer.addCoercions(row, analysis);
values.add(castToRowExpression(ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(analysis.getParameters(), analysis), expression)));
values.add(rewriteRow(row));
}

rowsBuilder.add(values.build());
}

ValuesNode valuesNode = new ValuesNode(idAllocator.getNextId(), outputVariablesBuilder.build(), rowsBuilder.build());
return new RelationPlan(valuesNode, scope, outputVariablesBuilder.build());
}

private RowExpression rewriteRow(Expression row)
{
// resolve enum literals
Expression expression = ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() {
@Override
public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
Type nodeType = analysis.getType(node);
Optional<Object> maybeEnumValue = tryResolveEnumLiteral(node, nodeType);
if (maybeEnumValue.isPresent()) {
return new EnumLiteral(nodeType.getTypeSignature().toString(), maybeEnumValue.get());
}
return node;
}
}, row);
expression = Coercer.addCoercions(expression, analysis);
expression = ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(analysis.getParameters(), analysis), expression);
return castToRowExpression(expression);
}

@Override
protected RelationPlan visitUnnest(Unnest node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.facebook.presto.sql.analyzer.ResolvedField;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.EnumLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
Expand All @@ -36,6 +37,7 @@
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.tryResolveEnumLiteral;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -234,6 +236,13 @@ public Expression rewriteDereferenceExpression(DereferenceExpression node, Void
// do not rewrite outer references, it will be handled in outer scope planner
return node;
}

Type nodeType = analysis.getType(node);
Optional<Object> maybeEnumValue = tryResolveEnumLiteral(node, nodeType);
if (maybeEnumValue.isPresent()) {
return new EnumLiteral(nodeType.getTypeSignature().toString(), maybeEnumValue.get());
}

return rewriteExpression(node, context, treeRewriter);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.ROW_CONSTRUCTOR;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.tryResolveEnumLiteral;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
Expand Down Expand Up @@ -562,6 +563,12 @@ private RowExpression buildSwitch(RowExpression operand, List<WhenClause> whenCl
@Override
protected RowExpression visitDereferenceExpression(DereferenceExpression node, Void context)
{
Type returnType = getType(node);
Optional<Object> maybeEnumLiteral = tryResolveEnumLiteral(node, returnType);
if (maybeEnumLiteral.isPresent()) {
return constant(maybeEnumLiteral.get(), returnType);
}

RowType rowType = (RowType) getType(node.getBase());
String fieldName = node.getField().getValue();
List<Field> fields = rowType.getFields();
Expand All @@ -582,7 +589,6 @@ protected RowExpression visitDereferenceExpression(DereferenceExpression node, V
}

checkState(index >= 0, "could not find field name: %s", node.getField());
Type returnType = getType(node);
return specialForm(DEREFERENCE, returnType, process(node.getBase(), context), constant((long) index, INTEGER));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ public void testInvalidAttribute()
assertFails(MISSING_ATTRIBUTE, "SELECT * FROM t1 WHERE f > 1");
}

@Test(expectedExceptions = SemanticException.class, expectedExceptionsMessageRegExp = "line 1:8: Column 't.y' cannot be resolved")
@Test(expectedExceptions = SemanticException.class, expectedExceptionsMessageRegExp = "line 1:8: 't.y' cannot be resolved")
public void testInvalidAttributeCorrectErrorMessage()
{
analyze("SELECT t.y FROM (VALUES 1) t(x)");
Expand Down Expand Up @@ -1165,7 +1165,7 @@ public void testCreateTableAsColumns()
assertFails(MISMATCHED_COLUMN_ALIASES, 1, 19, "CREATE TABLE test(x, y) AS (VALUES 1)");
assertFails(DUPLICATE_COLUMN_NAME, 1, 24, "CREATE TABLE test(abc, AbC) AS SELECT 1, 2");
assertFails(COLUMN_TYPE_UNKNOWN, 1, 1, "CREATE TABLE test(x) AS SELECT null");
assertFails(MISSING_ATTRIBUTE, ".*Column 'y' cannot be resolved", "CREATE TABLE test(x) WITH (p1 = y) AS SELECT null");
assertFails(MISSING_ATTRIBUTE, ".*'y' cannot be resolved", "CREATE TABLE test(x) WITH (p1 = y) AS SELECT null");
assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1", "CREATE TABLE test(x) WITH (p1 = 'p1', p2 = 'p2', p1 = 'p3') AS SELECT null");
assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1", "CREATE TABLE test(x) WITH (p1 = 'p1', \"p1\" = 'p2') AS SELECT null");
}
Expand All @@ -1176,7 +1176,7 @@ public void testCreateTable()
analyze("CREATE TABLE test (id bigint)");
analyze("CREATE TABLE test (id bigint) WITH (p1 = 'p1')");

assertFails(MISSING_ATTRIBUTE, ".*Column 'y' cannot be resolved", "CREATE TABLE test (x bigint) WITH (p1 = y)");
assertFails(MISSING_ATTRIBUTE, ".*'y' cannot be resolved", "CREATE TABLE test (x bigint) WITH (p1 = y)");
assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1", "CREATE TABLE test (id bigint) WITH (p1 = 'p1', p2 = 'p2', p1 = 'p3')");
assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1", "CREATE TABLE test (id bigint) WITH (p1 = 'p1', \"p1\" = 'p2')");
}
Expand All @@ -1197,7 +1197,7 @@ public void testCreateSchema()
analyze("CREATE SCHEMA test");
analyze("CREATE SCHEMA test WITH (p1 = 'p1')");

assertFails(MISSING_ATTRIBUTE, ".*Column 'y' cannot be resolved", "CREATE SCHEMA test WITH (p1 = y)");
assertFails(MISSING_ATTRIBUTE, ".*'y' cannot be resolved", "CREATE SCHEMA test WITH (p1 = y)");
assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1", "CREATE SCHEMA test WITH (p1 = 'p1', p2 = 'p2', p1 = 'p3')");
assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1", "CREATE SCHEMA test WITH (p1 = 'p1', \"p1\" = 'p2')");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void testColumnReferences()
"SELECT t.k FROM " +
"(VALUES (1, 'a')) AS t(k, v1) JOIN" +
"(VALUES (1, 'b')) AS u(k, v2) USING (k)",
".*Column 't.k' cannot be resolved.*");
".*'t.k' cannot be resolved.*");
}

@Test
Expand Down
Loading