Skip to content

Commit

Permalink
Use BufferPtr for HashProbe::outputTableRows_ (facebookincubator#10864)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#10864

Follow up on
facebookincubator#10832 (comment).
Use `BufferPtr` so that the memory is accounted for in the operator's memory pool.

Reviewed By: xiaoxmeng

Differential Revision: D61875530
  • Loading branch information
Yuhta authored and facebook-github-bot committed Aug 28, 2024
1 parent db8875c commit a173cf8
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 53 deletions.
103 changes: 51 additions & 52 deletions velox/exec/HashProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ RowTypePtr makeTableType(
// 'result'. Reuses 'result' children where possible.
void extractColumns(
BaseHashTable* table,
folly::Range<char**> rows,
folly::Range<char* const*> rows,
folly::Range<const IdentityProjection*> projections,
memory::MemoryPool* pool,
const std::vector<TypePtr>& resultTypes,
Expand Down Expand Up @@ -108,6 +108,16 @@ SpillPartitionNumSet toPartitionNumSet(
}
return partitionNumSet;
}

// Makes sure 'buffer' can hold at least 'size' values of type T and returns a
// mutable pointer to its contents. A buffer that is already set and large
// enough is reused as is. Otherwise a new buffer for 'size' values is
// allocated using 'pool' and replaces the old one; nothing is copied over from
// the previous buffer. Fails the check if 'buffer' is set but not mutable.
template <typename T>
T* initBuffer(BufferPtr& buffer, vector_size_t size, memory::MemoryPool* pool) {
  VELOX_CHECK(!buffer || buffer->isMutable());
  const bool reusable = buffer && buffer->size() >= size * sizeof(T);
  if (!reusable) {
    buffer = AlignedBuffer::allocate<T>(size, pool);
  }
  return buffer->asMutable<T>();
}

} // namespace

HashProbe::HashProbe(
Expand All @@ -132,7 +142,7 @@ HashProbe::HashProbe(
operatorCtx_->driverCtx()->splitGroupId,
planNodeId())),
filterResult_(1),
outputTableRows_(outputBatchSize_) {
outputTableRowsCapacity_(outputBatchSize_) {
VELOX_CHECK_NOT_NULL(joinBridge_);
}

Expand Down Expand Up @@ -723,14 +733,15 @@ void HashProbe::fillLeftSemiProjectMatchColumn(vector_size_t size) {
auto flatMatch = matchColumn()->as<FlatVector<bool>>();
flatMatch->resize(size);
auto rawValues = flatMatch->mutableRawValues<uint64_t>();
auto* outputTableRows = outputTableRows_->as<char*>();
for (auto i = 0; i < size; ++i) {
if (nullAware_) {
// Null-aware join may produce TRUE, FALSE or NULL.
if (filter_) {
if (leftSemiProjectIsNull_.isValid(i)) {
flatMatch->setNull(i, true);
} else {
bool hasMatch = outputTableRows_[i] != nullptr;
const bool hasMatch = outputTableRows[i] != nullptr;
bits::setBit(rawValues, i, hasMatch);
}
} else {
Expand All @@ -739,7 +750,7 @@ void HashProbe::fillLeftSemiProjectMatchColumn(vector_size_t size) {
flatMatch->setNull(i, true);
} else {
// Probe key is not null.
bool hasMatch = outputTableRows_[i] != nullptr;
const bool hasMatch = outputTableRows[i] != nullptr;
if (!hasMatch && buildSideHasNullKeys_) {
flatMatch->setNull(i, true);
} else {
Expand All @@ -748,7 +759,7 @@ void HashProbe::fillLeftSemiProjectMatchColumn(vector_size_t size) {
}
}
} else {
bool hasMatch = outputTableRows_[i] != nullptr;
const bool hasMatch = outputTableRows[i] != nullptr;
bits::setBit(rawValues, i, hasMatch);
}
}
Expand All @@ -773,7 +784,7 @@ void HashProbe::fillOutput(vector_size_t size) {
} else {
extractColumns(
table_.get(),
folly::Range<char**>(outputTableRows_.data(), size),
folly::Range<char* const*>(outputTableRows_->as<char*>(), size),
tableOutputProjections_,
pool(),
outputType_->children(),
Expand All @@ -782,27 +793,28 @@ void HashProbe::fillOutput(vector_size_t size) {
}

RowVectorPtr HashProbe::getBuildSideOutput() {
outputTableRows_.resize(outputBatchSize_);
auto* outputTableRows =
initBuffer<char*>(outputTableRows_, outputTableRowsCapacity_, pool());
int32_t numOut;
if (isRightSemiFilterJoin(joinType_)) {
numOut = table_->listProbedRows(
&lastProbeIterator_,
outputBatchSize_,
outputTableRowsCapacity_,
RowContainer::kUnlimited,
outputTableRows_.data());
outputTableRows);
} else if (isRightSemiProjectJoin(joinType_)) {
numOut = table_->listAllRows(
&lastProbeIterator_,
outputBatchSize_,
outputTableRowsCapacity_,
RowContainer::kUnlimited,
outputTableRows_.data());
outputTableRows);
} else {
// Must be a right join or full join.
numOut = table_->listNotProbedRows(
&lastProbeIterator_,
outputBatchSize_,
outputTableRowsCapacity_,
RowContainer::kUnlimited,
outputTableRows_.data());
outputTableRows);
}
if (numOut == 0) {
return nullptr;
Expand All @@ -818,7 +830,7 @@ RowVectorPtr HashProbe::getBuildSideOutput() {

extractColumns(
table_.get(),
folly::Range<char**>(outputTableRows_.data(), numOut),
folly::Range<char**>(outputTableRows, numOut),
tableOutputProjections_,
pool(),
outputType_->children(),
Expand All @@ -832,7 +844,7 @@ RowVectorPtr HashProbe::getBuildSideOutput() {
matchColumn() = createConstantFalse(numOut, pool());
} else {
table_->rows()->extractProbedFlags(
outputTableRows_.data(),
outputTableRows,
numOut,
nullAware_,
nullAware_ && probeSideHasNullKeys_,
Expand Down Expand Up @@ -986,18 +998,19 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) {
auto outputBatchSize = (isLeftSemiOrAntiJoinNoFilter || emptyBuildSide)
? inputSize
: outputBatchSize_;
auto outputBatchCapacity = outputBatchSize;
outputTableRowsCapacity_ = outputBatchSize;
if (filter_ &&
(isLeftJoin(joinType_) || isFullJoin(joinType_) ||
isAntiJoin(joinType_))) {
// If we need non-matching probe side rows, such a row may exist at the end
// of an input batch and be carried over into the next output batch, so we
// need to make room for one extra row in the output.
++outputBatchCapacity;
++outputTableRowsCapacity_;
}
auto mapping = initializeRowNumberMapping(
outputRowMapping_, outputBatchCapacity, pool());
outputTableRows_.resize(outputBatchCapacity);
outputRowMapping_, outputTableRowsCapacity_, pool());
auto* outputTableRows =
initBuffer<char*>(outputTableRows_, outputTableRowsCapacity_, pool());

for (;;) {
int numOut = 0;
Expand All @@ -1006,10 +1019,7 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) {
// When build side is empty, anti and left joins return all probe side
// rows, including ones with null join keys.
std::iota(mapping.begin(), mapping.begin() + inputSize, 0);
std::fill(
outputTableRows_.begin(),
outputTableRows_.begin() + inputSize,
nullptr);
std::fill(outputTableRows, outputTableRows + inputSize, nullptr);
numOut = inputSize;
} else if (isAntiJoin(joinType_) && !filter_) {
if (nullAware_) {
Expand Down Expand Up @@ -1037,7 +1047,7 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) {
*resultIter_,
joinIncludesMissesFromLeft(joinType_),
folly::Range(mapping.data(), outputBatchSize),
folly::Range(outputTableRows_.data(), outputBatchSize),
folly::Range(outputTableRows, outputBatchSize),
operatorCtx_->driverCtx()->queryConfig().preferredOutputBatchBytes());
}

Expand All @@ -1058,7 +1068,7 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) {

if (needLastProbe()) {
// Mark build-side rows that have a match on the join condition.
table_->rows()->setProbedFlag(outputTableRows_.data(), numOut);
table_->rows()->setProbedFlag(outputTableRows, numOut);
}

// Right semi join only returns the build side output when the probe side
Expand Down Expand Up @@ -1115,7 +1125,7 @@ RowVectorPtr HashProbe::createFilterInput(vector_size_t size) {

extractColumns(
table_.get(),
folly::Range<char**>(outputTableRows_.data(), size),
folly::Range<char* const*>(outputTableRows_->as<char*>(), size),
filterTableProjections_,
pool(),
filterInputType_->children(),
Expand Down Expand Up @@ -1314,19 +1324,6 @@ SelectivityVector HashProbe::evalFilterForNullAwareJoin(
return filterPassedRows;
}

namespace {

// Makes sure 'buffer' can hold at least 'size' values of type T and returns a
// mutable pointer to its contents. A buffer that is already set and large
// enough is reused as is. Otherwise a new buffer for 'size' values is
// allocated using 'pool' and replaces the old one; nothing is copied over from
// the previous buffer. Fails the check if 'buffer' is set but not mutable.
template <typename T>
T* initBuffer(BufferPtr& buffer, vector_size_t size, memory::MemoryPool* pool) {
  VELOX_CHECK(!buffer || buffer->isMutable());
  const bool reusable = buffer && buffer->size() >= size * sizeof(T);
  if (!reusable) {
    buffer = AlignedBuffer::allocate<T>(size, pool);
  }
  return buffer->asMutable<T>();
}

} // namespace

int32_t HashProbe::evalFilter(int32_t numRows) {
if (!filter_) {
return numRows;
Expand All @@ -1335,6 +1332,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
const bool filterPropagateNulls = filter_->expr(0)->propagatesNulls();
auto* rawOutputProbeRowMapping =
outputRowMapping_->asMutable<vector_size_t>();
auto* outputTableRows = outputTableRows_->asMutable<char*>();

filterInputRows_.resizeFill(numRows);

Expand All @@ -1345,7 +1343,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
// TODO Apply the same to left joins.
if (isAntiJoin(joinType_) || isLeftSemiProjectJoin(joinType_)) {
for (auto i = 0; i < numRows; ++i) {
if (outputTableRows_[i] == nullptr) {
if (outputTableRows[i] == nullptr) {
filterInputRows_.setValid(i, false);
}
}
Expand All @@ -1370,9 +1368,9 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
// for build side.
if (noMatchDetector_.hasLastMissedRow()) {
auto* tempOutputTableRows = initBuffer<char*>(
tempOutputTableRows_, outputTableRows_.size(), pool());
tempOutputTableRows_, outputTableRowsCapacity_, pool());
auto* tempOutputRowMapping = initBuffer<vector_size_t>(
tempOutputRowMapping_, outputTableRows_.size(), pool());
tempOutputRowMapping_, outputTableRowsCapacity_, pool());
auto addMiss = [&](auto row) {
tempOutputTableRows[numPassed] = nullptr;
tempOutputRowMapping[numPassed++] = row;
Expand All @@ -1381,7 +1379,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
const bool passed = filterPassed(i);
noMatchDetector_.advance(rawOutputProbeRowMapping[i], passed, addMiss);
if (passed) {
tempOutputTableRows[numPassed] = outputTableRows_[i];
tempOutputTableRows[numPassed] = outputTableRows[i];
tempOutputRowMapping[numPassed++] = rawOutputProbeRowMapping[i];
}
}
Expand All @@ -1391,21 +1389,21 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
std::copy(
tempOutputTableRows,
tempOutputTableRows + numPassed,
outputTableRows_.data());
outputTableRows);
std::copy(
tempOutputRowMapping,
tempOutputRowMapping + numPassed,
rawOutputProbeRowMapping);
} else {
auto addMiss = [&](auto row) {
outputTableRows_[numPassed] = nullptr;
outputTableRows[numPassed] = nullptr;
rawOutputProbeRowMapping[numPassed++] = row;
};
for (auto i = 0; i < numRows; ++i) {
const bool passed = filterPassed(i);
noMatchDetector_.advance(rawOutputProbeRowMapping[i], passed, addMiss);
if (passed) {
outputTableRows_[numPassed] = outputTableRows_[i];
outputTableRows[numPassed] = outputTableRows[i];
rawOutputProbeRowMapping[numPassed++] = rawOutputProbeRowMapping[i];
}
}
Expand All @@ -1415,7 +1413,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
}
} else if (isLeftSemiFilterJoin(joinType_)) {
auto addLastMatch = [&](auto row) {
outputTableRows_[numPassed] = nullptr;
outputTableRows[numPassed] = nullptr;
rawOutputProbeRowMapping[numPassed++] = row;
};
for (auto i = 0; i < numRows; ++i) {
Expand All @@ -1439,7 +1437,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) {

auto addLast = [&](auto row, std::optional<bool> passed) {
if (passed.has_value()) {
outputTableRows_[numPassed] =
outputTableRows[numPassed] =
passed.value() ? const_cast<char*>(kPassed) : nullptr;
} else {
leftSemiProjectIsNull_.setValid(numPassed, true);
Expand All @@ -1466,7 +1464,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
}
} else {
auto addLast = [&](auto row, std::optional<bool> passed) {
outputTableRows_[numPassed] =
outputTableRows[numPassed] =
passed.value() ? const_cast<char*>(kPassed) : nullptr;
rawOutputProbeRowMapping[numPassed++] = row;
};
Expand All @@ -1480,7 +1478,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
}
} else if (isAntiJoin(joinType_)) {
auto addMiss = [&](auto row) {
outputTableRows_[numPassed] = nullptr;
outputTableRows[numPassed] = nullptr;
rawOutputProbeRowMapping[numPassed++] = row;
};
if (nullAware_) {
Expand All @@ -1503,12 +1501,12 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
} else {
for (auto i = 0; i < numRows; ++i) {
if (filterPassed(i)) {
outputTableRows_[numPassed] = outputTableRows_[i];
outputTableRows[numPassed] = outputTableRows[i];
rawOutputProbeRowMapping[numPassed++] = rawOutputProbeRowMapping[i];
}
}
}
VELOX_CHECK_LE(numPassed, outputTableRows_.size());
VELOX_CHECK_LE(numPassed, outputTableRowsCapacity_);
return numPassed;
}

Expand Down Expand Up @@ -1995,6 +1993,7 @@ void HashProbe::close() {
table_.reset();
outputRowMapping_.reset();
tempOutputRowMapping_.reset();
outputTableRows_.reset();
tempOutputTableRows_.reset();
output_.reset();
nonSpillInputIndicesBuffer_.reset();
Expand Down
3 changes: 2 additions & 1 deletion velox/exec/HashProbe.h
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,8 @@ class HashProbe : public Operator {
std::vector<IdentityProjection> tableOutputProjections_;

// Rows of table found by join probe, later filtered by 'filter_'.
std::vector<char*> outputTableRows_;
BufferPtr outputTableRows_;
vector_size_t outputTableRowsCapacity_;

// For left join with filter, we could overwrite the row which we have not
// checked if there is a carryover. Use a temporary buffer in this case.
Expand Down

0 comments on commit a173cf8

Please sign in to comment.