Skip to content

Commit

Permalink
simplify ResultHandler
Browse files Browse the repository at this point in the history
use HeapResultHandler in Hnsw
  • Loading branch information
yangzq50 committed Dec 19, 2023
1 parent 2519de5 commit c0f8976
Show file tree
Hide file tree
Showing 37 changed files with 1,375 additions and 2,639 deletions.
8 changes: 4 additions & 4 deletions benchmark/embedding/ann_ivfflat_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ static const char *ivfflat_index_name_suffix = "/ivfflat_index.save";
using namespace infinity;

template <typename T>
std::unique_ptr<T[]> load_data(const std::string &filename,
size_t &num,
size_t &dim) { // load data with sift10K pattern
std::unique_ptr<T[]> load_data(const std::string &filename, size_t &num, size_t &dim) {
std::ifstream in(filename, std::ios::binary);
if (!in.is_open()) {
std::cout << "open file error" << std::endl;
exit(-1);
}
in.read((char *)&dim, 4);
int dim_;
in.read((char *)&dim_, 4);
dim = (size_t)dim_;
in.seekg(0, std::ios::end);
auto ss = in.tellg();
num = ((size_t)ss) / (dim + 1) / 4;
Expand Down
17 changes: 14 additions & 3 deletions benchmark/remote_infinity/remote_query_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,13 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, auto fn) {
int main() {
size_t thread_num = 1;
size_t total_times = 1;
size_t ef = 100;
std::cout << "Please input thread_num, 0 means use all resources:" << std::endl;
std::cin >> thread_num;
std::cout << "Please input total_times:" << std::endl;
std::cin >> total_times;
std::cout << "Please input ef:" << std::endl;
std::cin >> ef;

infinity::LocalFileSystem fs;

Expand Down Expand Up @@ -175,12 +178,11 @@ int main() {
for (auto &v : query_results) {
v.reserve(100);
}
auto query_function = [dimension, topk, queries, &query_results](size_t query_idx, InfinityClient &client, size_t threadId) {
int64_t session_id = client.session_id;
auto query_function = [ef, dimension, topk, queries, &query_results](size_t query_idx, InfinityClient &client, size_t threadId) {
SelectRequest req;
SelectResponse ret;
{
req.session_id = session_id;
req.session_id = client.session_id;
req.__isset.session_id = true;
req.db_name = "default";
req.__isset.db_name = true;
Expand Down Expand Up @@ -215,6 +217,15 @@ int main() {
knn_expr.__isset.distance_type = true;
knn_expr.topn = topk;
knn_expr.__isset.topn = true;
InitParameter init_param;
{
init_param.param_name = "ef";
init_param.__isset.param_name = true;
init_param.param_value = std::to_string(ef);
init_param.__isset.param_value = true;
}
knn_expr.opt_params.push_back(std::move(init_param));
knn_expr.__isset.opt_params = true;
}
req.search_expr.knn_exprs.push_back(std::move(knn_expr));
req.search_expr.__isset.knn_exprs = true;
Expand Down
4 changes: 4 additions & 0 deletions src/common/default_values.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ export {
constexpr SizeT HNSW_M = 16;
constexpr SizeT HNSW_EF_CONSTRUCTION = 200;
constexpr SizeT HNSW_EF = 200;

// default distance compute blas parameter
constexpr SizeT DISTANCE_COMPUTE_BLAS_QUERY_BS = 4096;
constexpr SizeT DISTANCE_COMPUTE_BLAS_DATABASE_BS = 1024;
}

// constexpr SizeT DEFAULT_BUFFER_SIZE = 8192;
Expand Down
10 changes: 10 additions & 0 deletions src/common/stl.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,11 @@ export {
return std::numeric_limits<T>::min();
}

template <typename T>
constexpr T LimitLowest() {
return std::numeric_limits<T>::lowest();
}

template <typename T>
using Atomic = std::atomic<T>;

Expand All @@ -277,6 +282,11 @@ export {
return std::make_unique<T>(std::forward<Args>(args)...);
}

template <typename T, typename... Args>
inline UniquePtr<T> MakeUniqueForOverwrite(Args && ...args) {
return std::make_unique_for_overwrite<T>(std::forward<Args>(args)...);
}

template <typename T, typename U>
inline constexpr Pair<T, U> MakePair(T && first, U && second) {
return std::make_pair<T, U>(std::forward<T>(first), std::forward<U>(second));
Expand Down
97 changes: 79 additions & 18 deletions src/executor/operator/physical_knn_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import column_expression;
import index_base;
import buffer_manager;
import merge_knn;
import faiss;
import knn_result_handler;
import index_def;
import ann_ivf_flat;
import annivfflat_index_data;
Expand Down Expand Up @@ -338,14 +338,14 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat
auto index = static_cast<const AnnIVFFlatIndexData<DataType> *>(index_handle.GetData());
i32 n_probes = 1;
auto IVFFlatScan = [&]<typename AnnIVFFlatType>() {
AnnIVFFlatType ann_ivfflat_query{query,
AnnIVFFlatType ann_ivfflat_query(query,
knn_scan_shared_data->query_count_,
knn_scan_shared_data->topk_,
knn_scan_shared_data->dimension_,
knn_scan_shared_data->elem_type_};
knn_scan_shared_data->elem_type_);
ann_ivfflat_query.Begin();
ann_ivfflat_query.Search(index, segment_id, n_probes, bitmask);
ann_ivfflat_query.End();
ann_ivfflat_query.EndWithoutSort();
auto dists = ann_ivfflat_query.GetDistances();
auto row_ids = ann_ivfflat_query.GetIDs();
// TODO: now only work for one query
Expand Down Expand Up @@ -373,7 +373,7 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat
case IndexType::kHnsw: {
BufferHandle index_handle = SegmentColumnIndexEntry::GetIndex(segment_column_index_entry, buffer_mgr);
auto index_hnsw = static_cast<IndexHnsw *>(segment_column_index_entry->column_index_entry_->index_base_.get());
auto KnnScan = [&](auto *index) {
auto KnnScanOld = [&](auto *index) {
Vector<DataType> dists(knn_scan_shared_data->topk_ * knn_scan_shared_data->query_count_);
Vector<RowID> row_ids(knn_scan_shared_data->topk_ * knn_scan_shared_data->query_count_);

Expand Down Expand Up @@ -419,6 +419,65 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat
}
merge_heap->Search(dists.data(), row_ids.data(), result_n);
};
auto KnnScanUseHeap = [&]<typename LabelType>(auto *index) {
if constexpr (!std::is_same_v<LabelType, u64>) {
Error<ExecutorException>("Bug: Hnsw LabelType must be u64");
}
for (const auto &opt_param : knn_scan_shared_data->opt_params_) {
if (opt_param.param_name_ == "ef") {
u64 ef = std::stoull(opt_param.param_value_);
index->SetEf(ef);
}
}
i64 result_n = -1;
for (u64 query_idx = 0; query_idx < knn_scan_shared_data->query_count_; ++query_idx) {
const DataType *query =
static_cast<const DataType *>(knn_scan_shared_data->query_embedding_) + query_idx * knn_scan_shared_data->dimension_;
auto search_result = index->KnnSearchReturnPair(query, knn_scan_shared_data->topk_, bitmask);
auto &[result_size, unique_ptr_pair] = search_result;
auto &[d_ptr, l_ptr] = unique_ptr_pair;
if (result_n < 0) {
result_n = result_size;
} else if (result_n != (i64)result_size) {
throw ExecutorException("Bug");
}
if (result_size <= 0) {
continue;
}
UniquePtr<RowID[]> row_ids_ptr;
RowID *row_ids = nullptr;
if constexpr (sizeof(RowID) == sizeof(LabelType)) {
row_ids = reinterpret_cast<RowID *>(l_ptr.get());
} else {
row_ids_ptr = MakeUniqueForOverwrite<RowID[]>(result_size);
row_ids = row_ids_ptr.get();
}
for (SizeT i = 0; i < result_size; ++i) {
row_ids[i] = RowID::FromUint64(l_ptr[i]);
}
switch (knn_scan_shared_data->knn_distance_type_) {
case KnnDistanceType::kInvalid: {
throw ExecutorException("Bug");
}
case KnnDistanceType::kL2:
case KnnDistanceType::kHamming: {
break;
}
case KnnDistanceType::kCosine:
case KnnDistanceType::kInnerProduct: {
for (SizeT i = 0; i < result_size; ++i) {
d_ptr[i] = -d_ptr[i];
}
break;
}
}
merge_heap->Search(0, d_ptr.get(), row_ids, result_size);
}
};
auto KnnScan = [&](auto *index) {
using LabelType = typename std::remove_pointer_t<decltype(index)>::HnswLabelType;
KnnScanUseHeap.template operator()<LabelType>(index);
};
switch (index_hnsw->encode_type_) {
case HnswEncodeType::kPlain: {
switch (index_hnsw->metric_type_) {
Expand Down Expand Up @@ -476,20 +535,22 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat
merge_heap->End();
i64 result_n = Min(knn_scan_shared_data->topk_, merge_heap->total_count());

UniquePtr<DataBlock> data_block = DataBlock::MakeUniquePtr();
data_block->Init(*GetOutputTypes());
operator_state->data_block_array_.emplace_back(Move(data_block));
SizeT row_idx = DEFAULT_BLOCK_CAPACITY;

SizeT total_data_row_count = knn_scan_shared_data->query_count_ * result_n;
for (; row_idx < total_data_row_count; row_idx += DEFAULT_BLOCK_CAPACITY) {
data_block = DataBlock::MakeUniquePtr();
data_block->Init(*GetOutputTypes());
operator_state->data_block_array_.emplace_back(Move(data_block));
if (!operator_state->data_block_array_.empty()) {
Error<ExecutorException>("In physical_knn_scan : operator_state->data_block_array_ is not empty.");
}
{
SizeT total_data_row_count = knn_scan_shared_data->query_count_ * result_n;
SizeT row_idx = 0;
do {
auto data_block = DataBlock::MakeUniquePtr();
data_block->Init(*GetOutputTypes());
operator_state->data_block_array_.emplace_back(Move(data_block));
row_idx += DEFAULT_BLOCK_CAPACITY;
} while (row_idx < total_data_row_count);
}

SizeT output_block_row_id = 0;
SizeT output_block_idx = operator_state->data_block_array_.size() - 1;
SizeT output_block_idx = 0;
DataBlock *output_block_ptr = operator_state->data_block_array_[output_block_idx].get();
for (u64 query_idx = 0; query_idx < knn_scan_shared_data->query_count_; ++query_idx) {
DataType *result_dists = merge_heap->GetDistancesByIdx(query_idx);
Expand All @@ -505,12 +566,12 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat

BlockEntry *block_entry = block_index->GetBlockEntry(segment_id, block_id);
if (block_entry == nullptr) {
throw ExecutorException(Format("Cannot find block segment id: {}, block id: {}", segment_id, block_id));
Error<ExecutorException>(Format("Cannot find block segment id: {}, block id: {}", segment_id, block_id));
}

if (output_block_row_id == DEFAULT_BLOCK_CAPACITY) {
output_block_ptr->Finalize();
--output_block_idx;
++output_block_idx;
output_block_ptr = operator_state->data_block_array_[output_block_idx].get();
output_block_row_id = 0;
}
Expand Down
2 changes: 1 addition & 1 deletion src/executor/operator/physical_merge_knn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import logger;

import infinity_exception;
import merge_knn_data;
import faiss;
import knn_result_handler;
import merge_knn;
import block_index;
import column_buffer;
Expand Down
2 changes: 1 addition & 1 deletion src/function/table/knn_scan_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import knn_flat_l2_top1;
import knn_flat_l2_top1_blas;

import merge_knn;
import faiss;
import knn_result_handler;
import vector_distance;
import data_block;
import column_vector;
Expand Down
2 changes: 1 addition & 1 deletion src/function/table/merge_knn_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import base_table_ref;

import infinity_exception;
import merge_knn;
import faiss;
import knn_result_handler;

module merge_knn_data;

Expand Down
Loading

0 comments on commit c0f8976

Please sign in to comment.