Skip to content

Commit

Permalink
check for rehash
Browse files Browse the repository at this point in the history
  • Loading branch information
MrPresent-Han committed Nov 13, 2024
1 parent f6b7f6b commit 36ec496
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 62 deletions.
94 changes: 49 additions & 45 deletions internal/core/src/exec/HashTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,25 @@
namespace milvus{
namespace exec {

// Collects the indices of all active (set) rows from `activeRows` into
// `lookupRows`; these row indices drive the subsequent hash-table group probe.
//
// NOTE(review): the two branches populate `lookupRows` differently — the
// all-active fast path writes via std::iota into the vector's EXISTING
// element range (so the caller must have pre-sized it to the row count),
// while the sparse path appends with emplace_back (so the caller must have
// left it empty). TODO confirm callers uphold both preconditions.
void populateLookupRows(const TargetBitmapView& activeRows, std::vector<vector_size_t>& lookupRows) {
if (activeRows.all()) {
// Fast path: every row is active, so the lookup rows are simply 0..N-1.
std::iota(lookupRows.begin(), lookupRows.end(), 0);
} else {
// Sparse path: walk the bitmap and append each active row index.
auto start = 0;
do {
// find_next(start) presumably returns the first set bit STRICTLY
// after `start` (boost::dynamic_bitset-style semantics). If so, an
// active row 0 would never be emitted because the scan starts at 0
// — TODO confirm find_next semantics / row-0 handling against
// TargetBitmapView before relying on this path.
auto next_active = activeRows.find_next(start);
if (!next_active.has_value()) break;
auto next_active_row = next_active.value();
lookupRows.emplace_back(next_active_row);
// Advance past the row just recorded; loop ends when no set bit
// remains after `start`.
start = next_active_row;
} while(true);
}
}

void BaseHashTable::prepareForGroupProbe(HashLookup& lookup,
const RowVectorPtr& input,
TargetBitmap& activeRows,
bool nullableKeys) {
bool ignoreNullKeys) {
auto& hashers = lookup.hashers_;
int numKeys = hashers.size();
// set up column vector to each column
Expand All @@ -36,7 +51,7 @@ void BaseHashTable::prepareForGroupProbe(HashLookup& lookup,
AssertInfo(column_ptr!=nullptr, "Failed to get column vector from row vector input");
hashers[i]->setColumnData(column_ptr);
// deselect null values
if (!nullableKeys) {
if (!ignoreNullKeys) {
int64_t length = column_ptr->size();
TargetBitmapView valid_bits_view(column_ptr->GetValidRawData(), length);
activeRows&=valid_bits_view;
Expand All @@ -47,11 +62,12 @@ void BaseHashTable::prepareForGroupProbe(HashLookup& lookup,
const auto mode = hashMode();
for (auto i = 0; i < hashers.size(); i++) {
if (mode == BaseHashTable::HashMode::kHash) {
hashers[i]->hash(i > 0, lookup.hashes_);
hashers[i]->hash(i > 0, activeRows, lookup.hashes_);
} else {
PanicInfo(milvus::OpTypeInvalid, "Not support target hashMode, only support kHash for now");
}
}
}
populateLookupRows(activeRows, lookup.rows_);
}

class ProbeState {
Expand Down Expand Up @@ -98,7 +114,6 @@ class ProbeState {
template<Operation op, typename Compare, typename Insert, typename Table>
inline char* fullProbe(Table& table, int32_t firstKey,
Compare compare, Insert insert,
int64_t& numTombstones,
bool extraCheck) {
AssertInfo(op == Operation::kInsert, "Only support insert operation for group cases");
if (group_ && compare(group_, row_)) {
Expand All @@ -111,46 +126,42 @@ class ProbeState {
hits_ = milvus::toBitMask(tagsInTable_ == wantedTags_);
}

const int64_t startBucketOffset = bucketOffset_;
int64_t insertBucketOffset = -1;
const auto kEmptyGroup = BaseHashTable::TagVector::broadcast(0);
const auto kTombstoneGroup = BaseHashTable::TagVector::broadcast(kTombstoneTag);
for(int64_t numProbedBuckets = 0; numProbedBuckets < table.numBuckets(); ++numProbedBuckets) {
while(hits_ > 0) {
while (hits_ > 0) {
loadNextHit<op>(table, firstKey);
if (!(extraCheck && group_ == alreadyChecked) && compare(group_, row_)) {
return group_;
}
}
}

uint16_t empty = milvus::toBitMask(tagsInTable_ == kEmptyGroup) & kFullMask;
// if there are still empty slot available, try to insert into existing empty slot or tombstone slot
if (empty > 0) {
if (op == ProbeState::Operation::kProbe) {
return nullptr;
}
if (indexInTags_ != kNotSet) {
// We came to the end of the probe without a hit. We replace the first
// tombstone on the way.
--numTombstones;
return insert(row_, insertBucketOffset + indexInTags_);
uint16_t empty = milvus::toBitMask(tagsInTable_ == kEmptyGroup) & kFullMask;
// if there are still empty slot available, try to insert into existing empty slot or tombstone slot
if (empty > 0) {
if (op == ProbeState::Operation::kProbe) {
return nullptr;
}
if (indexInTags_ != kNotSet) {
return insert(row_, insertBucketOffset + indexInTags_);
}
auto pos = milvus::bits::getAndClearLastSetBit(empty);
return insert(row_, bucketOffset_ + pos);
}
auto pos = milvus::bits::getAndClearLastSetBit(empty);
return insert(row_, bucketOffset_ + pos);
}
if (op == Operation::kInsert && indexInTags_ == kNotSet) {
// We passed through a full group.
uint16_t tombstones =
milvus::toBitMask(tagsInTable_ == kTombstoneGroup) & kFullMask;
if (tombstones > 0) {
insertBucketOffset = bucketOffset_;
indexInTags_ = milvus::bits::getAndClearLastSetBit(tombstones);
if (op == Operation::kInsert && indexInTags_ == kNotSet) {
// We passed through a full group.
uint16_t tombstones =
milvus::toBitMask(tagsInTable_ == kTombstoneGroup) & kFullMask;
if (tombstones > 0) {
insertBucketOffset = bucketOffset_;
indexInTags_ = milvus::bits::getAndClearLastSetBit(tombstones);
}
}
bucketOffset_ = table.nextBucketOffset(bucketOffset_);
tagsInTable_ = table.loadTags(bucketOffset_);
hits_ = milvus::toBitMask(tagsInTable_ == wantedTags_);
}
bucketOffset_ = table.nextBucketOffset(bucketOffset_);
tagsInTable_ = table.loadTags(bucketOffset_);
hits_ = milvus::toBitMask(tagsInTable_ == wantedTags_);
}


Expand Down Expand Up @@ -183,9 +194,8 @@ void HashTable<nullableKeys>::allocateTables(uint64_t size) {
const uint64_t byteSize = capacity_ * tableSlotSize();
AssertInfo(byteSize % kBucketSize == 0, "byteSize:{} for hashTable must be a multiple of kBucketSize:{}",
byteSize, kBucketSize);
numTombstones_ = 0;
sizeMask_ = byteSize - 1;
numBuckets_ = byteSize / kBucketSize;
sizeMask_ = byteSize - 1;
sizeBits_ = __builtin_popcountll(sizeMask_);
bucketOffsetMask_ = sizeMask_ & ~(kBucketSize - 1);
// The total size is 8 bytes per slot, in groups of 16 slots with 16 bytes of
Expand All @@ -198,18 +208,17 @@ void HashTable<nullableKeys>::allocateTables(uint64_t size) {

template<bool nullableKeys>
void HashTable<nullableKeys>::checkSize(int32_t numNew, bool initNormalizedKeys) {
AssertInfo(capacity_ == 0 || capacity_ > (numDistinct_ + numTombstones_),
"size {}, numDistinct {}, numTombstoneRows {}",
AssertInfo(capacity_ == 0 || capacity_ > numDistinct_,
"capacity_ {}, numDistinct {}",
capacity_,
numDistinct_,
numTombstones_);
numDistinct_);
const int64_t newNumDistinct = numNew + numDistinct_;
if (table_ == nullptr || capacity_ == 0) {
const auto newSize = newHashTableEntriesNumber(numDistinct_, numNew);
allocateTables(newSize);
} else if (newNumDistinct > rehashSize()) {
const auto newCapacity =
milvus::bits::nextPowerOfTwo(std::max(newNumDistinct, capacity_ - numTombstones_) + 1);
milvus::bits::nextPowerOfTwo(std::max(newNumDistinct, capacity_) + 1);
allocateTables(newCapacity);
}
}
Expand Down Expand Up @@ -239,10 +248,6 @@ void HashTable<nullableKeys>::storeKeys(milvus::exec::HashLookup &lookup, milvus

template<bool nullableKeys>
void HashTable<nullableKeys>::storeRowPointer(uint64_t index, uint64_t hash, char *row) {
if (hashMode_==HashMode::kArray) {
reinterpret_cast<char**>(table_)[index] = row;
return;
}
const int64_t bktOffset = bucketOffset(index);
auto* bucket = bucketAt(bktOffset);
const auto slotIndex = index & (sizeof(TagVector) - 1);
Expand Down Expand Up @@ -274,15 +279,14 @@ FOLLY_ALWAYS_INLINE void HashTable<nullableKeys>::fullProbe(HashLookup &lookup,
[&](int32_t row, uint64_t index) {
return isJoin? nullptr: insertEntry(lookup, index, row);
},
numTombstones_,
!isJoin && extraCheck);

}

template<bool nullableKeys>
void HashTable<nullableKeys>::groupProbe(milvus::exec::HashLookup &lookup) {
AssertInfo(hashMode_ == HashMode::kHash, "Only support kHash mode for now");
checkSize(lookup.rows_.size(), false); // hc---
checkSize(lookup.rows_.size(), false);
ProbeState state1;
ProbeState state2;
ProbeState state3;
Expand Down
4 changes: 1 addition & 3 deletions internal/core/src/exec/HashTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ class HashTable : public BaseHashTable {
}

uint64_t rehashSize() const {
return rehashSize(capacity_ - numTombstones_);
return rehashSize(capacity_);
}

static uint64_t newHashTableEntriesNumber(uint64_t numDistinct, uint64_t numNew) {
Expand All @@ -278,8 +278,6 @@ class HashTable : public BaseHashTable {
int64_t bucketOffsetMask_{0};
int64_t numBuckets_{0};
int64_t numDistinct_{0};
// Counts the number of tombstone table slots.
int64_t numTombstones_{0};

// Number of slots across all buckets.
int64_t capacity_{0};
Expand Down
30 changes: 18 additions & 12 deletions internal/core/src/exec/VectorHasher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,35 +34,41 @@ std::vector<std::unique_ptr<VectorHasher>> createVectorHashers(
}

template<DataType Type>
void VectorHasher::hashValues(const ColumnVectorPtr& column_data, bool mix, uint64_t* result) {
void VectorHasher::hashValues(const ColumnVectorPtr& column_data, const TargetBitmapView& activeRows, bool mix, uint64_t* result) {
if constexpr (Type==DataType::ROW || Type==DataType::ARRAY || Type==DataType::JSON) {
PanicInfo(milvus::DataTypeInvalid, "NotSupport hash for complext type row/array/json:{}", Type);
} else {
using T = typename TypeTraits<Type>::NativeType;
auto element_data_type = ChannelDataType();
auto element_size = GetDataTypeSize(element_data_type);
auto element_count = column_data->size();
for(auto i = 0; i < element_count; i++) {
void *raw_value = column_data->RawValueAt(i, element_size);
AssertInfo(raw_value != nullptr, "Failed to get raw value pointer from column data");
if (!column_data->ValidAt(i)) {
result[i] = kNullHash;
auto start = 0;
do {
auto next_valid_op = activeRows.find_next(start);
if (!next_valid_op.has_value()){
break;
}
auto next_valid_row = next_valid_op.value();
if (!column_data->ValidAt(next_valid_row)) {
result[next_valid_row] = mix? milvus::bits::hashMix(result[next_valid_row], kNullHash): kNullHash;
continue;
}
auto value = static_cast<T *>(raw_value);
void* raw_value = column_data->RawValueAt(next_valid_row, element_size);
AssertInfo(raw_value != nullptr, "Failed to get raw value pointer from column data");
auto* value = static_cast<T*>(raw_value);
uint64_t hash_value = kNullHash;
if constexpr (std::is_floating_point_v<T>) {
hash_value = milvus::NaNAwareHash<T>()(*value);
} else {
hash_value = folly::hasher<T>()(*value);
}
result[i] = mix? milvus::bits::hashMix(result[i], hash_value) : hash_value;
}
result[next_valid_row] = mix? milvus::bits::hashMix(result[next_valid_row], hash_value) : hash_value;
start = next_valid_row;
} while(true);
}
}

void
VectorHasher::hash(bool mix, std::vector<uint64_t>& result) {
VectorHasher::hash(bool mix, const TargetBitmapView& activeRows, std::vector<uint64_t>& result) {

// auto element_size = GetDataTypeSize(element_data_type);
// auto element_count = column_data->size();
Expand All @@ -71,7 +77,7 @@ VectorHasher::hash(bool mix, std::vector<uint64_t>& result) {
// void* raw_value = column_data->RawValueAt(i, element_size);
// }
auto element_data_type = ChannelDataType();
MILVUS_DYNAMIC_TYPE_DISPATCH(hashValues, element_data_type, columnData(), mix, result.data());
MILVUS_DYNAMIC_TYPE_DISPATCH(hashValues, element_data_type, columnData(), activeRows, mix, result.data());
//PanicInfo(DataTypeInvalid, "Unsupported data type for dispatch");
}

Expand Down
4 changes: 2 additions & 2 deletions internal/core/src/exec/VectorHasher.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class VectorHasher{
}

void
hash(bool mix, std::vector<uint64_t>& result);
hash(bool mix, const TargetBitmapView& activeRows, std::vector<uint64_t>& result);

static constexpr uint64_t kNullHash = 1;

Expand All @@ -61,7 +61,7 @@ static bool typeSupportValueIds(DataType type) {
}

template<DataType type>
void hashValues(const ColumnVectorPtr& column_data, bool mix, uint64_t* result);
void hashValues(const ColumnVectorPtr& column_data, const TargetBitmapView& activeRows, bool mix, uint64_t* result);

void setColumnData(const ColumnVectorPtr& column_data) {
column_data_ = column_data;
Expand Down

0 comments on commit 36ec496

Please sign in to comment.