Skip to content

Commit

Permalink
[Refactor] refactor error msg for high-order functions (StarRocks#30766)
Browse files Browse the repository at this point in the history
Signed-off-by: Zhuhe Fang <fzhedu@gmail.com>
  • Loading branch information
fzhedu authored Sep 13, 2023
1 parent 15f0cae commit 07f5efd
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 78 deletions.
6 changes: 3 additions & 3 deletions fe/fe-core/src/main/java/com/starrocks/analysis/Expr.java
Original file line number Diff line number Diff line change
Expand Up @@ -1392,17 +1392,17 @@ public boolean hasLambdaFunction(Expr expression) {
}
if (num == 1 && (idx == 0 || idx == children.size() - 1)) {
if (children.size() <= 1) {
throw new SemanticException("Lambda functions need array inputs in high-order functions.");
throw new SemanticException("Lambda functions need array/map inputs in high-order functions");
}
return true;
} else if (num > 1) {
throw new SemanticException("A high-order function should have only 1 lambda function, " +
"but there are " + num + " lambda functions.");
"but there are " + num + " lambda functions");
} else if (idx > 0 && idx < children.size() - 1) {
throw new SemanticException(
"Lambda functions should only be the first or last argument of any high-order function, " +
"or lambda arguments should be in () if there are more than one lambda arguments, " +
"like (x,y)->x+y.");
"like (x,y)->x+y");
} else if (num == 0) {
if (expression instanceof FunctionCallExpr) {
String funcName = ((FunctionCallExpr) expression).getFnName().getFunction();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,15 @@ public Expr clone() {
return new VariableExpr(this);
}

@Override
public String toSqlImpl() {
String msg = setType.toString() + " " + name;
if (value != null) {
msg += " = " + value.toString();
}
return msg;
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), name, setType, value, isNull);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ void analyzeHighOrderFunction(Visitor visitor, Expr expression, Scope scope) {
funcName = expression.toString();
}
throw new SemanticException(funcName + " can't use lambda functions, " +
"as it is not a supported high-order function", expression.getPos());
"as it is not a supported high-order function");
}
int childSize = expression.getChildren().size();
// move the lambda function to the first if it is at the last.
Expand All @@ -279,7 +279,7 @@ void analyzeHighOrderFunction(Visitor visitor, Expr expression, Scope scope) {
expr.setType(Type.ARRAY_INT); // Let it have item type.
}
if (!expr.getType().isArrayType()) {
throw new SemanticException(i + "th lambda input should be arrays", expr.getPos());
throw new SemanticException(i + "-th lambda input ( " + expr + " ) should be arrays");
}
Type itemType = ((ArrayType) expr.getType()).getItemType();
scope.putLambdaInput(new PlaceHolderExpr(-1, expr.isNullable(), itemType));
Expand All @@ -290,21 +290,21 @@ void analyzeHighOrderFunction(Visitor visitor, Expr expression, Scope scope) {
// map_apply(func, map)
if (functionCallExpr.getFnName().getFunction().equals(FunctionSet.MAP_APPLY)) {
if (!(expression.getChild(0).getChild(0) instanceof MapExpr)) {
throw new SemanticException("The right part of map lambda function (" +
expression.getChild(0).toSql() + ") should have key and value arguments",
throw new SemanticException("The right part of map lambda function ( " +
expression.getChild(0).toSql() + " ) should have key and value arguments",
expression.getChild(0).getPos());
}
} else {
if (expression.getChild(0).getChild(0) instanceof MapExpr) {
throw new SemanticException("The right part of map lambda function (" +
expression.getChild(0).toSql() + ") should have only one arguments",
throw new SemanticException("The right part of map lambda function ( " +
expression.getChild(0).toSql() + " ) should have only one arguments",
expression.getChild(0).getPos());
}
}
if (expression.getChild(0).getChildren().size() != 3) {
Expr child = expression.getChild(0);
throw new SemanticException("The left part of map lambda function (" +
child.toSql() + ") should have 2 arguments, but there are "
throw new SemanticException("The left part of map lambda function ( " +
child.toSql() + " ) should have 2 arguments, but there are "
+ (child.getChildren().size() - 1) + " arguments", child.getPos());
}
Expr expr = expression.getChild(1);
Expand All @@ -313,7 +313,7 @@ void analyzeHighOrderFunction(Visitor visitor, Expr expression, Scope scope) {
expr.setType(Type.ANY_MAP); // Let it have item type.
}
if (!expr.getType().isMapType()) {
throw new SemanticException("Lambda inputs should be maps", expr.getPos());
throw new SemanticException("Lambda input ( " + expr.toSql() + " ) should be maps");
}
Type keyType = ((MapType) expr.getType()).getKeyType();
Type valueType = ((MapType) expr.getType()).getValueType();
Expand Down Expand Up @@ -348,14 +348,34 @@ void analyzeHighOrderFunction(Visitor visitor, Expr expression, Scope scope) {
}

private void bottomUpAnalyze(Visitor visitor, Expr expression, Scope scope) {
if (expression.hasLambdaFunction(expression)) {
analyzeHighOrderFunction(visitor, expression, scope);
boolean hasLambdaFunc = false;
String originalSQL = expression.toSql();
try {
hasLambdaFunc = expression.hasLambdaFunction(expression);
} catch (SemanticException e) {
if (e.canNested()) {
throw new SemanticException(e.getDetailMsg() + " in " + originalSQL, expression.getPos(), false);
} else {
throw e;
}
}
if (hasLambdaFunc) {
try {
analyzeHighOrderFunction(visitor, expression, scope);
visitor.visit(expression, scope);
} catch (SemanticException e) {
if (e.canNested()) {
throw new SemanticException(e.getDetailMsg() + " in " + originalSQL, expression.getPos(), false);
} else {
throw e;
}
}
} else {
for (Expr expr : expression.getChildren()) {
bottomUpAnalyze(visitor, expr, scope);
}
visitor.visit(expression, scope);
}
visitor.visit(expression, scope);
}

public static class Visitor extends AstVisitor<Void, Scope> {
Expand Down Expand Up @@ -1185,17 +1205,17 @@ private void checkFunction(String fnName, FunctionCallExpr node, Type[] argument
node.getPos());
}
if (!node.getChild(0).getType().isArrayType() && !node.getChild(0).getType().isNull()) {
throw new SemanticException("The first input of " + fnName +
throw new SemanticException(fnName + "'s first input " + node.getChild(0).toSql() +
" should be an array or a lambda function", node.getPos());
}
if (!node.getChild(1).getType().isArrayType() && !node.getChild(1).getType().isNull()) {
throw new SemanticException("The second input of " + fnName +
throw new SemanticException(fnName + "'s second input " + node.getChild(1).toSql() +
" should be an array or a lambda function", node.getPos());
}
// force the second array be of Type.ARRAY_BOOLEAN
if (!Type.canCastTo(node.getChild(1).getType(), Type.ARRAY_BOOLEAN)) {
throw new SemanticException("The second input of array_filter " +
node.getChild(1).getType().toString() + " can't cast to ARRAY<BOOL>", node.getPos());
throw new SemanticException(fnName + "'s second input " + node.getChild(1).toSql() +
" can't cast to ARRAY<BOOL>", node.getPos());
}
break;
case FunctionSet.ALL_MATCH:
Expand All @@ -1204,13 +1224,13 @@ private void checkFunction(String fnName, FunctionCallExpr node, Type[] argument
throw new SemanticException(fnName + " should have a input array", node.getPos());
}
if (!node.getChild(0).getType().isArrayType() && !node.getChild(0).getType().isNull()) {
throw new SemanticException("The first input of " + fnName + " should be an array",
node.getPos());
throw new SemanticException(fnName + "'s input " + node.getChild(0).toSql() + " should be " +
"an array", node.getPos());
}
// force the second array be of Type.ARRAY_BOOLEAN
// force the input array be of Type.ARRAY_BOOLEAN
if (!Type.canCastTo(node.getChild(0).getType(), Type.ARRAY_BOOLEAN)) {
throw new SemanticException("The second input of " + fnName +
node.getChild(0).getType().toString() + " can't cast to ARRAY<BOOL>", node.getPos());
throw new SemanticException(fnName + "'s input " +
node.getChild(0).toSql() + " can't cast to ARRAY<BOOL>", node.getPos());
}
break;
case FunctionSet.ARRAY_SORTBY:
Expand All @@ -1219,11 +1239,11 @@ private void checkFunction(String fnName, FunctionCallExpr node, Type[] argument
node.getPos());
}
if (!node.getChild(0).getType().isArrayType() && !node.getChild(0).getType().isNull()) {
throw new SemanticException("The first input of " + fnName +
throw new SemanticException(fnName + "'s first input " + node.getChild(0).toSql() +
" should be an array or a lambda function", node.getPos());
}
if (!node.getChild(1).getType().isArrayType() && !node.getChild(1).getType().isNull()) {
throw new SemanticException("The second input of " + fnName +
throw new SemanticException(fnName + "'s second input " + node.getChild(1).toSql() +
" should be an array or a lambda function", node.getPos());
}
break;
Expand All @@ -1244,20 +1264,20 @@ private void checkFunction(String fnName, FunctionCallExpr node, Type[] argument
case FunctionSet.MAP_FILTER:
if (node.getChildren().size() != 2) {
throw new SemanticException(fnName + " should have 2 inputs, " +
"but there are just " + node.getChildren().size() + " inputs.");
"but there are just " + node.getChildren().size() + " inputs");
}
if (!node.getChild(0).getType().isMapType() && !node.getChild(0).getType().isNull()) {
throw new SemanticException("The first input of " + fnName +
throw new SemanticException(fnName + "'s first input " + node.getChild(0).toSql() +
" should be a map or a lambda function.");
}
if (!node.getChild(1).getType().isArrayType() && !node.getChild(1).getType().isNull()) {
throw new SemanticException("The second input of " + fnName +
" should be a array or a lambda function.");
throw new SemanticException(fnName + "'s second input " + node.getChild(1).toSql() +
" should be an array or a lambda function.");
}
// force the second array be of Type.ARRAY_BOOLEAN
if (!Type.canCastTo(node.getChild(1).getType(), Type.ARRAY_BOOLEAN)) {
throw new SemanticException("The second input of map_filter " +
node.getChild(1).getType().toString() + " can't cast to ARRAY<BOOL>");
throw new SemanticException(fnName + "'s second input " + node.getChild(1).toSql() +
" can't cast to ARRAY<BOOL>");
}
break;
case FunctionSet.GROUP_CONCAT:
Expand Down Expand Up @@ -1302,7 +1322,6 @@ private void checkFunction(String fnName, FunctionCallExpr node, Type[] argument
throw new SemanticException("named_struct contains duplicate subfield name: " +
name + " at " + (i + 1) + "-th input", node.getPos());
}

check.add(name.toLowerCase());
}
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ public class SemanticException extends StarRocksPlannerException {

protected final NodePosition pos;

protected boolean canNested = true;


public SemanticException(String formatString) {
this(formatString, NodePosition.ZERO);
Expand All @@ -40,6 +42,17 @@ public SemanticException(String detailMsg, NodePosition pos) {
this.pos = pos;
}

public SemanticException(String detailMsg, NodePosition pos, boolean canNested) {
super(detailMsg, ErrorType.USER_ERROR);
this.detailMsg = detailMsg;
this.pos = pos;
this.canNested = canNested;
}

boolean canNested() {
return canNested;
}

public SemanticException(String formatString, Object... args) {
this(format(formatString, args), NodePosition.ZERO);
}
Expand Down
14 changes: 14 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/sql/ast/SetType.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,18 @@ public static SetType fromThrift(TVarType tType) {
}
return SetType.SESSION;
}

public String toString() {
switch (this) {
case USER:
return "USER";
case GLOBAL:
return "GLOBAL";
case SESSION:
return "SESSION";
case VERBOSE:
return "VERBOSE";
}
return "UNKNOWN";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ public void testLambdaFunction() {
analyzeFail("select transform(1)");
analyzeFail("select array_map(x->x+ array_length(array_agg(x)),[2,6]) from tarray");
analyzeFail("select array_map(x->x > count(v1), v3) from tarray");
analyzeFail("select array_map(array_map(x2->x2+1,[1,2,3]),array_map(x1->x1+x2,[1,2,3]),(x,y)->(x+y))");
}

@Test
Expand Down
Loading

0 comments on commit 07f5efd

Please sign in to comment.