Skip to content

Commit

Permalink
[enhancement](fe) Tune for stats framework (apache#17860)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kikyou1997 authored Mar 22, 2023
1 parent 173d684 commit f600f70
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.stats;

import org.apache.doris.nereids.stats.FilterEstimation.EstimationContext;
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
Expand All @@ -33,6 +34,7 @@
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.statistics.Bucket;
Expand All @@ -49,16 +51,29 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;

/**
* Calculate selectivity of expression that produces boolean value.
* TODO: Should consider the distribution of data.
*/
public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationContext> {
public static final double DEFAULT_INEQUALITY_COEFFICIENT = 0.5;
public static final double DEFAULT_IN_COEFFICIENT = 1.0 / 3.0;

public static final double DEFAULT_HAVING_COEFFICIENT = 0.01;

public static final double DEFAULT_EQUALITY_COMPARISON_SELECTIVITY = 0.1;

// Slots treated as aggregate outputs. Stays null when the no-arg constructor is used.
// When non-null, visitComparisonPredicate checks whether either side of a comparison
// references one of these slots and, if so, estimates the predicate as a HAVING-clause
// filter using DEFAULT_HAVING_COEFFICIENT.
private Set<Slot> aggSlots;

/**
 * Creates an estimator without aggregate-slot information (aggSlots stays null),
 * so the HAVING-specific branch of visitComparisonPredicate is never taken.
 */
public FilterEstimation() {
}

/**
 * Creates an estimator that knows which slots are aggregate outputs.
 *
 * @param aggSlots slots to be recognized as aggregate outputs; stored as is (not copied)
 */
public FilterEstimation(Set<Slot> aggSlots) {
this.aggSlots = aggSlots;
}

/**
* This method will update the stats according to the selectivity.
*/
Expand Down Expand Up @@ -104,7 +119,6 @@ public Statistics visitCompoundPredicate(CompoundPredicate predicate, Estimation
estimatedColStatsBuilder.setMaxValue(rightColStats.maxValue);
estimatedColStatsBuilder.setMaxExpr(rightColStats.maxExpr);
}
orStats.addColumnStats(entry.getKey(), estimatedColStatsBuilder.build());
}
return orStats;
}
Expand All @@ -127,6 +141,24 @@ public Statistics visitComparisonPredicate(ComparisonPredicate cp, EstimationCon
}
ColumnStatistic statsForLeft = ExpressionEstimation.estimate(left, context.statistics);
ColumnStatistic statsForRight = ExpressionEstimation.estimate(right, context.statistics);
if (aggSlots != null) {
Predicate<TreeNode<Expression>> containsAggSlot = e -> {
if (e instanceof SlotReference) {
SlotReference slot = (SlotReference) e;
return aggSlots.contains(slot);
}
return false;
};
boolean leftAgg = left.anyMatch(containsAggSlot);
boolean rightAgg = right.anyMatch(containsAggSlot);
// It means this predicate appears in HAVING clause.
if (leftAgg || rightAgg) {
double rowCount = context.statistics.getRowCount();
double newRowCount = Math.max(rowCount * DEFAULT_HAVING_COEFFICIENT,
Math.max(statsForLeft.ndv, statsForRight.ndv));
return context.statistics.withRowCount(newRowCount);
}
}
if (!(left instanceof Literal) && !(right instanceof Literal)) {
return calculateWhenBothColumn(cp, context, statsForLeft, statsForRight);
} else {
Expand Down Expand Up @@ -167,14 +199,11 @@ private Statistics calculateWhenLiteralRight(ComparisonPredicate cp,
double ndv = statsForLeft.ndv;
double val = statsForRight.maxValue;
if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
if (statsForLeft == ColumnStatistic.UNKNOWN) {
selectivity = DEFAULT_EQUALITY_COMPARISON_SELECTIVITY;

if (val > statsForLeft.maxValue || val < statsForLeft.minValue) {
selectivity = 0.0;
} else {
if (val > statsForLeft.maxValue || val < statsForLeft.minValue) {
selectivity = 0.0;
} else {
selectivity = StatsMathUtil.minNonNaN(1.0, 1.0 / ndv);
}
selectivity = StatsMathUtil.minNonNaN(1.0, 1.0 / ndv);
}
if (context.isNot) {
selectivity = 1 - selectivity;
Expand Down Expand Up @@ -249,8 +278,8 @@ public Statistics visitInPredicate(InPredicate inPredicate, EstimationContext co
boolean isNotIn = context != null && context.isNot;
Expression compareExpr = inPredicate.getCompareExpr();
ColumnStatistic compareExprStats = ExpressionEstimation.estimate(compareExpr, context.statistics);
if (compareExprStats.isUnKnown) {
return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT);
if (compareExprStats.isUnKnown || compareExpr instanceof Function) {
return context.statistics.withSel(DEFAULT_IN_COEFFICIENT);
}
List<Expression> options = inPredicate.getOptions();
double maxOption = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatisticsBuilder;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

Expand All @@ -33,6 +33,7 @@
* TODO: Update other props in the ColumnStats properly.
*/
public class JoinEstimation {

private static Statistics estimateInnerJoin(Statistics crossJoinStats, List<Expression> joinConditions) {
List<Pair<Expression, Double>> sortedJoinConditions = joinConditions.stream()
.map(expression -> Pair.of(expression, estimateJoinConditionSel(crossJoinStats, expression)))
Expand All @@ -51,7 +52,7 @@ private static Statistics estimateInnerJoin(Statistics crossJoinStats, List<Expr
for (int i = 0; i < sortedJoinConditions.size(); i++) {
sel *= Math.pow(sortedJoinConditions.get(i).second, 1 / Math.pow(2, i));
}
return crossJoinStats.withSel(sel);
return crossJoinStats.updateRowCountOnly(crossJoinStats.getRowCount() * sel);
}

private static double estimateJoinConditionSel(Statistics crossJoinStats, Expression joinCond) {
Expand All @@ -69,13 +70,19 @@ public static Statistics estimate(Statistics leftStats, Statistics rightStats, J
.putColumnStatistics(leftStats.columnStatistics())
.putColumnStatistics(rightStats.columnStatistics())
.build();
List<Expression> joinConditions = join.getHashJoinConjuncts();
Statistics innerJoinStats = estimateInnerJoin(crossJoinStats, joinConditions);
if (!join.getOtherJoinConjuncts().isEmpty()) {
FilterEstimation filterEstimation = new FilterEstimation();
innerJoinStats = filterEstimation.estimate(
ExpressionUtils.and(join.getOtherJoinConjuncts()), innerJoinStats);
Statistics innerJoinStats = null;
if (crossJoinStats.getRowCount() != 0) {
List<Expression> joinConditions = new ArrayList<>(join.getHashJoinConjuncts());
joinConditions.addAll(join.getOtherJoinConjuncts());
innerJoinStats = estimateInnerJoin(crossJoinStats, joinConditions);
} else {
innerJoinStats = crossJoinStats;
}
// if (!join.getOtherJoinConjuncts().isEmpty()) {
// FilterEstimation filterEstimation = new FilterEstimation();
// innerJoinStats = filterEstimation.estimate(
// ExpressionUtils.and(join.getOtherJoinConjuncts()), innerJoinStats);
// }
innerJoinStats.setWidth(leftStats.getWidth() + rightStats.getWidth());
innerJoinStats.setPenalty(0);
double rowCount;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation;
Expand Down Expand Up @@ -94,6 +96,7 @@
import org.apache.doris.statistics.StatisticsBuilder;

import com.google.common.collect.Maps;
import org.apache.commons.collections.CollectionUtils;

import java.util.AbstractMap.SimpleEntry;
import java.util.HashMap;
Expand Down Expand Up @@ -394,9 +397,24 @@ private Statistics computeAssertNumRows(long desiredNumOfRows) {
}

/**
 * Estimates the statistics of a filter over the statistics of its first child.
 *
 * <p>If the plan returned by {@link #tryToFindChild} is an {@link Aggregate}, the slots of
 * its aliased outputs whose expression contains an {@link AggregateFunction} are collected.
 * When the filter predicate references any of them, the estimation is delegated to a
 * {@link FilterEstimation} that is aware of those slots; otherwise a plain
 * {@link FilterEstimation} is used.
 *
 * @param filter the filter whose predicate is estimated
 * @return the statistics estimated for the filter output
 */
private Statistics computeFilter(Filter filter) {
    Statistics stats = groupExpression.childStatistics(0);
    Expression predicate = filter.getPredicate();
    Plan plan = tryToFindChild(groupExpression);
    // instanceof is false for null, so a separate null check is unnecessary.
    if (plan instanceof Aggregate) {
        Aggregate<?> agg = (Aggregate<?>) plan;
        List<NamedExpression> expressions = agg.getOutputExpressions();
        // Keep only aliases wrapping an expression that contains an aggregate function.
        Set<Slot> slots = expressions
                .stream()
                .filter(Alias.class::isInstance)
                .filter(s -> ((Alias) s).child().anyMatch(AggregateFunction.class::isInstance))
                .map(NamedExpression::toSlot)
                .collect(Collectors.toSet());
        if (predicate.anyMatch(slots::contains)) {
            return new FilterEstimation(slots).estimate(predicate, stats);
        }
    }
    return new FilterEstimation().estimate(predicate, stats);
}

// TODO: 1. Subtract the pruned partition
Expand Down Expand Up @@ -441,13 +459,8 @@ private Statistics computeAggregate(Aggregate<? extends Plan> aggregate) {
if (!groupByExpressions.isEmpty()) {
Map<Expression, ColumnStatistic> childSlotToColumnStats = childStats.columnStatistics();
double inputRowCount = childStats.getRowCount();
if (inputRowCount == 0) {
//on empty relation, Agg output 1 tuple
resultSetCount = 1;
} else {
if (inputRowCount != 0) {
List<ColumnStatistic> groupByKeyStats = groupByExpressions.stream()
.flatMap(expr -> expr.getInputSlots().stream())
.map(Slot::getExprId)
.filter(childSlotToColumnStats::containsKey)
.map(childSlotToColumnStats::get)
.filter(s -> !s.isUnKnown)
Expand Down Expand Up @@ -692,4 +705,16 @@ private ColumnStatistic unionColumn(ColumnStatistic leftStats, double leftRowCou
.setAvgSizeByte(newAverageRowSize);
return columnStatisticBuilder.build();
}

/**
 * Picks a representative plan from the first child group of the given group expression.
 * Logical expressions of the child group are preferred; physical expressions are used
 * only when there are no logical ones.
 *
 * @param groupExpression the group expression whose first child group is inspected
 * @return the plan of the first expression found, or null when the child group has
 *         neither logical nor physical expressions
 */
private Plan tryToFindChild(GroupExpression groupExpression) {
    List<GroupExpression> candidates = groupExpression.child(0).getLogicalExpressions();
    if (CollectionUtils.isEmpty(candidates)) {
        // No logical expressions: fall back to the physical ones.
        candidates = groupExpression.child(0).getPhysicalExpressions();
    }
    return CollectionUtils.isEmpty(candidates) ? null : candidates.get(0).getPlan();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public StatsErrorEstimator() {
}

/**
* Map plan id to stats.
* Invoked by PhysicalPlanTranslator, put the translated plan node and corresponding physical plan to estimator.
*/
public void updateLegacyPlanIdToPhysicalPlan(PlanNode planNode, AbstractPlan physicalPlan) {
Statistics statistics = physicalPlan.getStats();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,6 @@ public class StatisticConstants {

public static final int LOAD_TASK_LIMITS = 10;

public static final double DEFAULT_INNER_JOIN_FACTOR = 0.1;

}
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,25 @@ public double getRowCount() {
return rowCount;
}

/**
 * Returns statistics with the given row count, with every column statistic adjusted
 * through {@link #fix(double, double)}.
 *
 * @param rowCount the new row count; when it is NaN this object is returned unchanged
 * @return a new Statistics built on a copy of the column-statistics map, or this
 *         instance if {@code rowCount} is NaN
 */
public Statistics withRowCount(double rowCount) {
    if (!Double.isNaN(rowCount)) {
        // Work on a copy of the map so that fix() does not touch this object's entries.
        Statistics adjusted = new Statistics(rowCount, new HashMap<>(expressionToColumnStats), width, penalty);
        // The original row count is passed through nonZeroDivisor before fix() divides by it.
        double originRowCount = StatsMathUtil.nonZeroDivisor(this.rowCount);
        adjusted.fix(rowCount, originRowCount);
        return adjusted;
    }
    return this;
}

/**
 * Returns a new Statistics carrying the given row count and this object's
 * column-statistics map passed as is; column statistics are not adjusted here.
 * NOTE(review): width and penalty are not passed to this constructor - confirm which
 * values the two-argument constructor assigns to them.
 *
 * @param rowCount the row count of the returned statistics
 * @return a new Statistics with the given row count
 */
public Statistics updateRowCountOnly(double rowCount) {
    Statistics rowCountUpdated = new Statistics(rowCount, expressionToColumnStats);
    return rowCountUpdated;
}

public void fix(double newRowCount, double originRowCount) {
double sel = newRowCount / originRowCount;

for (Entry<Expression, ColumnStatistic> entry : expressionToColumnStats.entrySet()) {
ColumnStatistic columnStatistic = entry.getValue();
ColumnStatisticBuilder columnStatisticBuilder = new ColumnStatisticBuilder(columnStatistic);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ public void testInNaN() {
Statistics stat = new Statistics(1000, slotToColumnStat);
FilterEstimation filterEstimation = new FilterEstimation();
Statistics expected = filterEstimation.estimate(in, stat);
Assertions.assertEquals(
FilterEstimation.DEFAULT_INEQUALITY_COEFFICIENT * stat.getRowCount(),
expected.getRowCount());
Assertions.assertTrue(Precision.equals(333.33, expected.getRowCount(), 0.01));
}

@Test
Expand All @@ -134,9 +132,7 @@ public void testNotInNaN() {
Statistics stat = new Statistics(1000, slotToColumnStat);
FilterEstimation filterEstimation = new FilterEstimation();
Statistics expected = filterEstimation.estimate(notIn, stat);
Assertions.assertEquals(
FilterEstimation.DEFAULT_INEQUALITY_COEFFICIENT * stat.getRowCount(),
expected.getRowCount());
Assertions.assertTrue(Precision.equals(333.33, expected.getRowCount(), 0.01));
}

/**
Expand Down
28 changes: 22 additions & 6 deletions tools/qerror.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import requests
import json
import time

mycli_cmd = "mysql -h127.0.0.1 -P9030 -uroot -Dtpch1G"

Expand All @@ -37,6 +38,8 @@
SET session_context='trace_id:{}';
"""

q_err_list = []


def extract_number(string):
    """Return the integer formed by concatenating every digit character in ``string``.

    Non-digit characters are dropped; ``int()`` raises ``ValueError`` when no
    digit is present.
    """
    digits = ''.join(filter(str.isdigit, string))
    return int(digits)
Expand Down Expand Up @@ -73,6 +76,7 @@ def execute_sql(sql_file: str):


def get_q_error(trace_id):
time.sleep(1)
# 'YWRtaW46' is the base64 encoded result for 'admin:'
headers = {'Authorization': 'BASIC YWRtaW46'}
resp_wrapper = requests.get(trace_url.format(trace_id), headers=headers)
Expand All @@ -81,9 +85,13 @@ def get_q_error(trace_id):
resp_wrapper = requests.get(qerror_url.format(query_id), headers=headers)
resp_text = resp_wrapper.text
write_result(str(trace_id), resp_text)
print(trace_id)
print(resp_text)
qerr = json.loads(resp_text)["qError"]
q_err_list.append(float(qerr))


def iterates_sqls(path: str) -> list:
def iterates_sqls(path: str, if_write_results: bool) -> list:
cost_times = []
files = os.listdir(path)
files.sort(key=extract_number)
Expand All @@ -93,14 +101,22 @@ def iterates_sqls(path: str) -> list:
traced_sql_file = filepath + ".traced"
content = read_lines(filepath)
sql_num = extract_number(filename)
write_results(traced_sql_file, str(sql_file_prefix_for_trace.format(sql_num)), content)
execute_sql(traced_sql_file)
get_q_error(sql_num)
os.remove(traced_sql_file)
print("sql num" + str(sql_num))
if if_write_results:
write_results(traced_sql_file, str(sql_file_prefix_for_trace.format(sql_num)), content)
execute_sql(traced_sql_file)
get_q_error(sql_num)
os.remove(traced_sql_file)
else:
execute_sql(filepath)
return cost_times


if __name__ == '__main__':
execute_command("echo 'set global enable_nereids_planner=true' | mysql -h127.0.0.1 -P9030")
execute_command("echo 'set global enable_fallback_to_original_planner=false' | mysql -h127.0.0.1 -P9030")
iterates_sqls(original_sql_dir)
print("Preparing")
iterates_sqls(original_sql_dir, False)
print("Started...")
iterates_sqls(original_sql_dir, True)
# Average q-error over the collected queries: divide by the number of collected
# q-errors, not by the length of the qerror_url template string.
write_results(qerr_saved_file_path, "AVG\n", [sum(q_err_list) / len(q_err_list)])

0 comments on commit f600f70

Please sign in to comment.