@@ -113,9 +113,10 @@ void MergeJoin::initialize() {
113113 isSemiFilterJoin (joinType_)) {
114114 joinTracker_ = JoinTracker (outputBatchSize_, pool ());
115115 }
116- } else if (joinNode_->isAntiJoin ()) {
116+ } else if (joinNode_->isAntiJoin () || joinNode_-> isFullJoin () ) {
117117 // Anti join needs to track the left side rows that have no match on the
118- // right.
118+ // right. Full outer join needs to track the right side rows that have no
119+ // match on the left.
119120 joinTracker_ = JoinTracker (outputBatchSize_, pool ());
120121 }
121122
@@ -392,7 +393,8 @@ bool MergeJoin::tryAddOutputRow(
392393 const RowVectorPtr& leftBatch,
393394 vector_size_t leftRow,
394395 const RowVectorPtr& rightBatch,
395- vector_size_t rightRow) {
396+ vector_size_t rightRow,
397+ bool isRightJoinForFullOuter) {
396398 if (outputSize_ == outputBatchSize_) {
397399 return false ;
398400 }
@@ -426,12 +428,15 @@ bool MergeJoin::tryAddOutputRow(
426428 filterRightInputProjections_);
427429
428430 if (joinTracker_) {
429- if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
431+ if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_) ||
432+ (isFullJoin (joinType_) && isRightJoinForFullOuter)) {
430433 // Record right-side row with a match on the left-side.
431- joinTracker_->addMatch (rightBatch, rightRow, outputSize_);
434+ joinTracker_->addMatch (
435+ rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
432436 } else {
433437 // Record left-side row with a match on the right-side.
434- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
438+ joinTracker_->addMatch (
439+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
435440 }
436441 }
437442 }
@@ -441,7 +446,8 @@ bool MergeJoin::tryAddOutputRow(
441446 if (isAntiJoin (joinType_)) {
442447 VELOX_CHECK (joinTracker_.has_value ());
443448 // Record left-side row with a match on the right-side.
444- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
449+ joinTracker_->addMatch (
450+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
445451 }
446452
447453 ++outputSize_;
@@ -459,14 +465,15 @@ bool MergeJoin::prepareOutput(
459465 return true ;
460466 }
461467
462- if (isRightJoin (joinType_) && right != currentRight_) {
463- return true ;
464- }
465-
466468 // If there is a new right, we need to flatten the dictionary.
467469 if (!isRightFlattened_ && right && currentRight_ != right) {
468470 flattenRightProjections ();
469471 }
472+
473+ if (right != currentRight_) {
474+ return true ;
475+ }
476+
470477 return false ;
471478 }
472479
@@ -489,11 +496,15 @@ bool MergeJoin::prepareOutput(
489496 }
490497 } else {
491498 for (const auto & projection : leftProjections_) {
499+ auto column = left->childAt (projection.inputChannel );
500+ // Flatten the left column if the column already is DictionaryVector.
501+ if (column->wrappedVector ()->encoding () ==
502+ VectorEncoding::Simple::DICTIONARY) {
503+ BaseVector::flattenVector (column);
504+ }
505+ column->clearContainingLazyAndWrapped ();
492506 localColumns[projection.outputChannel ] = BaseVector::wrapInDictionary (
493- {},
494- leftOutputIndices_,
495- outputBatchSize_,
496- left->childAt (projection.inputChannel ));
507+ {}, leftOutputIndices_, outputBatchSize_, column);
497508 }
498509 }
499510 currentLeft_ = left;
@@ -509,11 +520,10 @@ bool MergeJoin::prepareOutput(
509520 isRightFlattened_ = true ;
510521 } else {
511522 for (const auto & projection : rightProjections_) {
523+ auto column = right->childAt (projection.inputChannel );
524+ column->clearContainingLazyAndWrapped ();
512525 localColumns[projection.outputChannel ] = BaseVector::wrapInDictionary (
513- {},
514- rightOutputIndices_,
515- outputBatchSize_,
516- right->childAt (projection.inputChannel ));
526+ {}, rightOutputIndices_, outputBatchSize_, column);
517527 }
518528 isRightFlattened_ = false ;
519529 }
@@ -577,6 +587,39 @@ bool MergeJoin::prepareOutput(
577587bool MergeJoin::addToOutput () {
578588 if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
579589 return addToOutputForRightJoin ();
590+ } else if (isFullJoin (joinType_) && filter_) {
591+ if (!leftForRightJoinMatch_) {
592+ leftForRightJoinMatch_ = leftMatch_;
593+ rightForRightJoinMatch_ = rightMatch_;
594+ }
595+
596+ if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
597+ auto left = addToOutputForLeftJoin ();
598+ if (!leftMatch_) {
599+ leftJoinForFullFinished_ = true ;
600+ }
601+ if (left) {
602+ if (!leftMatch_) {
603+ leftMatch_ = leftForRightJoinMatch_;
604+ rightMatch_ = rightForRightJoinMatch_;
605+ }
606+
607+ return true ;
608+ }
609+ }
610+
611+ if (!leftMatch_ && !rightJoinForFullFinished_) {
612+ leftMatch_ = leftForRightJoinMatch_;
613+ rightMatch_ = rightForRightJoinMatch_;
614+ rightJoinForFullFinished_ = true ;
615+ }
616+
617+ auto right = addToOutputForRightJoin ();
618+
619+ leftForRightJoinMatch_ = leftMatch_;
620+ rightForRightJoinMatch_ = rightMatch_;
621+
622+ return right;
580623 } else {
581624 return addToOutputForLeftJoin ();
582625 }
@@ -669,7 +712,13 @@ bool MergeJoin::addToOutputImpl() {
669712 } else {
670713 for (auto innerRow = innerStartRow; innerRow < innerEndRow;
671714 ++innerRow) {
672- if (!tryAddOutputRow (leftBatch, innerRow, rightBatch, outerRow)) {
715+ const auto isRightJoinForFullOuter = isFullJoin (joinType_);
716+ if (!tryAddOutputRow (
717+ leftBatch,
718+ innerRow,
719+ rightBatch,
720+ outerRow,
721+ isRightJoinForFullOuter)) {
673722 outerMatch->setCursor (outerBatchIndex, outerRow);
674723 innerMatch->setCursor (innerBatchIndex, innerRow);
675724 return true ;
@@ -938,7 +987,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
938987 isFullJoin (joinType_)) {
939988 // If output_ is currently wrapping a different buffer, return it
940989 // first.
941- if (prepareOutput (input_, nullptr )) {
990+ if (prepareOutput (input_, rightInput_ )) {
942991 output_->resize (outputSize_);
943992 return std::move (output_);
944993 }
@@ -963,7 +1012,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
9631012 if (isRightJoin (joinType_) || isFullJoin (joinType_)) {
9641013 // If output_ is currently wrapping a different buffer, return it
9651014 // first.
966- if (prepareOutput (nullptr , rightInput_)) {
1015+ if (prepareOutput (input_ , rightInput_)) {
9671016 output_->resize (outputSize_);
9681017 return std::move (output_);
9691018 }
@@ -1013,6 +1062,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
10131062 matchedLeftRows_ += leftEndRow - leftMatch_->startRowIndex ;
10141063 matchedRightRows_ += rightEndRow - rightMatch_->startRowIndex ;
10151064
1065+ leftJoinForFullFinished_ = false ;
1066+ rightJoinForFullFinished_ = false ;
10161067 if (!leftMatch_->complete || !rightMatch_->complete ) {
10171068 if (!leftMatch_->complete ) {
10181069 // Need to continue looking for the end of match.
@@ -1274,8 +1325,6 @@ void MergeJoin::clearRightInput() {
12741325RowVectorPtr MergeJoin::applyFilter (const RowVectorPtr& output) {
12751326 const auto numRows = output->size ();
12761327
1277- RowVectorPtr fullOuterOutput = nullptr ;
1278-
12791328 BufferPtr indices = allocateIndices (numRows, pool ());
12801329 auto * rawIndices = indices->asMutable <vector_size_t >();
12811330 vector_size_t numPassed = 0 ;
@@ -1292,84 +1341,41 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12921341
12931342 // If all matches for a given left-side row fail the filter, add a row to
12941343 // the output with nulls for the right-side columns.
1295- const auto onMiss = [&](auto row) {
1344+ const auto onMiss = [&](auto row, bool isRightJoinForFullOuter ) {
12961345 if (isSemiFilterJoin (joinType_)) {
12971346 return ;
12981347 }
12991348 rawIndices[numPassed++] = row;
13001349
1301- if (isFullJoin (joinType_)) {
1302- // For filtered rows, it is necessary to insert additional data
1303- // to ensure the result set is complete. Specifically, we
1304- // need to generate two records: one record containing the
1305- // columns from the left table along with nulls for the
1306- // right table, and another record containing the columns
1307- // from the right table along with nulls for the left table.
1308- // For instance, the current output is filtered based on the condition
1309- // t > 1.
1310-
1311- // 1, 1
1312- // 2, 2
1313- // 3, 3
1314-
1315- // In this scenario, we need to additionally insert a record 1, 1.
1316- // Subsequently, we will set the values of the columns on the left to
1317- // null and the values of the columns on the right to null as well. By
1318- // doing so, we will obtain the final result set.
1319-
1320- // 1, null
1321- // null, 1
1322- // 2, 2
1323- // 3, 3
1324- fullOuterOutput = BaseVector::create<RowVector>(
1325- output->type (), output->size () + 1 , pool ());
1326-
1327- for (auto i = 0 ; i < row + 1 ; ++i) {
1328- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1329- fullOuterOutput->childAt (j)->copy (
1330- output->childAt (j).get (), i, i, 1 );
1350+ if (!isRightJoin (joinType_)) {
1351+ if (isFullJoin (joinType_) && isRightJoinForFullOuter) {
1352+ for (auto & projection : leftProjections_) {
1353+ auto target = output->childAt (projection.outputChannel );
1354+ target->setNull (row, true );
13311355 }
1332- }
1333-
1334- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1335- fullOuterOutput->childAt (j)->copy (
1336- output->childAt (j).get (), row + 1 , row, 1 );
1337- }
1338-
1339- for (auto i = row + 1 ; i < output->size (); ++i) {
1340- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1341- fullOuterOutput->childAt (j)->copy (
1342- output->childAt (j).get (), i + 1 , i, 1 );
1356+ } else {
1357+ for (auto & projection : rightProjections_) {
1358+ auto target = output->childAt (projection.outputChannel );
1359+ target->setNull (row, true );
13431360 }
13441361 }
1345-
1346- for (auto & projection : leftProjections_) {
1347- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1348- target->setNull (row, true );
1349- }
1350-
1351- for (auto & projection : rightProjections_) {
1352- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1353- target->setNull (row + 1 , true );
1354- }
1355- } else if (!isRightJoin (joinType_)) {
1356- for (auto & projection : rightProjections_) {
1357- auto & target = output->childAt (projection.outputChannel );
1358- target->setNull (row, true );
1359- }
13601362 } else {
13611363 for (auto & projection : leftProjections_) {
1362- auto & target = output->childAt (projection.outputChannel );
1364+ auto target = output->childAt (projection.outputChannel );
13631365 target->setNull (row, true );
13641366 }
13651367 }
13661368 };
13671369
13681370 auto onMatch = [&](auto row, bool firstMatch) {
1369- const bool isNonSemiAntiJoin =
1370- !isSemiFilterJoin (joinType_) && !isAntiJoin (joinType_);
1371+ const bool isFullLeftJoin =
1372+ isFullJoin (joinType_) && !joinTracker_->isRightJoinForFullOuter (row);
1373+
1374+ const bool isNonSemiAntiFullJoin = !isSemiFilterJoin (joinType_) &&
1375+ !isAntiJoin (joinType_) && !isFullJoin (joinType_);
13711376
1372- if ((isSemiFilterJoin (joinType_) && firstMatch) || isNonSemiAntiJoin) {
1377+ if ((isSemiFilterJoin (joinType_) && firstMatch) ||
1378+ isNonSemiAntiFullJoin || isFullLeftJoin) {
13731379 rawIndices[numPassed++] = row;
13741380 }
13751381 };
@@ -1430,17 +1436,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
14301436
14311437 if (numPassed == numRows) {
14321438 // All rows passed.
1433- if (fullOuterOutput) {
1434- return fullOuterOutput;
1435- }
14361439 return output;
14371440 }
14381441
14391442 // Some, but not all rows passed.
1440- if (fullOuterOutput) {
1441- return wrap (numPassed, indices, fullOuterOutput);
1442- }
1443-
14441443 return wrap (numPassed, indices, output);
14451444}
14461445
0 commit comments