Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,20 @@
import org.apache.flink.util.CollectionUtil;
import org.apache.flink.util.Preconditions;

import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList;

import org.apache.calcite.plan.Context;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptSchema;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
Expand Down Expand Up @@ -197,12 +193,6 @@ public RelBuilder aggregate(
final LogicalAggregate logicalAggregate = (LogicalAggregate) relNode;
if (isTableAggregate(logicalAggregate.getAggCallList())) {
relNode = LogicalTableAggregate.create(logicalAggregate);
} else if (isCountStarAgg(logicalAggregate)) {
final RelNode newAggInput =
push(logicalAggregate.getInput(0)).project(literal(0)).build();
relNode =
logicalAggregate.copy(
logicalAggregate.getTraitSet(), ImmutableList.of(newAggInput));
}
}

Expand Down Expand Up @@ -270,14 +260,4 @@ public RelBuilder transform(UnaryOperator<Config> transform) {
cluster.getPlanner().getContext());
return FlinkRelBuilder.of(mergedContext, cluster, relOptSchema);
}

private static boolean isCountStarAgg(LogicalAggregate agg) {
if (agg.getGroupCount() != 0 || agg.getAggCallList().size() != 1) {
return false;
}
final AggregateCall call = agg.getAggCallList().get(0);
return call.getAggregation().getKind() == SqlKind.COUNT
&& call.filterArg == -1
&& call.getArgList().isEmpty();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.plan.rules.logical;

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;
import org.immutables.value.Value;

import java.util.Collections;

/**
* Planner rule that prunes the input columns of a {@link LogicalAggregate} representing {@code
* COUNT(*)} (i.e. no group keys, a single COUNT aggregate call with no arguments and no filter) by
* inserting a project of a constant literal {@code 0} between the aggregate and its input.
*
* <p>This avoids reading all columns from the source when only counting rows.
*
* <p>Before:
*
* <pre>
* LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
* +- SomeInput
* </pre>
*
* <p>After:
*
* <pre>
* LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
* +- LogicalProject($f0=[0])
* +- SomeInput
* </pre>
*/
@Value.Enclosing
public class PruneCountStarInputRule
        extends RelRule<PruneCountStarInputRule.PruneCountStarInputRuleConfig> {

    /** Shared, stateless instance of this rule. */
    public static final PruneCountStarInputRule INSTANCE =
            PruneCountStarInputRule.PruneCountStarInputRuleConfig.DEFAULT.toRule();

    protected PruneCountStarInputRule(PruneCountStarInputRuleConfig config) {
        super(config);
    }

    /**
     * Matches only a global {@code COUNT(*)} aggregate — no group keys, exactly one aggregate
     * call of kind {@link SqlKind#COUNT} with an empty argument list and no {@code FILTER} —
     * whose input still exposes more than one field.
     *
     * <p>The field-count guard is what keeps the rule from firing twice: once rewritten, the
     * aggregate's input is a single-column project of the literal {@code 0}, so the rule no
     * longer matches even if other rules in the same phase reshape or remove that project.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        final LogicalAggregate aggregate = call.rel(0);

        // Must be a global aggregation carrying exactly one aggregate call.
        final boolean singleGlobalCall =
                aggregate.getGroupCount() == 0 && aggregate.getAggCallList().size() == 1;
        if (!singleGlobalCall) {
            return false;
        }

        // That single call must be COUNT() — argument-less and unfiltered, i.e. COUNT(*).
        final AggregateCall onlyCall = aggregate.getAggCallList().get(0);
        final boolean isCountStar =
                onlyCall.getAggregation().getKind() == SqlKind.COUNT
                        && onlyCall.filterArg == -1
                        && onlyCall.getArgList().isEmpty();
        if (!isCountStar) {
            return false;
        }

        // Only rewrite while the input is still wider than one column (see Javadoc above).
        return aggregate.getInput().getRowType().getFieldCount() > 1;
    }

    /**
     * Rewrites {@code Aggregate(COUNT(*)) <- input} into
     * {@code Aggregate(COUNT(*)) <- Project(0) <- input}, so downstream projection push-down
     * can prune every source column instead of reading them all just to count rows.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final LogicalAggregate aggregate = call.rel(0);

        final RelBuilder builder = call.builder();
        // Shrink the aggregate's input to a single constant column: Project($f0 = 0).
        final RelNode prunedInput =
                builder.push(aggregate.getInput()).project(builder.literal(0)).build();
        // Keep the aggregate itself (and its traits) untouched; only swap the input.
        call.transformTo(
                aggregate.copy(aggregate.getTraitSet(), Collections.singletonList(prunedInput)));
    }

    /** Rule configuration. */
    @Value.Immutable(singleton = false)
    public interface PruneCountStarInputRuleConfig extends RelRule.Config {
        PruneCountStarInputRule.PruneCountStarInputRuleConfig DEFAULT =
                ImmutablePruneCountStarInputRule.PruneCountStarInputRuleConfig.builder()
                        .operandSupplier(b0 -> b0.operand(LogicalAggregate.class).anyInputs())
                        .description("PruneCountStarInputRule")
                        .build();

        @Override
        default PruneCountStarInputRule toRule() {
            return new PruneCountStarInputRule(this);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ object FlinkBatchRuleSets {
// vector search rule.
ConstantVectorSearchCallToCorrelateRule.INSTANCE,
// Wrap arguments for JSON aggregate functions
WrapJsonAggFunctionArgumentsRule.INSTANCE
WrapJsonAggFunctionArgumentsRule.INSTANCE,
// prune COUNT(*) input to project a constant before aggregation
PruneCountStarInputRule.INSTANCE
)).asJava)

/** RuleSet about filter */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ object FlinkStreamRuleSets {
// rewrite constant table function scan to correlate
JoinTableFunctionScanToCorrelateRule.INSTANCE,
// Wrap arguments for JSON aggregate functions
WrapJsonAggFunctionArgumentsRule.INSTANCE
WrapJsonAggFunctionArgumentsRule.INSTANCE,
// prune COUNT(*) input to project a constant before aggregation
PruneCountStarInputRule.INSTANCE
)
).asJava)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ NestedLoopJoin(joinType=[FullOuterJoin], where=[(cnt <> cnt0)], select=[cnt, cnt
: +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS cnt])(reuse_id=[1])
: +- Exchange(distribution=[single])
: +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0])
: +- TableSourceScan(table=[[default_catalog, default_database, x]], fields=[a, b, c])
: +- Calc(select=[0 AS $f0])
: +- TableSourceScan(table=[[default_catalog, default_database, x]], fields=[a, b, c])
+- Exchange(distribution=[single], shuffle_mode=[BATCH])
+- Calc(select=[cnt], where=[(cnt < 5)])
+- Reused(reference_id=[1])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ Calc(select=[a])
: +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS cnt])
: +- Exchange(distribution=[single])
: +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0])
: +- Reused(reference_id=[1])
: +- Calc(select=[0 AS $f0])
: +- Reused(reference_id=[1])
:- Exchange(distribution=[broadcast])
: +- Calc(select=[a])
: +- TableSourceScan(table=[[default_catalog, default_database, x]], fields=[a, b, c, nx])
Expand Down Expand Up @@ -213,7 +214,8 @@ Calc(select=[a])
: +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS cnt])
: +- Exchange(distribution=[single])
: +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0])
: +- Reused(reference_id=[1])
: +- Calc(select=[0 AS $f0])
: +- Reused(reference_id=[1])
:- Exchange(distribution=[broadcast])
: +- Calc(select=[a])
: +- TableSourceScan(table=[[default_catalog, default_database, x]], fields=[a, b, c, nx])
Expand Down Expand Up @@ -268,7 +270,8 @@ Calc(select=[a])
: +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS cnt])
: +- Exchange(distribution=[single])
: +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0])
: +- Reused(reference_id=[1])
: +- Calc(select=[0 AS $f0])
: +- Reused(reference_id=[1])
:- Exchange(distribution=[broadcast])
: +- Calc(select=[a])
: +- TableSourceScan(table=[[default_catalog, default_database, x]], fields=[a, b, c, nx])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,8 @@ Sink(table=[default_catalog.default_database.t], fields=[a, b])
+- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS EXPR$0])
+- Exchange(distribution=[single])
+- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0])
+- TableSourceScan(table=[[default_catalog, default_database, t]], fields=[a, b])
+- Calc(select=[0 AS $f0])
+- TableSourceScan(table=[[default_catalog, default_database, t]], fields=[a, b])
== Optimized Execution Plan ==
Sink(table=[default_catalog.default_database.t], fields=[a, b])
Expand All @@ -395,7 +396,8 @@ Sink(table=[default_catalog.default_database.t], fields=[a, b])
+- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS EXPR$0])
+- Exchange(distribution=[single])
+- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0])
+- Reused(reference_id=[1])
+- Calc(select=[0 AS $f0])
+- Reused(reference_id=[1])
== Physical Execution Plan ==
{
Expand All @@ -416,6 +418,17 @@ Sink(table=[default_catalog.default_database.t], fields=[a, b])
"ship_strategy" : "FORWARD",
"side" : "second"
} ]
}, {
"id" : ,
"type" : "Calc[]",
"pact" : "Operator",
"contents" : "[]:Calc(select=[0 AS $f0])",
"parallelism" : 1,
"predecessors" : [ {
"id" : ,
"ship_strategy" : "FORWARD",
"side" : "second"
} ]
}, {
"id" : ,
"type" : "HashAggregate[]",
Expand Down Expand Up @@ -510,7 +523,8 @@ Sink(table=[default_catalog.default_database.t], fields=[a, b])
+- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS EXPR$0])
+- Exchange(distribution=[single])
+- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0])
+- TableSourceScan(table=[[default_catalog, default_database, t]], fields=[a, b])
+- Calc(select=[0 AS $f0])
+- TableSourceScan(table=[[default_catalog, default_database, t]], fields=[a, b])
== Optimized Execution Plan ==
Sink(table=[default_catalog.default_database.t], fields=[a, b])
Expand All @@ -522,7 +536,8 @@ Sink(table=[default_catalog.default_database.t], fields=[a, b])
+- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS EXPR$0])
+- Exchange(distribution=[single])
+- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0])
+- Reused(reference_id=[1])
+- Calc(select=[0 AS $f0])
+- Reused(reference_id=[1])
== Physical Execution Plan ==
{
Expand All @@ -532,6 +547,17 @@ Sink(table=[default_catalog.default_database.t], fields=[a, b])
"pact" : "Data Source",
"contents" : "[]:TableSourceScan(table=[[default_catalog, default_database, t]], fields=[a, b])",
"parallelism" : 1
}, {
"id" : ,
"type" : "Calc[]",
"pact" : "Operator",
"contents" : "[]:Calc(select=[0 AS $f0])",
"parallelism" : 1,
"predecessors" : [ {
"id" : ,
"ship_strategy" : "FORWARD",
"side" : "second"
} ]
}, {
"id" : ,
"type" : "HashAggregate[]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,8 @@ Sink(table=[default_catalog.default_database.t], targetColumns=[[1]], fields=[a,
+- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS EXPR$0])
+- Exchange(distribution=[single])
+- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0])
+- TableSourceScan(table=[[default_catalog, default_database, t1]], fields=[a, b])
+- Calc(select=[0 AS $f0])
+- TableSourceScan(table=[[default_catalog, default_database, t1]], fields=[a, b])
== Optimized Execution Plan ==
Sink(table=[default_catalog.default_database.t], targetColumns=[[1]], fields=[a, b])
Expand All @@ -616,7 +617,8 @@ Sink(table=[default_catalog.default_database.t], targetColumns=[[1]], fields=[a,
+- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS EXPR$0])
+- Exchange(distribution=[single])
+- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0])
+- TableSourceScan(table=[[default_catalog, default_database, t1]], fields=[a, b])
+- Calc(select=[0 AS $f0])
+- TableSourceScan(table=[[default_catalog, default_database, t1]], fields=[a, b])
== Physical Execution Plan ==
{
Expand All @@ -632,6 +634,17 @@ Sink(table=[default_catalog.default_database.t], targetColumns=[[1]], fields=[a,
"pact" : "Data Source",
"contents" : "[]:TableSourceScan(table=[[default_catalog, default_database, t1]], fields=[a, b])",
"parallelism" : 1
}, {
"id" : ,
"type" : "Calc[]",
"pact" : "Operator",
"contents" : "[]:Calc(select=[0 AS $f0])",
"parallelism" : 1,
"predecessors" : [ {
"id" : ,
"ship_strategy" : "FORWARD",
"side" : "second"
} ]
}, {
"id" : ,
"type" : "HashAggregate[]",
Expand Down Expand Up @@ -716,7 +729,8 @@ Sink(table=[default_catalog.default_database.t], targetColumns=[[1]], fields=[a,
+- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS EXPR$0])
+- Exchange(distribution=[single])
+- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0])
+- TableSourceScan(table=[[default_catalog, default_database, t1]], fields=[a, b])
+- Calc(select=[0 AS $f0])
+- TableSourceScan(table=[[default_catalog, default_database, t1]], fields=[a, b])
== Optimized Execution Plan ==
Sink(table=[default_catalog.default_database.t], targetColumns=[[1]], fields=[a, b])
Expand All @@ -728,7 +742,8 @@ Sink(table=[default_catalog.default_database.t], targetColumns=[[1]], fields=[a,
+- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS EXPR$0])
+- Exchange(distribution=[single])
+- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0])
+- TableSourceScan(table=[[default_catalog, default_database, t1]], fields=[a, b])
+- Calc(select=[0 AS $f0])
+- TableSourceScan(table=[[default_catalog, default_database, t1]], fields=[a, b])
== Physical Execution Plan ==
{
Expand All @@ -755,6 +770,17 @@ Sink(table=[default_catalog.default_database.t], targetColumns=[[1]], fields=[a,
"pact" : "Data Source",
"contents" : "[]:TableSourceScan(table=[[default_catalog, default_database, t1]], fields=[a, b])",
"parallelism" : 1
}, {
"id" : ,
"type" : "Calc[]",
"pact" : "Operator",
"contents" : "[]:Calc(select=[0 AS $f0])",
"parallelism" : 1,
"predecessors" : [ {
"id" : ,
"ship_strategy" : "FORWARD",
"side" : "second"
} ]
}, {
"id" : ,
"type" : "HashAggregate[]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
<![CDATA[
HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS EXPR$0])
+- Exchange(distribution=[single])
+- TableSourceScan(table=[[default_catalog, default_database, ProjectableTable, aggregates=[grouping=[], aggFunctions=[Count1AggFunction()]]]], fields=[count1$0])
+- TableSourceScan(table=[[default_catalog, default_database, ProjectableTable, project=[a], metadata=[], aggregates=[grouping=[], aggFunctions=[Count1AggFunction()]]]], fields=[count1$0])
]]>
</Resource>
</TestCase>
Expand Down
Loading