Skip to content

Commit

Permalink
fix: fix chunked segment term filter expression and add ut (#37392)
Browse files Browse the repository at this point in the history
issue: #37143

---------

Signed-off-by: sunby <sunbingyi1992@gmail.com>
  • Loading branch information
sunby authored Nov 7, 2024
1 parent 5310d34 commit 40ba5a3
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 42 deletions.
12 changes: 4 additions & 8 deletions internal/core/src/exec/expression/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,14 +363,10 @@ class SegmentExpr : public Expr {
if (segment_->type() == SegmentType::Sealed) {
// first is the raw data, second is valid_data
// use valid_data to see if raw data is null
auto data_vec = segment_
->get_batch_views<T>(
field_id_, i, data_pos, size)
.first;
auto valid_data = segment_
->get_batch_views<T>(
field_id_, i, data_pos, size)
.second;
auto fetched_data = segment_->get_batch_views<T>(
field_id_, i, data_pos, size);
auto data_vec = fetched_data.first;
auto valid_data = fetched_data.second;
func(data_vec.data(),
valid_data.data(),
size,
Expand Down
17 changes: 15 additions & 2 deletions internal/core/src/exec/expression/TermExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "TermExpr.h"
#include <memory>
#include <utility>
#include "log/Log.h"
#include "query/Utils.h"
namespace milvus {
namespace exec {
Expand Down Expand Up @@ -137,9 +138,21 @@ PhyTermFilterExpr::CanSkipSegment() {
max = i == 0 ? val : std::max(val, max);
min = i == 0 ? val : std::min(val, min);
}
// Decides whether the skip index rules out the inclusive range [min, max]
// for the whole segment. Every chunk in [0, num_data_chunk_) is consulted;
// the answer is true only when at least one chunk exists and all of them
// report that the range can be skipped.
auto can_skip = [&]() -> bool {
    // With no chunks there is nothing that proves the range is absent.
    if (num_data_chunk_ <= 0) {
        return false;
    }
    for (int chunk_idx = 0; chunk_idx < num_data_chunk_; ++chunk_idx) {
        const bool chunk_skippable = skip_index.CanSkipBinaryRange<T>(
            field_id_, chunk_idx, min, max, true, true);
        // A single chunk that may hold a match keeps the segment alive.
        if (!chunk_skippable) {
            return false;
        }
    }
    return true;
};

// using skip index to help skipping this segment
if (segment_->type() == SegmentType::Sealed &&
skip_index.CanSkipBinaryRange<T>(field_id_, 0, min, max, true, true)) {
if (segment_->type() == SegmentType::Sealed && can_skip()) {
cached_bits_.resize(active_count_, false);
cached_bits_inited_ = true;
return true;
Expand Down
44 changes: 18 additions & 26 deletions internal/core/src/mmap/ChunkedColumn.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ class ChunkedColumnBase : public ColumnBase {
return true;
}

// Reports whether the row at `offset` inside chunk `chunk_id` is valid.
// `offset` is handed to the selected chunk unchanged, and `chunk_id` indexes
// `chunks_` directly without any range check in this method.
// Columns that are not nullable answer true without touching the chunk.
bool
IsValid(int64_t chunk_id, int64_t offset) const {
    return !nullable_ || chunks_[chunk_id]->isValid(offset);
}

bool
IsNullable() const {
return nullable_;
Expand Down Expand Up @@ -136,7 +144,7 @@ class ChunkedColumnBase : public ColumnBase {

// used for sequential access for search
virtual BufferView
GetBatchBuffer(int64_t start_offset, int64_t length) {
GetBatchBuffer(int64_t chunk_id, int64_t start_offset, int64_t length) {
PanicInfo(ErrorCode::Unsupported,
"GetBatchBuffer only supported for VariableColumn");
}
Expand Down Expand Up @@ -323,33 +331,17 @@ class ChunkedVariableColumn : public ChunkedColumnBase {
}

BufferView
GetBatchBuffer(int64_t start_offset, int64_t length) override {
if (start_offset < 0 || start_offset > num_rows_ ||
start_offset + length > num_rows_) {
PanicInfo(ErrorCode::OutOfRange, "index out of range");
}

int chunk_num = chunks_.size();

auto [start_chunk_id, start_offset_in_chunk] =
GetChunkIDByOffset(start_offset);
GetBatchBuffer(int64_t chunk_id,
int64_t start_offset,
int64_t length) override {
BufferView buffer_view;

std::vector<BufferView::Element> elements;
for (; start_chunk_id < chunk_num && length > 0; ++start_chunk_id) {
int chunk_size = chunks_[start_chunk_id]->RowNums();
int len =
std::min(int64_t(chunk_size - start_offset_in_chunk), length);
elements.push_back(
{chunks_[start_chunk_id]->Data(),
std::dynamic_pointer_cast<StringChunk>(chunks_[start_chunk_id])
->Offsets(),
static_cast<int>(start_offset_in_chunk),
static_cast<int>(start_offset_in_chunk + len)});

start_offset_in_chunk = 0;
length -= len;
}
elements.push_back(
{chunks_[chunk_id]->Data(),
std::dynamic_pointer_cast<StringChunk>(chunks_[chunk_id])
->Offsets(),
static_cast<int>(start_offset),
static_cast<int>(start_offset + length)});

buffer_view.data_ = elements;
return buffer_view;
Expand Down
17 changes: 11 additions & 6 deletions internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -776,11 +776,13 @@ ChunkedSegmentSealedImpl::get_chunk_buffer(FieldId field_id,
if (field_data->IsNullable()) {
valid_data.reserve(length);
for (int i = 0; i < length; i++) {
valid_data.push_back(field_data->IsValid(start_offset + i));
valid_data.push_back(
field_data->IsValid(chunk_id, start_offset + i));
}
}
return std::make_pair(field_data->GetBatchBuffer(start_offset, length),
valid_data);
return std::make_pair(
field_data->GetBatchBuffer(chunk_id, start_offset, length),
valid_data);
}
PanicInfo(ErrorCode::UnexpectedError,
"get_chunk_buffer only used for variable column field");
Expand Down Expand Up @@ -1227,9 +1229,10 @@ ChunkedSegmentSealedImpl::search_sorted_pk(const PkType& pk,
[](const int64_t& elem, const int64_t& value) {
return elem < value;
});
auto num_rows_until_chunk = pk_column->GetNumRowsUntilChunk(i);
for (; it != src + pk_column->NumRows() && *it == target;
++it) {
auto offset = it - src;
auto offset = it - src + num_rows_until_chunk;
if (condition(offset)) {
pk_offsets.emplace_back(offset);
}
Expand All @@ -1248,14 +1251,16 @@ ChunkedSegmentSealedImpl::search_sorted_pk(const PkType& pk,
auto num_chunk = var_column->num_chunks();
for (int i = 0; i < num_chunk; ++i) {
// TODO @xiaocai2333, @sunby: chunk need to record the min/max.
auto num_rows_until_chunk = pk_column->GetNumRowsUntilChunk(i);
auto string_chunk = std::dynamic_pointer_cast<StringChunk>(
var_column->GetChunk(i));
auto offset = string_chunk->binary_search_string(target);
for (; offset != -1 && offset < var_column->NumRows() &&
var_column->RawAt(offset) == target;
++offset) {
if (condition(offset)) {
pk_offsets.emplace_back(offset);
auto segment_offset = offset + num_rows_until_chunk;
if (condition(segment_offset)) {
pk_offsets.emplace_back(segment_offset);
}
}
}
Expand Down
124 changes: 124 additions & 0 deletions internal/core/unittest/test_chunked_segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,25 @@
#include <gtest/gtest.h>
#include <algorithm>
#include <cstdint>
#include "arrow/table_builder.h"
#include "arrow/type_fwd.h"
#include "common/BitsetView.h"
#include "common/Consts.h"
#include "common/FieldDataInterface.h"
#include "common/QueryInfo.h"
#include "common/Schema.h"
#include "common/Types.h"
#include "expr/ITypeExpr.h"
#include "knowhere/comp/index_param.h"
#include "mmap/ChunkedColumn.h"
#include "mmap/Types.h"
#include "query/ExecPlanNodeVisitor.h"
#include "query/SearchOnSealed.h"
#include "segcore/SegcoreConfig.h"
#include "segcore/SegmentSealedImpl.h"
#include "test_utils/DataGen.h"
#include <memory>
#include <numeric>
#include <vector>

struct DeferRelease {
Expand Down Expand Up @@ -135,3 +147,115 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
ASSERT_TRUE(offsets.find(i) != offsets.end());
}
}

// Exercises term ("IN") filter expressions against a sealed segment whose
// field data is loaded as two separate chunks. Three INT64 fields ("int64",
// "pk" and the timestamp field) all receive the same values, so the expected
// match counts are identical for the plain field and the primary key.
TEST(test_chunk_segment, TestTermExpr) {
// Schema: two debug INT64 fields plus the timestamp field; "pk" is primary.
// NOTE(review): the trailing `true` argument presumably marks the fields
// nullable - confirm against Schema::AddDebugField / AddField.
auto schema = std::make_shared<Schema>();
auto int64_fid = schema->AddDebugField("int64", DataType::INT64, true);
auto pk_fid = schema->AddDebugField("pk", DataType::INT64, true);
schema->AddField(FieldName("ts"), TimestampFieldID, DataType::INT64, true);
schema->set_primary_field_id(pk_fid);
// NOTE(review): the final `true` presumably selects the chunked sealed
// segment implementation - confirm against segcore::CreateSealedSegment.
auto segment =
segcore::CreateSealedSegment(schema,
nullptr,
-1,
segcore::SegcoreConfig::default_config(),
false,
false,
true);
// Number of rows placed in each chunk.
size_t test_data_count = 1000;

// One Arrow field description per schema field, in the same order as
// `field_ids` below so both can be indexed together.
auto arrow_i64_field = arrow::field("int64", arrow::int64());
auto arrow_pk_field = arrow::field("pk", arrow::int64());
auto arrow_ts_field = arrow::field("ts", arrow::int64());
std::vector<std::shared_ptr<arrow::Field>> arrow_fields = {
arrow_i64_field, arrow_pk_field, arrow_ts_field};

std::vector<FieldId> field_ids = {int64_fid, pk_fid, TimestampFieldID};

// First value of the first chunk, and how many chunks are produced.
int start_id = 1;
int chunk_num = 2;

// One FieldDataInfo per field, sized for the rows of all chunks combined.
std::vector<FieldDataInfo> field_infos;
for (auto fid : field_ids) {
FieldDataInfo field_info;
field_info.field_id = fid.get();
field_info.row_count = test_data_count * chunk_num;
field_infos.push_back(field_info);
}

// generate data
// Each iteration builds one chunk of consecutive values starting at
// `start_id`; `start_id` advances by `test_data_count` per chunk, so the
// chunks hold 1..1000 and 1001..2000.
for (int chunk_id = 0; chunk_id < chunk_num;
chunk_id++, start_id += test_data_count) {
std::vector<int64_t> test_data(test_data_count);
std::iota(test_data.begin(), test_data.end(), start_id);

auto builder = std::make_shared<arrow::Int64Builder>();
auto status = builder->AppendValues(test_data.begin(), test_data.end());
ASSERT_TRUE(status.ok());
auto res = builder->Finish();
ASSERT_TRUE(res.ok());
std::shared_ptr<arrow::Array> arrow_int64;
arrow_int64 = res.ValueOrDie();

// The same array is wrapped in a single-column record batch for every
// field and pushed to that field's reader channel.
for (int i = 0; i < arrow_fields.size(); i++) {
auto f = arrow_fields[i];
auto fid = field_ids[i];
auto arrow_schema =
std::make_shared<arrow::Schema>(arrow::FieldVector(1, f));
auto record_batch = arrow::RecordBatch::Make(
arrow_schema, arrow_int64->length(), {arrow_int64});

auto res2 = arrow::RecordBatchReader::Make({record_batch});
ASSERT_TRUE(res2.ok());
auto arrow_reader = res2.ValueOrDie();

field_infos[i].arrow_reader_channel->push(
std::make_shared<ArrowDataWrapper>(
arrow_reader, nullptr, nullptr));
}
}

// load
// Each channel is closed before the segment loads the field from it.
for (int i = 0; i < field_infos.size(); i++) {
field_infos[i].arrow_reader_channel->close();
segment->LoadFieldData(field_ids[i], field_infos[i]);
}

// query int64 expr
// Term values 1..10 all sit in the first chunk; exactly ten rows must match.
std::vector<proto::plan::GenericValue> filter_data;
for (int i = 1; i <= 10; ++i) {
proto::plan::GenericValue v;
v.set_int64_val(i);
filter_data.push_back(v);
}
auto term_filter_expr = std::make_shared<expr::TermFilterExpr>(
expr::ColumnInfo(int64_fid, DataType::INT64), filter_data);
BitsetType final;
auto plan = std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID,
term_filter_expr);
final = query::ExecuteQueryExpr(
plan, segment.get(), chunk_num * test_data_count, MAX_TIMESTAMP);
ASSERT_EQ(10, final.count());

// query pk expr
// Same term values against the primary-key field; again ten matches.
auto pk_term_filter_expr = std::make_shared<expr::TermFilterExpr>(
expr::ColumnInfo(pk_fid, DataType::INT64), filter_data);
plan = std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID,
pk_term_filter_expr);
final = query::ExecuteQueryExpr(
plan, segment.get(), chunk_num * test_data_count, MAX_TIMESTAMP);
ASSERT_EQ(10, final.count());

// query pk in second chunk
// test_data_count + 1 is the first value stored in the second chunk, so a
// single row must match when chunks beyond the first are handled.
std::vector<proto::plan::GenericValue> filter_data2;
proto::plan::GenericValue v;
v.set_int64_val(test_data_count + 1);
filter_data2.push_back(v);
pk_term_filter_expr = std::make_shared<expr::TermFilterExpr>(
expr::ColumnInfo(pk_fid, DataType::INT64), filter_data2);
plan = std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID,
pk_term_filter_expr);
final = query::ExecuteQueryExpr(
plan, segment.get(), chunk_num * test_data_count, MAX_TIMESTAMP);
ASSERT_EQ(1, final.count());
}

0 comments on commit 40ba5a3

Please sign in to comment.