From 3c80f8bc5ae02be31b5b039ef8287070be13b560 Mon Sep 17 00:00:00 2001 From: luzhang Date: Tue, 13 Jun 2023 12:53:29 -0400 Subject: [PATCH] Support element in json array in segcore part(#24677) Signed-off-by: luzhang --- internal/core/src/query/Expr.h | 8 +- internal/core/src/query/ExprImpl.h | 6 +- internal/core/src/query/Parser.cpp | 4 +- internal/core/src/query/PlanProto.cpp | 2 +- .../src/query/generated/ExecExprVisitor.h | 8 + .../src/query/visitors/ExecExprVisitor.cpp | 50 ++++- internal/core/unittest/test_c_api.cpp | 8 +- internal/core/unittest/test_expr.cpp | 198 +++++++++++++++++- internal/core/unittest/test_plan_proto.cpp | 4 +- internal/core/unittest/test_query.cpp | 5 +- internal/core/unittest/test_utils/DataGen.h | 91 ++++++++ 11 files changed, 369 insertions(+), 15 deletions(-) diff --git a/internal/core/src/query/Expr.h b/internal/core/src/query/Expr.h index 7456aceea664c..ba3547e3a2700 100644 --- a/internal/core/src/query/Expr.h +++ b/internal/core/src/query/Expr.h @@ -121,14 +121,18 @@ struct LogicalBinaryExpr : BinaryExprBase { struct TermExpr : Expr { const ColumnInfo column_; const proto::plan::GenericValue::ValCase val_case_; + const bool is_in_field_; protected: // prevent accidental instantiation TermExpr() = delete; TermExpr(ColumnInfo column, - const proto::plan::GenericValue::ValCase val_case) - : column_(std::move(column)), val_case_(val_case) { + const proto::plan::GenericValue::ValCase val_case, + const bool is_in_field) + : column_(std::move(column)), + val_case_(val_case), + is_in_field_(is_in_field) { } public: diff --git a/internal/core/src/query/ExprImpl.h b/internal/core/src/query/ExprImpl.h index 53cd2b39ddb32..3c523c87cf598 100644 --- a/internal/core/src/query/ExprImpl.h +++ b/internal/core/src/query/ExprImpl.h @@ -32,8 +32,10 @@ struct TermExprImpl : TermExpr { TermExprImpl(ColumnInfo column, const std::vector& terms, - const proto::plan::GenericValue::ValCase val_case) - : TermExpr(std::forward(column), val_case), terms_(terms) { + const proto::plan::GenericValue::ValCase val_case, + const bool is_in_field = false) + : TermExpr(std::forward(column), val_case, is_in_field), + terms_(terms) { } }; diff --git a/internal/core/src/query/Parser.cpp b/internal/core/src/query/Parser.cpp index 028ac65629cb9..43af189d81041 100644 --- a/internal/core/src/query/Parser.cpp +++ b/internal/core/src/query/Parser.cpp @@ -218,6 +218,7 @@ ExprPtr Parser::ParseTermNodeImpl(const FieldName& field_name, const Json& body) { Assert(body.is_object()); auto values = body["values"]; + auto is_in_field = body["is_in_field"]; std::vector terms(values.size()); auto val_case = proto::plan::GenericValue::ValCase::VAL_NOT_SET; @@ -242,7 +243,8 @@ Parser::ParseTermNodeImpl(const FieldName& field_name, const Json& body) { ColumnInfo(schema.get_field_id(field_name), schema[field_name].get_data_type()), terms, - val_case); + val_case, + is_in_field); } template diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index 247699b9a0c4b..a821a682278db 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -59,7 +59,7 @@ ExtractTermExprImpl(FieldId field_id, } std::sort(terms.begin(), terms.end()); return std::make_unique>( - expr_proto.column_info(), terms, val_case); + expr_proto.column_info(), terms, val_case, expr_proto.is_in_field()); } template diff --git a/internal/core/src/query/generated/ExecExprVisitor.h b/internal/core/src/query/generated/ExecExprVisitor.h index 996a5e0fed03a..3c608043e4bf4 100644 --- a/internal/core/src/query/generated/ExecExprVisitor.h +++ b/internal/core/src/query/generated/ExecExprVisitor.h @@ -117,6 +117,14 @@ class ExecExprVisitor : public ExprVisitor { auto ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType; + template + auto + ExecTermJsonVariableInField(TermExpr& expr_raw) -> BitsetType; + + template + auto + ExecTermJsonFieldInVariable(TermExpr& expr_raw) -> BitsetType; + template auto ExecTermVisitorImplTemplateJson(TermExpr& expr_raw) -> BitsetType; diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp index 2379ab0f0eeae..b741fa63704c2 100644 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp @@ -1740,8 +1740,7 @@ ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) template auto -ExecExprVisitor::ExecTermVisitorImplTemplateJson(TermExpr& expr_raw) - -> BitsetType { +ExecExprVisitor::ExecTermJsonFieldInVariable(TermExpr& expr_raw) -> BitsetType { using Index = index::ScalarIndex; auto& expr = static_cast&>(expr_raw); auto pointer = milvus::Json::pointer(expr.column_.nested_path); @@ -1783,6 +1782,53 @@ ExecExprVisitor::ExecTermVisitorImplTemplateJson(TermExpr& expr_raw) expr.column_.field_id, index_func, elem_func); } +template +auto +ExecExprVisitor::ExecTermJsonVariableInField(TermExpr& expr_raw) -> BitsetType { + using Index = index::ScalarIndex; + auto& expr = static_cast&>(expr_raw); + auto pointer = milvus::Json::pointer(expr.column_.nested_path); + auto index_func = [](Index* index) { return TargetBitmap{}; }; + + AssertInfo(expr.terms_.size() == 1, + "element length in json array must be one"); + ExprValueType target_val = expr.terms_[0]; + + auto elem_func = [&target_val, &pointer](const milvus::Json& json) { + using GetType = + std::conditional_t, + std::string_view, + ExprValueType>; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (array.error()) + return false; + for (auto it = array.begin(); it != array.end(); ++it) { + auto val = (*it).template get(); + if (val.error()) { + return false; + } + if (val.value() == target_val) + return true; + } + return false; + }; + + return ExecRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); +} + +template +auto +ExecExprVisitor::ExecTermVisitorImplTemplateJson(TermExpr& expr_raw) + -> BitsetType { + if (expr_raw.is_in_field_) { + return ExecTermJsonVariableInField(expr_raw); + } else { + return ExecTermJsonFieldInVariable(expr_raw); + } +} + void ExecExprVisitor::visit(TermExpr& expr) { auto& field_meta = segment_.get_schema()[expr.column_.field_id]; diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 792c2efa8a379..539d84a24fff7 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -2470,7 +2470,8 @@ TEST(CApiTest, Indexing_With_float_Predicate_Term) { { "term": { "counter": { - "values": [4200, 4201, 4202, 4203, 4204] + "values": [4200, 4201, 4202, 4203, 4204], + "is_in_field": false } } }, @@ -3134,7 +3135,8 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Term) { { "term": { "counter": { - "values": [4200, 4201, 4202, 4203, 4204] + "values": [4200, 4201, 4202, 4203, 4204], + "is_in_field": false } } }, @@ -4457,4 +4459,4 @@ TEST(CApiTest, AssembeChunkTest) { for (size_t i = 0; i < 105; i++) { ASSERT_EQ(result[index++], chunk[i]) << i; } -} +} \ No newline at end of file diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index f916116c5289d..4734417461cba 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -719,7 +719,8 @@ TEST(Expr, TestTerm) { { "term": { "age": { - "values": @@@@ + "values": @@@@, + "is_in_field" : false } } }, @@ -810,6 +811,7 @@ TEST(Expr, TestSimpleDsl) { } query::Json s; s["term"]["age"]["values"] = terms; + s["term"]["age"]["is_in_field"] = false; return s; }; // std::cout << get_item(0).dump(-2); @@ -2899,3 +2901,197 @@ TEST(Expr, TestExistsWithJSON) { } } } + +template +struct Testcase { + std::vector term; + std::vector nested_path; +}; + +TEST(Expr, TestTermInFieldJson) { + using namespace milvus::query; + using namespace milvus::segcore; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 10000; + std::vector json_col; + int num_iters = 2; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGenForJsonArray(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + ExecExprVisitor visitor( + *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + + std::vector> bool_testcases{{{true}, {"bool"}}, + {{false}, {"bool"}}}; + + for (auto testcase : bool_testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + plan.predicate_ = std::make_unique>( + ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), + testcase.term, + proto::plan::GenericValue::ValCase::kBoolVal, + true); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> double_testcases{ + {{1.123}, {"double"}}, + {{10.34}, {"double"}}, + {{100.234}, {"double"}}, + {{1000.4546}, {"double"}}, + }; + + for (auto testcase : double_testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + plan.predicate_ = std::make_unique>( + ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), + testcase.term, + proto::plan::GenericValue::ValCase::kFloatVal, + true); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> testcases{ + {{1}, {"int"}}, + {{10}, {"int"}}, + {{100}, {"int"}}, + {{1000}, {"int"}}, + }; + + for (auto testcase : testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + plan.predicate_ = std::make_unique>( + ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), + testcase.term, + proto::plan::GenericValue::ValCase::kInt64Val, + true); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> testcases_string = { + {{"1sads"}, {"string"}}, + {{"10dsf"}, {"string"}}, + {{"100"}, {"string"}}, + {{"100ddfdsssdfdsfsd0"}, {"string"}}, + }; + + for (auto testcase : testcases_string) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + plan.predicate_ = std::make_unique>( + ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), + testcase.term, + proto::plan::GenericValue::ValCase::kStringVal, + true); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res)); + } + } +} diff --git a/internal/core/unittest/test_plan_proto.cpp b/internal/core/unittest/test_plan_proto.cpp index 0deb5a44ae28e..71898a597ce0a 100644 --- a/internal/core/unittest/test_plan_proto.cpp +++ b/internal/core/unittest/test_plan_proto.cpp @@ -187,6 +187,7 @@ vector_anns: < values: < %4%: 3 > + is_in_field : false > > query_info: < @@ -218,7 +219,8 @@ vector_anns: < { "term": { "%1%": { - "values": [1,2,3] + "values": [1,2,3], + "is_in_field" : false } } }, diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index 0d404377bc680..8830f082766e8 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -381,7 +381,8 @@ TEST(Query, ExecTerm) { { "term": { "age": { - "values": [] + "values": [], + "is_in_field": false } } }, @@ -700,7 +701,7 @@ TEST(Query, FillSegment) { auto dataset = DataGen(schema, N); const auto std_vec = dataset.get_col(FieldId(101)); // ids field const auto std_vfloat_vec = - dataset.get_col(FieldId(100)); // vector field + dataset.get_col(FieldId(100)); // vector field const auto std_i32_vec = dataset.get_col(FieldId(102)); // scalar field diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index f53a7689882b8..c80c7129243d1 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -189,6 +189,13 @@ struct GeneratedData { uint64_t seed, uint64_t ts_offset, int repeat_count); + friend GeneratedData + DataGenForJsonArray(SchemaPtr schema, + int64_t N, + uint64_t seed, + uint64_t ts_offset, + int repeat_count, + int array_len); }; inline GeneratedData @@ -351,6 +358,90 @@ DataGen(SchemaPtr schema, return res; } +template +std::string +join(const std::vector& items, const std::string& delimiter) { + std::stringstream ss; + for (size_t i = 0; i < items.size(); ++i) { + if (i > 0) { + ss << delimiter; + } + ss << items[i]; + } + return ss.str(); +} + +inline GeneratedData +DataGenForJsonArray(SchemaPtr schema, + int64_t N, + uint64_t seed = 42, + uint64_t ts_offset = 0, + int repeat_count = 1, + int array_len = 1) { + using std::vector; + std::default_random_engine er(seed); + std::normal_distribution<> distr(0, 1); + + auto insert_data = std::make_unique(); + auto insert_cols = [&insert_data]( + auto& data, int64_t count, auto& field_meta) { + auto array = milvus::segcore::CreateDataArrayFrom( + data.data(), count, field_meta); + insert_data->mutable_fields_data()->AddAllocated(array.release()); + }; + for (auto field_id : schema->get_field_ids()) { + auto field_meta = schema->operator[](field_id); + switch (field_meta.get_data_type()) { + case DataType::INT64: { + vector data(N); + for (int i = 0; i < N; i++) { + data[i] = i / repeat_count; + } + insert_cols(data, N, field_meta); + break; + } + case DataType::JSON: { + vector data(N); + for (int i = 0; i < N / repeat_count; i++) { + std::vector intVec; + std::vector doubleVec; + std::vector stringVec; + std::vector boolVec; + for (int i = 0; i < array_len; ++i) { + intVec.push_back(std::to_string(er())); + doubleVec.push_back( + std::to_string(static_cast(er()))); + stringVec.push_back("\"" + std::to_string(er()) + "\""); + boolVec.push_back(i % 2 == 0 ? "true" : "false"); + } + auto str = R"({"int":[)" + join(intVec, ",") + + R"(],"double":[)" + join(doubleVec, ",") + + R"(],"string":[)" + join(stringVec, ",") + + R"(],"bool": [)" + join(boolVec, ",") + "]}"; + //std::cout << str << std::endl; + data[i] = str; + } + insert_cols(data, N, field_meta); + break; + } + default: { + throw std::runtime_error("unimplemented"); + } + } + } + + milvus::segcore::GeneratedData res; + res.schema_ = schema; + res.raw_ = insert_data.release(); + res.raw_->set_num_rows(N); + for (int i = 0; i < N; ++i) { + res.row_ids_.push_back(i); + res.timestamps_.push_back(i + ts_offset); + } + + return res; +} + inline auto CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) { namespace ser = milvus::proto::common;