Skip to content

Commit

Permalink
[Enhancement] support user variable in analytic function (backport #4…
Browse files Browse the repository at this point in the history
…7728) (#47789)

Co-authored-by: packy92 <110370499+packy92@users.noreply.github.com>
  • Loading branch information
mergify[bot] and packy92 authored Jul 3, 2024
1 parent f6972f0 commit 8f5113f
Show file tree
Hide file tree
Showing 12 changed files with 160 additions and 89 deletions.
6 changes: 5 additions & 1 deletion fe/fe-core/src/main/java/com/starrocks/analysis/Expr.java
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,11 @@ public Expr unwrapExpr(boolean implicitOnly) {

public static double getConstFromExpr(Expr e) throws AnalysisException {
Preconditions.checkState(e.isConstant());
double value = 0;
double value;
if (e instanceof UserVariableExpr) {
e = ((UserVariableExpr) e).getValue();
}

if (e instanceof LiteralExpr) {
LiteralExpr lit = (LiteralExpr) e;
value = lit.getDoubleValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,23 +316,6 @@ public boolean isDistinct() {
return fnParams.isDistinct();
}

/**
 * Reports whether this call should be treated as {@code COUNT(*)}.
 *
 * <p>The function name must equal {@code FunctionSet.COUNT} (ignoring case). Given that,
 * the call qualifies when its parameters are a star, when the argument list is
 * {@code null} or empty, or when at least one argument is a constant expression.
 *
 * @return {@code true} if the call qualifies as a count-star, {@code false} otherwise
 */
public boolean isCountStar() {
    // Anything that is not COUNT can never be a count-star.
    if (!fnName.getFunction().equalsIgnoreCase(FunctionSet.COUNT)) {
        return false;
    }

    // COUNT(*) and COUNT() with no arguments both qualify immediately.
    if (fnParams.isStar() || fnParams.exprs() == null || fnParams.exprs().isEmpty()) {
        return true;
    }

    // Otherwise a single constant argument is enough, e.g. COUNT(1).
    for (Expr argument : fnParams.exprs()) {
        if (argument.isConstant()) {
            return true;
        }
    }
    return false;
}

@Override
protected void toThrift(TExprNode msg) {
// TODO: we never serialize this to thrift if it's an aggregate function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

package com.starrocks.analysis;

import com.google.common.base.Preconditions;
import com.starrocks.catalog.Type;
import com.starrocks.common.AnalysisException;
import com.starrocks.sql.ast.AstVisitor;
import com.starrocks.sql.parser.NodePosition;

Expand Down Expand Up @@ -47,6 +50,7 @@ public Expr getValue() {

/**
 * Binds the given expression to this user variable and adopts its type.
 *
 * @param value the expression this variable stands for; it is dereferenced here
 *              to read its type, so it must not be {@code null}
 */
public void setValue(Expr value) {
this.value = value;
// Keep this node's type in sync with the bound value so callers need not set it separately.
this.type = value.getType();
}

@Override
Expand Down Expand Up @@ -78,4 +82,23 @@ public boolean equals(Object o) {
public int hashCode() {
// Combines the superclass hash with the variable name and the bound value.
return Objects.hash(super.hashCode(), name, value);
}

/**
 * Reports the nullability of the expression bound to this user variable.
 *
 * <p>A value must already have been bound (see the precondition message);
 * otherwise the precondition check fails.
 *
 * @return whatever the bound value reports for {@code isNullable()}
 */
@Override
public boolean isNullable() {
    final boolean hasValue = value != null;
    Preconditions.checkState(hasValue, "should analyze UserVariableExpr first then invoke isNullable");
    return value.isNullable();
}

/**
 * Renders this user variable as SQL text.
 *
 * @return the variable name prefixed with {@code @}
 */
@Override
public String toSqlImpl() {
return "@" + name;
}

/**
 * Casts this user variable to {@code targetType} by casting the value it is bound to.
 *
 * <p>This instance is left untouched: a copy is created through the copy constructor and
 * the cast value is bound to that copy.
 *
 * @param targetType the type to cast the bound value to
 * @return a new {@code UserVariableExpr} carrying the cast value
 * @throws AnalysisException if casting the bound value fails
 */
@Override
public Expr uncheckedCastTo(Type targetType) throws AnalysisException {
    Preconditions.checkState(value != null, "should analyze UserVariableExpr first then cast its value");
    final UserVariableExpr copy = new UserVariableExpr(this);
    final Expr castValue = value.uncheckedCastTo(targetType);
    copy.setValue(castValue);
    return copy;
}
}
51 changes: 0 additions & 51 deletions fe/fe-core/src/main/java/com/starrocks/common/util/ExprUtil.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
import com.starrocks.analysis.AnalyticWindow;
import com.starrocks.analysis.Expr;
import com.starrocks.analysis.FunctionCallExpr;
import com.starrocks.analysis.LiteralExpr;
import com.starrocks.analysis.NullLiteral;
import com.starrocks.analysis.OrderByElement;
import com.starrocks.analysis.UserVariableExpr;
import com.starrocks.catalog.AggregateFunction;
import com.starrocks.catalog.Function;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
import com.starrocks.common.AnalysisException;
import com.starrocks.common.util.ExprUtil;

import java.math.BigDecimal;

Expand Down Expand Up @@ -85,7 +86,7 @@ public static void verifyAnalyticExpression(AnalyticExpr analyticExpr) {

if (isOffsetFn(analyticFunction.getFn()) && analyticFunction.getChildren().size() > 1) {
Expr offset = analyticFunction.getChild(1);
if (!ExprUtil.isPositiveConstantInteger(offset)) {
if (!isPositiveConstantInteger(offset)) {
throw new SemanticException(
"The offset parameter of LEAD/LAG must be a constant positive integer: " +
analyticFunction.toSql(), analyticFunction.getPos());
Expand All @@ -112,7 +113,11 @@ public static void verifyAnalyticExpression(AnalyticExpr analyticExpr) {
// but the nullable info in FE is more relaxed than in BE (e.g. the nullable info of upper('a') is true,
// but the column actually derived in BE is not a nullableColumn),
// which makes the input column in the chunk not match the _agg_input_column in BE, so add this check in FE.
if (!analyticFunction.getChild(2).isLiteral() && analyticFunction.getChild(2).isNullable()) {
Expr theThirdChild = analyticFunction.getChild(2);
if (theThirdChild instanceof UserVariableExpr) {
theThirdChild = ((UserVariableExpr) theThirdChild).getValue();
}
if (!theThirdChild.isLiteral() && theThirdChild.isNullable()) {
throw new SemanticException("The type of the third parameter of LEAD/LAG not match the type " + firstType,
analyticFunction.getChild(2).getPos());
}
Expand All @@ -123,7 +128,7 @@ public static void verifyAnalyticExpression(AnalyticExpr analyticExpr) {

if (isNtileFn(analyticFunction.getFn())) {
Expr numBuckets = analyticFunction.getChild(0);
if (!ExprUtil.isPositiveConstantInteger(numBuckets)) {
if (!isPositiveConstantInteger(numBuckets)) {
throw new SemanticException(
"The num_buckets parameter of NTILE must be a constant positive integer: " +
analyticFunction.toSql(), numBuckets.getPos());
Expand Down Expand Up @@ -365,4 +370,15 @@ private static boolean isHllAggFn(Function fn) {

return fn.functionName().equalsIgnoreCase(AnalyticExpr.HLL_UNION_AGG);
}

/**
 * Checks whether {@code offset} is a literal of a fixed-point type whose value is greater than zero.
 *
 * <p>A {@link UserVariableExpr} is first replaced by the value it is bound to, so a user
 * variable holding a suitable literal is accepted as well.
 *
 * @param offset the expression to inspect
 * @return {@code true} only for a fixed-point literal with a long value above zero
 */
private static boolean isPositiveConstantInteger(Expr offset) {
    final Expr candidate = offset instanceof UserVariableExpr
            ? ((UserVariableExpr) offset).getValue()
            : offset;

    // The instanceof test comes first so the type is never read from a non-literal (or null) value.
    if (!(candidate instanceof LiteralExpr) || !candidate.getType().isFixedPointType()) {
        return false;
    }
    return ((LiteralExpr) candidate).getLongValue() > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1750,9 +1750,7 @@ public Void visitUserVariableExpr(UserVariableExpr node, Scope context) {
UserVariable userVariable = session.getUserVariable(node.getName());
if (userVariable == null) {
node.setValue(NullLiteral.create(Type.STRING));
node.setType(Type.STRING);
} else {
node.setType(userVariable.getEvaluatedExpression().getType());
node.setValue(userVariable.getEvaluatedExpression());
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@
import com.starrocks.analysis.LiteralExpr;
import com.starrocks.analysis.NullLiteral;
import com.starrocks.analysis.StringLiteral;
import com.starrocks.analysis.UserVariableExpr;
import com.starrocks.catalog.AggregateFunction;
import com.starrocks.catalog.ArrayType;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
import com.starrocks.common.FeConstants;
import com.starrocks.common.util.ExprUtil;
import com.starrocks.qe.ConnectContext;

import java.util.Optional;

public class FunctionAnalyzer {

public static void analyze(FunctionCallExpr functionCallExpr) {
Expand Down Expand Up @@ -379,39 +381,38 @@ private static void analyzeBuiltinAggFunction(FunctionCallExpr functionCallExpr)
}

if (fnName.getFunction().equals(FunctionSet.APPROX_TOP_K)) {
Long k = null;
Long counterNum = null;
Optional<Long> k = Optional.empty();
Optional<Long> counterNum = Optional.empty();
Expr kExpr = null;
Expr counterNumExpr = null;
if (functionCallExpr.hasChild(1)) {
kExpr = functionCallExpr.getChild(1);
if (!ExprUtil.isPositiveConstantInteger(kExpr)) {
k = extractIntegerValue(kExpr);
if (!k.isPresent() || k.get() <= 0) {
throw new SemanticException(
"The second parameter of APPROX_TOP_K must be a constant positive integer: " +
functionCallExpr.toSql(), kExpr.getPos());
}
k = ExprUtil.getIntegerConstant(kExpr);
}
if (functionCallExpr.hasChild(2)) {
counterNumExpr = functionCallExpr.getChild(2);
if (!ExprUtil.isPositiveConstantInteger(counterNumExpr)) {
counterNum = extractIntegerValue(counterNumExpr);
if (!counterNum.isPresent() || counterNum.get() <= 0) {
throw new SemanticException(
"The third parameter of APPROX_TOP_K must be a constant positive integer: " +
functionCallExpr.toSql(), counterNumExpr.getPos());
}
counterNum = ExprUtil.getIntegerConstant(counterNumExpr);
}
if (k != null && k > FeConstants.MAX_COUNTER_NUM_OF_TOP_K) {
if (k.isPresent() && k.get() > FeConstants.MAX_COUNTER_NUM_OF_TOP_K) {
throw new SemanticException("The maximum number of the second parameter is "
+ FeConstants.MAX_COUNTER_NUM_OF_TOP_K + ", " + functionCallExpr.toSql(), kExpr.getPos());
}
if (counterNum != null) {
Preconditions.checkNotNull(k);
if (counterNum > FeConstants.MAX_COUNTER_NUM_OF_TOP_K) {
if (counterNum.isPresent()) {
if (counterNum.get() > FeConstants.MAX_COUNTER_NUM_OF_TOP_K) {
throw new SemanticException("The maximum number of the third parameter is "
+ FeConstants.MAX_COUNTER_NUM_OF_TOP_K + ", " + functionCallExpr.toSql(), counterNumExpr.getPos());
}
if (k > counterNum) {
if (k.get() > counterNum.get()) {
throw new SemanticException(
"The second parameter must be smaller than or equal to the third parameter" +
functionCallExpr.toSql(), kExpr.getPos());
Expand Down Expand Up @@ -453,4 +454,16 @@ private static void analyzeBuiltinAggFunction(FunctionCallExpr functionCallExpr)
}
}
}

/**
 * Reads the long value of {@code expr} when it is a literal of a fixed-point type.
 *
 * <p>A {@link UserVariableExpr} is unwrapped to the value it is bound to before the check.
 *
 * @param expr the expression to inspect
 * @return the literal's long value, or an empty {@code Optional} if the (unwrapped)
 *         expression is not a fixed-point literal
 */
private static Optional<Long> extractIntegerValue(Expr expr) {
    final Expr resolved = expr instanceof UserVariableExpr
            ? ((UserVariableExpr) expr).getValue()
            : expr;

    // The instanceof test short-circuits, so the type is only read from an actual literal.
    final boolean integerLiteral = resolved instanceof LiteralExpr && resolved.getType().isFixedPointType();
    if (!integerLiteral) {
        return Optional.empty();
    }
    return Optional.of(((LiteralExpr) resolved).getLongValue());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ public Visitor(ExpressionMapping expressionMapping, ColumnRefFactory columnRefFa
@Override
public ScalarOperator visit(ParseNode node, Context context) {
Expr expr = (Expr) node;
if (expressionMapping.get(expr) != null && !(expr.isConstant())) {
if (expressionMapping.get(expr) != null && !expr.isConstant()) {
return expressionMapping.get(expr);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.starrocks.sql.plan;

import com.starrocks.sql.analyzer.SemanticException;
import com.starrocks.sql.ast.UserVariable;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -70,4 +71,53 @@ public void testRemoveEscapeCharacter() {
actual = UserVariable.removeEscapeCharacter(str);
Assert.assertEquals("abc\\abc", actual);
}

// Verifies that user variables set through the set_user_variable hint are accepted as
// constant arguments of LAG, APPROX_TOP_K, NTILE and percentile_approx, and that invalid
// values still produce the expected SemanticException messages.
@Test
public void test() throws Exception {
// LAG offset supplied through @a over a subquery; the plan shows the offset resolved to 1.
String sql = "SELECT /*+ set_user_variable(@a = 1) */ col_1, col_2, LAG(col_2, @a, 0) OVER (ORDER BY col_1) " +
"FROM (SELECT 1 AS col_1, NULL AS col_2 UNION ALL SELECT 2 AS col_1, 4 AS col_2) AS T ORDER BY col_1;";
String plan = getFragmentPlan(sql);
assertContains(plan, "functions: [, lag(9: cast, 1, 0), ]");

// LAG with both the offset (@a) and the default value (@b) supplied through user variables.
sql = "SELECT /*+ set_user_variable(@a = 1, @b = 1) */ lag(v1, @a, @b) over (ORDER BY v2) from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "functions: [, lag(1: v1, 1, 1), ]");

// LAG where all three arguments are user variables and the default value (@b) is null.
sql = "SELECT /*+ set_user_variable(@a = 1, @b = null) */ lag(@a, @a, @b) over (ORDER BY v2) from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "functions: [, lag(1, 1, NULL), ]");

// APPROX_TOP_K with its second and third parameters supplied through user variables.
sql = "select /*+ set_user_variable(@a = 1, @b = 100000) */ APPROX_TOP_K(v1, @a), APPROX_TOP_K(v1, @a, @b) from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "approx_top_k(1: v1, 1), approx_top_k(1: v1, 1, 100000)");

// NTILE with num_buckets supplied through @a.
sql = "select /*+ set_user_variable(@a = 1, @b = 10) */ ntile(@a) over (partition by v2 order by v3) " +
"as bucket_id from t0;";
plan = getFragmentPlan(sql);
assertContains(plan, "functions: [, ntile(1), ]");

// percentile_approx with its second parameter supplied through @a; the plan shows it as 1.0.
sql = "select /*+ set_user_variable(@a = 1, @b = 10) */ percentile_approx(v1, @a) from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "percentile_approx(CAST(1: v1 AS DOUBLE), 1.0)");

// Negative case: @b = 1000000 as the third APPROX_TOP_K parameter exceeds the allowed maximum.
Exception exception = Assert.assertThrows(SemanticException.class, () -> {
String invalidSql = "select /*+ set_user_variable(@a = 1, @b = 1000000) */ APPROX_TOP_K(v1, @a), " +
"APPROX_TOP_K(v1, @a, @b)from t0";
getFragmentPlan(invalidSql);
});
assertContains(exception.getMessage(), "The maximum number of the third parameter is 100000");


// Negative case: an array-valued user variable is rejected as the LAG offset.
exception = Assert.assertThrows(SemanticException.class, () -> {
String invalidSql = "select /*+ set_user_variable(@a = [1, 2, 3]) */ LAG(v1, @a) over (ORDER BY v2) from t0";
getFragmentPlan(invalidSql);
});
assertContains(exception.getMessage(), "The offset parameter of LEAD/LAG must be a constant positive integer");

// Negative case: @b = 10 is outside the 0..1 range required for percentile_approx's second parameter.
exception = Assert.assertThrows(SemanticException.class, () -> {
String invalidSql = "select /*+ set_user_variable(@a = 1, @b = 10) */ percentile_approx(@a, @b) from t0";
getFragmentPlan(invalidSql);
});
assertContains(exception.getMessage(), " percentile_approx second parameter'value must be between 0 and 1");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ public void testNtileWindowFunction() throws Exception {
.analysisError("The num_buckets parameter of NTILE must be a constant positive integer");

sql = "select v1, v2, NTILE(9223372036854775808) over (partition by v1 order by v2) as j1 from t0";
starRocksAssert.query(sql).analysisError("Number out of range");
starRocksAssert.query(sql).analysisError("The num_buckets parameter of NTILE must be a constant positive integer");

sql = "select v1, v2, NTILE((select v1 from t0)) over (partition by v1 order by v2) as j1 from t0";
starRocksAssert.query(sql)
Expand Down
Loading

0 comments on commit 8f5113f

Please sign in to comment.