Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] support user variable in analytic function #47728

Merged
merged 8 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion fe/fe-core/src/main/java/com/starrocks/analysis/Expr.java
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,11 @@ public Expr unwrapExpr(boolean implicitOnly) {

public static double getConstFromExpr(Expr e) throws AnalysisException {
Preconditions.checkState(e.isConstant());
double value = 0;
double value;
if (e instanceof UserVariableExpr) {
e = ((UserVariableExpr) e).getValue();
}

if (e instanceof LiteralExpr) {
LiteralExpr lit = (LiteralExpr) e;
value = lit.getDoubleValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,23 +316,6 @@ public boolean isDistinct() {
return fnParams.isDistinct();
}

/**
 * Reports whether this call should be treated as {@code COUNT(*)}.
 *
 * <p>The function name must match {@code FunctionSet.COUNT} ignoring case. Given that,
 * the call qualifies when its parameters are a star, when it has no argument
 * expressions (null or empty list), or when at least one argument is constant.
 *
 * @return {@code true} if the call qualifies, {@code false} otherwise
 */
public boolean isCountStar() {
    // Anything that is not COUNT can never qualify.
    if (!fnName.getFunction().equalsIgnoreCase(FunctionSet.COUNT)) {
        return false;
    }

    // COUNT(*) and COUNT() with no arguments qualify immediately.
    if (fnParams.isStar() || fnParams.exprs() == null || fnParams.exprs().isEmpty()) {
        return true;
    }

    // Otherwise a single constant argument is enough.
    for (Expr argument : fnParams.exprs()) {
        if (argument.isConstant()) {
            return true;
        }
    }
    return false;
}

@Override
protected void toThrift(TExprNode msg) {
// TODO: we never serialize this to thrift if it's an aggregate function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

package com.starrocks.analysis;

import com.google.common.base.Preconditions;
import com.starrocks.catalog.Type;
import com.starrocks.common.AnalysisException;
import com.starrocks.sql.ast.AstVisitor;
import com.starrocks.sql.parser.NodePosition;

Expand Down Expand Up @@ -47,6 +50,7 @@ public Expr getValue() {

/**
 * Binds {@code value} to this user variable and copies the value's type onto this node.
 *
 * @param value the expression to bind; it is dereferenced to read its type, so passing
 *              {@code null} throws {@link NullPointerException} after the field is assigned
 */
public void setValue(Expr value) {
this.value = value;
// Mirror the bound expression's type on this node.
this.type = value.getType();
}

@Override
Expand Down Expand Up @@ -78,4 +82,29 @@ public boolean equals(Object o) {
/**
 * Combines the superclass hash with this variable's {@code name} and bound {@code value}.
 */
public int hashCode() {
return Objects.hash(super.hashCode(), name, value);
}

/**
 * Reports the nullability of the bound value.
 *
 * @return whatever the bound value's own {@code isNullable()} returns
 * @throws IllegalStateException if no value has been bound yet ({@code value == null})
 */
@Override
public boolean isNullable() {
// The answer is delegated to the bound value, so it must have been set beforehand.
Preconditions.checkState(value != null, "should analyze UserVariableExpr first then invoke isNullable");
return value.isNullable();
}

/**
 * Reports whether this user variable counts as a constant.
 *
 * @return {@code true} only when the bound value is a {@code LiteralExpr}
 * @throws IllegalStateException if no value has been bound yet ({@code value == null})
 */
@Override
public boolean isConstantImpl() {
Preconditions.checkState(value != null, "should analyze UserVariableExpr first then invoke isConstantImpl");
// Any bound value that is not a literal makes this variable non-constant.
return value instanceof LiteralExpr;
}

/**
 * Renders this user variable as SQL text.
 *
 * @return the character {@code '@'} followed by the variable name; the bound value is not included
 */
@Override
public String toSqlImpl() {
return "@" + name;
}

/**
 * Produces a new {@code UserVariableExpr}, constructed from this one, whose bound value
 * is the current value cast to {@code targetType}.
 *
 * @param targetType the type to cast the bound value to
 * @return the new user-variable expression carrying the cast value
 * @throws AnalysisException if casting the bound value fails
 * @throws IllegalStateException if no value has been bound yet ({@code value == null})
 */
@Override
public Expr uncheckedCastTo(Type targetType) throws AnalysisException {
    Preconditions.checkState(value != null, "should analyze UserVariableExpr first then cast its value");

    // Build the copy first, then cast the value, then bind it to the copy.
    final UserVariableExpr castCopy = new UserVariableExpr(this);
    final Expr castValue = value.uncheckedCastTo(targetType);
    castCopy.setValue(castValue);
    return castCopy;
}
}
packy92 marked this conversation as resolved.
Show resolved Hide resolved
51 changes: 0 additions & 51 deletions fe/fe-core/src/main/java/com/starrocks/common/util/ExprUtil.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
import com.starrocks.analysis.AnalyticWindow;
import com.starrocks.analysis.Expr;
import com.starrocks.analysis.FunctionCallExpr;
import com.starrocks.analysis.LiteralExpr;
import com.starrocks.analysis.NullLiteral;
import com.starrocks.analysis.OrderByElement;
import com.starrocks.analysis.UserVariableExpr;
import com.starrocks.catalog.AggregateFunction;
import com.starrocks.catalog.Function;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
import com.starrocks.common.AnalysisException;
import com.starrocks.common.util.ExprUtil;

import java.math.BigDecimal;

Expand Down Expand Up @@ -85,7 +86,7 @@ public static void verifyAnalyticExpression(AnalyticExpr analyticExpr) {

if (isOffsetFn(analyticFunction.getFn()) && analyticFunction.getChildren().size() > 1) {
Expr offset = analyticFunction.getChild(1);
if (!ExprUtil.isPositiveConstantInteger(offset)) {
if (!isPositiveConstantInteger(offset)) {
throw new SemanticException(
"The offset parameter of LEAD/LAG must be a constant positive integer: " +
analyticFunction.toSql(), analyticFunction.getPos());
Expand All @@ -112,7 +113,11 @@ public static void verifyAnalyticExpression(AnalyticExpr analyticExpr) {
// but the nullable info in FE is a more relax than BE (such as the nullable info in upper('a') is true,
// but the actually derived column in BE is not nullableColumn)
// which make the input colum in chunk not match the _agg_input_column in BE. so add this check in FE.
if (!analyticFunction.getChild(2).isLiteral() && analyticFunction.getChild(2).isNullable()) {
Expr theThirdChild = analyticFunction.getChild(2);
if (theThirdChild instanceof UserVariableExpr) {
theThirdChild = ((UserVariableExpr) theThirdChild).getValue();
}
if (!theThirdChild.isLiteral() && theThirdChild.isNullable()) {
throw new SemanticException("The type of the third parameter of LEAD/LAG not match the type " + firstType,
analyticFunction.getChild(2).getPos());
}
Expand All @@ -123,7 +128,7 @@ public static void verifyAnalyticExpression(AnalyticExpr analyticExpr) {

if (isNtileFn(analyticFunction.getFn())) {
Expr numBuckets = analyticFunction.getChild(0);
if (!ExprUtil.isPositiveConstantInteger(numBuckets)) {
if (!isPositiveConstantInteger(numBuckets)) {
throw new SemanticException(
"The num_buckets parameter of NTILE must be a constant positive integer: " +
analyticFunction.toSql(), numBuckets.getPos());
Expand Down Expand Up @@ -365,4 +370,15 @@ private static boolean isHllAggFn(Function fn) {

return fn.functionName().equalsIgnoreCase(AnalyticExpr.HLL_UNION_AGG);
}

/**
 * Checks that {@code offset} is a literal whose type reports {@code isFixedPointType()}
 * and whose long value is strictly greater than zero.
 *
 * <p>A {@code UserVariableExpr} is replaced by its bound value before the check.
 *
 * @param offset the expression to inspect
 * @return {@code true} if the (unwrapped) expression is such a literal with a value above zero
 */
private static boolean isPositiveConstantInteger(Expr offset) {
    // Look through a user variable to the expression it is bound to.
    final Expr candidate = (offset instanceof UserVariableExpr)
            ? ((UserVariableExpr) offset).getValue()
            : offset;

    if (!(candidate instanceof LiteralExpr)) {
        return false;
    }
    if (!candidate.getType().isFixedPointType()) {
        return false;
    }
    return ((LiteralExpr) candidate).getLongValue() > 0;
}
}
packy92 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -1750,9 +1750,7 @@ public Void visitUserVariableExpr(UserVariableExpr node, Scope context) {
UserVariable userVariable = session.getUserVariable(node.getName());
if (userVariable == null) {
node.setValue(NullLiteral.create(Type.STRING));
node.setType(Type.STRING);
} else {
node.setType(userVariable.getEvaluatedExpression().getType());
node.setValue(userVariable.getEvaluatedExpression());
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@
import com.starrocks.analysis.LiteralExpr;
import com.starrocks.analysis.NullLiteral;
import com.starrocks.analysis.StringLiteral;
import com.starrocks.analysis.UserVariableExpr;
import com.starrocks.catalog.AggregateFunction;
import com.starrocks.catalog.ArrayType;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
import com.starrocks.common.FeConstants;
import com.starrocks.common.util.ExprUtil;
import com.starrocks.qe.ConnectContext;

import java.util.Optional;

public class FunctionAnalyzer {

public static void analyze(FunctionCallExpr functionCallExpr) {
Expand Down Expand Up @@ -383,40 +385,39 @@ private static void analyzeBuiltinAggFunction(FunctionCallExpr functionCallExpr)
}

if (fnName.getFunction().equals(FunctionSet.APPROX_TOP_K)) {
Long k = null;
Long counterNum = null;
Optional<Long> k = Optional.empty();
Optional<Long> counterNum = Optional.empty();
Expr kExpr = null;
Expr counterNumExpr = null;
if (functionCallExpr.hasChild(1)) {
kExpr = functionCallExpr.getChild(1);
if (!ExprUtil.isPositiveConstantInteger(kExpr)) {
k = extractIntegerValue(kExpr);
if (!k.isPresent() || k.get() <= 0) {
throw new SemanticException(
"The second parameter of APPROX_TOP_K must be a constant positive integer: " +
functionCallExpr.toSql(), kExpr.getPos());
}
k = ExprUtil.getIntegerConstant(kExpr);
}
if (functionCallExpr.hasChild(2)) {
counterNumExpr = functionCallExpr.getChild(2);
if (!ExprUtil.isPositiveConstantInteger(counterNumExpr)) {
counterNum = extractIntegerValue(counterNumExpr);
if (!counterNum.isPresent() || counterNum.get() <= 0) {
throw new SemanticException(
"The third parameter of APPROX_TOP_K must be a constant positive integer: " +
functionCallExpr.toSql(), counterNumExpr.getPos());
}
counterNum = ExprUtil.getIntegerConstant(counterNumExpr);
}
if (k != null && k > FeConstants.MAX_COUNTER_NUM_OF_TOP_K) {
if (k.isPresent() && k.get() > FeConstants.MAX_COUNTER_NUM_OF_TOP_K) {
throw new SemanticException("The maximum number of the second parameter is "
+ FeConstants.MAX_COUNTER_NUM_OF_TOP_K + ", " + functionCallExpr.toSql(), kExpr.getPos());
}
if (counterNum != null) {
Preconditions.checkNotNull(k);
if (counterNum > FeConstants.MAX_COUNTER_NUM_OF_TOP_K) {
if (counterNum.isPresent()) {
if (counterNum.get() > FeConstants.MAX_COUNTER_NUM_OF_TOP_K) {
throw new SemanticException("The maximum number of the third parameter is "
+ FeConstants.MAX_COUNTER_NUM_OF_TOP_K + ", " + functionCallExpr.toSql(),
counterNumExpr.getPos());
}
if (k > counterNum) {
if (k.get() > counterNum.get()) {
throw new SemanticException(
"The second parameter must be smaller than or equal to the third parameter" +
functionCallExpr.toSql(), kExpr.getPos());
Expand Down Expand Up @@ -459,4 +460,16 @@ private static void analyzeBuiltinAggFunction(FunctionCallExpr functionCallExpr)
}
}
}

/**
 * Extracts the long value of {@code expr} when it is a literal whose type reports
 * {@code isFixedPointType()}.
 *
 * <p>A {@code UserVariableExpr} is replaced by its bound value before the check.
 *
 * @param expr the expression to inspect
 * @return the literal's long value, or an empty {@code Optional} when the (unwrapped)
 *         expression is not such a literal
 */
private static Optional<Long> extractIntegerValue(Expr expr) {
    // Look through a user variable to the expression it is bound to.
    final Expr resolved = (expr instanceof UserVariableExpr)
            ? ((UserVariableExpr) expr).getValue()
            : expr;

    final boolean integerLiteral = resolved instanceof LiteralExpr && resolved.getType().isFixedPointType();
    if (!integerLiteral) {
        return Optional.empty();
    }
    return Optional.of(((LiteralExpr) resolved).getLongValue());
}
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

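The riskiest issue in this code is:
`Optional.get()` called without a matching `isPresent()` check: `k.get() > counterNum.get()` reads `k` inside a branch that only verifies `counterNum.isPresent()`.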

You can modify the code like this:

if (k.isPresent() && counterNum.isPresent() && k.get() > counterNum.get()) {
    throw new SemanticException(
        "The second parameter must be smaller than or equal to the third parameter" +
        functionCallExpr.toSql(), kExpr.getPos());
}

Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ public Visitor(ExpressionMapping expressionMapping, ColumnRefFactory columnRefFa
@Override
public ScalarOperator visit(ParseNode node, Context context) {
Expr expr = (Expr) node;
if (expressionMapping.get(expr) != null && !(expr.isConstant())) {
if (expressionMapping.get(expr) != null && !expr.isConstant()) {
return expressionMapping.get(expr);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.starrocks.sql.plan;

import com.starrocks.sql.analyzer.SemanticException;
import com.starrocks.sql.ast.UserVariable;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -70,4 +71,44 @@ public void testRemoveEscapeCharacter() {
actual = UserVariable.removeEscapeCharacter(str);
Assert.assertEquals("abc\\abc", actual);
}

/**
 * Plans queries that pass user variables (set through the {@code set_user_variable} hint)
 * as arguments of LAG, APPROX_TOP_K and NTILE, and checks the resulting plan text.
 * Also checks that two invalid variable values are rejected with a SemanticException.
 */
@Test
public void test() throws Exception {
// LAG with the offset taken from @a over an inline UNION ALL source.
String sql = "SELECT /*+ set_user_variable(@a = 1) */ col_1, col_2, LAG(col_2, @a, 0) OVER (ORDER BY col_1) " +
"FROM (SELECT 1 AS col_1, NULL AS col_2 UNION ALL SELECT 2 AS col_1, 4 AS col_2) AS T ORDER BY col_1;";
String plan = getFragmentPlan(sql);
assertContains(plan, "functions: [, lag(9: cast, 1, 0), ]");

// LAG with both the offset (@a) and the default value (@b) taken from variables.
sql = "SELECT /*+ set_user_variable(@a = 1, @b = 1) */ lag(v1, @a, @b) over (ORDER BY v2) from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "functions: [, lag(1: v1, 1, 1), ]");

// LAG where every argument is a variable and the default value variable is null.
sql = "SELECT /*+ set_user_variable(@a = 1, @b = null) */ lag(@a, @a, @b) over (ORDER BY v2) from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "functions: [, lag(1, 1, NULL), ]");

// APPROX_TOP_K with the second and third parameters taken from variables.
sql = "select /*+ set_user_variable(@a = 1, @b = 100000) */ APPROX_TOP_K(v1, @a), APPROX_TOP_K(v1, @a, @b) from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "approx_top_k(1: v1, 1), approx_top_k(1: v1, 1, 100000)");

// NTILE with the bucket count taken from @a (@b is set but not referenced by the query).
sql = "select /*+ set_user_variable(@a = 1, @b = 10) */ ntile(@a) over (partition by v2 order by v3) " +
"as bucket_id from t0;";
plan = getFragmentPlan(sql);
assertContains(plan, "functions: [, ntile(1), ]");


// @b = 1000000 as the third APPROX_TOP_K parameter must be rejected with the maximum-value message.
Exception exception = Assert.assertThrows(SemanticException.class, () -> {
String invalidSql = "select /*+ set_user_variable(@a = 1, @b = 1000000) */ APPROX_TOP_K(v1, @a), " +
"APPROX_TOP_K(v1, @a, @b)from t0";
getFragmentPlan(invalidSql);
});
assertContains(exception.getMessage(), "The maximum number of the third parameter is 100000");


// An array-valued variable used as the LAG offset must be rejected.
exception = Assert.assertThrows(SemanticException.class, () -> {
String invalidSql = "select /*+ set_user_variable(@a = [1, 2, 3]) */ LAG(v1, @a) over (ORDER BY v2) from t0";
getFragmentPlan(invalidSql);
});
assertContains(exception.getMessage(), "The offset parameter of LEAD/LAG must be a constant positive integer");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ public void testNtileWindowFunction() throws Exception {
.analysisError("The num_buckets parameter of NTILE must be a constant positive integer");

sql = "select v1, v2, NTILE(9223372036854775808) over (partition by v1 order by v2) as j1 from t0";
starRocksAssert.query(sql).analysisError("Number out of range");
starRocksAssert.query(sql).analysisError("The num_buckets parameter of NTILE must be a constant positive integer");

sql = "select v1, v2, NTILE((select v1 from t0)) over (partition by v1 order by v2) as j1 from t0";
starRocksAssert.query(sql)
Expand Down
Loading
Loading