diff --git a/internal/core/src/exec/HashTable.cpp b/internal/core/src/exec/HashTable.cpp index 6b25a17bf8c06..dfac68da1c878 100644 --- a/internal/core/src/exec/HashTable.cpp +++ b/internal/core/src/exec/HashTable.cpp @@ -22,10 +22,25 @@ namespace milvus{ namespace exec { +void populateLookupRows(const TargetBitmapView& activeRows, std::vector& lookupRows) { + if (activeRows.all()) { + std::iota(lookupRows.begin(), lookupRows.end(), 0); + } else { + auto start = 0; + do { + auto next_active = activeRows.find_next(start); + if (!next_active.has_value()) break; + auto next_active_row = next_active.value(); + lookupRows.emplace_back(next_active_row); + start = next_active_row; + } while(true); + } +} + void BaseHashTable::prepareForGroupProbe(HashLookup& lookup, const RowVectorPtr& input, TargetBitmap& activeRows, - bool nullableKeys) { + bool ignoreNullKeys) { auto& hashers = lookup.hashers_; int numKeys = hashers.size(); // set up column vector to each column @@ -36,7 +51,7 @@ void BaseHashTable::prepareForGroupProbe(HashLookup& lookup, AssertInfo(column_ptr!=nullptr, "Failed to get column vector from row vector input"); hashers[i]->setColumnData(column_ptr); // deselect null values - if (!nullableKeys) { + if (!ignoreNullKeys) { int64_t length = column_ptr->size(); TargetBitmapView valid_bits_view(column_ptr->GetValidRawData(), length); activeRows&=valid_bits_view; @@ -47,11 +62,12 @@ void BaseHashTable::prepareForGroupProbe(HashLookup& lookup, const auto mode = hashMode(); for (auto i = 0; i < hashers.size(); i++) { if (mode == BaseHashTable::HashMode::kHash) { - hashers[i]->hash(i > 0, lookup.hashes_); + hashers[i]->hash(i > 0, activeRows, lookup.hashes_); } else { PanicInfo(milvus::OpTypeInvalid, "Not support target hashMode, only support kHash for now"); } - } + } + populateLookupRows(activeRows, lookup.rows_); } class ProbeState { @@ -98,7 +114,6 @@ class ProbeState { template inline char* fullProbe(Table& table, int32_t firstKey, Compare compare, Insert insert, - int64_t& numTombstones, bool extraCheck) { AssertInfo(op == Operation::kInsert, "Only support insert operation for group cases"); if (group_ && compare(group_, row_)) { @@ -111,46 +126,42 @@ class ProbeState { hits_ = milvus::toBitMask(tagsInTable_ == wantedTags_); } - const int64_t startBucketOffset = bucketOffset_; int64_t insertBucketOffset = -1; const auto kEmptyGroup = BaseHashTable::TagVector::broadcast(0); const auto kTombstoneGroup = BaseHashTable::TagVector::broadcast(kTombstoneTag); for(int64_t numProbedBuckets = 0; numProbedBuckets < table.numBuckets(); ++numProbedBuckets) { - while(hits_ > 0) { + while (hits_ > 0) { loadNextHit(table, firstKey); if (!(extraCheck && group_ == alreadyChecked) && compare(group_, row_)) { return group_; } } - } - uint16_t empty = milvus::toBitMask(tagsInTable_ == kEmptyGroup) & kFullMask; - // if there are still empty slot available, try to insert into existing empty slot or tombstone slot - if (empty > 0) { - if (op == ProbeState::Operation::kProbe) { - return nullptr; - } - if (indexInTags_ != kNotSet) { - // We came to the end of the probe without a hit. We replace the first - // tombstone on the way. - --numTombstones; - return insert(row_, insertBucketOffset + indexInTags_); + uint16_t empty = milvus::toBitMask(tagsInTable_ == kEmptyGroup) & kFullMask; + // if there are still empty slot available, try to insert into existing empty slot or tombstone slot + if (empty > 0) { + if (op == ProbeState::Operation::kProbe) { + return nullptr; + } + if (indexInTags_ != kNotSet) { + return insert(row_, insertBucketOffset + indexInTags_); + } + auto pos = milvus::bits::getAndClearLastSetBit(empty); + return insert(row_, bucketOffset_ + pos); } - auto pos = milvus::bits::getAndClearLastSetBit(empty); - return insert(row_, bucketOffset_ + pos); - } - if (op == Operation::kInsert && indexInTags_ == kNotSet) { - // We passed through a full group. - uint16_t tombstones = - milvus::toBitMask(tagsInTable_ == kTombstoneGroup) & kFullMask; - if (tombstones > 0) { - insertBucketOffset = bucketOffset_; - indexInTags_ = milvus::bits::getAndClearLastSetBit(tombstones); + if (op == Operation::kInsert && indexInTags_ == kNotSet) { + // We passed through a full group. + uint16_t tombstones = + milvus::toBitMask(tagsInTable_ == kTombstoneGroup) & kFullMask; + if (tombstones > 0) { + insertBucketOffset = bucketOffset_; + indexInTags_ = milvus::bits::getAndClearLastSetBit(tombstones); + } } + bucketOffset_ = table.nextBucketOffset(bucketOffset_); + tagsInTable_ = table.loadTags(bucketOffset_); + hits_ = milvus::toBitMask(tagsInTable_ == wantedTags_); } - bucketOffset_ = table.nextBucketOffset(bucketOffset_); - tagsInTable_ = table.loadTags(bucketOffset_); - hits_ = milvus::toBitMask(tagsInTable_ == wantedTags_); } @@ -183,9 +194,8 @@ void HashTable::allocateTables(uint64_t size) { const uint64_t byteSize = capacity_ * tableSlotSize(); AssertInfo(byteSize % kBucketSize == 0, "byteSize:{} for hashTable must be a multiple of kBucketSize:{}", byteSize, kBucketSize); - numTombstones_ = 0; - sizeMask_ = byteSize - 1; numBuckets_ = byteSize / kBucketSize; + sizeMask_ = byteSize - 1; sizeBits_ = __builtin_popcountll(sizeMask_); bucketOffsetMask_ = sizeMask_ & ~(kBucketSize - 1); // The total size is 8 bytes per slot, in groups of 16 slots with 16 bytes of @@ -198,18 +208,17 @@ void HashTable::allocateTables(uint64_t size) { template void HashTable::checkSize(int32_t numNew, bool initNormalizedKeys) { - AssertInfo(capacity_ == 0 || capacity_ > (numDistinct_ + numTombstones_), - "size {}, numDistinct {}, numTombstoneRows {}", + AssertInfo(capacity_ == 0 || capacity_ > numDistinct_, + "capacity_ {}, numDistinct {}", capacity_, - numDistinct_, - numTombstones_); + numDistinct_); const int64_t newNumDistinct = numNew + numDistinct_; if (table_ == nullptr || capacity_ == 0) { const auto newSize = newHashTableEntriesNumber(numDistinct_, numNew); allocateTables(newSize); } else if (newNumDistinct > rehashSize()) { const auto newCapacity = - milvus::bits::nextPowerOfTwo(std::max(newNumDistinct, capacity_ - numTombstones_) + 1); + milvus::bits::nextPowerOfTwo(std::max(newNumDistinct, capacity_) + 1); allocateTables(newCapacity); } } @@ -239,10 +248,6 @@ void HashTable::storeKeys(milvus::exec::HashLookup &lookup, milvus template void HashTable::storeRowPointer(uint64_t index, uint64_t hash, char *row) { - if (hashMode_==HashMode::kArray) { - reinterpret_cast(table_)[index] = row; - return; - } const int64_t bktOffset = bucketOffset(index); auto* bucket = bucketAt(bktOffset); const auto slotIndex = index & (sizeof(TagVector) - 1); @@ -274,7 +279,6 @@ FOLLY_ALWAYS_INLINE void HashTable::fullProbe(HashLookup &lookup, [&](int32_t row, uint64_t index) { return isJoin? nullptr: insertEntry(lookup, index, row); }, - numTombstones_, !isJoin && extraCheck); } @@ -282,7 +286,7 @@ FOLLY_ALWAYS_INLINE void HashTable::fullProbe(HashLookup &lookup, template void HashTable::groupProbe(milvus::exec::HashLookup &lookup) { AssertInfo(hashMode_ == HashMode::kHash, "Only support kHash mode for now"); - checkSize(lookup.rows_.size(), false); // hc--- + checkSize(lookup.rows_.size(), false); ProbeState state1; ProbeState state2; ProbeState state3; diff --git a/internal/core/src/exec/HashTable.h b/internal/core/src/exec/HashTable.h index 7bc54ab1744f8..5e7d336a42edf 100644 --- a/internal/core/src/exec/HashTable.h +++ b/internal/core/src/exec/HashTable.h @@ -261,7 +261,7 @@ class HashTable : public BaseHashTable { } uint64_t rehashSize() const { - return rehashSize(capacity_ - numTombstones_); + return rehashSize(capacity_); } static uint64_t newHashTableEntriesNumber(uint64_t numDistinct, uint64_t numNew) { @@ -278,8 +278,6 @@ class HashTable : public BaseHashTable { int64_t bucketOffsetMask_{0}; int64_t numBuckets_{0}; int64_t numDistinct_{0}; - // Counts the number of tombstone table slots. - int64_t numTombstones_{0}; // Number of slots across all buckets. int64_t capacity_{0}; diff --git a/internal/core/src/exec/VectorHasher.cpp b/internal/core/src/exec/VectorHasher.cpp index 5aec8980b27d2..326222468071b 100644 --- a/internal/core/src/exec/VectorHasher.cpp +++ b/internal/core/src/exec/VectorHasher.cpp @@ -34,35 +34,41 @@ std::vector> createVectorHashers( } template -void VectorHasher::hashValues(const ColumnVectorPtr& column_data, bool mix, uint64_t* result) { +void VectorHasher::hashValues(const ColumnVectorPtr& column_data, const TargetBitmapView& activeRows, bool mix, uint64_t* result) { if constexpr (Type==DataType::ROW || Type==DataType::ARRAY || Type==DataType::JSON) { PanicInfo(milvus::DataTypeInvalid, "NotSupport hash for complext type row/array/json:{}", Type); } else { using T = typename TypeTraits::NativeType; auto element_data_type = ChannelDataType(); auto element_size = GetDataTypeSize(element_data_type); - auto element_count = column_data->size(); - for(auto i = 0; i < element_count; i++) { - void *raw_value = column_data->RawValueAt(i, element_size); - AssertInfo(raw_value != nullptr, "Failed to get raw value pointer from column data"); - if (!column_data->ValidAt(i)) { - result[i] = kNullHash; + auto start = 0; + do { + auto next_valid_op = activeRows.find_next(start); + if (!next_valid_op.has_value()){ + break; + } + auto next_valid_row = next_valid_op.value(); + if (!column_data->ValidAt(next_valid_row)) { + result[next_valid_row] = mix? milvus::bits::hashMix(result[next_valid_row], kNullHash): kNullHash; continue; } - auto value = static_cast(raw_value); + void* raw_value = column_data->RawValueAt(next_valid_row, element_size); + AssertInfo(raw_value != nullptr, "Failed to get raw value pointer from column data"); + auto* value = static_cast(raw_value); uint64_t hash_value = kNullHash; if constexpr (std::is_floating_point_v) { hash_value = milvus::NaNAwareHash()(*value); } else { hash_value = folly::hasher()(*value); } - result[i] = mix? milvus::bits::hashMix(result[i], hash_value) : hash_value; - } + result[next_valid_row] = mix? milvus::bits::hashMix(result[next_valid_row], hash_value) : hash_value; + start = next_valid_row; + } while(true); } } void -VectorHasher::hash(bool mix, std::vector& result) { +VectorHasher::hash(bool mix, const TargetBitmapView& activeRows, std::vector& result) { // auto element_size = GetDataTypeSize(element_data_type); // auto element_count = column_data->size(); @@ -71,7 +77,7 @@ VectorHasher::hash(bool mix, std::vector& result) { // void* raw_value = column_data->RawValueAt(i, element_size); // } auto element_data_type = ChannelDataType(); - MILVUS_DYNAMIC_TYPE_DISPATCH(hashValues, element_data_type, columnData(), mix, result.data()); + MILVUS_DYNAMIC_TYPE_DISPATCH(hashValues, element_data_type, columnData(), activeRows, mix, result.data()); //PanicInfo(DataTypeInvalid, "Unsupported data type for dispatch"); } diff --git a/internal/core/src/exec/VectorHasher.h b/internal/core/src/exec/VectorHasher.h index 7f12e1ffe0825..8c05fd13281d2 100644 --- a/internal/core/src/exec/VectorHasher.h +++ b/internal/core/src/exec/VectorHasher.h @@ -41,7 +41,7 @@ class VectorHasher{ } void - hash(bool mix, std::vector& result); + hash(bool mix, const TargetBitmapView& activeRows, std::vector& result); static constexpr uint64_t kNullHash = 1; @@ -61,7 +61,7 @@ static bool typeSupportValueIds(DataType type) { } template -void hashValues(const ColumnVectorPtr& column_data, bool mix, uint64_t* result); +void hashValues(const ColumnVectorPtr& column_data, const TargetBitmapView& activeRows, bool mix, uint64_t* result); void setColumnData(const ColumnVectorPtr& column_data) { column_data_ = column_data;