Skip to content

Commit 7375d70

Browse files
godfreyheKurtYoung
authored andcommitted
[FLINK-12076] [table-planner-blink] Add support for generating optimized logical plan for simple group aggregate on batch (apache#8092)
* [FLINK-12076] [table-planner-blink] Add support for generating optimized logical plan for batch simple group aggregate * fix checkstyle error * remove Ignore from SingleRowJoinTest, and update based on comments * fix checkstyle error
1 parent 0e0a980 commit 7375d70

34 files changed

+5506
-38
lines changed

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/api/PlannerConfigOptions.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ public class PlannerConfigOptions {
3939
"(only count RexCall node, including leaves and interior nodes). Negative number to" +
4040
" use the default threshold: double of number of nodes.");
4141

42+
public static final ConfigOption<String> SQL_OPTIMIZER_AGG_PHASE_ENFORCER =
43+
key("sql.optimizer.agg.phase.enforcer")
44+
.defaultValue("NONE")
45+
.withDescription("Strategy for agg phase. Only NONE, TWO_PHASE or ONE_PHASE can be set.\n" +
46+
"NONE: No special enforcer for aggregate stage. Whether to choose two stage aggregate or one" +
47+
" stage aggregate depends on cost. \n" +
48+
"TWO_PHASE: Enforce to use two stage aggregate which has localAggregate and globalAggregate. " +
49+
"NOTE: If aggregate call does not support split into two phase, still use one stage aggregate.\n" +
50+
"ONE_PHASE: Enforce to use one stage aggregate which only has CompleteGlobalAggregate.");
51+
4252
public static final ConfigOption<Boolean> SQL_OPTIMIZER_SHUFFLE_PARTIAL_KEY_ENABLED =
4353
key("sql.optimizer.shuffle.partial-key.enabled")
4454
.defaultValue(false)

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/ExpressionBuilder.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525

2626
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.DIVIDE;
2727
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.EQUALS;
28+
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.GREATER_THAN;
2829
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.IF;
2930
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.IS_NULL;
31+
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.LESS_THAN;
3032
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.MINUS;
3133
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.PLUS;
3234

@@ -78,4 +80,12 @@ public static Expression div(Expression input1, Expression input2) {
7880
public static Expression equalTo(Expression input1, Expression input2) {
7981
return call(EQUALS, input1, input2);
8082
}
83+
84+
public static Expression lessThan(Expression input1, Expression input2) {
85+
return call(LESS_THAN, input1, input2);
86+
}
87+
88+
public static Expression greaterThan(Expression input1, Expression input2) {
89+
return call(GREATER_THAN, input1, input2);
90+
}
8191
}

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/AvgAggFunction.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
import org.apache.flink.table.expressions.Expression;
2424
import org.apache.flink.table.expressions.UnresolvedFieldReferenceExpression;
2525
import org.apache.flink.table.type.DecimalType;
26+
import org.apache.flink.table.type.InternalType;
27+
import org.apache.flink.table.type.InternalTypes;
28+
import org.apache.flink.table.type.TypeConverters;
2629
import org.apache.flink.table.typeutils.DecimalTypeInfo;
2730

2831
import java.math.BigDecimal;
@@ -44,9 +47,7 @@ public abstract class AvgAggFunction extends DeclarativeAggregateFunction {
4447
private UnresolvedFieldReferenceExpression sum = new UnresolvedFieldReferenceExpression("sum");
4548
private UnresolvedFieldReferenceExpression count = new UnresolvedFieldReferenceExpression("count");
4649

47-
public TypeInformation getSumType() {
48-
return Types.LONG;
49-
}
50+
public abstract TypeInformation getSumType();
5051

5152
@Override
5253
public int operandCount() {
@@ -60,6 +61,14 @@ public UnresolvedFieldReferenceExpression[] aggBufferAttributes() {
6061
count};
6162
}
6263

64+
@Override
65+
public InternalType[] getAggBufferTypes() {
66+
return new InternalType[] {
67+
TypeConverters.createInternalTypeFromTypeInfo(getSumType()),
68+
InternalTypes.LONG
69+
};
70+
}
71+
6372
@Override
6473
public Expression[] initialValuesExpressions() {
6574
return new Expression[] {
@@ -110,6 +119,11 @@ public static class IntegralAvgAggFunction extends AvgAggFunction {
110119
public TypeInformation getResultType() {
111120
return Types.DOUBLE;
112121
}
122+
123+
@Override
124+
public TypeInformation getSumType() {
125+
return Types.LONG;
126+
}
113127
}
114128

115129
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.functions;
20+
21+
import org.apache.flink.api.common.typeinfo.TypeInformation;
22+
import org.apache.flink.api.common.typeinfo.Types;
23+
import org.apache.flink.table.expressions.Expression;
24+
import org.apache.flink.table.expressions.UnresolvedFieldReferenceExpression;
25+
import org.apache.flink.table.type.InternalType;
26+
import org.apache.flink.table.type.InternalTypes;
27+
28+
import static org.apache.flink.table.expressions.ExpressionBuilder.literal;
29+
import static org.apache.flink.table.expressions.ExpressionBuilder.minus;
30+
import static org.apache.flink.table.expressions.ExpressionBuilder.plus;
31+
32+
/**
33+
* This count1 aggregate function returns the count1 of values
34+
* which go into it like [[CountAggFunction]].
35+
* It differs in that null values are also counted.
36+
*/
37+
public class Count1AggFunction extends DeclarativeAggregateFunction {
38+
private UnresolvedFieldReferenceExpression count1 = new UnresolvedFieldReferenceExpression("count1");
39+
40+
@Override
41+
public int operandCount() {
42+
return 1;
43+
}
44+
45+
@Override
46+
public UnresolvedFieldReferenceExpression[] aggBufferAttributes() {
47+
return new UnresolvedFieldReferenceExpression[] { count1 };
48+
}
49+
50+
@Override
51+
public InternalType[] getAggBufferTypes() {
52+
return new InternalType[] { InternalTypes.LONG };
53+
}
54+
55+
@Override
56+
public TypeInformation getResultType() {
57+
return Types.LONG;
58+
}
59+
60+
@Override
61+
public Expression[] initialValuesExpressions() {
62+
return new Expression[] {
63+
/* count1 = */ literal(0L, getResultType())
64+
};
65+
}
66+
67+
@Override
68+
public Expression[] accumulateExpressions() {
69+
return new Expression[] {
70+
/* count1 = */ plus(count1, literal(1L))
71+
};
72+
}
73+
74+
@Override
75+
public Expression[] retractExpressions() {
76+
return new Expression[] {
77+
/* count1 = */ minus(count1, literal(1L))
78+
};
79+
}
80+
81+
@Override
82+
public Expression[] mergeExpressions() {
83+
return new Expression[] {
84+
/* count1 = */ plus(count1, mergeOperand(count1))
85+
};
86+
}
87+
88+
@Override
89+
public Expression getValueExpression() {
90+
return count1;
91+
}
92+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.functions;
20+
21+
import org.apache.flink.api.common.typeinfo.TypeInformation;
22+
import org.apache.flink.api.common.typeinfo.Types;
23+
import org.apache.flink.table.expressions.Expression;
24+
import org.apache.flink.table.expressions.UnresolvedFieldReferenceExpression;
25+
import org.apache.flink.table.type.InternalType;
26+
import org.apache.flink.table.type.InternalTypes;
27+
28+
import static org.apache.flink.table.expressions.ExpressionBuilder.ifThenElse;
29+
import static org.apache.flink.table.expressions.ExpressionBuilder.isNull;
30+
import static org.apache.flink.table.expressions.ExpressionBuilder.literal;
31+
import static org.apache.flink.table.expressions.ExpressionBuilder.minus;
32+
import static org.apache.flink.table.expressions.ExpressionBuilder.plus;
33+
34+
/**
35+
* built-in count aggregate function.
36+
*/
37+
public class CountAggFunction extends DeclarativeAggregateFunction {
38+
private UnresolvedFieldReferenceExpression count = new UnresolvedFieldReferenceExpression("count");
39+
40+
@Override
41+
public int operandCount() {
42+
return 1;
43+
}
44+
45+
@Override
46+
public UnresolvedFieldReferenceExpression[] aggBufferAttributes() {
47+
return new UnresolvedFieldReferenceExpression[] { count };
48+
}
49+
50+
@Override
51+
public InternalType[] getAggBufferTypes() {
52+
return new InternalType[] { InternalTypes.LONG };
53+
}
54+
55+
@Override
56+
public TypeInformation getResultType() {
57+
return Types.LONG;
58+
}
59+
60+
@Override
61+
public Expression[] initialValuesExpressions() {
62+
return new Expression[] {
63+
/* count = */ literal(0L, getResultType())
64+
};
65+
}
66+
67+
@Override
68+
public Expression[] accumulateExpressions() {
69+
return new Expression[] {
70+
/* count = */ ifThenElse(isNull(operand(0)), count, plus(count, literal(1L)))
71+
};
72+
}
73+
74+
@Override
75+
public Expression[] retractExpressions() {
76+
return new Expression[] {
77+
/* count = */ ifThenElse(isNull(operand(0)), count, minus(count, literal(1L)))
78+
};
79+
}
80+
81+
@Override
82+
public Expression[] mergeExpressions() {
83+
return new Expression[] {
84+
/* count = */ plus(count, mergeOperand(count))
85+
};
86+
}
87+
88+
// If all input are nulls, count will be 0 and we will get result 0.
89+
@Override
90+
public Expression getValueExpression() {
91+
return count;
92+
}
93+
}

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/DeclarativeAggregateFunction.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.apache.flink.api.common.typeinfo.TypeInformation;
2222
import org.apache.flink.table.expressions.Expression;
2323
import org.apache.flink.table.expressions.UnresolvedFieldReferenceExpression;
24+
import org.apache.flink.table.type.InternalType;
2425
import org.apache.flink.util.Preconditions;
2526

2627
import java.util.Arrays;
@@ -57,6 +58,11 @@ public abstract class DeclarativeAggregateFunction extends UserDefinedFunction {
5758
*/
5859
public abstract UnresolvedFieldReferenceExpression[] aggBufferAttributes();
5960

61+
/**
62+
* All types of the aggregate buffer.
63+
*/
64+
public abstract InternalType[] getAggBufferTypes();
65+
6066
/**
6167
* The result type of the function.
6268
*/

0 commit comments

Comments
 (0)