From c2c90c079c95e54e69eda99eaa052cf661f7c42d Mon Sep 17 00:00:00 2001 From: Leiqing Cai Date: Thu, 6 Feb 2020 13:04:00 -0800 Subject: [PATCH] Use Return statement as the body for SQL-invoked functions --- .../testing/SqlInvokedFunctionTestUtils.java | 6 +- .../sql/analyzer/StatementAnalyzer.java | 4 +- .../sql/relational/SqlFunctionUtils.java | 3 +- .../com/facebook/presto/sql/parser/SqlBase.g4 | 8 ++ .../com/facebook/presto/sql/SqlFormatter.java | 15 +++- .../presto/sql/parser/AstBuilder.java | 17 +++- .../facebook/presto/sql/parser/SqlParser.java | 6 ++ .../facebook/presto/sql/tree/AstVisitor.java | 5 ++ .../presto/sql/tree/CreateFunction.java | 10 +-- .../com/facebook/presto/sql/tree/Return.java | 89 +++++++++++++++++++ .../presto/sql/parser/TestSqlParser.java | 7 +- 11 files changed, 152 insertions(+), 18 deletions(-) create mode 100644 presto-parser/src/main/java/com/facebook/presto/sql/tree/Return.java diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/testing/SqlInvokedFunctionTestUtils.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/testing/SqlInvokedFunctionTestUtils.java index a1fc2184ad4ef..d5d81dfde5af0 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/testing/SqlInvokedFunctionTestUtils.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/testing/SqlInvokedFunctionTestUtils.java @@ -44,7 +44,7 @@ private SqlInvokedFunctionTestUtils() parseTypeSignature(DOUBLE), "power tower", RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).build(), - "pow(x, x)", + "RETURN pow(x, x)", Optional.empty()); public static final SqlInvokedFunction FUNCTION_POWER_TOWER_DOUBLE_UPDATED = new SqlInvokedFunction( @@ -53,7 +53,7 @@ private SqlInvokedFunctionTestUtils() parseTypeSignature(DOUBLE), "power tower", RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), - "pow(x, x)", + "RETURN pow(x, x)", Optional.empty()); public static final SqlInvokedFunction FUNCTION_POWER_TOWER_INT = new SqlInvokedFunction( @@ -62,6 +62,6 @@ private SqlInvokedFunctionTestUtils() parseTypeSignature(INTEGER), "power tower", RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), - "pow(x, x)", + "RETURN pow(x, x)", Optional.empty()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index ae16b61558ff8..6eaf1965fcc65 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -580,12 +580,12 @@ protected Scope visitCreateFunction(CreateFunction node, Optional scope) Scope functionScope = Scope.builder() .withRelationType(RelationId.anonymous(), new RelationType(fields)) .build(); - Type bodyType = analyzeExpression(node.getBody(), functionScope).getExpressionTypes().get(NodeRef.of(node.getBody())); + Type bodyType = analyzeExpression(node.getBody().getExpression(), functionScope).getExpressionTypes().get(NodeRef.of(node.getBody().getExpression())); if (!bodyType.equals(returnType)) { throw new SemanticException(TYPE_MISMATCH, node, "Function implementation type '%s' does not match declared return type '%s'", bodyType, returnType); } - Analyzer.verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), metadata.getFunctionManager(), node.getBody(), "CREATE FUNCTION body"); + Analyzer.verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), metadata.getFunctionManager(), node.getBody().getExpression(), "CREATE FUNCTION body"); // TODO: Check body contains no SQL invoked functions diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java index cc65596f1bbbf..73e79f4f75004 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java @@ -104,7 +104,8 @@ private static Expression parseSqlFunctionExpression(SqlInvokedScalarFunctionImp ParsingOptions parsingOptions = ParsingOptions.builder() .setDecimalLiteralTreatment(sqlFunctionProperties.isParseDecimalLiteralAsDouble() ? AS_DOUBLE : AS_DECIMAL) .build(); - return new SqlParser().createExpression(functionImplementation.getImplementation(), parsingOptions); + // TODO: Use injector-created SqlParser, which could potentially be different from the adhoc SqlParser. + return new SqlParser().createRoutineBody(functionImplementation.getImplementation(), parsingOptions).getExpression(); } private static Map getFunctionArgumentTypes(FunctionMetadata functionMetadata, Metadata metadata) diff --git a/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 b/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 index b8c4eb20e2437..4025fd8401752 100644 --- a/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 +++ b/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 @@ -26,6 +26,10 @@ standaloneExpression : expression EOF ; +standaloneRoutineBody + : routineBody EOF + ; + statement : query #statementDefault | USE schema=identifier #use @@ -170,6 +174,10 @@ alterRoutineCharacteristic ; routineBody + : returnStatement + ; + +returnStatement : RETURN expression ; diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java b/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java index 6d3c22b3b3e1c..7e9b64b1e6687 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java @@ -74,6 +74,7 @@ import com.facebook.presto.sql.tree.RenameSchema; import com.facebook.presto.sql.tree.RenameTable; import com.facebook.presto.sql.tree.ResetSession; +import com.facebook.presto.sql.tree.Return; import com.facebook.presto.sql.tree.Revoke; import com.facebook.presto.sql.tree.RevokeRoles; import com.facebook.presto.sql.tree.Rollback; @@ -570,8 +571,9 @@ protected Void visitCreateFunction(CreateFunction node, Integer indent) } builder.append("\n") .append(formatRoutineCharacteristics(node.getCharacteristics())) - .append("\nRETURN ") - .append(formatExpression(node.getBody(), parameters)); + .append("\n"); + + process(node.getBody(), 0); return null; } @@ -601,6 +603,15 @@ protected Void visitDropFunction(DropFunction node, Integer indent) return null; } + @Override + protected Void visitReturn(Return node, Integer indent) + { + append(indent, "RETURN "); + builder.append(formatExpression(node.getExpression(), parameters)); + + return null; + } + @Override protected Void visitDropView(DropView node, Integer context) { diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java index 70925d1faf874..0a27edd17e563 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java @@ -118,6 +118,7 @@ import com.facebook.presto.sql.tree.RenameSchema; import com.facebook.presto.sql.tree.RenameTable; import com.facebook.presto.sql.tree.ResetSession; +import com.facebook.presto.sql.tree.Return; import com.facebook.presto.sql.tree.Revoke; import com.facebook.presto.sql.tree.RevokeRoles; import com.facebook.presto.sql.tree.Rollback; @@ -218,6 +219,12 @@ public Node visitStandaloneExpression(SqlBaseParser.StandaloneExpressionContext return visit(context.expression()); } + @Override + public Node visitStandaloneRoutineBody(SqlBaseParser.StandaloneRoutineBodyContext context) + { + return visit(context.routineBody()); + } + // ******************* statements ********************** @Override @@ -421,7 +428,7 @@ public Node visitCreateFunction(SqlBaseParser.CreateFunctionContext context) getType(context.returnType), comment, getRoutineCharacteristics(context.routineCharacteristics()), - (Expression) visit(context.routineBody())); + (Return) visit(context.routineBody())); } @Override @@ -445,7 +452,13 @@ public Node visitDropFunction(SqlBaseParser.DropFunctionContext context) @Override public Node visitRoutineBody(SqlBaseParser.RoutineBodyContext context) { - return visit(context.expression()); + return visit(context.returnStatement()); + } + + @Override + public Node visitReturnStatement(SqlBaseParser.ReturnStatementContext context) + { + return new Return((Expression) visit(context.expression())); } @Override diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/parser/SqlParser.java b/presto-parser/src/main/java/com/facebook/presto/sql/parser/SqlParser.java index b84838068d15c..cd7e209ea1cdb 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/parser/SqlParser.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/parser/SqlParser.java @@ -15,6 +15,7 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Node; +import com.facebook.presto.sql.tree.Return; import com.facebook.presto.sql.tree.Statement; import org.antlr.v4.runtime.BaseErrorListener; import org.antlr.v4.runtime.CharStreams; @@ -112,6 +113,11 @@ public Expression createExpression(String expression, ParsingOptions parsingOpti return (Expression) invokeParser("expression", expression, SqlBaseParser::standaloneExpression, parsingOptions); } + public Return createRoutineBody(String routineBody, ParsingOptions parsingOptions) + { + return (Return) invokeParser("routineBody", routineBody, SqlBaseParser::standaloneRoutineBody, parsingOptions); + } + private Node invokeParser(String name, String sql, Function parseFunction, ParsingOptions parsingOptions) { try { diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java index f7ddacbf642db..c3e5123f7c7a2 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java @@ -761,4 +761,9 @@ protected R visitCurrentUser(CurrentUser node, C context) { return visitExpression(node, context); } + + protected R visitReturn(Return node, C context) + { + return visitStatement(node, context); + } } diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/CreateFunction.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/CreateFunction.java index 2031da7b45db7..02214a5265d4b 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/CreateFunction.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/CreateFunction.java @@ -31,19 +31,19 @@ public class CreateFunction private final String returnType; private final Optional comment; private final RoutineCharacteristics characteristics; - private final Expression body; + private final Return body; - public CreateFunction(QualifiedName functionName, boolean replace, List parameters, String returnType, Optional comment, RoutineCharacteristics characteristics, Expression body) + public CreateFunction(QualifiedName functionName, boolean replace, List parameters, String returnType, Optional comment, RoutineCharacteristics characteristics, Return body) { this(Optional.empty(), replace, functionName, parameters, returnType, comment, characteristics, body); } - public CreateFunction(NodeLocation location, boolean replace, QualifiedName functionName, List parameters, String returnType, Optional comment, RoutineCharacteristics characteristics, Expression body) + public CreateFunction(NodeLocation location, boolean replace, QualifiedName functionName, List parameters, String returnType, Optional comment, RoutineCharacteristics characteristics, Return body) { this(Optional.of(location), replace, functionName, parameters, returnType, comment, characteristics, body); } - private CreateFunction(Optional location, boolean replace, QualifiedName functionName, List parameters, String returnType, Optional comment, RoutineCharacteristics characteristics, Expression body) + private CreateFunction(Optional location, boolean replace, QualifiedName functionName, List parameters, String returnType, Optional comment, RoutineCharacteristics characteristics, Return body) { super(location); this.functionName = requireNonNull(functionName, "functionName is null"); @@ -85,7 +85,7 @@ public RoutineCharacteristics getCharacteristics() return characteristics; } - public Expression getBody() + public Return getBody() { return body; } diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/Return.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Return.java new file mode 100644 index 0000000000000..a5d6357360bf4 --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Return.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class Return + extends Statement +{ + private final Expression expression; + + public Return(Expression expression) + { + this(Optional.empty(), expression); + } + + public Return(NodeLocation location, Expression expression) + { + this(Optional.of(location), expression); + } + + private Return(Optional location, Expression expression) + { + super(location); + this.expression = requireNonNull(expression, "Expression is null"); + } + + public Expression getExpression() + { + return expression; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitReturn(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(expression); + } + + @Override + public int hashCode() + { + return Objects.hash(expression); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + Return o = (Return) obj; + return Objects.equals(expression, o.expression); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("expression", expression) + .toString(); + } +} diff --git a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java index af5efdf6c8277..07095490153d1 100644 --- a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java +++ b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java @@ -102,6 +102,7 @@ import com.facebook.presto.sql.tree.RenameSchema; import com.facebook.presto.sql.tree.RenameTable; import com.facebook.presto.sql.tree.ResetSession; +import com.facebook.presto.sql.tree.Return; import com.facebook.presto.sql.tree.Revoke; import com.facebook.presto.sql.tree.RevokeRoles; import com.facebook.presto.sql.tree.Rollback; @@ -1460,10 +1461,10 @@ public void testCreateFunction() "double", Optional.of("tangent trigonometric function"), new RoutineCharacteristics(SQL, DETERMINISTIC, RETURNS_NULL_ON_NULL_INPUT), - new ArithmeticBinaryExpression( + new Return(new ArithmeticBinaryExpression( DIVIDE, new FunctionCall(QualifiedName.of("sin"), ImmutableList.of(identifier("x"))), - new FunctionCall(QualifiedName.of("cos"), ImmutableList.of(identifier("x")))))); + new FunctionCall(QualifiedName.of("cos"), ImmutableList.of(identifier("x"))))))); CreateFunction createFunctionRand = new CreateFunction( QualifiedName.of("dev", "testing", "rand"), @@ -1472,7 +1473,7 @@ public void testCreateFunction() "double", Optional.empty(), new RoutineCharacteristics(SQL, NOT_DETERMINISTIC, CALLED_ON_NULL_INPUT), - new FunctionCall(QualifiedName.of("rand"), ImmutableList.of())); + new Return(new FunctionCall(QualifiedName.of("rand"), ImmutableList.of()))); assertStatement( "CREATE OR REPLACE FUNCTION dev.testing.rand ()\n" + "RETURNS double\n" +