Skip to content

Commit

Permalink
fix: bm25 brute force search needs index params k1 and b
Browse files Browse the repository at this point in the history
Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
  • Loading branch information
aoiasd committed Nov 15, 2024
1 parent 0f0162f commit 4295b66
Show file tree
Hide file tree
Showing 12 changed files with 58 additions and 6 deletions.
13 changes: 9 additions & 4 deletions internal/core/src/query/SearchBruteForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ CheckBruteForceSearchParam(const FieldMeta& field,
}

knowhere::Json
PrepareBFSearchParams(const SearchInfo& search_info) {
PrepareBFSearchParams(const SearchInfo& search_info, const std::map<std::string, std::string>& index_info) {
knowhere::Json search_cfg = search_info.search_params_;

search_cfg[knowhere::meta::METRIC_TYPE] = search_info.metric_type_;
Expand All @@ -62,6 +62,10 @@ PrepareBFSearchParams(const SearchInfo& search_info) {
if (search_info.metric_type_ == knowhere::metric::BM25) {
search_cfg[knowhere::meta::BM25_AVGDL] =
search_info.search_params_[knowhere::meta::BM25_AVGDL];
search_cfg[knowhere::meta::BM25_K1] =
std::stof(index_info.at(knowhere::meta::BM25_K1));
search_cfg[knowhere::meta::BM25_B] =
std::stof(index_info.at(knowhere::meta::BM25_B));
}
return search_cfg;
}
Expand All @@ -71,6 +75,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
SubSearchResult sub_result(dataset.num_queries,
Expand All @@ -87,12 +92,11 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
}
auto search_cfg = PrepareBFSearchParams(search_info);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
// `range_search_k` is only used as one of the conditions for iterator early termination.
// not guaranteed to return exactly `range_search_k` results, which may be more or less.
// set it to -1 will return all results in the range.
search_cfg[knowhere::meta::RANGE_SEARCH_K] = topk;

sub_result.mutable_seg_offsets().resize(nq * topk);
sub_result.mutable_distances().resize(nq * topk);

Expand Down Expand Up @@ -201,6 +205,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
auto nq = dataset.num_queries;
Expand All @@ -211,7 +216,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
}
auto search_cfg = PrepareBFSearchParams(search_info);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);

knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
iterators_val;
Expand Down
2 changes: 2 additions & 0 deletions internal/core/src/query/SearchBruteForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type);

Expand All @@ -36,6 +37,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type);

Expand Down
13 changes: 13 additions & 0 deletions internal/core/src/query/SearchOnGrowing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
#include "common/Tracer.h"
#include "common/Types.h"
#include "SearchOnGrowing.h"
#include <cstddef>
#include "knowhere/comp/index_param.h"
#include "knowhere/config.h"
#include "log/Log.h"
#include "query/SearchBruteForce.h"
#include "query/SearchOnIndex.h"

Expand Down Expand Up @@ -109,6 +113,13 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
dataset::SearchDataset search_dataset{
metric_type, num_queries, topk, round_decimal, dim, query_data};
int32_t current_chunk_id = 0;

// get K1 and B from index for bm25 brute force
std::map<std::string, std::string> index_info;
if (metric_type == knowhere::metric::BM25){
index_info = segment.get_indexing_record().get_field_index_meta(vecfield_id).GetIndexParams();
}

// step 3: brute force search where small indexing is unavailable
auto vec_ptr = record.get_data_base(vecfield_id);
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
Expand All @@ -129,6 +140,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
chunk_data,
size_per_chunk,
info,
index_info,
sub_view,
data_type);
final_qr.merge(sub_qr);
Expand All @@ -137,6 +149,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
chunk_data,
size_per_chunk,
info,
index_info,
sub_view,
data_type);

Expand Down
8 changes: 6 additions & 2 deletions internal/core/src/query/SearchOnSealed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ void
SearchOnSealed(const Schema& schema,
std::shared_ptr<ChunkedColumnBase> column,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const void* query_data,
int64_t num_queries,
int64_t row_count,
Expand Down Expand Up @@ -137,6 +138,7 @@ SearchOnSealed(const Schema& schema,
vec_data,
chunk_size,
search_info,
index_info,
bitset_view,
data_type);
final_qr.merge(sub_qr);
Expand All @@ -145,6 +147,7 @@ SearchOnSealed(const Schema& schema,
vec_data,
chunk_size,
search_info,
index_info,
bitset_view,
data_type);
for (auto& o : sub_qr.mutable_seg_offsets()) {
Expand Down Expand Up @@ -177,6 +180,7 @@ void
SearchOnSealed(const Schema& schema,
const void* vec_data,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const void* query_data,
int64_t num_queries,
int64_t row_count,
Expand All @@ -201,12 +205,12 @@ SearchOnSealed(const Schema& schema,
CheckBruteForceSearchParam(field, search_info);
if (search_info.group_by_field_id_.has_value()) {
auto sub_qr = BruteForceSearchIterators(
dataset, vec_data, row_count, search_info, bitset, data_type);
dataset, vec_data, row_count, search_info,index_info, bitset, data_type);
result.AssembleChunkVectorIterators(
num_queries, 1, {0}, sub_qr.chunk_iterators());
} else {
auto sub_qr = BruteForceSearch(
dataset, vec_data, row_count, search_info, bitset, data_type);
dataset, vec_data, row_count, search_info,index_info, bitset, data_type);
result.distances_ = std::move(sub_qr.mutable_distances());
result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
}
Expand Down
2 changes: 2 additions & 0 deletions internal/core/src/query/SearchOnSealed.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ void
SearchOnSealed(const Schema& schema,
std::shared_ptr<ChunkedColumnBase> column,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const void* query_data,
int64_t num_queries,
int64_t row_count,
Expand All @@ -41,6 +42,7 @@ void
SearchOnSealed(const Schema& schema,
const void* vec_data,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const void* query_data,
int64_t num_queries,
int64_t row_count,
Expand Down
2 changes: 2 additions & 0 deletions internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -940,9 +940,11 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
AssertInfo(num_rows_.has_value(), "Can't get row count value");
auto row_count = num_rows_.value();
auto vec_data = fields_.at(field_id);
auto index_info = col_index_meta_->GetFieldIndexMeta(field_id).GetIndexParams();
query::SearchOnSealed(*schema_,
vec_data,
search_info,
index_info,
query_data,
query_count,
row_count,
Expand Down
7 changes: 7 additions & 0 deletions internal/core/src/segcore/FieldIndexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@

#include "AckResponder.h"
#include "InsertRecord.h"
#include "common/FieldMeta.h"
#include "common/Schema.h"
#include "common/IndexMeta.h"
#include "IndexConfigGenerator.h"
#include "knowhere/config.h"
#include "log/Log.h"
#include "segcore/SegcoreConfig.h"
#include "index/VectorIndex.h"
Expand Down Expand Up @@ -429,6 +431,11 @@ class IndexingRecord {
return *ptr;
}

// Looks up the index metadata for `fieldId` by delegating to
// index_meta_->GetFieldIndexMeta() and hands the result back as a
// const reference.
const FieldIndexMeta&
get_field_index_meta(FieldId fieldId) const {
    const FieldIndexMeta& field_index_meta =
        index_meta_->GetFieldIndexMeta(fieldId);
    return field_index_meta;
}

bool
is_in(FieldId field_id) const {
return field_indexings_.count(field_id);
Expand Down
3 changes: 3 additions & 0 deletions internal/core/src/segcore/SegmentSealedImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -985,9 +985,12 @@ SegmentSealedImpl::vector_search(SearchInfo& search_info,
AssertInfo(num_rows_.has_value(), "Can't get row count value");
auto row_count = num_rows_.value();
auto vec_data = fields_.at(field_id);
auto index_info = col_index_meta_->GetFieldIndexMeta(field_id).GetIndexParams();

query::SearchOnSealed(*schema_,
vec_data->Data(),
search_info,
index_info,
query_data,
query_count,
row_count,
Expand Down
2 changes: 2 additions & 0 deletions internal/core/unittest/test_bf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {

auto base = GenFloatVecs(dim, nb, metric_type);
auto query = GenFloatVecs(dim, nq, metric_type);
auto index_info = std::map<std::string, std::string>{};

dataset::SearchDataset dataset{
metric_type, nq, topk, -1, dim, query.data()};
Expand All @@ -137,6 +138,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {
base.data(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_FLOAT);
for (int i = 0; i < nq; i++) {
Expand Down
6 changes: 6 additions & 0 deletions internal/core/unittest/test_bf_sparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License

#include <gtest/gtest.h>
#include <map>
#include <random>

#include "common/Utils.h"
Expand Down Expand Up @@ -98,6 +99,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {

auto base = milvus::segcore::GenerateRandomSparseFloatVector(nb);
auto query = milvus::segcore::GenerateRandomSparseFloatVector(nq);
auto index_info = std::map<std::string, std::string>{};
SearchInfo search_info;
search_info.topk_ = topk;
search_info.metric_type_ = metric_type;
Expand All @@ -108,6 +110,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
base.get(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT));
return;
Expand All @@ -116,6 +119,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
base.get(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
for (int i = 0; i < nq; i++) {
Expand All @@ -130,6 +134,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
base.get(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
for (int i = 0; i < nq; i++) {
Expand All @@ -143,6 +148,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
base.get(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
auto iterators = result3.chunk_iterators();
Expand Down
4 changes: 4 additions & 0 deletions internal/core/unittest/test_chunked_segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
auto schema = std::make_shared<Schema>();
auto fakevec_id = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, dim, knowhere::metric::COSINE);


for (int i = 0; i < chunk_num; i++) {
auto dataset = segcore::DataGen(schema, chunk_size);
Expand Down Expand Up @@ -106,11 +107,13 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
auto query_ds = segcore::DataGen(schema, 1);
auto col_query_data = query_ds.get_col<float>(fakevec_id);
auto query_data = col_query_data.data();
auto index_info = std::map<std::string, std::string>{};
SearchResult search_result;

query::SearchOnSealed(*schema,
column,
search_info,
index_info,
query_data,
1,
total_row_count,
Expand All @@ -135,6 +138,7 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
query::SearchOnSealed(*schema,
column,
search_info,
index_info,
query_data,
1,
total_row_count,
Expand Down
2 changes: 2 additions & 0 deletions internal/indexnode/task_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"runtime/debug"
"sync"
"time"

"github.com/cockroachdb/errors"
"go.uber.org/zap"
Expand Down Expand Up @@ -254,6 +255,7 @@ func (sched *TaskScheduler) indexBuildLoop() {
return
case <-sched.TaskQueue.utChan():
tasks := sched.scheduleIndexBuildTask()
time.Sleep(6000 * time.Second)
var wg sync.WaitGroup
for _, t := range tasks {
wg.Add(1)
Expand Down

0 comments on commit 4295b66

Please sign in to comment.