Skip to content

Commit

Permalink
Use Return statement as the body for SQL-invoked functions
Browse files Browse the repository at this point in the history
  • Loading branch information
caithagoras0 committed Feb 8, 2020
1 parent dc129d8 commit c2c90c0
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ private SqlInvokedFunctionTestUtils()
parseTypeSignature(DOUBLE),
"power tower",
RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).build(),
"pow(x, x)",
"RETURN pow(x, x)",
Optional.empty());

public static final SqlInvokedFunction FUNCTION_POWER_TOWER_DOUBLE_UPDATED = new SqlInvokedFunction(
Expand All @@ -53,7 +53,7 @@ private SqlInvokedFunctionTestUtils()
parseTypeSignature(DOUBLE),
"power tower",
RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(),
"pow(x, x)",
"RETURN pow(x, x)",
Optional.empty());

public static final SqlInvokedFunction FUNCTION_POWER_TOWER_INT = new SqlInvokedFunction(
Expand All @@ -62,6 +62,6 @@ private SqlInvokedFunctionTestUtils()
parseTypeSignature(INTEGER),
"power tower",
RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(),
"pow(x, x)",
"RETURN pow(x, x)",
Optional.empty());
}
Original file line number Diff line number Diff line change
Expand Up @@ -580,12 +580,12 @@ protected Scope visitCreateFunction(CreateFunction node, Optional<Scope> scope)
Scope functionScope = Scope.builder()
.withRelationType(RelationId.anonymous(), new RelationType(fields))
.build();
Type bodyType = analyzeExpression(node.getBody(), functionScope).getExpressionTypes().get(NodeRef.of(node.getBody()));
Type bodyType = analyzeExpression(node.getBody().getExpression(), functionScope).getExpressionTypes().get(NodeRef.of(node.getBody().getExpression()));
if (!bodyType.equals(returnType)) {
throw new SemanticException(TYPE_MISMATCH, node, "Function implementation type '%s' does not match declared return type '%s'", bodyType, returnType);
}

Analyzer.verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), metadata.getFunctionManager(), node.getBody(), "CREATE FUNCTION body");
Analyzer.verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), metadata.getFunctionManager(), node.getBody().getExpression(), "CREATE FUNCTION body");

// TODO: Check body contains no SQL invoked functions

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ private static Expression parseSqlFunctionExpression(SqlInvokedScalarFunctionImp
ParsingOptions parsingOptions = ParsingOptions.builder()
.setDecimalLiteralTreatment(sqlFunctionProperties.isParseDecimalLiteralAsDouble() ? AS_DOUBLE : AS_DECIMAL)
.build();
return new SqlParser().createExpression(functionImplementation.getImplementation(), parsingOptions);
// TODO: Use injector-created SqlParser, which could potentially be different from the adhoc SqlParser.
return new SqlParser().createRoutineBody(functionImplementation.getImplementation(), parsingOptions).getExpression();
}

private static Map<String, Type> getFunctionArgumentTypes(FunctionMetadata functionMetadata, Metadata metadata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ standaloneExpression
: expression EOF
;

standaloneRoutineBody
: routineBody EOF
;

statement
: query #statementDefault
| USE schema=identifier #use
Expand Down Expand Up @@ -170,6 +174,10 @@ alterRoutineCharacteristic
;

routineBody
: returnStatement
;

returnStatement
: RETURN expression
;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
import com.facebook.presto.sql.tree.RenameSchema;
import com.facebook.presto.sql.tree.RenameTable;
import com.facebook.presto.sql.tree.ResetSession;
import com.facebook.presto.sql.tree.Return;
import com.facebook.presto.sql.tree.Revoke;
import com.facebook.presto.sql.tree.RevokeRoles;
import com.facebook.presto.sql.tree.Rollback;
Expand Down Expand Up @@ -570,8 +571,9 @@ protected Void visitCreateFunction(CreateFunction node, Integer indent)
}
builder.append("\n")
.append(formatRoutineCharacteristics(node.getCharacteristics()))
.append("\nRETURN ")
.append(formatExpression(node.getBody(), parameters));
.append("\n");

process(node.getBody(), 0);

return null;
}
Expand Down Expand Up @@ -601,6 +603,15 @@ protected Void visitDropFunction(DropFunction node, Integer indent)
return null;
}

@Override
protected Void visitReturn(Return node, Integer indent)
{
append(indent, "RETURN ");
builder.append(formatExpression(node.getExpression(), parameters));

return null;
}

@Override
protected Void visitDropView(DropView node, Integer context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
import com.facebook.presto.sql.tree.RenameSchema;
import com.facebook.presto.sql.tree.RenameTable;
import com.facebook.presto.sql.tree.ResetSession;
import com.facebook.presto.sql.tree.Return;
import com.facebook.presto.sql.tree.Revoke;
import com.facebook.presto.sql.tree.RevokeRoles;
import com.facebook.presto.sql.tree.Rollback;
Expand Down Expand Up @@ -218,6 +219,12 @@ public Node visitStandaloneExpression(SqlBaseParser.StandaloneExpressionContext
return visit(context.expression());
}

@Override
public Node visitStandaloneRoutineBody(SqlBaseParser.StandaloneRoutineBodyContext context)
{
return visit(context.routineBody());
}

// ******************* statements **********************

@Override
Expand Down Expand Up @@ -421,7 +428,7 @@ public Node visitCreateFunction(SqlBaseParser.CreateFunctionContext context)
getType(context.returnType),
comment,
getRoutineCharacteristics(context.routineCharacteristics()),
(Expression) visit(context.routineBody()));
(Return) visit(context.routineBody()));
}

@Override
Expand All @@ -445,7 +452,13 @@ public Node visitDropFunction(SqlBaseParser.DropFunctionContext context)
@Override
public Node visitRoutineBody(SqlBaseParser.RoutineBodyContext context)
{
return visit(context.expression());
return visit(context.returnStatement());
}

@Override
public Node visitReturnStatement(SqlBaseParser.ReturnStatementContext context)
{
return new Return((Expression) visit(context.expression()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.Return;
import com.facebook.presto.sql.tree.Statement;
import org.antlr.v4.runtime.BaseErrorListener;
import org.antlr.v4.runtime.CharStreams;
Expand Down Expand Up @@ -112,6 +113,11 @@ public Expression createExpression(String expression, ParsingOptions parsingOpti
return (Expression) invokeParser("expression", expression, SqlBaseParser::standaloneExpression, parsingOptions);
}

public Return createRoutineBody(String routineBody, ParsingOptions parsingOptions)
{
return (Return) invokeParser("routineBody", routineBody, SqlBaseParser::standaloneRoutineBody, parsingOptions);
}

private Node invokeParser(String name, String sql, Function<SqlBaseParser, ParserRuleContext> parseFunction, ParsingOptions parsingOptions)
{
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -761,4 +761,9 @@ protected R visitCurrentUser(CurrentUser node, C context)
{
return visitExpression(node, context);
}

protected R visitReturn(Return node, C context)
{
return visitStatement(node, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@ public class CreateFunction
private final String returnType;
private final Optional<String> comment;
private final RoutineCharacteristics characteristics;
private final Expression body;
private final Return body;

public CreateFunction(QualifiedName functionName, boolean replace, List<SqlParameterDeclaration> parameters, String returnType, Optional<String> comment, RoutineCharacteristics characteristics, Expression body)
public CreateFunction(QualifiedName functionName, boolean replace, List<SqlParameterDeclaration> parameters, String returnType, Optional<String> comment, RoutineCharacteristics characteristics, Return body)
{
this(Optional.empty(), replace, functionName, parameters, returnType, comment, characteristics, body);
}

public CreateFunction(NodeLocation location, boolean replace, QualifiedName functionName, List<SqlParameterDeclaration> parameters, String returnType, Optional<String> comment, RoutineCharacteristics characteristics, Expression body)
public CreateFunction(NodeLocation location, boolean replace, QualifiedName functionName, List<SqlParameterDeclaration> parameters, String returnType, Optional<String> comment, RoutineCharacteristics characteristics, Return body)
{
this(Optional.of(location), replace, functionName, parameters, returnType, comment, characteristics, body);
}

private CreateFunction(Optional<NodeLocation> location, boolean replace, QualifiedName functionName, List<SqlParameterDeclaration> parameters, String returnType, Optional<String> comment, RoutineCharacteristics characteristics, Expression body)
private CreateFunction(Optional<NodeLocation> location, boolean replace, QualifiedName functionName, List<SqlParameterDeclaration> parameters, String returnType, Optional<String> comment, RoutineCharacteristics characteristics, Return body)
{
super(location);
this.functionName = requireNonNull(functionName, "functionName is null");
Expand Down Expand Up @@ -85,7 +85,7 @@ public RoutineCharacteristics getCharacteristics()
return characteristics;
}

public Expression getBody()
public Return getBody()
{
return body;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.tree;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;

public class Return
extends Statement
{
private final Expression expression;

public Return(Expression expression)
{
this(Optional.empty(), expression);
}

public Return(NodeLocation location, Expression expression)
{
this(Optional.of(location), expression);
}

private Return(Optional<NodeLocation> location, Expression expression)
{
super(location);
this.expression = requireNonNull(expression, "Expression is null");
}

public Expression getExpression()
{
return expression;
}

@Override
public <R, C> R accept(AstVisitor<R, C> visitor, C context)
{
return visitor.visitReturn(this, context);
}

@Override
public List<Node> getChildren()
{
return ImmutableList.of(expression);
}

@Override
public int hashCode()
{
return Objects.hash(expression);
}

@Override
public boolean equals(Object obj)
{
if (this == obj) {
return true;
}
if ((obj == null) || (getClass() != obj.getClass())) {
return false;
}
Return o = (Return) obj;
return Objects.equals(expression, o.expression);
}

@Override
public String toString()
{
return toStringHelper(this)
.add("expression", expression)
.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
import com.facebook.presto.sql.tree.RenameSchema;
import com.facebook.presto.sql.tree.RenameTable;
import com.facebook.presto.sql.tree.ResetSession;
import com.facebook.presto.sql.tree.Return;
import com.facebook.presto.sql.tree.Revoke;
import com.facebook.presto.sql.tree.RevokeRoles;
import com.facebook.presto.sql.tree.Rollback;
Expand Down Expand Up @@ -1460,10 +1461,10 @@ public void testCreateFunction()
"double",
Optional.of("tangent trigonometric function"),
new RoutineCharacteristics(SQL, DETERMINISTIC, RETURNS_NULL_ON_NULL_INPUT),
new ArithmeticBinaryExpression(
new Return(new ArithmeticBinaryExpression(
DIVIDE,
new FunctionCall(QualifiedName.of("sin"), ImmutableList.of(identifier("x"))),
new FunctionCall(QualifiedName.of("cos"), ImmutableList.of(identifier("x"))))));
new FunctionCall(QualifiedName.of("cos"), ImmutableList.of(identifier("x")))))));

CreateFunction createFunctionRand = new CreateFunction(
QualifiedName.of("dev", "testing", "rand"),
Expand All @@ -1472,7 +1473,7 @@ public void testCreateFunction()
"double",
Optional.empty(),
new RoutineCharacteristics(SQL, NOT_DETERMINISTIC, CALLED_ON_NULL_INPUT),
new FunctionCall(QualifiedName.of("rand"), ImmutableList.of()));
new Return(new FunctionCall(QualifiedName.of("rand"), ImmutableList.of())));
assertStatement(
"CREATE OR REPLACE FUNCTION dev.testing.rand ()\n" +
"RETURNS double\n" +
Expand Down

0 comments on commit c2c90c0

Please sign in to comment.