|
22 | 22 |
|
23 | 23 | namespace facebook::velox::substrait { |
24 | 24 |
|
25 | | -namespace { |
26 | | -template <TypeKind KIND> |
27 | | -VectorPtr setVectorFromVariantsByKind( |
28 | | - const std::vector<velox::variant>& value, |
29 | | - memory::MemoryPool* pool) { |
30 | | - using T = typename TypeTraits<KIND>::NativeType; |
31 | | - |
32 | | - auto flatVector = std::dynamic_pointer_cast<FlatVector<T>>( |
33 | | - BaseVector::create(CppToType<T>::create(), value.size(), pool)); |
34 | | - |
35 | | - for (vector_size_t i = 0; i < value.size(); i++) { |
36 | | - if (value[i].isNull()) { |
37 | | - flatVector->setNull(i, true); |
38 | | - } else { |
39 | | - flatVector->set(i, value[i].value<T>()); |
40 | | - } |
41 | | - } |
42 | | - return flatVector; |
43 | | -} |
44 | | - |
45 | | -template <> |
46 | | -VectorPtr setVectorFromVariantsByKind<TypeKind::VARBINARY>( |
47 | | - const std::vector<velox::variant>& value, |
48 | | - memory::MemoryPool* pool) { |
49 | | - throw std::invalid_argument("Return of VARBINARY data is not supported"); |
50 | | -} |
51 | | - |
52 | | -template <> |
53 | | -VectorPtr setVectorFromVariantsByKind<TypeKind::VARCHAR>( |
54 | | - const std::vector<velox::variant>& value, |
55 | | - memory::MemoryPool* pool) { |
56 | | - auto flatVector = std::dynamic_pointer_cast<FlatVector<StringView>>( |
57 | | - BaseVector::create(VARCHAR(), value.size(), pool)); |
58 | | - |
59 | | - for (vector_size_t i = 0; i < value.size(); i++) { |
60 | | - if (value[i].isNull()) { |
61 | | - flatVector->setNull(i, true); |
62 | | - } else { |
63 | | - flatVector->set(i, StringView(value[i].value<Varchar>())); |
64 | | - } |
| 25 | +std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan( |
| 26 | + const ::substrait::JoinRel& sJoin) { |
| 27 | + if (!sJoin.has_left()) { |
| 28 | + VELOX_FAIL("Left Rel is expected in JoinRel."); |
| 29 | + } |
| 30 | + if (!sJoin.has_right()) { |
| 31 | + VELOX_FAIL("Right Rel is expected in JoinRel."); |
| 32 | + } |
| 33 | + |
| 34 | + auto leftNode = toVeloxPlan(sJoin.left()); |
| 35 | + auto rightNode = toVeloxPlan(sJoin.right()); |
| 36 | + |
| 37 | + auto outputSize = |
| 38 | + leftNode->outputType()->size() + rightNode->outputType()->size(); |
| 39 | + std::vector<std::string> outputNames; |
| 40 | + std::vector<std::shared_ptr<const Type>> outputTypes; |
| 41 | + outputNames.reserve(outputSize); |
| 42 | + outputTypes.reserve(outputSize); |
| 43 | + for (const auto& node : {leftNode, rightNode}) { |
| 44 | + const auto& names = node->outputType()->names(); |
| 45 | + outputNames.insert(outputNames.end(), names.begin(), names.end()); |
| 46 | + const auto& types = node->outputType()->children(); |
| 47 | + outputTypes.insert(outputTypes.end(), types.begin(), types.end()); |
| 48 | + } |
| 49 | + auto outputRowType = std::make_shared<const RowType>( |
| 50 | + std::move(outputNames), std::move(outputTypes)); |
| 51 | + |
| 52 | + // extract join keys from join expression |
| 53 | + std::vector<const ::substrait::Expression::FieldReference*> leftExprs, |
| 54 | + rightExprs; |
| 55 | + extractJoinKeys(sJoin.expression(), leftExprs, rightExprs); |
| 56 | + VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size()); |
| 57 | + size_t numKeys = leftExprs.size(); |
| 58 | + |
| 59 | + std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> leftKeys, |
| 60 | + rightKeys; |
| 61 | + leftKeys.reserve(numKeys); |
| 62 | + rightKeys.reserve(numKeys); |
| 63 | + for (size_t i = 0; i < numKeys; ++i) { |
| 64 | + leftKeys.emplace_back( |
| 65 | + exprConverter_->toVeloxExpr(*leftExprs[i], outputRowType)); |
| 66 | + rightKeys.emplace_back( |
| 67 | + exprConverter_->toVeloxExpr(*rightExprs[i], outputRowType)); |
| 68 | + } |
| 69 | + |
| 70 | + std::shared_ptr<const core::ITypedExpr> filter; |
| 71 | + if (sJoin.has_post_join_filter()) { |
| 72 | + filter = |
| 73 | + exprConverter_->toVeloxExpr(sJoin.post_join_filter(), outputRowType); |
| 74 | + } |
| 75 | + |
| 76 | + // Map join type |
| 77 | + core::JoinType joinType; |
| 78 | + switch (sJoin.type()) { |
| 79 | + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_INNER: |
| 80 | + joinType = core::JoinType::kInner; |
| 81 | + break; |
| 82 | + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_OUTER: |
| 83 | + joinType = core::JoinType::kFull; |
| 84 | + break; |
| 85 | + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT: |
| 86 | + joinType = core::JoinType::kLeft; |
| 87 | + break; |
| 88 | + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT: |
| 89 | + joinType = core::JoinType::kRight; |
| 90 | + break; |
| 91 | + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_SEMI: |
| 92 | + joinType = core::JoinType::kSemi; |
| 93 | + break; |
| 94 | + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_ANTI: |
| 95 | + joinType = core::JoinType::kAnti; |
| 96 | + break; |
| 97 | + default: |
| 98 | + VELOX_NYI("Unsupported Join type: {}", sJoin.type()); |
65 | 99 | } |
66 | | - return flatVector; |
67 | | -} |
68 | 100 |
|
69 | | -VectorPtr setVectorFromVariants( |
70 | | - const TypePtr& type, |
71 | | - const std::vector<velox::variant>& value, |
72 | | - velox::memory::MemoryPool* pool) { |
73 | | - return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( |
74 | | - setVectorFromVariantsByKind, type->kind(), value, pool); |
| 101 | + // Create join node |
| 102 | + return std::make_shared<core::HashJoinNode>( |
| 103 | + nextPlanNodeId(), |
| 104 | + joinType, |
| 105 | + leftKeys, |
| 106 | + rightKeys, |
| 107 | + filter, |
| 108 | + leftNode, |
| 109 | + rightNode, |
| 110 | + outputRowType); |
75 | 111 | } |
76 | | -} // namespace |
77 | 112 |
|
78 | | -core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( |
79 | | - const ::substrait::AggregateRel& aggRel, |
80 | | - memory::MemoryPool* pool) { |
81 | | - core::PlanNodePtr childNode; |
82 | | - if (aggRel.has_input()) { |
83 | | - childNode = toVeloxPlan(aggRel.input(), pool); |
| 113 | +std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan( |
| 114 | + const ::substrait::AggregateRel& sAgg) { |
| 115 | + std::shared_ptr<const core::PlanNode> childNode; |
| 116 | + if (sAgg.has_input()) { |
| 117 | + childNode = toVeloxPlan(sAgg.input()); |
84 | 118 | } else { |
85 | 119 | VELOX_FAIL("Child Rel is expected in AggregateRel."); |
86 | 120 | } |
@@ -524,12 +558,11 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( |
524 | 558 | if (rel.has_project()) { |
525 | 559 | return toVeloxPlan(rel.project(), pool); |
526 | 560 | } |
527 | | - if (rel.has_filter()) { |
528 | | - return toVeloxPlan(rel.filter(), pool); |
| 561 | + if (sRel.has_join()) { |
| 562 | + return toVeloxPlan(sRel.join()); |
529 | 563 | } |
530 | | - if (rel.has_read()) { |
531 | | - return toVeloxPlan( |
532 | | - rel.read(), pool, partitionIndex_, paths_, starts_, lengths_); |
| 564 | + if (sRel.has_read()) { |
| 565 | + return toVeloxPlan(sRel.read(), partitionIndex_, paths_, starts_, lengths_); |
533 | 566 | } |
534 | 567 | VELOX_NYI("Substrait conversion not supported for Rel."); |
535 | 568 | } |
@@ -823,4 +856,38 @@ int32_t SubstraitVeloxPlanConverter::streamIsInput( |
823 | 856 | VELOX_FAIL("Local file is expected."); |
824 | 857 | } |
825 | 858 |
|
| 859 | +void SubstraitVeloxPlanConverter::extractJoinKeys( |
| 860 | + const ::substrait::Expression& joinExpression, |
| 861 | + std::vector<const ::substrait::Expression::FieldReference*>& leftExprs, |
| 862 | + std::vector<const ::substrait::Expression::FieldReference*>& rightExprs) { |
| 863 | + std::vector<const ::substrait::Expression*> expressions; |
| 864 | + expressions.push_back(&joinExpression); |
| 865 | + while (!expressions.empty()) { |
| 866 | + auto visited = expressions.back(); |
| 867 | + expressions.pop_back(); |
| 868 | + if (visited->rex_type_case() == |
| 869 | + ::substrait::Expression::RexTypeCase::kScalarFunction) { |
| 870 | + const auto& funcName = |
| 871 | + subParser_->getSubFunctionName(subParser_->findVeloxFunction( |
| 872 | + functionMap_, visited->scalar_function().function_reference())); |
| 873 | + const auto& args = visited->scalar_function().args(); |
| 874 | + if (funcName == "and") { |
| 875 | + expressions.push_back(&args[0]); |
| 876 | + expressions.push_back(&args[1]); |
| 877 | + } else if (funcName == "equal") { |
| 878 | + VELOX_CHECK(std::all_of( |
| 879 | + args.cbegin(), args.cend(), [](const ::substrait::Expression& arg) { |
| 880 | + return arg.has_selection(); |
| 881 | + })); |
| 882 | + leftExprs.push_back(&args[0].selection()); |
| 883 | + rightExprs.push_back(&args[1].selection()); |
| 884 | + } |
| 885 | + } else { |
| 886 | + VELOX_FAIL( |
| 887 | + "Unable to parse from join expression: {}", |
| 888 | + joinExpression.DebugString()); |
| 889 | + } |
| 890 | + } |
| 891 | +} |
| 892 | + |
826 | 893 | } // namespace facebook::velox::substrait |
0 commit comments