From f600f7061988e6ad6182243aef6beb13dc39846e Mon Sep 17 00:00:00 2001 From: AKIRA <33112463+Kikyou1997@users.noreply.github.com> Date: Wed, 22 Mar 2023 12:07:56 +0900 Subject: [PATCH] [ehancement](fe) Tune for stats framework (#17860) --- .../doris/nereids/stats/FilterEstimation.java | 49 +++++++++++++++---- .../doris/nereids/stats/JoinEstimation.java | 23 ++++++--- .../doris/nereids/stats/StatsCalculator.java | 41 +++++++++++++--- .../nereids/stats/StatsErrorEstimator.java | 2 +- .../doris/statistics/StatisticConstants.java | 2 + .../apache/doris/statistics/Statistics.java | 11 +++++ .../nereids/stats/FilterEstimationTest.java | 8 +-- tools/qerror.py | 28 ++++++++--- 8 files changed, 125 insertions(+), 39 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java index e2159f1040559e..1aaaeaa955f304 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.stats; import org.apache.doris.nereids.stats.FilterEstimation.EstimationContext; +import org.apache.doris.nereids.trees.TreeNode; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; @@ -33,6 +34,7 @@ import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.Function; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.statistics.Bucket; @@ -49,6 +51,7 @@ import java.util.List; import 
java.util.Map; import java.util.Set; +import java.util.function.Predicate; /** * Calculate selectivity of expression that produces boolean value. @@ -56,9 +59,21 @@ */ public class FilterEstimation extends ExpressionVisitor { public static final double DEFAULT_INEQUALITY_COEFFICIENT = 0.5; + public static final double DEFAULT_IN_COEFFICIENT = 1.0 / 3.0; + + public static final double DEFAULT_HAVING_COEFFICIENT = 0.01; public static final double DEFAULT_EQUALITY_COMPARISON_SELECTIVITY = 0.1; + private Set aggSlots; + + public FilterEstimation() { + } + + public FilterEstimation(Set aggSlots) { + this.aggSlots = aggSlots; + } + /** * This method will update the stats according to the selectivity. */ @@ -104,7 +119,6 @@ public Statistics visitCompoundPredicate(CompoundPredicate predicate, Estimation estimatedColStatsBuilder.setMaxValue(rightColStats.maxValue); estimatedColStatsBuilder.setMaxExpr(rightColStats.maxExpr); } - orStats.addColumnStats(entry.getKey(), estimatedColStatsBuilder.build()); } return orStats; } @@ -127,6 +141,24 @@ public Statistics visitComparisonPredicate(ComparisonPredicate cp, EstimationCon } ColumnStatistic statsForLeft = ExpressionEstimation.estimate(left, context.statistics); ColumnStatistic statsForRight = ExpressionEstimation.estimate(right, context.statistics); + if (aggSlots != null) { + Predicate> containsAggSlot = e -> { + if (e instanceof SlotReference) { + SlotReference slot = (SlotReference) e; + return aggSlots.contains(slot); + } + return false; + }; + boolean leftAgg = left.anyMatch(containsAggSlot); + boolean rightAgg = right.anyMatch(containsAggSlot); + // It means this predicate appears in HAVING clause. 
+ if (leftAgg || rightAgg) { + double rowCount = context.statistics.getRowCount(); + double newRowCount = Math.max(rowCount * DEFAULT_HAVING_COEFFICIENT, + Math.max(statsForLeft.ndv, statsForRight.ndv)); + return context.statistics.withRowCount(newRowCount); + } + } if (!(left instanceof Literal) && !(right instanceof Literal)) { return calculateWhenBothColumn(cp, context, statsForLeft, statsForRight); } else { @@ -167,14 +199,11 @@ private Statistics calculateWhenLiteralRight(ComparisonPredicate cp, double ndv = statsForLeft.ndv; double val = statsForRight.maxValue; if (cp instanceof EqualTo || cp instanceof NullSafeEqual) { - if (statsForLeft == ColumnStatistic.UNKNOWN) { - selectivity = DEFAULT_EQUALITY_COMPARISON_SELECTIVITY; + + if (val > statsForLeft.maxValue || val < statsForLeft.minValue) { + selectivity = 0.0; } else { - if (val > statsForLeft.maxValue || val < statsForLeft.minValue) { - selectivity = 0.0; - } else { - selectivity = StatsMathUtil.minNonNaN(1.0, 1.0 / ndv); - } + selectivity = StatsMathUtil.minNonNaN(1.0, 1.0 / ndv); } if (context.isNot) { selectivity = 1 - selectivity; @@ -249,8 +278,8 @@ public Statistics visitInPredicate(InPredicate inPredicate, EstimationContext co boolean isNotIn = context != null && context.isNot; Expression compareExpr = inPredicate.getCompareExpr(); ColumnStatistic compareExprStats = ExpressionEstimation.estimate(compareExpr, context.statistics); - if (compareExprStats.isUnKnown) { - return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT); + if (compareExprStats.isUnKnown || compareExpr instanceof Function) { + return context.statistics.withSel(DEFAULT_IN_COEFFICIENT); } List options = inPredicate.getOptions(); double maxOption = 0; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java index 0e8516e00ca263..d1427ef4699c6b 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java @@ -21,10 +21,10 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.algebra.Join; -import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.statistics.Statistics; import org.apache.doris.statistics.StatisticsBuilder; +import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -33,6 +33,7 @@ * TODO: Update other props in the ColumnStats properly. */ public class JoinEstimation { + private static Statistics estimateInnerJoin(Statistics crossJoinStats, List joinConditions) { List> sortedJoinConditions = joinConditions.stream() .map(expression -> Pair.of(expression, estimateJoinConditionSel(crossJoinStats, expression))) @@ -51,7 +52,7 @@ private static Statistics estimateInnerJoin(Statistics crossJoinStats, List joinConditions = join.getHashJoinConjuncts(); - Statistics innerJoinStats = estimateInnerJoin(crossJoinStats, joinConditions); - if (!join.getOtherJoinConjuncts().isEmpty()) { - FilterEstimation filterEstimation = new FilterEstimation(); - innerJoinStats = filterEstimation.estimate( - ExpressionUtils.and(join.getOtherJoinConjuncts()), innerJoinStats); + Statistics innerJoinStats = null; + if (crossJoinStats.getRowCount() != 0) { + List joinConditions = new ArrayList<>(join.getHashJoinConjuncts()); + joinConditions.addAll(join.getOtherJoinConjuncts()); + innerJoinStats = estimateInnerJoin(crossJoinStats, joinConditions); + } else { + innerJoinStats = crossJoinStats; } + // if (!join.getOtherJoinConjuncts().isEmpty()) { + // FilterEstimation filterEstimation = new FilterEstimation(); + // innerJoinStats = filterEstimation.estimate( + // ExpressionUtils.and(join.getOtherJoinConjuncts()), innerJoinStats); + // } 
innerJoinStats.setWidth(leftStats.getWidth() + rightStats.getWidth()); innerJoinStats.setPenalty(0); double rowCount; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java index 8f10d6f4ea0a3b..3143e196f554dc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java @@ -22,10 +22,12 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.algebra.Aggregate; import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation; @@ -94,6 +96,7 @@ import org.apache.doris.statistics.StatisticsBuilder; import com.google.common.collect.Maps; +import org.apache.commons.collections.CollectionUtils; import java.util.AbstractMap.SimpleEntry; import java.util.HashMap; @@ -394,9 +397,24 @@ private Statistics computeAssertNumRows(long desiredNumOfRows) { } private Statistics computeFilter(Filter filter) { - FilterEstimation filterEstimation = new FilterEstimation(); Statistics stats = groupExpression.childStatistics(0); - return filterEstimation.estimate(filter.getPredicate(), stats); + Plan plan = tryToFindChild(groupExpression); + if (plan != null) { + if (plan instanceof Aggregate) { + Aggregate agg = ((Aggregate) plan); + List expressions = agg.getOutputExpressions(); + Set slots = expressions + 
.stream() + .filter(Alias.class::isInstance) + .filter(s -> ((Alias) s).child().anyMatch(AggregateFunction.class::isInstance)) + .map(NamedExpression::toSlot).collect(Collectors.toSet()); + Expression predicate = filter.getPredicate(); + if (predicate.anyMatch(s -> slots.contains(s))) { + return new FilterEstimation(slots).estimate(filter.getPredicate(), stats); + } + } + } + return new FilterEstimation().estimate(filter.getPredicate(), stats); } // TODO: 1. Subtract the pruned partition @@ -441,13 +459,8 @@ private Statistics computeAggregate(Aggregate aggregate) { if (!groupByExpressions.isEmpty()) { Map childSlotToColumnStats = childStats.columnStatistics(); double inputRowCount = childStats.getRowCount(); - if (inputRowCount == 0) { - //on empty relation, Agg output 1 tuple - resultSetCount = 1; - } else { + if (inputRowCount != 0) { List groupByKeyStats = groupByExpressions.stream() - .flatMap(expr -> expr.getInputSlots().stream()) - .map(Slot::getExprId) .filter(childSlotToColumnStats::containsKey) .map(childSlotToColumnStats::get) .filter(s -> !s.isUnKnown) @@ -692,4 +705,16 @@ private ColumnStatistic unionColumn(ColumnStatistic leftStats, double leftRowCou .setAvgSizeByte(newAverageRowSize); return columnStatisticBuilder.build(); } + + private Plan tryToFindChild(GroupExpression groupExpression) { + List groupExprs = groupExpression.child(0).getLogicalExpressions(); + if (CollectionUtils.isEmpty(groupExprs)) { + groupExprs = groupExpression.child(0).getPhysicalExpressions(); + if (CollectionUtils.isEmpty(groupExprs)) { + return null; + } + } + return groupExprs.get(0).getPlan(); + } + } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsErrorEstimator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsErrorEstimator.java index 6966fc97ea9507..54d309ac1439d7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsErrorEstimator.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsErrorEstimator.java @@ -52,7 +52,7 @@ public StatsErrorEstimator() { } /** - * Map plan id to stats. + * Invoked by PhysicalPlanTranslator, put the translated plan node and corresponding physical plan to estimator. */ public void updateLegacyPlanIdToPhysicalPlan(PlanNode planNode, AbstractPlan physicalPlan) { Statistics statistics = physicalPlan.getStats(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticConstants.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticConstants.java index df34c2f9d69d66..0345c7930edd0c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticConstants.java +++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticConstants.java @@ -62,4 +62,6 @@ public class StatisticConstants { public static final int LOAD_TASK_LIMITS = 10; + public static final double DEFAULT_INNER_JOIN_FACTOR = 0.1; + } diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java index b9cf6040e8c338..2aa0d1ad33667a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java +++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java @@ -83,14 +83,25 @@ public double getRowCount() { return rowCount; } + /* + * Return a stats with new rowCount and fix each column stats. 
+ */ public Statistics withRowCount(double rowCount) { + if (Double.isNaN(rowCount)) { + return this; + } Statistics statistics = new Statistics(rowCount, new HashMap<>(expressionToColumnStats), width, penalty); statistics.fix(rowCount, StatsMathUtil.nonZeroDivisor(this.rowCount)); return statistics; } + public Statistics updateRowCountOnly(double rowCount) { + return new Statistics(rowCount, expressionToColumnStats); + } + public void fix(double newRowCount, double originRowCount) { double sel = newRowCount / originRowCount; + for (Entry entry : expressionToColumnStats.entrySet()) { ColumnStatistic columnStatistic = entry.getValue(); ColumnStatisticBuilder columnStatisticBuilder = new ColumnStatisticBuilder(columnStatistic); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java index 691cf5372065d3..0ab0c7c2ff7e0a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java @@ -115,9 +115,7 @@ public void testInNaN() { Statistics stat = new Statistics(1000, slotToColumnStat); FilterEstimation filterEstimation = new FilterEstimation(); Statistics expected = filterEstimation.estimate(in, stat); - Assertions.assertEquals( - FilterEstimation.DEFAULT_INEQUALITY_COEFFICIENT * stat.getRowCount(), - expected.getRowCount()); + Assertions.assertTrue(Precision.equals(333.33, expected.getRowCount(), 0.01)); } @Test @@ -134,9 +132,7 @@ public void testNotInNaN() { Statistics stat = new Statistics(1000, slotToColumnStat); FilterEstimation filterEstimation = new FilterEstimation(); Statistics expected = filterEstimation.estimate(notIn, stat); - Assertions.assertEquals( - FilterEstimation.DEFAULT_INEQUALITY_COEFFICIENT * stat.getRowCount(), - expected.getRowCount()); + Assertions.assertTrue(Precision.equals(333.33, 
expected.getRowCount(), 0.01)); } /** diff --git a/tools/qerror.py b/tools/qerror.py index 428d9fdf9073dd..70920b60a44790 100644 --- a/tools/qerror.py +++ b/tools/qerror.py @@ -21,6 +21,7 @@ import requests import json +import time mycli_cmd = "mysql -h127.0.0.1 -P9030 -uroot -Dtpch1G" @@ -37,6 +38,8 @@ SET session_context='trace_id:{}'; """ +q_err_list = [] + def extract_number(string): return int(''.join([c for c in string if c.isdigit()])) @@ -73,6 +76,7 @@ def execute_sql(sql_file: str): def get_q_error(trace_id): + time.sleep(1) # 'YWRtaW46' is the base64 encoded result for 'admin:' headers = {'Authorization': 'BASIC YWRtaW46'} resp_wrapper = requests.get(trace_url.format(trace_id), headers=headers) @@ -81,9 +85,13 @@ def get_q_error(trace_id): resp_wrapper = requests.get(qerror_url.format(query_id), headers=headers) resp_text = resp_wrapper.text write_result(str(trace_id), resp_text) + print(trace_id) + print(resp_text) + qerr = json.loads(resp_text)["qError"] + q_err_list.append(float(qerr)) -def iterates_sqls(path: str) -> list: +def iterates_sqls(path: str, if_write_results: bool) -> list: cost_times = [] files = os.listdir(path) files.sort(key=extract_number) @@ -93,14 +101,22 @@ def iterates_sqls(path: str) -> list: traced_sql_file = filepath + ".traced" content = read_lines(filepath) sql_num = extract_number(filename) - write_results(traced_sql_file, str(sql_file_prefix_for_trace.format(sql_num)), content) - execute_sql(traced_sql_file) - get_q_error(sql_num) - os.remove(traced_sql_file) + print("sql num" + str(sql_num)) + if if_write_results: + write_results(traced_sql_file, str(sql_file_prefix_for_trace.format(sql_num)), content) + execute_sql(traced_sql_file) + get_q_error(sql_num) + os.remove(traced_sql_file) + else: + execute_sql(filepath) return cost_times if __name__ == '__main__': execute_command("echo 'set global enable_nereids_planner=true' | mysql -h127.0.0.1 -P9030") execute_command("echo 'set global 
enable_fallback_to_original_planner=false' | mysql -h127.0.0.1 -P9030") - iterates_sqls(original_sql_dir) + print("Preparing") + iterates_sqls(original_sql_dir, False) + print("Started...") + iterates_sqls(original_sql_dir, True) + write_results(qerr_saved_file_path, "AVG\n", [sum(q_err_list) / max(len(q_err_list), 1)])  # mean q-error over the collected queries; max(..., 1) avoids ZeroDivisionError when none were collected