Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support IVF search: Part 2 #1984

Merged
merged 2 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/executor/operator/physical_scan/physical_knn_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ void PhysicalKnnScan::ExecuteInternalByColumnDataTypeAndQueryDataType(QueryConte
switch (segment_index_entry->table_index_entry()->index_base()->index_type_) {
case IndexType::kIVF: {
const SegmentOffset max_segment_offset = block_index->GetSegmentOffset(segment_id);
const auto ivf_search_params = IVF_Search_Params::Make(knn_scan_shared_data);
const auto ivf_search_params = IVF_Search_Params::Make(knn_scan_function_data);
auto ivf_result_handler =
GetIVFSearchHandler<t, C, DistanceDataType>(ivf_search_params, use_bitmask, bitmask, max_segment_offset);
ivf_result_handler->Begin();
Expand Down
2 changes: 1 addition & 1 deletion src/function/table/knn_filter.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ export class AppendFilter final : public FilterBase<SegmentOffset> {
public:
AppendFilter(SegmentOffset max_segment_offset) : max_segment_offset_(max_segment_offset) {}

bool operator()(const SegmentOffset &segment_offset) const final { return segment_offset <= max_segment_offset_; }
bool operator()(const SegmentOffset &segment_offset) const final { return segment_offset < max_segment_offset_; }

private:
const SegmentOffset max_segment_offset_;
Expand Down
4 changes: 2 additions & 2 deletions src/function/table/knn_scan_data.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ class KnnDistance1 : public KnnDistanceBase1 {
public:
KnnDistance1(KnnDistanceType dist_type);

Vector<DistType> Calculate(const QueryDataType *datas, SizeT data_count, const QueryDataType *query, SizeT dim) {
Vector<DistType> Calculate(const QueryDataType *datas, SizeT data_count, const QueryDataType *query, SizeT dim) const {
Vector<DistType> res(data_count);
for (SizeT i = 0; i < data_count; ++i) {
res[i] = dist_func_(query, datas + i * dim, dim);
}
return res;
}

Vector<DistType> Calculate(const QueryDataType *datas, SizeT data_count, const QueryDataType *query, SizeT dim, Bitmask &bitmask) {
Vector<DistType> Calculate(const QueryDataType *datas, SizeT data_count, const QueryDataType *query, SizeT dim, Bitmask &bitmask) const {
Vector<DistType> res(data_count);
for (SizeT i = 0; i < data_count; ++i) {
if (bitmask.IsTrue(i)) {
Expand Down
1 change: 1 addition & 0 deletions src/storage/knn_index/knn_ivf/ivf_index_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import column_vector;
import logger;
import kmeans_partition;
import logical_type;
import ivf_index_util_func;

namespace infinity {

Expand Down
102 changes: 78 additions & 24 deletions src/storage/knn_index/knn_ivf/ivf_index_data_in_mem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

module;

#include <vector>
module ivf_index_data_in_mem;

import stl;
Expand All @@ -37,26 +38,10 @@ import search_top_1;
import column_vector;
import ivf_index_data;
import buffer_handle;
import knn_scan_data;
import ivf_index_util_func;

namespace infinity {
template <IsAnyOf<u8, i8, f64, f32, Float16T, BFloat16T> ColumnEmbeddingElementT>
Pair<const f32 *, UniquePtr<f32[]>> GetF32Ptr(const ColumnEmbeddingElementT *src_data_ptr, const u32 src_data_cnt) {
Pair<const f32 *, UniquePtr<f32[]>> dst_data_ptr;
if constexpr (std::is_same_v<f32, ColumnEmbeddingElementT>) {
dst_data_ptr.first = src_data_ptr;
} else {
dst_data_ptr.second = MakeUniqueForOverwrite<f32[]>(src_data_cnt);
dst_data_ptr.first = dst_data_ptr.second.get();
for (u32 i = 0; i < src_data_cnt; ++i) {
if constexpr (std::is_same_v<f64, ColumnEmbeddingElementT>) {
dst_data_ptr.second[i] = static_cast<f32>(src_data_ptr[i]);
} else {
dst_data_ptr.second[i] = src_data_ptr[i];
}
}
}
return dst_data_ptr;
}

IVFIndexInMem::IVFIndexInMem(const RowID begin_row_id,
const IndexIVFOption &ivf_option,
Expand Down Expand Up @@ -169,6 +154,7 @@ class IVFIndexInMemT final : public IVFIndexInMem {
}

void BuildIndex() {
LOG_TRACE("Start building in-memory IVF index");
if (have_ivf_index_.test(std::memory_order_acquire)) {
UnrecoverableError("Already have index");
}
Expand Down Expand Up @@ -210,11 +196,78 @@ class IVFIndexInMemT final : public IVFIndexInMem {
return new_chunk_index_entry;
}

void SearchIndexInMem(KnnDistanceType knn_distance_type,
void SearchIndexInMem(const KnnDistanceBase1 *knn_distance,
const void *query_ptr,
EmbeddingDataType query_element_type,
const EmbeddingDataType query_element_type,
std::function<bool(SegmentOffset)> satisfy_filter_func,
std::function<void(f32, SegmentOffset)> add_result_func) const override {
// TODO
auto ReturnT = [&]<EmbeddingDataType query_element_type> {
if constexpr ((query_element_type == EmbeddingDataType::kElemFloat && IsAnyOf<ColumnEmbeddingElementT, f64, f32, Float16T, BFloat16T>) ||
(query_element_type == embedding_data_type &&
(query_element_type == EmbeddingDataType::kElemInt8 || query_element_type == EmbeddingDataType::kElemUInt8))) {
return SearchIndexInMemT<query_element_type>(knn_distance,
static_cast<const EmbeddingDataTypeToCppTypeT<query_element_type> *>(query_ptr),
satisfy_filter_func,
add_result_func);
} else {
UnrecoverableError("Invalid Query EmbeddingDataType");
}
};
switch (query_element_type) {
case EmbeddingDataType::kElemFloat: {
return ReturnT.template operator()<EmbeddingDataType::kElemFloat>();
}
case EmbeddingDataType::kElemUInt8: {
return ReturnT.template operator()<EmbeddingDataType::kElemUInt8>();
}
case EmbeddingDataType::kElemInt8: {
return ReturnT.template operator()<EmbeddingDataType::kElemInt8>();
}
default: {
UnrecoverableError("Invalid EmbeddingDataType");
}
}
}

template <EmbeddingDataType query_element_type>
void SearchIndexInMemT(const KnnDistanceBase1 *knn_distance,
const EmbeddingDataTypeToCppTypeT<query_element_type> *query_ptr,
std::function<bool(SegmentOffset)> satisfy_filter_func,
std::function<void(f32, SegmentOffset)> add_result_func) const {
using QueryDataType = EmbeddingDataTypeToCppTypeT<query_element_type>;
auto knn_distance_1 = dynamic_cast<const KnnDistance1<QueryDataType, f32> *>(knn_distance);
if (!knn_distance_1) [[unlikely]] {
UnrecoverableError("Invalid KnnDistance1");
}
if constexpr (column_logical_type == LogicalType::kEmbedding) {
auto dist_func = knn_distance_1->dist_func_;
for (u32 i = 0; i < in_mem_storage_.source_offsets_.size(); ++i) {
const auto segment_offset = in_mem_storage_.source_offsets_[i];
if (!satisfy_filter_func(segment_offset)) {
continue;
}
auto v_ptr = in_mem_storage_.raw_source_data_.data() + i * embedding_dimension();
auto [calc_ptr, _] = GetSearchCalcPtr<QueryDataType>(v_ptr, embedding_dimension());
auto d = dist_func(calc_ptr, query_ptr, embedding_dimension());
add_result_func(d, segment_offset);
}
} else if constexpr (column_logical_type == LogicalType::kMultiVector) {
for (u32 i = 0; i < in_mem_storage_.source_offsets_.size(); ++i) {
const auto segment_offset = in_mem_storage_.source_offsets_[i];
if (!satisfy_filter_func(segment_offset)) {
continue;
}
auto mv_ptr = in_mem_storage_.raw_source_data_.data() + in_mem_storage_.multi_vector_data_start_pos_[i];
auto mv_num = in_mem_storage_.multi_vector_embedding_num_[i];
auto [calc_ptr, _] = GetSearchCalcPtr<QueryDataType>(mv_ptr, mv_num * embedding_dimension());
auto dists = knn_distance_1->Calculate(calc_ptr, mv_num, query_ptr, embedding_dimension());
for (const auto d : dists) {
add_result_func(d, segment_offset);
}
}
} else {
static_assert(false);
}
}
};

Expand Down Expand Up @@ -267,16 +320,17 @@ SharedPtr<IVFIndexInMem> IVFIndexInMem::NewIVFIndexInMem(const ColumnDef *column
return {};
}

void IVFIndexInMem::SearchIndex(const KnnDistanceType knn_distance_type,
void IVFIndexInMem::SearchIndex(const KnnDistanceBase1 *knn_distance,
const void *query_ptr,
const EmbeddingDataType query_element_type,
const u32 nprobe,
std::function<bool(SegmentOffset)> satisfy_filter_func,
std::function<void(f32, SegmentOffset)> add_result_func) const {
std::shared_lock lock(rw_mutex_);
if (have_ivf_index_.test(std::memory_order_acquire)) {
ivf_index_storage_->SearchIndex(knn_distance_type, query_ptr, query_element_type, nprobe, add_result_func);
ivf_index_storage_->SearchIndex(knn_distance, query_ptr, query_element_type, nprobe, satisfy_filter_func, add_result_func);
} else {
SearchIndexInMem(knn_distance_type, query_ptr, query_element_type, add_result_func);
SearchIndexInMem(knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func);
}
}

Expand Down
8 changes: 5 additions & 3 deletions src/storage/knn_index/knn_ivf/ivf_index_data_in_mem.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import ivf_index_storage;
import column_def;
import logical_type;
import buffer_handle;
import knn_expr;

namespace infinity {

Expand All @@ -32,6 +31,7 @@ class BufferManager;
class ChunkIndexEntry;
class SegmentIndexEntry;
class IndexBase;
class KnnDistanceBase1;

export class IVFIndexInMem {
protected:
Expand Down Expand Up @@ -61,17 +61,19 @@ public:
u32 row_offset,
u32 row_count) = 0;
virtual SharedPtr<ChunkIndexEntry> Dump(SegmentIndexEntry *segment_index_entry, BufferManager *buffer_mgr) = 0;
void SearchIndex(KnnDistanceType knn_distance_type,
void SearchIndex(const KnnDistanceBase1 *knn_distance,
const void *query_ptr,
EmbeddingDataType query_element_type,
u32 nprobe,
std::function<bool(SegmentOffset)> satisfy_filter_func,
std::function<void(f32, SegmentOffset)> add_result_func) const;
static SharedPtr<IVFIndexInMem> NewIVFIndexInMem(const ColumnDef *column_def, const IndexBase *index_base, RowID begin_row_id);

private:
virtual void SearchIndexInMem(KnnDistanceType knn_distance_type,
virtual void SearchIndexInMem(const KnnDistanceBase1 *knn_distance,
const void *query_ptr,
EmbeddingDataType query_element_type,
std::function<bool(SegmentOffset)> satisfy_filter_func,
std::function<void(f32, SegmentOffset)> add_result_func) const = 0;
};

Expand Down
5 changes: 4 additions & 1 deletion src/storage/knn_index/knn_ivf/ivf_index_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ import ivf_index_storage;

namespace infinity {

IVF_Search_Params IVF_Search_Params::Make(const KnnScanSharedData *knn_scan_shared_data) {
IVF_Search_Params IVF_Search_Params::Make(const KnnScanFunctionData *knn_scan_function_data) {
IVF_Search_Params params;
params.knn_distance_ = knn_scan_function_data->knn_distance_.get();
const auto *knn_scan_shared_data = knn_scan_function_data->knn_scan_shared_data_;
params.knn_scan_shared_data_ = knn_scan_shared_data;
if (knn_scan_shared_data->query_count_ != 1) {
RecoverableError(Status::SyntaxError(fmt::format("Invalid query_count: {} which is not 1.", knn_scan_shared_data->query_count_)));
}
Expand Down
31 changes: 23 additions & 8 deletions src/storage/knn_index/knn_ivf/ivf_index_search.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ import search_top_k;
namespace infinity {

export struct IVF_Search_Params {
KnnScanSharedData *knn_scan_shared_data_{};
const KnnDistanceBase1 *knn_distance_{};
const KnnScanSharedData *knn_scan_shared_data_{};
i64 topk_{};
void *query_embedding_{};
const void *query_embedding_{};
EmbeddingDataType query_elem_type_{EmbeddingDataType::kElemInvalid};
KnnDistanceType knn_distance_type_{KnnDistanceType::kInvalid};
i32 nprobe_{1};

static IVF_Search_Params Make(const KnnScanSharedData *knn_scan_shared_data);
static IVF_Search_Params Make(const KnnScanFunctionData *knn_scan_function_data);
};

export template <typename DistanceDataType>
Expand Down Expand Up @@ -81,14 +82,14 @@ template <>
struct IVF_Filter<true> {
BitmaskFilter<SegmentOffset> filter_;
IVF_Filter(const Bitmask &bitmask, const SegmentOffset max_segment_offset) : filter_(bitmask) {}
bool operator()(const SegmentOffset &segment_offset) const { return filter_(segment_offset); }
bool operator()(const SegmentOffset segment_offset) const { return filter_(segment_offset); }
};

template <>
struct IVF_Filter<false> {
AppendFilter filter_;
IVF_Filter(const Bitmask &bitmask, const SegmentOffset max_segment_offset) : filter_(max_segment_offset) {}
bool operator()(const SegmentOffset &segment_offset) const { return filter_(segment_offset); }
bool operator()(const SegmentOffset segment_offset) const { return filter_(segment_offset); }
};

template <LogicalType t,
Expand All @@ -98,6 +99,7 @@ template <LogicalType t,
bool use_bitmask,
typename MultiVectorInnerTopnIndexType = void>
class IVF_Search_HandlerT final : public IVF_Search_Handler<DistanceDataType> {
static_assert(std::is_same_v<DistanceDataType, f32>); // KnnDistanceBase1 type?
static_assert(t == LogicalType::kEmbedding || t == LogicalType::kMultiVector);
static constexpr bool NEED_FLIP = !std::is_same_v<CompareMax<DistanceDataType, SegmentOffset>, C<DistanceDataType, SegmentOffset>>;
using ResultHandler = std::conditional_t<t == LogicalType::kEmbedding,
Expand All @@ -113,20 +115,27 @@ public:
void Begin() override { result_handler_.Begin(); }
void Search(const IVFIndexInChunk *ivf_index_in_chunk) override {
const auto *ivf_index_storage = ivf_index_in_chunk->GetIVFIndexStoragePtr();
ivf_index_storage->SearchIndex(this->ivf_params_.knn_distance_type_,
ivf_index_storage->SearchIndex(this->ivf_params_.knn_distance_,
this->ivf_params_.query_embedding_,
this->ivf_params_.query_elem_type_,
this->ivf_params_.nprobe_,
std::bind(&IVF_Search_HandlerT::SatisfyFilter, this, std::placeholders::_1),
std::bind(&IVF_Search_HandlerT::AddResult, this, std::placeholders::_1, std::placeholders::_2));
}
void Search(const IVFIndexInMem *ivf_index_in_mem) override {
ivf_index_in_mem->SearchIndex(this->ivf_params_.knn_distance_type_,
ivf_index_in_mem->SearchIndex(this->ivf_params_.knn_distance_,
this->ivf_params_.query_embedding_,
this->ivf_params_.query_elem_type_,
this->ivf_params_.nprobe_,
std::bind(&IVF_Search_HandlerT::SatisfyFilter, this, std::placeholders::_1),
std::bind(&IVF_Search_HandlerT::AddResult, this, std::placeholders::_1, std::placeholders::_2));
}
bool SatisfyFilter(SegmentOffset i) { return filter_(i); }
void AddResult(DistanceDataType d, SegmentOffset i) {
assert(SatisfyFilter(i));
if constexpr (NEED_FLIP) {
d = -d;
}
if constexpr (t == LogicalType::kEmbedding) {
result_handler_.AddResult(0, d, i);
} else {
Expand All @@ -136,7 +145,13 @@ public:
}
SizeT EndWithoutSortAndGetResultSize() override {
result_handler_.EndWithoutSort();
return result_handler_.GetSize(0);
const auto result_cnt = result_handler_.GetSize(0);
if constexpr (NEED_FLIP) {
for (u32 i = 0; i < result_cnt; ++i) {
this->distance_output_ptr_[i] = -(this->distance_output_ptr_[i]);
}
}
return result_cnt;
}
};

Expand Down
Loading
Loading