Skip to content

Commit

Permalink
[BugFix] Clear probe RF whose probe expr contains dict mapping expr (StarRocks#50690)
Browse files Browse the repository at this point in the history

Signed-off-by: zihe.liu <ziheliu1024@gmail.com>
Signed-off-by: zhiminr.ren <1240388654@qq.com>
  • Loading branch information
ZiheLiu authored and renzhimin7 committed Nov 7, 2024
1 parent 41844c8 commit 52cf135
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 3 deletions.
24 changes: 22 additions & 2 deletions be/src/exprs/runtime_filter_bank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "column/column.h"
#include "exec/pipeline/runtime_filter_types.h"
#include "exprs/dictmapping_expr.h"
#include "exprs/in_const_predicate.hpp"
#include "exprs/literal.h"
#include "exprs/runtime_filter.h"
Expand Down Expand Up @@ -642,6 +643,23 @@ void RuntimeFilterProbeCollector::update_selectivity(Chunk* chunk, RuntimeBloomF
}
}

// Returns true when `expr` itself, or any expression reachable through its
// children, has the exact dynamic type DictMappingExpr.
//
// A null `expr` (including a null child) is reported as "not found":
// applying typeid to a dereferenced null pointer of a polymorphic type
// would throw std::bad_typeid instead of yielding an answer.
static bool contains_dict_mapping_expr(Expr* expr) {
    if (expr == nullptr) {
        return false;
    }
    if (typeid(*expr) == typeid(DictMappingExpr)) {
        return true;
    }

    // Bind the child container once so both iterators come from the same object.
    const auto& children = expr->children();
    return std::any_of(children.begin(), children.end(),
                       [](Expr* child) { return contains_dict_mapping_expr(child); });
}

// Descriptor-level overload: inspects the root of the descriptor's probe
// expression context. A descriptor without a probe expression context is
// treated as not containing a dict mapping expression.
static bool contains_dict_mapping_expr(RuntimeFilterProbeDescriptor* probe_desc) {
    auto* ctx = probe_desc->probe_expr_ctx();
    return ctx != nullptr && contains_dict_mapping_expr(ctx->root());
}

void RuntimeFilterProbeCollector::push_down(const RuntimeState* state, TPlanNodeId target_plan_node_id,
RuntimeFilterProbeCollector* parent, const std::vector<TupleId>& tuple_ids,
std::set<TPlanNodeId>& local_rf_waiting_set) {
Expand All @@ -653,8 +671,10 @@ void RuntimeFilterProbeCollector::push_down(const RuntimeState* state, TPlanNode
++iter;
continue;
}
if (desc->is_bound(tuple_ids) && !(state->broadcast_join_right_offsprings().contains(target_plan_node_id) &&
state->non_broadcast_rf_ids().contains(desc->filter_id()))) {
if (desc->is_bound(tuple_ids) &&
!(state->broadcast_join_right_offsprings().contains(target_plan_node_id) &&
state->non_broadcast_rf_ids().contains(desc->filter_id())) &&
!contains_dict_mapping_expr(desc)) {
add_descriptor(desc);
if (desc->is_local()) {
local_rf_waiting_set.insert(desc->build_plan_node_id());
Expand Down
11 changes: 11 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/analysis/Expr.java
Original file line number Diff line number Diff line change
Expand Up @@ -1524,4 +1524,15 @@ public List<String> getHints() {
return hints;
}

/**
 * Instance-level entry point: delegates to the static overload, passing
 * {@code this} as the expression to inspect.
 *
 * @return whatever the static {@code containsDictMappingExpr(Expr)} overload returns for this expression
 */
public boolean containsDictMappingExpr() {
return containsDictMappingExpr(this);
}

/**
 * Depth-first search for a {@code DictMappingExpr} in the tree rooted at {@code expr}.
 *
 * @param expr root of the expression tree to inspect
 * @return {@code true} if {@code expr} is a {@code DictMappingExpr} or any of its
 *         descendants is; {@code false} otherwise
 */
private static boolean containsDictMappingExpr(Expr expr) {
    if (expr instanceof DictMappingExpr) {
        return true;
    }
    // Stop at the first child subtree that contains a match.
    for (Expr child : expr.getChildren()) {
        if (containsDictMappingExpr(child)) {
            return true;
        }
    }
    return false;
}

}
18 changes: 18 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/planner/PlanFragment.java
Original file line number Diff line number Diff line change
Expand Up @@ -972,4 +972,22 @@ public void removeRfOnRightOffspringsOfBroadcastJoin() {

removeRfOfRightOffspring(getPlanRoot(), localRightOffsprings, filterIds);
}

/**
 * Public entry point: runs the per-node overload of this method starting
 * from this fragment's plan root ({@code getPlanRoot()}).
 */
public void removeDictMappingProbeRuntimeFilters() {
removeDictMappingProbeRuntimeFilters(getPlanRoot());
}

/**
 * Removes from {@code root} every probe runtime filter whose probe expression
 * registered for that node reports {@code containsDictMappingExpr()}, then
 * recurses into the children that share {@code root}'s fragment id.
 *
 * @param root plan node whose probe runtime filters are pruned in place
 */
private void removeDictMappingProbeRuntimeFilters(PlanNode root) {
    root.getProbeRuntimeFilters().removeIf(filter -> {
        Expr probeExpr = filter.getNodeIdToProbeExpr().get(root.getId().asInt());
        // The lookup yields null when no probe expr is registered under this node id;
        // keep such a filter instead of failing with a NullPointerException.
        return probeExpr != null && probeExpr.containsDictMappingExpr();
    });

    // Children with a different fragment id are not visited from here.
    for (PlanNode child : root.getChildren()) {
        if (child.getFragmentId().equals(root.getFragmentId())) {
            removeDictMappingProbeRuntimeFilters(child);
        }
    }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,11 @@ public boolean pushDownRuntimeFilters(RuntimeFilterPushDownContext context,
return false;
}

Optional<List<Expr>> optProbeExprCandidates = candidatesOfSlotExpr(probeExpr, couldBound(description, descTbl));
optProbeExprCandidates.ifPresent(exprs -> exprs.removeIf(probeExprCandidate -> probeExprCandidate.containsDictMappingExpr()));

return pushdownRuntimeFilterForChildOrAccept(context, probeExpr,
candidatesOfSlotExpr(probeExpr, couldBound(description, descTbl)),
optProbeExprCandidates,
partitionByExprs, candidatesOfSlotExprs(partitionByExprs, couldBoundForPartitionExpr()), 0, true);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ private static ExecPlan finalizeFragments(ExecPlan execPlan, TResultSinkType res
fragment.computeLocalRfWaitingSet(fragment.getPlanRoot(), shouldClearRuntimeFilters);
}

fragments.forEach(PlanFragment::removeDictMappingProbeRuntimeFilters);

if (useQueryCache(execPlan)) {
for (PlanFragment fragment : execPlan.getFragments()) {
FragmentNormalizer normalizer = new FragmentNormalizer(execPlan, fragment);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1943,4 +1943,50 @@ public void testExistRequiredDistribution() throws Exception {
Assert.assertFalse("table doesn't contain global dict, we can change its distribution",
execPlan.getOptExpression(1).isExistRequiredDistribution());
}

/**
 * Checks the cost-explain output for a broadcast join whose probe side is a
 * UNION ALL with a CASE branch over dict-encoded columns. The expected plan
 * fragment shows filter_id = 1 probing on the Decode node and filter_id = 0
 * probing on slot 12 of the Project node, above the OlapScanNode.
 */
@Test
public void testRuntimeFilterOnProjectWithDictExpr() throws Exception {
    String sql = "WITH \n" +
            " w1 AS (\n" +
            " SELECT CASE\n" +
            " WHEN P_NAME = 'a' THEN 'a1'\n" +
            " WHEN P_BRAND = 'b' THEN 'b1'\n" +
            " ELSE 'c1'\n" +
            " END as P_NAME2, P_NAME from part_v2\n" +
            " UNION ALL\n" +
            " SELECT P_NAME, P_NAME from part_v2\n" +
            ")\n" +
            "SELECT count(1) \n" +
            "FROM \n" +
            " w1 t1 \n" +
            " JOIN [broadcast] part_v2 t2 ON t1.P_NAME2 = t2.P_NAME AND t1.P_NAME = t2.P_NAME;";
    String plan = getCostExplain(sql);
    // No debug print of the plan here: the assertion below is the whole check.
    assertContains(plan, " 3:Decode\n" +
            " | <dict id 38> : <string id 2>\n" +
            " | cardinality: 1\n" +
            " | probe runtime filters:\n" +
            " | - filter_id = 1, probe_expr = (2: P_NAME)\n" +
            " | column statistics: \n" +
            " | * P_NAME-->[-Infinity, Infinity, 0.0, 1.0, 1.0] UNKNOWN\n" +
            " | * P_BRAND-->[-Infinity, Infinity, 0.0, 1.0, 1.0] UNKNOWN\n" +
            " | * cast-->[-Infinity, Infinity, 0.0, 16.0, 3.0] ESTIMATE\n" +
            " | \n" +
            " 2:Project\n" +
            " | output columns:\n" +
            " | 12 <-> CASE WHEN DictDecode(38: P_NAME, [<place-holder> = 'a']) THEN 'a1' " +
            "WHEN DictDecode(39: P_BRAND, [<place-holder> = 'b']) THEN 'b1' ELSE 'c1' END\n" +
            " | 38 <-> [38: P_NAME, INT, false]\n" +
            " | cardinality: 1\n" +
            " | probe runtime filters:\n" +
            " | - filter_id = 0, probe_expr = (<slot 12>)\n" +
            " | column statistics: \n" +
            " | * cast-->[-Infinity, Infinity, 0.0, 16.0, 3.0] ESTIMATE\n" +
            " | \n" +
            " 1:OlapScanNode\n" +
            " table: part_v2, rollup: part_v2\n" +
            " preAggregation: on\n" +
            " dict_col=P_NAME,P_BRAND");
}

}

0 comments on commit 52cf135

Please sign in to comment.