Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: bm25 brute force search need index params k1 and b #37721

Merged
merged 2 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions internal/core/src/query/SearchBruteForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
}

knowhere::Json
PrepareBFSearchParams(const SearchInfo& search_info) {
PrepareBFSearchParams(const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info) {
knowhere::Json search_cfg = search_info.search_params_;

search_cfg[knowhere::meta::METRIC_TYPE] = search_info.metric_type_;
Expand All @@ -62,6 +63,10 @@
if (search_info.metric_type_ == knowhere::metric::BM25) {
search_cfg[knowhere::meta::BM25_AVGDL] =
search_info.search_params_[knowhere::meta::BM25_AVGDL];
search_cfg[knowhere::meta::BM25_K1] =
std::stof(index_info.at(knowhere::meta::BM25_K1));
search_cfg[knowhere::meta::BM25_B] =
std::stof(index_info.at(knowhere::meta::BM25_B));

Check warning on line 69 in internal/core/src/query/SearchBruteForce.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/query/SearchBruteForce.cpp#L66-L69

Added lines #L66 - L69 were not covered by tests
}
return search_cfg;
}
Expand All @@ -71,6 +76,7 @@
const void* chunk_data_raw,
int64_t chunk_rows,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
SubSearchResult sub_result(dataset.num_queries,
Expand All @@ -87,12 +93,11 @@
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
}
auto search_cfg = PrepareBFSearchParams(search_info);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
// `range_search_k` is only used as one of the conditions for iterator early termination.
// It is not guaranteed to return exactly `range_search_k` results; there may be more or fewer.
// Setting it to -1 returns all results in the range.
search_cfg[knowhere::meta::RANGE_SEARCH_K] = topk;

sub_result.mutable_seg_offsets().resize(nq * topk);
sub_result.mutable_distances().resize(nq * topk);

Expand Down Expand Up @@ -201,6 +206,7 @@
const void* chunk_data_raw,
int64_t chunk_rows,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
auto nq = dataset.num_queries;
Expand All @@ -211,7 +217,7 @@
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
}
auto search_cfg = PrepareBFSearchParams(search_info);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);

knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
iterators_val;
Expand Down
2 changes: 2 additions & 0 deletions internal/core/src/query/SearchBruteForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type);

Expand All @@ -36,6 +37,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type);

Expand Down
15 changes: 15 additions & 0 deletions internal/core/src/query/SearchOnGrowing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
#include "common/Tracer.h"
#include "common/Types.h"
#include "SearchOnGrowing.h"
#include <cstddef>
#include "knowhere/comp/index_param.h"
#include "knowhere/config.h"
#include "log/Log.h"
#include "query/SearchBruteForce.h"
#include "query/SearchOnIndex.h"

Expand Down Expand Up @@ -109,6 +113,15 @@
dataset::SearchDataset search_dataset{
metric_type, num_queries, topk, round_decimal, dim, query_data};
int32_t current_chunk_id = 0;

// get K1 and B from index for bm25 brute force
std::map<std::string, std::string> index_info;
if (metric_type == knowhere::metric::BM25) {
index_info = segment.get_indexing_record()

Check warning on line 120 in internal/core/src/query/SearchOnGrowing.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/query/SearchOnGrowing.cpp#L120

Added line #L120 was not covered by tests
.get_field_index_meta(vecfield_id)
.GetIndexParams();
}

// step 3: brute force search where small indexing is unavailable
auto vec_ptr = record.get_data_base(vecfield_id);
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
Expand All @@ -129,6 +142,7 @@
chunk_data,
size_per_chunk,
info,
index_info,
sub_view,
data_type);
final_qr.merge(sub_qr);
Expand All @@ -137,6 +151,7 @@
chunk_data,
size_per_chunk,
info,
index_info,
sub_view,
data_type);

Expand Down
22 changes: 18 additions & 4 deletions internal/core/src/query/SearchOnSealed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ void
SearchOnSealed(const Schema& schema,
std::shared_ptr<ChunkedColumnBase> column,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const void* query_data,
int64_t num_queries,
int64_t row_count,
Expand Down Expand Up @@ -137,6 +138,7 @@ SearchOnSealed(const Schema& schema,
vec_data,
chunk_size,
search_info,
index_info,
bitset_view,
data_type);
final_qr.merge(sub_qr);
Expand All @@ -145,6 +147,7 @@ SearchOnSealed(const Schema& schema,
vec_data,
chunk_size,
search_info,
index_info,
bitset_view,
data_type);
for (auto& o : sub_qr.mutable_seg_offsets()) {
Expand Down Expand Up @@ -177,6 +180,7 @@ void
SearchOnSealed(const Schema& schema,
const void* vec_data,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const void* query_data,
int64_t num_queries,
int64_t row_count,
Expand All @@ -200,13 +204,23 @@ SearchOnSealed(const Schema& schema,
auto data_type = field.get_data_type();
CheckBruteForceSearchParam(field, search_info);
if (search_info.group_by_field_id_.has_value()) {
auto sub_qr = BruteForceSearchIterators(
dataset, vec_data, row_count, search_info, bitset, data_type);
auto sub_qr = BruteForceSearchIterators(dataset,
vec_data,
row_count,
search_info,
index_info,
bitset,
data_type);
result.AssembleChunkVectorIterators(
num_queries, 1, {0}, sub_qr.chunk_iterators());
} else {
auto sub_qr = BruteForceSearch(
dataset, vec_data, row_count, search_info, bitset, data_type);
auto sub_qr = BruteForceSearch(dataset,
vec_data,
row_count,
search_info,
index_info,
bitset,
data_type);
result.distances_ = std::move(sub_qr.mutable_distances());
result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
}
Expand Down
2 changes: 2 additions & 0 deletions internal/core/src/query/SearchOnSealed.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ void
SearchOnSealed(const Schema& schema,
std::shared_ptr<ChunkedColumnBase> column,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const void* query_data,
int64_t num_queries,
int64_t row_count,
Expand All @@ -41,6 +42,7 @@ void
SearchOnSealed(const Schema& schema,
const void* vec_data,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const void* query_data,
int64_t num_queries,
int64_t row_count,
Expand Down
9 changes: 9 additions & 0 deletions internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -940,9 +940,18 @@
AssertInfo(num_rows_.has_value(), "Can't get row count value");
auto row_count = num_rows_.value();
auto vec_data = fields_.at(field_id);

// get index params for bm25 brute force
std::map<std::string, std::string> index_info;
if (search_info.metric_type_ == knowhere::metric::BM25) {
auto index_info =
col_index_meta_->GetFieldIndexMeta(field_id).GetIndexParams();

Check warning on line 948 in internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp#L945-L948

Added lines #L945 - L948 were not covered by tests
}

query::SearchOnSealed(*schema_,
vec_data,
search_info,
index_info,
query_data,
query_count,
row_count,
Expand Down
7 changes: 7 additions & 0 deletions internal/core/src/segcore/FieldIndexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@

#include "AckResponder.h"
#include "InsertRecord.h"
#include "common/FieldMeta.h"
#include "common/Schema.h"
#include "common/IndexMeta.h"
#include "IndexConfigGenerator.h"
#include "knowhere/config.h"
#include "log/Log.h"
#include "segcore/SegcoreConfig.h"
#include "index/VectorIndex.h"
Expand Down Expand Up @@ -429,6 +431,11 @@
return *ptr;
}

// Returns a const reference to the index metadata for `fieldId`, obtained by
// forwarding the lookup to index_meta_->GetFieldIndexMeta().
const FieldIndexMeta&
get_field_index_meta(FieldId fieldId) const {
return index_meta_->GetFieldIndexMeta(fieldId);

Check warning on line 436 in internal/core/src/segcore/FieldIndexing.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/segcore/FieldIndexing.h#L436

Added line #L436 was not covered by tests
}

bool
is_in(FieldId field_id) const {
return field_indexings_.count(field_id);
Expand Down
9 changes: 9 additions & 0 deletions internal/core/src/segcore/SegmentSealedImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -985,9 +985,18 @@
AssertInfo(num_rows_.has_value(), "Can't get row count value");
auto row_count = num_rows_.value();
auto vec_data = fields_.at(field_id);

// get index params for bm25 brute force
std::map<std::string, std::string> index_info;
if (search_info.metric_type_ == knowhere::metric::BM25) {
auto index_info =
col_index_meta_->GetFieldIndexMeta(field_id).GetIndexParams();

Check warning on line 993 in internal/core/src/segcore/SegmentSealedImpl.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/segcore/SegmentSealedImpl.cpp#L992-L993

Added lines #L992 - L993 were not covered by tests
}

query::SearchOnSealed(*schema_,
vec_data->Data(),
search_info,
index_info,
query_data,
query_count,
row_count,
Expand Down
2 changes: 2 additions & 0 deletions internal/core/unittest/test_bf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {

auto base = GenFloatVecs(dim, nb, metric_type);
auto query = GenFloatVecs(dim, nq, metric_type);
auto index_info = std::map<std::string, std::string>{};

dataset::SearchDataset dataset{
metric_type, nq, topk, -1, dim, query.data()};
Expand All @@ -137,6 +138,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {
base.data(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_FLOAT);
for (int i = 0; i < nq; i++) {
Expand Down
6 changes: 6 additions & 0 deletions internal/core/unittest/test_bf_sparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License

#include <gtest/gtest.h>
#include <map>
#include <random>

#include "common/Utils.h"
Expand Down Expand Up @@ -98,6 +99,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {

auto base = milvus::segcore::GenerateRandomSparseFloatVector(nb);
auto query = milvus::segcore::GenerateRandomSparseFloatVector(nq);
auto index_info = std::map<std::string, std::string>{};
SearchInfo search_info;
search_info.topk_ = topk;
search_info.metric_type_ = metric_type;
Expand All @@ -108,6 +110,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
base.get(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT));
return;
Expand All @@ -116,6 +119,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
base.get(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
for (int i = 0; i < nq; i++) {
Expand All @@ -130,6 +134,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
base.get(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
for (int i = 0; i < nq; i++) {
Expand All @@ -143,6 +148,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
base.get(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
auto iterators = result3.chunk_iterators();
Expand Down
3 changes: 3 additions & 0 deletions internal/core/unittest/test_chunked_segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,13 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
auto query_ds = segcore::DataGen(schema, 1);
auto col_query_data = query_ds.get_col<float>(fakevec_id);
auto query_data = col_query_data.data();
auto index_info = std::map<std::string, std::string>{};
SearchResult search_result;

query::SearchOnSealed(*schema,
column,
search_info,
index_info,
query_data,
1,
total_row_count,
Expand All @@ -135,6 +137,7 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
query::SearchOnSealed(*schema,
column,
search_info,
index_info,
query_data,
1,
total_row_count,
Expand Down
2 changes: 2 additions & 0 deletions internal/core/unittest/test_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,15 @@ TEST(Indexing, BinaryBruteForce) {
};

SearchInfo search_info;
auto index_info = std::map<std::string, std::string>{};
search_info.topk_ = topk;
search_info.round_decimal_ = round_decimal;
search_info.metric_type_ = metric_type;
auto sub_result = query::BruteForceSearch(search_dataset,
bin_vec.data(),
N,
search_info,
index_info,
nullptr,
DataType::VECTOR_BINARY);

Expand Down
2 changes: 2 additions & 0 deletions internal/core/unittest/test_string_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto index_info = std::map<std::string, std::string>{};

std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};

Expand All @@ -1257,6 +1258,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
vec_col.data(),
N,
search_info,
index_info,
nullptr,
DataType::VECTOR_FLOAT);

Expand Down
Loading