Skip to content

Commit 42b90da

Browse files
marin-mazhejiangxiaomai
authored and committed
[OPPRO-10] Enable hash join in Substrait-to-Velox conversion (#9)
* hash join * remove extra projection
1 parent 37f2710 commit 42b90da

File tree

5 files changed

+220
-65
lines changed

5 files changed

+220
-65
lines changed

velox/substrait/SubstraitToVeloxExpr.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class SubstraitVeloxExprConverter {
2828
/// subParser: A Substrait parser used to convert Substrait representations
2929
/// into recognizable representations. functionMap: A pre-constructed map
3030
/// storing the relations between the function id and the function name.
31-
SubstraitVeloxExprConverter(
31+
explicit SubstraitVeloxExprConverter(
3232
const std::unordered_map<uint64_t, std::string>& functionMap)
3333
: functionMap_(functionMap) {}
3434

velox/substrait/SubstraitToVeloxPlan.cpp

Lines changed: 127 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -22,65 +22,99 @@
2222

2323
namespace facebook::velox::substrait {
2424

25-
namespace {
26-
template <TypeKind KIND>
27-
VectorPtr setVectorFromVariantsByKind(
28-
const std::vector<velox::variant>& value,
29-
memory::MemoryPool* pool) {
30-
using T = typename TypeTraits<KIND>::NativeType;
31-
32-
auto flatVector = std::dynamic_pointer_cast<FlatVector<T>>(
33-
BaseVector::create(CppToType<T>::create(), value.size(), pool));
34-
35-
for (vector_size_t i = 0; i < value.size(); i++) {
36-
if (value[i].isNull()) {
37-
flatVector->setNull(i, true);
38-
} else {
39-
flatVector->set(i, value[i].value<T>());
40-
}
41-
}
42-
return flatVector;
43-
}
44-
45-
template <>
46-
VectorPtr setVectorFromVariantsByKind<TypeKind::VARBINARY>(
47-
const std::vector<velox::variant>& value,
48-
memory::MemoryPool* pool) {
49-
throw std::invalid_argument("Return of VARBINARY data is not supported");
50-
}
51-
52-
template <>
53-
VectorPtr setVectorFromVariantsByKind<TypeKind::VARCHAR>(
54-
const std::vector<velox::variant>& value,
55-
memory::MemoryPool* pool) {
56-
auto flatVector = std::dynamic_pointer_cast<FlatVector<StringView>>(
57-
BaseVector::create(VARCHAR(), value.size(), pool));
58-
59-
for (vector_size_t i = 0; i < value.size(); i++) {
60-
if (value[i].isNull()) {
61-
flatVector->setNull(i, true);
62-
} else {
63-
flatVector->set(i, StringView(value[i].value<Varchar>()));
64-
}
25+
std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan(
26+
const ::substrait::JoinRel& sJoin) {
27+
if (!sJoin.has_left()) {
28+
VELOX_FAIL("Left Rel is expected in JoinRel.");
29+
}
30+
if (!sJoin.has_right()) {
31+
VELOX_FAIL("Right Rel is expected in JoinRel.");
32+
}
33+
34+
auto leftNode = toVeloxPlan(sJoin.left());
35+
auto rightNode = toVeloxPlan(sJoin.right());
36+
37+
auto outputSize =
38+
leftNode->outputType()->size() + rightNode->outputType()->size();
39+
std::vector<std::string> outputNames;
40+
std::vector<std::shared_ptr<const Type>> outputTypes;
41+
outputNames.reserve(outputSize);
42+
outputTypes.reserve(outputSize);
43+
for (const auto& node : {leftNode, rightNode}) {
44+
const auto& names = node->outputType()->names();
45+
outputNames.insert(outputNames.end(), names.begin(), names.end());
46+
const auto& types = node->outputType()->children();
47+
outputTypes.insert(outputTypes.end(), types.begin(), types.end());
48+
}
49+
auto outputRowType = std::make_shared<const RowType>(
50+
std::move(outputNames), std::move(outputTypes));
51+
52+
// extract join keys from join expression
53+
std::vector<const ::substrait::Expression::FieldReference*> leftExprs,
54+
rightExprs;
55+
extractJoinKeys(sJoin.expression(), leftExprs, rightExprs);
56+
VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size());
57+
size_t numKeys = leftExprs.size();
58+
59+
std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> leftKeys,
60+
rightKeys;
61+
leftKeys.reserve(numKeys);
62+
rightKeys.reserve(numKeys);
63+
for (size_t i = 0; i < numKeys; ++i) {
64+
leftKeys.emplace_back(
65+
exprConverter_->toVeloxExpr(*leftExprs[i], outputRowType));
66+
rightKeys.emplace_back(
67+
exprConverter_->toVeloxExpr(*rightExprs[i], outputRowType));
68+
}
69+
70+
std::shared_ptr<const core::ITypedExpr> filter;
71+
if (sJoin.has_post_join_filter()) {
72+
filter =
73+
exprConverter_->toVeloxExpr(sJoin.post_join_filter(), outputRowType);
74+
}
75+
76+
// Map join type
77+
core::JoinType joinType;
78+
switch (sJoin.type()) {
79+
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_INNER:
80+
joinType = core::JoinType::kInner;
81+
break;
82+
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_OUTER:
83+
joinType = core::JoinType::kFull;
84+
break;
85+
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT:
86+
joinType = core::JoinType::kLeft;
87+
break;
88+
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT:
89+
joinType = core::JoinType::kRight;
90+
break;
91+
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_SEMI:
92+
joinType = core::JoinType::kSemi;
93+
break;
94+
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_ANTI:
95+
joinType = core::JoinType::kAnti;
96+
break;
97+
default:
98+
VELOX_NYI("Unsupported Join type: {}", sJoin.type());
6599
}
66-
return flatVector;
67-
}
68100

69-
VectorPtr setVectorFromVariants(
70-
const TypePtr& type,
71-
const std::vector<velox::variant>& value,
72-
velox::memory::MemoryPool* pool) {
73-
return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
74-
setVectorFromVariantsByKind, type->kind(), value, pool);
101+
// Create join node
102+
return std::make_shared<core::HashJoinNode>(
103+
nextPlanNodeId(),
104+
joinType,
105+
leftKeys,
106+
rightKeys,
107+
filter,
108+
leftNode,
109+
rightNode,
110+
outputRowType);
75111
}
76-
} // namespace
77112

78-
core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
79-
const ::substrait::AggregateRel& aggRel,
80-
memory::MemoryPool* pool) {
81-
core::PlanNodePtr childNode;
82-
if (aggRel.has_input()) {
83-
childNode = toVeloxPlan(aggRel.input(), pool);
113+
std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan(
114+
const ::substrait::AggregateRel& sAgg) {
115+
std::shared_ptr<const core::PlanNode> childNode;
116+
if (sAgg.has_input()) {
117+
childNode = toVeloxPlan(sAgg.input());
84118
} else {
85119
VELOX_FAIL("Child Rel is expected in AggregateRel.");
86120
}
@@ -524,12 +558,11 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
524558
if (rel.has_project()) {
525559
return toVeloxPlan(rel.project(), pool);
526560
}
527-
if (rel.has_filter()) {
528-
return toVeloxPlan(rel.filter(), pool);
561+
if (sRel.has_join()) {
562+
return toVeloxPlan(sRel.join());
529563
}
530-
if (rel.has_read()) {
531-
return toVeloxPlan(
532-
rel.read(), pool, partitionIndex_, paths_, starts_, lengths_);
564+
if (sRel.has_read()) {
565+
return toVeloxPlan(sRel.read(), partitionIndex_, paths_, starts_, lengths_);
533566
}
534567
VELOX_NYI("Substrait conversion not supported for Rel.");
535568
}
@@ -823,4 +856,38 @@ int32_t SubstraitVeloxPlanConverter::streamIsInput(
823856
VELOX_FAIL("Local file is expected.");
824857
}
825858

859+
void SubstraitVeloxPlanConverter::extractJoinKeys(
860+
const ::substrait::Expression& joinExpression,
861+
std::vector<const ::substrait::Expression::FieldReference*>& leftExprs,
862+
std::vector<const ::substrait::Expression::FieldReference*>& rightExprs) {
863+
std::vector<const ::substrait::Expression*> expressions;
864+
expressions.push_back(&joinExpression);
865+
while (!expressions.empty()) {
866+
auto visited = expressions.back();
867+
expressions.pop_back();
868+
if (visited->rex_type_case() ==
869+
::substrait::Expression::RexTypeCase::kScalarFunction) {
870+
const auto& funcName =
871+
subParser_->getSubFunctionName(subParser_->findVeloxFunction(
872+
functionMap_, visited->scalar_function().function_reference()));
873+
const auto& args = visited->scalar_function().args();
874+
if (funcName == "and") {
875+
expressions.push_back(&args[0]);
876+
expressions.push_back(&args[1]);
877+
} else if (funcName == "equal") {
878+
VELOX_CHECK(std::all_of(
879+
args.cbegin(), args.cend(), [](const ::substrait::Expression& arg) {
880+
return arg.has_selection();
881+
}));
882+
leftExprs.push_back(&args[0].selection());
883+
rightExprs.push_back(&args[1].selection());
884+
}
885+
} else {
886+
VELOX_FAIL(
887+
"Unable to parse from join expression: {}",
888+
joinExpression.DebugString());
889+
}
890+
}
891+
}
892+
826893
} // namespace facebook::velox::substrait

velox/substrait/SubstraitToVeloxPlan.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,13 @@ namespace facebook::velox::substrait {
2525
/// This class is used to convert the Substrait plan into Velox plan.
2626
class SubstraitVeloxPlanConverter {
2727
public:
28-
/// Convert Substrait AggregateRel into Velox PlanNode.
29-
core::PlanNodePtr toVeloxPlan(
30-
const ::substrait::AggregateRel& aggRel,
31-
memory::MemoryPool* pool);
28+
/// Used to convert Substrait JoinRel into Velox PlanNode.
29+
std::shared_ptr<const core::PlanNode> toVeloxPlan(
30+
const ::substrait::JoinRel& sJoin);
31+
32+
/// Used to convert Substrait AggregateRel into Velox PlanNode.
33+
std::shared_ptr<const core::PlanNode> toVeloxPlan(
34+
const ::substrait::AggregateRel& sAgg);
3235

3336
/// Convert Substrait ProjectRel into Velox PlanNode.
3437
core::PlanNodePtr toVeloxPlan(
@@ -123,6 +126,16 @@ class SubstraitVeloxPlanConverter {
123126
/// Used to find the function specification in the constructed function map.
124127
std::string findFuncSpec(uint64_t id);
125128

129+
/// Extract join keys from joinExpression.
130+
/// joinExpression is a boolean condition that describes whether each record
131+
/// from the left set matches a record from the right set. The condition
132+
/// must only include the following operations: AND, ==, field references.
133+
/// Field references correspond to the direct output order of the data.
134+
void extractJoinKeys(
135+
const ::substrait::Expression& joinExpression,
136+
std::vector<const ::substrait::Expression::FieldReference*>& leftExprs,
137+
std::vector<const ::substrait::Expression::FieldReference*>& rightExprs);
138+
126139
private:
127140
/// Returns unique ID to use for plan node. Produces sequential numbers
128141
/// starting from zero.

velox/substrait/SubstraitToVeloxPlanValidator.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,75 @@ bool SubstraitToVeloxPlanValidator::validate(
115115
return false;
116116
}
117117

118+
bool SubstraitToVeloxPlanValidator::validate(
119+
const ::substrait::JoinRel& sJoin) {
120+
if (sJoin.has_left() && !validate(sJoin.left())) {
121+
return false;
122+
}
123+
if (sJoin.has_right() && !validate(sJoin.right())) {
124+
return false;
125+
}
126+
127+
switch (sJoin.type()) {
128+
case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER:
129+
case ::substrait::JoinRel_JoinType_JOIN_TYPE_OUTER:
130+
case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT:
131+
case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT:
132+
case ::substrait::JoinRel_JoinType_JOIN_TYPE_SEMI:
133+
case ::substrait::JoinRel_JoinType_JOIN_TYPE_ANTI:
134+
break;
135+
default:
136+
return false;
137+
}
138+
139+
// Validate input types.
140+
if (!sJoin.has_advanced_extension()) {
141+
std::cout << "Input types are expected in JoinRel." << std::endl;
142+
return false;
143+
}
144+
145+
const auto& extension = sJoin.advanced_extension();
146+
std::vector<TypePtr> types;
147+
if (!validateInputTypes(extension, types)) {
148+
std::cout << "Validation failed for input types in JoinRel" << std::endl;
149+
return false;
150+
}
151+
152+
int32_t inputPlanNodeId = 0;
153+
std::vector<std::string> names;
154+
names.reserve(types.size());
155+
for (auto colIdx = 0; colIdx < types.size(); colIdx++) {
156+
names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx));
157+
}
158+
auto rowType = std::make_shared<RowType>(std::move(names), std::move(types));
159+
160+
if (sJoin.has_expression()) {
161+
std::vector<const ::substrait::Expression::FieldReference*> leftExprs,
162+
rightExprs;
163+
try {
164+
planConverter_->extractJoinKeys(
165+
sJoin.expression(), leftExprs, rightExprs);
166+
} catch (const VeloxException& err) {
167+
std::cout << "Validation failed for expression in JoinRel due to:"
168+
<< err.message() << std::endl;
169+
return false;
170+
}
171+
}
172+
173+
if (sJoin.has_post_join_filter()) {
174+
try {
175+
auto expression =
176+
exprConverter_->toVeloxExpr(sJoin.post_join_filter(), rowType);
177+
exec::ExprSet exprSet({std::move(expression)}, &execCtx_);
178+
} catch (const VeloxException& err) {
179+
std::cout << "Validation failed for expression in ProjectRel due to:"
180+
<< err.message() << std::endl;
181+
return false;
182+
}
183+
}
184+
return true;
185+
}
186+
118187
bool SubstraitToVeloxPlanValidator::validate(
119188
const ::substrait::AggregateRel& sAgg) {
120189
if (sAgg.has_input() && !validate(sAgg.input())) {
@@ -304,6 +373,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Rel& sRel) {
304373
if (sRel.has_filter()) {
305374
return validate(sRel.filter());
306375
}
376+
if (sRel.has_join()) {
377+
return validate(sRel.join());
378+
}
307379
if (sRel.has_read()) {
308380
return validate(sRel.read());
309381
}

velox/substrait/SubstraitToVeloxPlanValidator.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ class SubstraitToVeloxPlanValidator {
3636
/// Used to validate whether the computing of this Filter is supported.
3737
bool validate(const ::substrait::FilterRel& sFilter);
3838

39+
/// Used to validate Join.
40+
bool validate(const ::substrait::JoinRel& sJoin);
41+
3942
/// Used to validate whether the computing of this Read is supported.
4043
bool validate(const ::substrait::ReadRel& sRead);
4144

0 commit comments

Comments
 (0)