|
| 1 | +/* Copyright (c) 2022 vesoft inc. All rights reserved. |
| 2 | + * |
| 3 | + * This source code is licensed under Apache 2.0 License. |
| 4 | + */ |
| 5 | + |
| 6 | +#include "graph/optimizer/rule/PushFilterDownInnerJoinRule.h" |
| 7 | + |
| 8 | +#include "graph/optimizer/OptContext.h" |
| 9 | +#include "graph/optimizer/OptGroup.h" |
| 10 | +#include "graph/planner/plan/PlanNode.h" |
| 11 | +#include "graph/planner/plan/Query.h" |
| 12 | +#include "graph/util/ExpressionUtils.h" |
| 13 | + |
| 14 | +using nebula::graph::PlanNode; |
| 15 | +using nebula::graph::QueryContext; |
| 16 | + |
| 17 | +namespace nebula { |
| 18 | +namespace opt { |
| 19 | + |
| 20 | +std::unique_ptr<OptRule> PushFilterDownInnerJoinRule::kInstance = |
| 21 | + std::unique_ptr<PushFilterDownInnerJoinRule>(new PushFilterDownInnerJoinRule()); |
| 22 | + |
| 23 | +PushFilterDownInnerJoinRule::PushFilterDownInnerJoinRule() { |
| 24 | + RuleSet::QueryRules().addRule(this); |
| 25 | +} |
| 26 | + |
| 27 | +const Pattern& PushFilterDownInnerJoinRule::pattern() const { |
| 28 | + static Pattern pattern = Pattern::create(graph::PlanNode::Kind::kFilter, |
| 29 | + {Pattern::create(graph::PlanNode::Kind::kInnerJoin)}); |
| 30 | + return pattern; |
| 31 | +} |
| 32 | + |
| 33 | +StatusOr<OptRule::TransformResult> PushFilterDownInnerJoinRule::transform( |
| 34 | + OptContext* octx, const MatchedResult& matched) const { |
| 35 | + auto* filterGroupNode = matched.node; |
| 36 | + auto* oldFilterNode = filterGroupNode->node(); |
| 37 | + auto deps = matched.dependencies; |
| 38 | + DCHECK_EQ(deps.size(), 1); |
| 39 | + auto innerJoinGroupNode = deps.front().node; |
| 40 | + auto* innerJoinNode = innerJoinGroupNode->node(); |
| 41 | + DCHECK_EQ(oldFilterNode->kind(), PlanNode::Kind::kFilter); |
| 42 | + DCHECK_EQ(innerJoinNode->kind(), PlanNode::Kind::kInnerJoin); |
| 43 | + auto* oldInnerJoinNode = static_cast<graph::InnerJoin*>(innerJoinNode); |
| 44 | + const auto* condition = static_cast<graph::Filter*>(oldFilterNode)->condition(); |
| 45 | + DCHECK(condition); |
| 46 | + const std::pair<std::string, int64_t>& leftVar = oldInnerJoinNode->leftVar(); |
| 47 | + auto symTable = octx->qctx()->symTable(); |
| 48 | + std::vector<std::string> leftVarColNames = symTable->getVar(leftVar.first)->colNames; |
| 49 | + |
| 50 | + // split the `condition` based on whether the varPropExpr comes from the left |
| 51 | + // child |
| 52 | + auto picker = [&leftVarColNames](const Expression* e) -> bool { |
| 53 | + auto varProps = graph::ExpressionUtils::collectAll(e, {Expression::Kind::kVarProperty}); |
| 54 | + if (varProps.empty()) { |
| 55 | + return false; |
| 56 | + } |
| 57 | + std::vector<std::string> propNames; |
| 58 | + for (auto* expr : varProps) { |
| 59 | + DCHECK(expr->kind() == Expression::Kind::kVarProperty); |
| 60 | + propNames.emplace_back(static_cast<const VariablePropertyExpression*>(expr)->prop()); |
| 61 | + } |
| 62 | + for (auto prop : propNames) { |
| 63 | + auto iter = std::find_if(leftVarColNames.begin(), |
| 64 | + leftVarColNames.end(), |
| 65 | + [&prop](std::string item) { return !item.compare(prop); }); |
| 66 | + if (iter == leftVarColNames.end()) { |
| 67 | + return false; |
| 68 | + } |
| 69 | + } |
| 70 | + return true; |
| 71 | + }; |
| 72 | + Expression* filterPicked = nullptr; |
| 73 | + Expression* filterUnpicked = nullptr; |
| 74 | + graph::ExpressionUtils::splitFilter(condition, picker, &filterPicked, &filterUnpicked); |
| 75 | + |
| 76 | + if (!filterPicked) { |
| 77 | + return TransformResult::noTransform(); |
| 78 | + } |
| 79 | + |
| 80 | + // produce new left Filter node |
| 81 | + auto* newLeftFilterNode = |
| 82 | + graph::Filter::make(octx->qctx(), |
| 83 | + const_cast<graph::PlanNode*>(oldInnerJoinNode->dep()), |
| 84 | + graph::ExpressionUtils::rewriteInnerVar(filterPicked, leftVar.first)); |
| 85 | + newLeftFilterNode->setInputVar(leftVar.first); |
| 86 | + newLeftFilterNode->setColNames(leftVarColNames); |
| 87 | + auto newFilterGroup = OptGroup::create(octx); |
| 88 | + auto newFilterGroupNode = newFilterGroup->makeGroupNode(newLeftFilterNode); |
| 89 | + for (auto dep : innerJoinGroupNode->dependencies()) { |
| 90 | + newFilterGroupNode->dependsOn(dep); |
| 91 | + } |
| 92 | + auto newLeftFilterOutputVar = newLeftFilterNode->outputVar(); |
| 93 | + |
| 94 | + // produce new InnerJoin node |
| 95 | + auto* newInnerJoinNode = static_cast<graph::InnerJoin*>(oldInnerJoinNode->clone()); |
| 96 | + newInnerJoinNode->setLeftVar({newLeftFilterOutputVar, 0}); |
| 97 | + const std::vector<Expression*>& hashKeys = oldInnerJoinNode->hashKeys(); |
| 98 | + std::vector<Expression*> newHashKeys; |
| 99 | + for (auto* k : hashKeys) { |
| 100 | + newHashKeys.emplace_back(graph::ExpressionUtils::rewriteInnerVar(k, newLeftFilterOutputVar)); |
| 101 | + } |
| 102 | + newInnerJoinNode->setHashKeys(newHashKeys); |
| 103 | + |
| 104 | + TransformResult result; |
| 105 | + result.eraseAll = true; |
| 106 | + if (filterUnpicked) { |
| 107 | + auto* newAboveFilterNode = graph::Filter::make(octx->qctx(), newInnerJoinNode); |
| 108 | + newAboveFilterNode->setOutputVar(oldFilterNode->outputVar()); |
| 109 | + newAboveFilterNode->setCondition(filterUnpicked); |
| 110 | + auto newAboveFilterGroupNode = |
| 111 | + OptGroupNode::create(octx, newAboveFilterNode, filterGroupNode->group()); |
| 112 | + |
| 113 | + auto newInnerJoinGroup = OptGroup::create(octx); |
| 114 | + auto newInnerJoinGroupNode = newInnerJoinGroup->makeGroupNode(newInnerJoinNode); |
| 115 | + newAboveFilterGroupNode->setDeps({newInnerJoinGroup}); |
| 116 | + newInnerJoinGroupNode->setDeps({newFilterGroup}); |
| 117 | + result.newGroupNodes.emplace_back(newAboveFilterGroupNode); |
| 118 | + } else { |
| 119 | + newInnerJoinNode->setOutputVar(oldFilterNode->outputVar()); |
| 120 | + newInnerJoinNode->setColNames(oldInnerJoinNode->colNames()); |
| 121 | + auto newInnerJoinGroupNode = |
| 122 | + OptGroupNode::create(octx, newInnerJoinNode, filterGroupNode->group()); |
| 123 | + newInnerJoinGroupNode->setDeps({newFilterGroup}); |
| 124 | + result.newGroupNodes.emplace_back(newInnerJoinGroupNode); |
| 125 | + } |
| 126 | + return result; |
| 127 | +} |
| 128 | + |
| 129 | +std::string PushFilterDownInnerJoinRule::toString() const { |
| 130 | + return "PushFilterDownInnerJoinRule"; |
| 131 | +} |
| 132 | + |
| 133 | +} // namespace opt |
| 134 | +} // namespace nebula |
0 commit comments