Skip to content

Commit

Permalink
Support element in json array in segcore part(milvus-io#24677)
Browse files Browse the repository at this point in the history
Signed-off-by: luzhang <luzhang@zilliz.com>
  • Loading branch information
luzhang committed Jun 13, 2023
1 parent 707bb77 commit 3c80f8b
Show file tree
Hide file tree
Showing 11 changed files with 369 additions and 15 deletions.
8 changes: 6 additions & 2 deletions internal/core/src/query/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,18 @@ struct LogicalBinaryExpr : BinaryExprBase {
struct TermExpr : Expr {
const ColumnInfo column_;
const proto::plan::GenericValue::ValCase val_case_;
const bool is_in_field_;

protected:
// prevent accidental instantiation
TermExpr() = delete;

TermExpr(ColumnInfo column,
const proto::plan::GenericValue::ValCase val_case)
: column_(std::move(column)), val_case_(val_case) {
const proto::plan::GenericValue::ValCase val_case,
const bool is_in_field)
: column_(std::move(column)),
val_case_(val_case),
is_in_field_(is_in_field) {
}

public:
Expand Down
6 changes: 4 additions & 2 deletions internal/core/src/query/ExprImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ struct TermExprImpl : TermExpr {

TermExprImpl(ColumnInfo column,
const std::vector<T>& terms,
const proto::plan::GenericValue::ValCase val_case)
: TermExpr(std::forward<ColumnInfo>(column), val_case), terms_(terms) {
const proto::plan::GenericValue::ValCase val_case,
const bool is_in_field = false)
: TermExpr(std::forward<ColumnInfo>(column), val_case, is_in_field),
terms_(terms) {
}
};

Expand Down
4 changes: 3 additions & 1 deletion internal/core/src/query/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ ExprPtr
Parser::ParseTermNodeImpl(const FieldName& field_name, const Json& body) {
Assert(body.is_object());
auto values = body["values"];
auto is_in_field = body["is_in_field"];

std::vector<T> terms(values.size());
auto val_case = proto::plan::GenericValue::ValCase::VAL_NOT_SET;
Expand All @@ -242,7 +243,8 @@ Parser::ParseTermNodeImpl(const FieldName& field_name, const Json& body) {
ColumnInfo(schema.get_field_id(field_name),
schema[field_name].get_data_type()),
terms,
val_case);
val_case,
is_in_field);
}

template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/query/PlanProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ ExtractTermExprImpl(FieldId field_id,
}
std::sort(terms.begin(), terms.end());
return std::make_unique<TermExprImpl<T>>(
expr_proto.column_info(), terms, val_case);
expr_proto.column_info(), terms, val_case, expr_proto.is_in_field());
}

template <typename T>
Expand Down
8 changes: 8 additions & 0 deletions internal/core/src/query/generated/ExecExprVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ class ExecExprVisitor : public ExprVisitor {
auto
ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType;

// Evaluates a JSON term expression of the form `<value> in <json array>`.
// expr_raw is treated as a TermExprImpl<ExprValueType> that carries exactly
// one term; a row matches when the JSON array found at the column's nested
// path contains that term value.
template <typename ExprValueType>
auto
ExecTermJsonVariableInField(TermExpr& expr_raw) -> BitsetType;

// Evaluates a JSON term expression when TermExpr::is_in_field_ is false.
// NOTE(review): presumably the `<json value> in [term, ...]` form, i.e. the
// value at the nested path is looked up in the expression's term list --
// confirm against the definition in ExecExprVisitor.cpp.
template <typename ExprValueType>
auto
ExecTermJsonFieldInVariable(TermExpr& expr_raw) -> BitsetType;

// Entry point for term expressions over a JSON column. Dispatches on
// expr_raw.is_in_field_: true goes to ExecTermJsonVariableInField, false
// goes to ExecTermJsonFieldInVariable.
template <typename ExprValueType>
auto
ExecTermVisitorImplTemplateJson(TermExpr& expr_raw) -> BitsetType;
Expand Down
50 changes: 48 additions & 2 deletions internal/core/src/query/visitors/ExecExprVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1740,8 +1740,7 @@ ExecExprVisitor::ExecTermVisitorImplTemplate<bool>(TermExpr& expr_raw)

template <typename ExprValueType>
auto
ExecExprVisitor::ExecTermVisitorImplTemplateJson(TermExpr& expr_raw)
-> BitsetType {
ExecExprVisitor::ExecTermJsonFieldInVariable(TermExpr& expr_raw) -> BitsetType {
using Index = index::ScalarIndex<milvus::Json>;
auto& expr = static_cast<TermExprImpl<ExprValueType>&>(expr_raw);
auto pointer = milvus::Json::pointer(expr.column_.nested_path);
Expand Down Expand Up @@ -1783,6 +1782,53 @@ ExecExprVisitor::ExecTermVisitorImplTemplateJson(TermExpr& expr_raw)
expr.column_.field_id, index_func, elem_func);
}

// Evaluates a term expression of the form `<value> in <json array field>`.
//
// For each row, the JSON value located at expr.column_.nested_path is read as
// an array, and the row matches when at least one element of that array equals
// the single term value carried by the expression.
//
// A row does not match when:
//   - the path is missing or does not hold an array, or
//   - no element of the array equals the term.
// Elements that cannot be read as ExprValueType (for example a string inside
// an otherwise numeric array) are skipped rather than ending the scan, so a
// later element of the right type can still produce a match.
//
// expr_raw is cast to TermExprImpl<ExprValueType> without a runtime check and
// must hold exactly one term (enforced by AssertInfo).
template <typename ExprValueType>
auto
ExecExprVisitor::ExecTermJsonVariableInField(TermExpr& expr_raw) -> BitsetType {
    using Index = index::ScalarIndex<milvus::Json>;
    // String terms are compared against std::string_view elements; every
    // other term type is read from the document as itself.
    using GetType =
        std::conditional_t<std::is_same_v<ExprValueType, std::string>,
                           std::string_view,
                           ExprValueType>;

    auto& expr = static_cast<TermExprImpl<ExprValueType>&>(expr_raw);
    auto pointer = milvus::Json::pointer(expr.column_.nested_path);
    // The index callback only hands back an empty bitmap; the actual
    // evaluation is done per element by elem_func below.
    auto index_func = [](Index* index) { return TargetBitmap{}; };

    AssertInfo(expr.terms_.size() == 1,
               "element length in json array must be one");
    ExprValueType target_val = expr.terms_[0];

    auto elem_func = [&target_val, &pointer](const milvus::Json& json) {
        auto doc = json.doc();
        auto array = doc.at_pointer(pointer).get_array();
        if (array.error()) {
            return false;
        }
        for (auto&& element : array) {
            auto val = element.template get<GetType>();
            if (val.error()) {
                // Wrong-typed element: it cannot equal the term, but the
                // remaining elements still have to be inspected.
                continue;
            }
            if (val.value() == target_val) {
                return true;
            }
        }
        return false;
    };

    return ExecRangeVisitorImpl<milvus::Json>(
        expr.column_.field_id, index_func, elem_func);
}

// Routes a JSON term expression to the matching evaluator, based on the
// is_in_field_ flag carried by the expression:
//   false -> ExecTermJsonFieldInVariable
//   true  -> ExecTermJsonVariableInField
template <typename ExprValueType>
auto
ExecExprVisitor::ExecTermVisitorImplTemplateJson(TermExpr& expr_raw)
    -> BitsetType {
    if (!expr_raw.is_in_field_) {
        return ExecTermJsonFieldInVariable<ExprValueType>(expr_raw);
    }
    return ExecTermJsonVariableInField<ExprValueType>(expr_raw);
}

void
ExecExprVisitor::visit(TermExpr& expr) {
auto& field_meta = segment_.get_schema()[expr.column_.field_id];
Expand Down
8 changes: 5 additions & 3 deletions internal/core/unittest/test_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2470,7 +2470,8 @@ TEST(CApiTest, Indexing_With_float_Predicate_Term) {
{
"term": {
"counter": {
"values": [4200, 4201, 4202, 4203, 4204]
"values": [4200, 4201, 4202, 4203, 4204],
"is_in_field": false
}
}
},
Expand Down Expand Up @@ -3134,7 +3135,8 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Term) {
{
"term": {
"counter": {
"values": [4200, 4201, 4202, 4203, 4204]
"values": [4200, 4201, 4202, 4203, 4204],
"is_in_field": false
}
}
},
Expand Down Expand Up @@ -4457,4 +4459,4 @@ TEST(CApiTest, AssembeChunkTest) {
for (size_t i = 0; i < 105; i++) {
ASSERT_EQ(result[index++], chunk[i]) << i;
}
}
}
198 changes: 197 additions & 1 deletion internal/core/unittest/test_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,8 @@ TEST(Expr, TestTerm) {
{
"term": {
"age": {
"values": @@@@
"values": @@@@,
"is_in_field" : false
}
}
},
Expand Down Expand Up @@ -810,6 +811,7 @@ TEST(Expr, TestSimpleDsl) {
}
query::Json s;
s["term"]["age"]["values"] = terms;
s["term"]["age"]["is_in_field"] = false;
return s;
};
// std::cout << get_item(0).dump(-2);
Expand Down Expand Up @@ -2899,3 +2901,197 @@ TEST(Expr, TestExistsWithJSON) {
}
}
}

// Parameter bundle for one term-expression test case.
// Kept as a plain aggregate so cases can be written as `{{value}, {"key"}}`.
template <class T>
struct Testcase {
    // Term values passed to the expression under test.
    std::vector<T> term;
    // Keys forming the path to the JSON value the expression targets.
    std::vector<std::string> nested_path;
};

// Checks term expressions built with is_in_field == true against a growing
// segment whose "json" column is filled by DataGenForJsonArray.
//
// For every test case the expression is executed over all inserted rows and
// each result bit is compared with a reference answer computed directly from
// the raw JSON: "does the array at nested_path contain term[0]?".
// Covered term types: bool, double, int64_t and std::string.
TEST(Expr, TestTermInFieldJson) {
    using namespace milvus::query;
    using namespace milvus::segcore;

    auto schema = std::make_shared<Schema>();
    auto i64_fid = schema->AddDebugField("id", DataType::INT64);
    auto json_fid = schema->AddDebugField("json", DataType::JSON);
    schema->set_primary_field_id(i64_fid);

    auto seg = CreateGrowingSegment(schema, empty_index_meta);
    int N = 10000;
    std::vector<std::string> json_col;
    int num_iters = 2;
    for (int iter = 0; iter < num_iters; ++iter) {
        auto raw_data = DataGenForJsonArray(schema, N, iter);
        auto new_json_col = raw_data.get_col<std::string>(json_fid);

        // Keep a copy of the raw JSON so expected answers can be recomputed.
        json_col.insert(
            json_col.end(), new_json_col.begin(), new_json_col.end());
        seg->PreInsert(N);
        seg->Insert(iter * N,
                    N,
                    raw_data.row_ids_.data(),
                    raw_data.timestamps_.data(),
                    raw_data.raw_);
    }

    auto seg_promote = dynamic_cast<SegmentGrowingImpl*>(seg.get());
    // Fail the test instead of dereferencing null if the cast does not hold.
    ASSERT_TRUE(seg_promote != nullptr);
    ExecExprVisitor visitor(
        *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP);

    // Runs every case in `cases` and verifies each row of the result.
    //   cases    - std::vector<Testcase<TermType>>
    //   val_case - protobuf value tag matching TermType
    //   elem_tag - value whose type is the element type used when reading the
    //              raw JSON array for the reference answer (only its type is
    //              used)
    auto run_cases = [&](const auto& cases,
                         const proto::plan::GenericValue::ValCase val_case,
                         auto elem_tag) {
        (void)elem_tag;
        using ElemType = decltype(elem_tag);
        using TermType = typename decltype(cases.front().term)::value_type;

        for (const auto& testcase : cases) {
            auto check = [&](const std::vector<ElemType>& values) {
                return std::find(values.begin(),
                                 values.end(),
                                 testcase.term[0]) != values.end();
            };
            RetrievePlanNode plan;
            auto pointer = milvus::Json::pointer(testcase.nested_path);
            plan.predicate_ = std::make_unique<TermExprImpl<TermType>>(
                ColumnInfo(json_fid, DataType::JSON, testcase.nested_path),
                testcase.term,
                val_case,
                true);
            auto start = std::chrono::steady_clock::now();
            auto result = visitor.call_child(*plan.predicate_.value());
            std::cout
                << "cost"
                << std::chrono::duration_cast<std::chrono::microseconds>(
                       std::chrono::steady_clock::now() - start)
                       .count()
                << std::endl;
            EXPECT_EQ(result.size(), N * num_iters);

            for (int i = 0; i < N * num_iters; ++i) {
                auto ans = result[i];
                auto array =
                    milvus::Json(simdjson::padded_string(json_col[i]))
                        .array_at(pointer);
                std::vector<ElemType> res;
                for (const auto& element : array) {
                    res.push_back(element.template get<ElemType>());
                }
                ASSERT_EQ(ans, check(res));
            }
        }
    };

    std::vector<Testcase<bool>> bool_testcases{{{true}, {"bool"}},
                                               {{false}, {"bool"}}};
    // ASSERT_* inside the lambda only leaves the lambda, so the wrapper is
    // needed to stop the whole test on the first fatal failure.
    ASSERT_NO_FATAL_FAILURE(run_cases(
        bool_testcases, proto::plan::GenericValue::ValCase::kBoolVal, bool{}));

    std::vector<Testcase<double>> double_testcases{
        {{1.123}, {"double"}},
        {{10.34}, {"double"}},
        {{100.234}, {"double"}},
        {{1000.4546}, {"double"}},
    };
    ASSERT_NO_FATAL_FAILURE(
        run_cases(double_testcases,
                  proto::plan::GenericValue::ValCase::kFloatVal,
                  double{}));

    std::vector<Testcase<int64_t>> testcases{
        {{1}, {"int"}},
        {{10}, {"int"}},
        {{100}, {"int"}},
        {{1000}, {"int"}},
    };
    ASSERT_NO_FATAL_FAILURE(run_cases(
        testcases, proto::plan::GenericValue::ValCase::kInt64Val, int64_t{}));

    std::vector<Testcase<std::string>> testcases_string = {
        {{"1sads"}, {"string"}},
        {{"10dsf"}, {"string"}},
        {{"100"}, {"string"}},
        {{"100ddfdsssdfdsfsd0"}, {"string"}},
    };
    // String terms are checked against string_view elements of the raw JSON.
    ASSERT_NO_FATAL_FAILURE(
        run_cases(testcases_string,
                  proto::plan::GenericValue::ValCase::kStringVal,
                  std::string_view{}));
}
Loading

0 comments on commit 3c80f8b

Please sign in to comment.