Skip to content

Commit b228025

Browse files
zhouyuanglutenperfbot
authored and committed
fix: Fix smj result mismatch issue in semi, anti and full outer join
Signed-off-by: Yuan <yuanzhou@apache.org>
1 parent 0272dd5 commit b228025

File tree

3 files changed

+154
-97
lines changed

3 files changed

+154
-97
lines changed

velox/exec/MergeJoin.cpp

Lines changed: 91 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,10 @@ void MergeJoin::initialize() {
113113
isSemiFilterJoin(joinType_)) {
114114
joinTracker_ = JoinTracker(outputBatchSize_, pool());
115115
}
116-
} else if (joinNode_->isAntiJoin()) {
116+
} else if (joinNode_->isAntiJoin() || joinNode_->isFullJoin()) {
117117
// Anti join needs to track the left side rows that have no match on the
118-
// right.
118+
// right. Full outer join needs to track the right side rows that have no
119+
// match on the left.
119120
joinTracker_ = JoinTracker(outputBatchSize_, pool());
120121
}
121122

@@ -392,7 +393,8 @@ bool MergeJoin::tryAddOutputRow(
392393
const RowVectorPtr& leftBatch,
393394
vector_size_t leftRow,
394395
const RowVectorPtr& rightBatch,
395-
vector_size_t rightRow) {
396+
vector_size_t rightRow,
397+
bool isRightJoinForFullOuter) {
396398
if (outputSize_ == outputBatchSize_) {
397399
return false;
398400
}
@@ -426,12 +428,15 @@ bool MergeJoin::tryAddOutputRow(
426428
filterRightInputProjections_);
427429

428430
if (joinTracker_) {
429-
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
431+
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_) ||
432+
(isFullJoin(joinType_) && isRightJoinForFullOuter)) {
430433
// Record right-side row with a match on the left-side.
431-
joinTracker_->addMatch(rightBatch, rightRow, outputSize_);
434+
joinTracker_->addMatch(
435+
rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
432436
} else {
433437
// Record left-side row with a match on the right-side.
434-
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
438+
joinTracker_->addMatch(
439+
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
435440
}
436441
}
437442
}
@@ -441,7 +446,8 @@ bool MergeJoin::tryAddOutputRow(
441446
if (isAntiJoin(joinType_)) {
442447
VELOX_CHECK(joinTracker_.has_value());
443448
// Record left-side row with a match on the right-side.
444-
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
449+
joinTracker_->addMatch(
450+
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
445451
}
446452

447453
++outputSize_;
@@ -459,14 +465,15 @@ bool MergeJoin::prepareOutput(
459465
return true;
460466
}
461467

462-
if (isRightJoin(joinType_) && right != currentRight_) {
463-
return true;
464-
}
465-
466468
// If there is a new right, we need to flatten the dictionary.
467469
if (!isRightFlattened_ && right && currentRight_ != right) {
468470
flattenRightProjections();
469471
}
472+
473+
if (right != currentRight_) {
474+
return true;
475+
}
476+
470477
return false;
471478
}
472479

@@ -489,11 +496,15 @@ bool MergeJoin::prepareOutput(
489496
}
490497
} else {
491498
for (const auto& projection : leftProjections_) {
499+
auto column = left->childAt(projection.inputChannel);
500+
// Flatten the left column if the column already is DictionaryVector.
501+
if (column->wrappedVector()->encoding() ==
502+
VectorEncoding::Simple::DICTIONARY) {
503+
BaseVector::flattenVector(column);
504+
}
505+
column->clearContainingLazyAndWrapped();
492506
localColumns[projection.outputChannel] = BaseVector::wrapInDictionary(
493-
{},
494-
leftOutputIndices_,
495-
outputBatchSize_,
496-
left->childAt(projection.inputChannel));
507+
{}, leftOutputIndices_, outputBatchSize_, column);
497508
}
498509
}
499510
currentLeft_ = left;
@@ -509,11 +520,10 @@ bool MergeJoin::prepareOutput(
509520
isRightFlattened_ = true;
510521
} else {
511522
for (const auto& projection : rightProjections_) {
523+
auto column = right->childAt(projection.inputChannel);
524+
column->clearContainingLazyAndWrapped();
512525
localColumns[projection.outputChannel] = BaseVector::wrapInDictionary(
513-
{},
514-
rightOutputIndices_,
515-
outputBatchSize_,
516-
right->childAt(projection.inputChannel));
526+
{}, rightOutputIndices_, outputBatchSize_, column);
517527
}
518528
isRightFlattened_ = false;
519529
}
@@ -577,6 +587,39 @@ bool MergeJoin::prepareOutput(
577587
bool MergeJoin::addToOutput() {
578588
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
579589
return addToOutputForRightJoin();
590+
} else if (isFullJoin(joinType_) && filter_) {
591+
if (!leftForRightJoinMatch_) {
592+
leftForRightJoinMatch_ = leftMatch_;
593+
rightForRightJoinMatch_ = rightMatch_;
594+
}
595+
596+
if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
597+
auto left = addToOutputForLeftJoin();
598+
if (!leftMatch_) {
599+
leftJoinForFullFinished_ = true;
600+
}
601+
if (left) {
602+
if (!leftMatch_) {
603+
leftMatch_ = leftForRightJoinMatch_;
604+
rightMatch_ = rightForRightJoinMatch_;
605+
}
606+
607+
return true;
608+
}
609+
}
610+
611+
if (!leftMatch_ && !rightJoinForFullFinished_) {
612+
leftMatch_ = leftForRightJoinMatch_;
613+
rightMatch_ = rightForRightJoinMatch_;
614+
rightJoinForFullFinished_ = true;
615+
}
616+
617+
auto right = addToOutputForRightJoin();
618+
619+
leftForRightJoinMatch_ = leftMatch_;
620+
rightForRightJoinMatch_ = rightMatch_;
621+
622+
return right;
580623
} else {
581624
return addToOutputForLeftJoin();
582625
}
@@ -669,7 +712,13 @@ bool MergeJoin::addToOutputImpl() {
669712
} else {
670713
for (auto innerRow = innerStartRow; innerRow < innerEndRow;
671714
++innerRow) {
672-
if (!tryAddOutputRow(leftBatch, innerRow, rightBatch, outerRow)) {
715+
const auto isRightJoinForFullOuter = isFullJoin(joinType_);
716+
if (!tryAddOutputRow(
717+
leftBatch,
718+
innerRow,
719+
rightBatch,
720+
outerRow,
721+
isRightJoinForFullOuter)) {
673722
outerMatch->setCursor(outerBatchIndex, outerRow);
674723
innerMatch->setCursor(innerBatchIndex, innerRow);
675724
return true;
@@ -938,7 +987,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
938987
isFullJoin(joinType_)) {
939988
// If output_ is currently wrapping a different buffer, return it
940989
// first.
941-
if (prepareOutput(input_, nullptr)) {
990+
if (prepareOutput(input_, rightInput_)) {
942991
output_->resize(outputSize_);
943992
return std::move(output_);
944993
}
@@ -963,7 +1012,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
9631012
if (isRightJoin(joinType_) || isFullJoin(joinType_)) {
9641013
// If output_ is currently wrapping a different buffer, return it
9651014
// first.
966-
if (prepareOutput(nullptr, rightInput_)) {
1015+
if (prepareOutput(input_, rightInput_)) {
9671016
output_->resize(outputSize_);
9681017
return std::move(output_);
9691018
}
@@ -1013,6 +1062,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
10131062
matchedLeftRows_ += leftEndRow - leftMatch_->startRowIndex;
10141063
matchedRightRows_ += rightEndRow - rightMatch_->startRowIndex;
10151064

1065+
leftJoinForFullFinished_ = false;
1066+
rightJoinForFullFinished_ = false;
10161067
if (!leftMatch_->complete || !rightMatch_->complete) {
10171068
if (!leftMatch_->complete) {
10181069
// Need to continue looking for the end of match.
@@ -1274,8 +1325,6 @@ void MergeJoin::clearRightInput() {
12741325
RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12751326
const auto numRows = output->size();
12761327

1277-
RowVectorPtr fullOuterOutput = nullptr;
1278-
12791328
BufferPtr indices = allocateIndices(numRows, pool());
12801329
auto* rawIndices = indices->asMutable<vector_size_t>();
12811330
vector_size_t numPassed = 0;
@@ -1292,84 +1341,41 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12921341

12931342
// If all matches for a given left-side row fail the filter, add a row to
12941343
// the output with nulls for the right-side columns.
1295-
const auto onMiss = [&](auto row) {
1344+
const auto onMiss = [&](auto row, bool isRightJoinForFullOuter) {
12961345
if (isSemiFilterJoin(joinType_)) {
12971346
return;
12981347
}
12991348
rawIndices[numPassed++] = row;
13001349

1301-
if (isFullJoin(joinType_)) {
1302-
// For filtered rows, it is necessary to insert additional data
1303-
// to ensure the result set is complete. Specifically, we
1304-
// need to generate two records: one record containing the
1305-
// columns from the left table along with nulls for the
1306-
// right table, and another record containing the columns
1307-
// from the right table along with nulls for the left table.
1308-
// For instance, the current output is filtered based on the condition
1309-
// t > 1.
1310-
1311-
// 1, 1
1312-
// 2, 2
1313-
// 3, 3
1314-
1315-
// In this scenario, we need to additionally insert a record 1, 1.
1316-
// Subsequently, we will set the values of the columns on the left to
1317-
// null and the values of the columns on the right to null as well. By
1318-
// doing so, we will obtain the final result set.
1319-
1320-
// 1, null
1321-
// null, 1
1322-
// 2, 2
1323-
// 3, 3
1324-
fullOuterOutput = BaseVector::create<RowVector>(
1325-
output->type(), output->size() + 1, pool());
1326-
1327-
for (auto i = 0; i < row + 1; ++i) {
1328-
for (auto j = 0; j < output->type()->size(); ++j) {
1329-
fullOuterOutput->childAt(j)->copy(
1330-
output->childAt(j).get(), i, i, 1);
1350+
if (!isRightJoin(joinType_)) {
1351+
if (isFullJoin(joinType_) && isRightJoinForFullOuter) {
1352+
for (auto& projection : leftProjections_) {
1353+
auto target = output->childAt(projection.outputChannel);
1354+
target->setNull(row, true);
13311355
}
1332-
}
1333-
1334-
for (auto j = 0; j < output->type()->size(); ++j) {
1335-
fullOuterOutput->childAt(j)->copy(
1336-
output->childAt(j).get(), row + 1, row, 1);
1337-
}
1338-
1339-
for (auto i = row + 1; i < output->size(); ++i) {
1340-
for (auto j = 0; j < output->type()->size(); ++j) {
1341-
fullOuterOutput->childAt(j)->copy(
1342-
output->childAt(j).get(), i + 1, i, 1);
1356+
} else {
1357+
for (auto& projection : rightProjections_) {
1358+
auto target = output->childAt(projection.outputChannel);
1359+
target->setNull(row, true);
13431360
}
13441361
}
1345-
1346-
for (auto& projection : leftProjections_) {
1347-
auto& target = fullOuterOutput->childAt(projection.outputChannel);
1348-
target->setNull(row, true);
1349-
}
1350-
1351-
for (auto& projection : rightProjections_) {
1352-
auto& target = fullOuterOutput->childAt(projection.outputChannel);
1353-
target->setNull(row + 1, true);
1354-
}
1355-
} else if (!isRightJoin(joinType_)) {
1356-
for (auto& projection : rightProjections_) {
1357-
auto& target = output->childAt(projection.outputChannel);
1358-
target->setNull(row, true);
1359-
}
13601362
} else {
13611363
for (auto& projection : leftProjections_) {
1362-
auto& target = output->childAt(projection.outputChannel);
1364+
auto target = output->childAt(projection.outputChannel);
13631365
target->setNull(row, true);
13641366
}
13651367
}
13661368
};
13671369

13681370
auto onMatch = [&](auto row, bool firstMatch) {
1369-
const bool isNonSemiAntiJoin =
1370-
!isSemiFilterJoin(joinType_) && !isAntiJoin(joinType_);
1371+
const bool isFullLeftJoin =
1372+
isFullJoin(joinType_) && !joinTracker_->isRightJoinForFullOuter(row);
1373+
1374+
const bool isNonSemiAntiFullJoin = !isSemiFilterJoin(joinType_) &&
1375+
!isAntiJoin(joinType_) && !isFullJoin(joinType_);
13711376

1372-
if ((isSemiFilterJoin(joinType_) && firstMatch) || isNonSemiAntiJoin) {
1377+
if ((isSemiFilterJoin(joinType_) && firstMatch) ||
1378+
isNonSemiAntiFullJoin || isFullLeftJoin) {
13731379
rawIndices[numPassed++] = row;
13741380
}
13751381
};
@@ -1430,17 +1436,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
14301436

14311437
if (numPassed == numRows) {
14321438
// All rows passed.
1433-
if (fullOuterOutput) {
1434-
return fullOuterOutput;
1435-
}
14361439
return output;
14371440
}
14381441

14391442
// Some, but not all rows passed.
1440-
if (fullOuterOutput) {
1441-
return wrap(numPassed, indices, fullOuterOutput);
1442-
}
1443-
14441443
return wrap(numPassed, indices, output);
14451444
}
14461445

0 commit comments

Comments
 (0)