diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index ffe37eb8dbe9..4013d7a2dee7 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -1526,6 +1526,10 @@ class AbstractJoinNode : public PlanNode { return joinType_ == JoinType::kAnti; } + bool isPreservingProbeOrder() const { + return isInnerJoin() || isLeftJoin() || isAntiJoin(); + } + const std::vector& leftKeys() const { return leftKeys_; } diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index 95f060b78914..defbbcdd8e11 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -986,9 +986,18 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { auto outputBatchSize = (isLeftSemiOrAntiJoinNoFilter || emptyBuildSide) ? inputSize : outputBatchSize_; - auto mapping = - initializeRowNumberMapping(outputRowMapping_, outputBatchSize, pool()); - outputTableRows_.resize(outputBatchSize); + auto outputBatchCapacity = outputBatchSize; + if (filter_ && + (isLeftJoin(joinType_) || isFullJoin(joinType_) || + isAntiJoin(joinType_))) { + // If we need non-matching probe side row, there is a possibility that such + // row exists at end of an input batch and being carried over in the next + // output batch, so we need to make extra room of one row in output. + ++outputBatchCapacity; + } + auto mapping = initializeRowNumberMapping( + outputRowMapping_, outputBatchCapacity, pool()); + outputTableRows_.resize(outputBatchCapacity); for (;;) { int numOut = 0; @@ -996,8 +1005,11 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { if (emptyBuildSide) { // When build side is empty, anti and left joins return all probe side // rows, including ones with null join keys. - std::iota(mapping.begin(), mapping.end(), 0); - std::fill(outputTableRows_.begin(), outputTableRows_.end(), nullptr); + std::iota(mapping.begin(), mapping.begin() + inputSize, 0); + std::fill( + outputTableRows_.begin(), + outputTableRows_.begin() + inputSize, + nullptr); numOut = inputSize; } else if (isAntiJoin(joinType_) && !filter_) { if (nullAware_) { @@ -1024,8 +1036,8 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { numOut = table_->listJoinResults( *resultIter_, joinIncludesMissesFromLeft(joinType_), - mapping, - folly::Range(outputTableRows_.data(), outputTableRows_.size()), + folly::Range(mapping.data(), outputBatchSize), + folly::Range(outputTableRows_.data(), outputBatchSize), operatorCtx_->driverCtx()->queryConfig().preferredOutputBatchBytes()); } @@ -1036,7 +1048,7 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { input_ = nullptr; return nullptr; } - VELOX_CHECK_LE(numOut, outputTableRows_.size()); + VELOX_CHECK_LE(numOut, outputBatchSize); numOut = evalFilter(numOut); @@ -1302,6 +1314,19 @@ SelectivityVector HashProbe::evalFilterForNullAwareJoin( return filterPassedRows; } +namespace { + +template +T* initBuffer(BufferPtr& buffer, vector_size_t size, memory::MemoryPool* pool) { + VELOX_CHECK(!buffer || buffer->isMutable()); + if (!buffer || buffer->size() < size * sizeof(T)) { + buffer = AlignedBuffer::allocate(size, pool); + } + return buffer->asMutable(); +} + +} // namespace + int32_t HashProbe::evalFilter(int32_t numRows) { if (!filter_) { return numRows; @@ -1343,21 +1368,51 @@ int32_t HashProbe::evalFilter(int32_t numRows) { if (isLeftJoin(joinType_) || isFullJoin(joinType_)) { // Identify probe rows which got filtered out and add them back with nulls // for build side. - auto addMiss = [&](auto row) { - outputTableRows_[numPassed] = nullptr; - rawOutputProbeRowMapping[numPassed++] = row; - }; - for (auto i = 0; i < numRows; ++i) { - const bool passed = filterPassed(i); - noMatchDetector_.advance(rawOutputProbeRowMapping[i], passed, addMiss); - if (passed) { - outputTableRows_[numPassed] = outputTableRows_[i]; - rawOutputProbeRowMapping[numPassed++] = rawOutputProbeRowMapping[i]; + if (noMatchDetector_.hasLastMissedRow()) { + auto* tempOutputTableRows = initBuffer( + tempOutputTableRows_, outputTableRows_.size(), pool()); + auto* tempOutputRowMapping = initBuffer( + tempOutputRowMapping_, outputTableRows_.size(), pool()); + auto addMiss = [&](auto row) { + tempOutputTableRows[numPassed] = nullptr; + tempOutputRowMapping[numPassed++] = row; + }; + for (auto i = 0; i < numRows; ++i) { + const bool passed = filterPassed(i); + noMatchDetector_.advance(rawOutputProbeRowMapping[i], passed, addMiss); + if (passed) { + tempOutputTableRows[numPassed] = outputTableRows_[i]; + tempOutputRowMapping[numPassed++] = rawOutputProbeRowMapping[i]; + } + } + if (resultIter_->atEnd()) { + noMatchDetector_.finish(addMiss); + } + std::copy( + tempOutputTableRows, + tempOutputTableRows + numPassed, + outputTableRows_.data()); + std::copy( + tempOutputRowMapping, + tempOutputRowMapping + numPassed, + rawOutputProbeRowMapping); + } else { + auto addMiss = [&](auto row) { + outputTableRows_[numPassed] = nullptr; + rawOutputProbeRowMapping[numPassed++] = row; + }; + for (auto i = 0; i < numRows; ++i) { + const bool passed = filterPassed(i); + noMatchDetector_.advance(rawOutputProbeRowMapping[i], passed, addMiss); + if (passed) { + outputTableRows_[numPassed] = outputTableRows_[i]; + rawOutputProbeRowMapping[numPassed++] = rawOutputProbeRowMapping[i]; + } + } + if (resultIter_->atEnd()) { + noMatchDetector_.finish(addMiss); } } - - noMatchDetector_.finishIteration( - addMiss, resultIter_->atEnd(), outputTableRows_.size() - numPassed); } else if (isLeftSemiFilterJoin(joinType_)) { auto addLastMatch = [&](auto row) { outputTableRows_[numPassed] = nullptr; @@ -1442,9 +1497,9 @@ int32_t HashProbe::evalFilter(int32_t numRows) { noMatchDetector_.advance(probeRow, filterPassed(i), addMiss); } } - - noMatchDetector_.finishIteration( - addMiss, resultIter_->atEnd(), outputTableRows_.size() - numPassed); + if (resultIter_->atEnd()) { + noMatchDetector_.finish(addMiss); + } } else { for (auto i = 0; i < numRows; ++i) { if (filterPassed(i)) { @@ -1453,6 +1508,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) { } } } + VELOX_CHECK_LE(numPassed, outputTableRows_.size()); return numPassed; } @@ -1938,6 +1994,8 @@ void HashProbe::close() { inputSpiller_.reset(); table_.reset(); outputRowMapping_.reset(); + tempOutputRowMapping_.reset(); + tempOutputTableRows_.reset(); output_.reset(); nonSpillInputIndicesBuffer_.reset(); spillInputIndicesBuffers_.clear(); diff --git a/velox/exec/HashProbe.h b/velox/exec/HashProbe.h index 79709e2917f0..ddb310af6bba 100644 --- a/velox/exec/HashProbe.h +++ b/velox/exec/HashProbe.h @@ -430,12 +430,20 @@ class HashProbe : public Operator { // Row number in 'input_' for each output row. BufferPtr outputRowMapping_; + // For left join with filter, we could overwrite the row which we have not + // checked if there is a carryover. Use a temporary buffer in this case. + BufferPtr tempOutputRowMapping_; + // maps from column index in 'table_' to channel in 'output_'. std::vector tableOutputProjections_; // Rows of table found by join probe, later filtered by 'filter_'. std::vector outputTableRows_; + // For left join with filter, we could overwrite the row which we have not + // checked if there is a carryover. Use a temporary buffer in this case. + BufferPtr tempOutputTableRows_; + // Indicates probe-side rows which should produce a NULL in left semi project // with filter. SelectivityVector leftSemiProjectIsNull_; @@ -447,89 +455,40 @@ class HashProbe : public Operator { // Called for each row that the filter was evaluated on. Expects that probe // side rows with multiple matches on the build side are next to each other. template - void advance(vector_size_t row, bool passed, TOnMiss onMiss) { - if (currentRow != row) { - // Check if 'currentRow' is the same input row as the last missed row - // from a previous output batch. If so finishIteration will call - // onMiss. - if (currentRow != -1 && !currentRowPassed && - (!lastMissedRow || currentRow != lastMissedRow)) { - onMiss(currentRow); + void advance(vector_size_t row, bool passed, TOnMiss&& onMiss) { + if (currentRow_ != row) { + if (hasLastMissedRow()) { + onMiss(currentRow_); } - currentRow = row; - currentRowPassed = false; + currentRow_ = row; + currentRowPassed_ = false; } - if (passed) { - // lastMissedRow can only be a row that has never passed the filter. If - // it passes there's no need to continue carrying it forward. - if (lastMissedRow && currentRow == lastMissedRow) { - lastMissedRow.reset(); - } - - currentRowPassed = true; + currentRowPassed_ = true; } } - // Invoked at the end of one output batch processing. 'end' is set to true - // at the end of processing an input batch. 'freeOutputRows' is the number - // of rows that can still be written to the output batch. + // Invoked at the end of all output batches. template - void - finishIteration(TOnMiss onMiss, bool endOfData, size_t freeOutputRows) { - if (endOfData) { - if (!currentRowPassed && currentRow != -1) { - // If we're at the end of the input batch and the current row hasn't - // passed the filter, it never will, process it as a miss. - // We're guaranteed to have space, at least the last row was never - // written out since it was a miss. - onMiss(currentRow); - freeOutputRows--; - } - - // We no longer need to carry the current row since we already called - // onMiss on it. - if (lastMissedRow && currentRow == lastMissedRow) { - lastMissedRow.reset(); - } - - currentRow = -1; - currentRowPassed = false; - } - - // If there's space left in the output batch, write out the last missed - // row. - if (lastMissedRow && currentRow != lastMissedRow && freeOutputRows > 0) { - onMiss(*lastMissedRow); - lastMissedRow.reset(); - } - - // If the current row hasn't passed the filter, we need to carry it - // forward in case it never passes the filter. - if (!currentRowPassed && currentRow != -1) { - lastMissedRow = currentRow; + void finish(TOnMiss&& onMiss) { + if (hasLastMissedRow()) { + onMiss(currentRow_); } + currentRow_ = -1; } // Returns if we're carrying forward a missed input row. Notably, if this is // true, we're not yet done processing the input batch. - bool hasLastMissedRow() { - return lastMissedRow.has_value(); + bool hasLastMissedRow() const { + return currentRow_ != -1 && !currentRowPassed_; } private: // Row number being processed. - vector_size_t currentRow{-1}; + vector_size_t currentRow_{-1}; - // True if currentRow has a match. - bool currentRowPassed{false}; - - // If set, it points to the last missed (input) row carried over from - // previous output batch processing. The last missed row is either written - // as a passed row if the same input row has a hit in the next output batch - // processed or written to the first output batch which has space at - // the end if it never has a hit. - std::optional lastMissedRow; + // True if currentRow_ has a match. + bool currentRowPassed_{false}; }; // For left semi join filter with extra filter, de-duplicates probe side rows diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index 37601f0f913c..c07d462e5fd0 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -6627,6 +6627,48 @@ TEST_F(HashJoinTest, leftJoinWithMissAtEndOfBatchMultipleBuildMatches) { test("t_k2 != 4 and t_k2 != 8"); } +TEST_F(HashJoinTest, leftJoinPreserveProbeOrder) { + const std::vector probeVectors = { + makeRowVector( + {"k1", "v1"}, + { + makeConstant(0, 2), + makeFlatVector({1, 0}), + }), + }; + const std::vector buildVectors = { + makeRowVector( + {"k2", "v2"}, + { + makeConstant(0, 2), + makeConstant(0, 2), + }), + }; + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .hashJoin( + {"k1"}, + {"k2"}, + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(), + "v1 % 2 = v2 % 2", + {"v1"}, + core::JoinType::kLeft) + .planNode(); + auto result = AssertQueryBuilder(plan) + .config(core::QueryConfig::kPreferredOutputBatchRows, "1") + .singleThreaded(true) + .copyResults(pool_.get()); + ASSERT_EQ(result->size(), 3); + auto* v1 = + result->childAt(0)->loadedVector()->asUnchecked>(); + ASSERT_FALSE(v1->mayHaveNulls()); + ASSERT_EQ(v1->valueAt(0), 1); + ASSERT_EQ(v1->valueAt(1), 0); + ASSERT_EQ(v1->valueAt(2), 0); +} + DEBUG_ONLY_TEST_F(HashJoinTest, minSpillableMemoryReservation) { constexpr int64_t kMaxBytes = 1LL << 30; // 1GB VectorFuzzer fuzzer({.vectorSize = 1000}, pool());