diff --git a/internal/agg/aggregate.go b/internal/agg/aggregate.go index 0f21c15df8844..c91689f6f20b9 100644 --- a/internal/agg/aggregate.go +++ b/internal/agg/aggregate.go @@ -177,3 +177,27 @@ func (max *MaxAggregate) Reduce() { func (max *MaxAggregate) ToPB() *planpb.Aggregate { return &planpb.Aggregate{Op: planpb.AggregateOp_max, FieldId: max.aggFieldID()} } + +// OrganizeAggregates decomposes each user aggregate and returns the resulting sub-aggregates keyed by ID, keeping the first one seen for every ID. +func OrganizeAggregates(userAggregates []AggregateBase) map[AggID]AggregateBase { + realAggregates := make(map[AggID]AggregateBase, 0) + for _, userAgg := range userAggregates { + subAggs := userAgg.Decompose() + for _, subAgg := range subAggs { + if _, exist := realAggregates[subAgg.ID()]; !exist { + realAggregates[subAgg.ID()] = subAgg + } + } + } + return realAggregates +} + +// AggregatesToPB converts every aggregate in the map to its protobuf form via ToPB. +// The map key is an AggID rather than a slice position, so results are appended instead of being stored at ret[key]; the output order follows Go's unspecified map iteration order. +func AggregatesToPB(aggregates map[AggID]AggregateBase) []*planpb.Aggregate { + ret := make([]*planpb.Aggregate, 0, len(aggregates)) + for _, agg := range aggregates { + ret = append(ret, agg.ToPB()) + } + return ret +} diff --git a/internal/core/src/common/Schema.h b/internal/core/src/common/Schema.h index 6a610fc4691d7..57bef10867f7d 100644 --- a/internal/core/src/common/Schema.h +++ b/internal/core/src/common/Schema.h @@ -183,7 +183,7 @@ class Schema { FieldId get_field_id(const FieldName& field_name) const { - AssertInfo(name_ids_.count(field_name), "Cannot find field_name"); + AssertInfo(name_ids_.count(field_name), "Cannot find field_name:{}", field_name.get()); return name_ids_.at(field_name); } @@ -232,6 +232,21 @@ class Schema { field_ids_.emplace_back(field_id); } + DataType + FieldType(const FieldId& field_id) const { + AssertInfo(fields_.count(field_id), "field_id:{} is not existed in the schema", field_id.get()); + auto& meta = fields_.at(field_id); + return meta.get_data_type(); + } + + const std::string& + FieldName(const FieldId& field_id) const { + AssertInfo(fields_.count(field_id), "field_id:{} is not existed in the schema", field_id.get()); + auto& meta = fields_.at(field_id); + return 
meta.get_name().get(); + } + + private: int64_t debug_id = START_USER_FIELDID; std::vector field_ids_; diff --git a/internal/core/src/expr/ITypeExpr.h b/internal/core/src/expr/ITypeExpr.h index f41b76d1a2001..5285b87a61a41 100644 --- a/internal/core/src/expr/ITypeExpr.h +++ b/internal/core/src/expr/ITypeExpr.h @@ -262,14 +262,18 @@ using CallTypeExprPtr = std::shared_ptr; class FieldAccessTypeExpr : public ITypeExpr { public: - FieldAccessTypeExpr(DataType type, const std::string& name) - : ITypeExpr{type}, name_(name), is_input_column_(true) { + FieldAccessTypeExpr(DataType type, const std::string& name, FieldId fieldId) + : ITypeExpr{type}, name_(name), field_id_(fieldId), is_input_column_(true) { } + FieldAccessTypeExpr(DataType type, FieldId fieldId) + : ITypeExpr{type}, name_(""), field_id_(fieldId), is_input_column_(true){} + FieldAccessTypeExpr(DataType type, const TypedExprPtr& input, - const std::string& name) - : ITypeExpr{type, {std::move(input)}}, name_(name) { + const std::string& name, + FieldId fieldId) + : ITypeExpr{type, {std::move(input)}}, name_(name), field_id_(fieldId) { is_input_column_ = dynamic_cast(inputs_[0].get()) != nullptr; } @@ -285,11 +289,12 @@ class FieldAccessTypeExpr : public ITypeExpr { return fmt::format("{}", name_); } - return fmt::format("{}[{}]", inputs_[0]->ToString(), name_); + return fmt::format("{}[{}{}]", inputs_[0]->ToString(), name_, field_id_.get()); } private: std::string name_; + const FieldId field_id_; bool is_input_column_; }; diff --git a/internal/core/src/plan/PlanNode.h b/internal/core/src/plan/PlanNode.h index 2c0f877b0a3eb..edda348a14c92 100644 --- a/internal/core/src/plan/PlanNode.h +++ b/internal/core/src/plan/PlanNode.h @@ -26,6 +26,7 @@ #include "common/EasyAssert.h" #include "segcore/SegmentInterface.h" #include "plan/PlanNodeIdGenerator.h" +#include "pb/plan.pb.h" namespace milvus { namespace plan { @@ -414,18 +415,24 @@ class AggregationNode: public PlanNode { struct Aggregate { /// Function name and 
input column names. - expr::CallTypeExprPtr call; - - /// Raw input types used to properly identify aggregate function. These - /// might be different from the input types specified in 'call' when - /// aggregation step is kIntermediate or kFinal. - std::vector rawInputTypes; + expr::CallTypeExprPtr call_; + public: + Aggregate(expr::CallTypeExprPtr call):call_(call){} }; std::vector sources() const override { return sources_; } + AggregationNode(const PlanNodeId& id, + std::vector&& groupingKeys, + std::vector&& aggNames, + std::vector&& aggregates, + std::vector&& sources, + RowType&& output_type) + : PlanNode(id), groupingKeys_(std::move(groupingKeys)), aggregateNames_(std::move(aggNames)), aggregates_(std::move(aggregates)), + sources_(std::move(sources)), output_type_(std::move(output_type)), ignoreNullKeys_(true){} + private: const std::vector groupingKeys_; const std::vector aggregateNames_; diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index d199f0be77de7..9cffdaf4cf82a 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -22,10 +22,22 @@ #include "query/Utils.h" #include "knowhere/comp/materialized_view.h" #include "plan/PlanNode.h" +#include "expr/ITypeExpr.h" namespace milvus::query { namespace planpb = milvus::proto::plan; +std::string getAggregateOpName(planpb::AggregateOp op) { + switch (op) { + case planpb::sum: return "sum"; + case planpb::count: return "count"; + case planpb::avg: return "avg"; + case planpb::min: return "min"; + case planpb::max: return "max"; + default: return "unknown"; + } +} + std::unique_ptr ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { // TODO: add more buffs @@ -156,7 +168,10 @@ ProtoParser::RetrievePlanNodeFromProto( milvus::plan::GetNextPlanNodeId(), sources); node->plannodes_ = std::move(plannode); } else { + // mvccNode--->FilterBitsNode or + // aggNode--->mvccNode--->FilterBitsNode auto& query = 
plan_node_proto.query(); + // 1. FilterBitsNode if (query.has_predicates()) { auto& predicate_proto = query.predicates(); auto expr_parser = [&]() -> plan::PlanNodePtr { @@ -168,16 +183,46 @@ ProtoParser::RetrievePlanNodeFromProto( sources = std::vector{plannode}; } + // 2. mvccNode plannode = std::make_shared( milvus::plan::GetNextPlanNodeId(), sources); sources = std::vector{plannode}; - node->is_count_ = query.is_count(); + // 3. aggNode + /*node->is_count_ = query.is_count(); node->limit_ = query.limit(); if (node->is_count_) { plannode = std::make_shared( milvus::plan::GetNextPlanNodeId(), sources); sources = std::vector{plannode}; + }*/ + std::vector groupingKeys; + if(query.group_by_field_ids_size() > 0) { + groupingKeys.reserve(query.group_by_field_ids_size()); + for(int i = 0; i < query.group_by_field_ids_size(); i++) { + auto input_field_id = query.group_by_field_ids(i); + AssertInfo(input_field_id > 0, "input field_id to group by must be positive, but is:{}", input_field_id); + auto field_id = FieldId(input_field_id); + auto field_type = schema.FieldType(field_id); + groupingKeys.emplace_back(std::make_shared(field_type, field_id)); + } + } + std::vector aggregates; + if (query.aggregates_size() > 0) { + aggregates.reserve(query.aggregates_size()); + for(int i = 0; i < query.aggregates_size(); i++) { + auto aggregate = query.aggregates(i); + auto input_agg_field_id = aggregate.field_id(); + AssertInfo(input_agg_field_id > 0, "input field_id to aggregate must be positive, but is:{}", input_agg_field_id); + auto field_id = FieldId(input_agg_field_id); + auto field_type = schema.FieldType(field_id); + auto field_name = schema.FieldName(field_id); + auto agg_name = getAggregateOpName(aggregate.op()); + auto agg_input = std::make_shared(field_type, field_name, field_id); + auto call = std::make_shared(field_type, std::vector{agg_input}, agg_name); + aggregates.emplace_back(plan::AggregationNode::Aggregate{call}); + //check type conversion here + } } 
node->plannodes_ = plannode; } diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index 154191d2db51c..84e8a6becef78 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -4,6 +4,7 @@ option go_package = "github.com/milvus-io/milvus/internal/proto/internalpb"; import "common.proto"; import "schema.proto"; +import "plan.proto"; message GetTimeTickChannelRequest { } @@ -188,6 +189,9 @@ message RetrieveRequest { string username = 15; bool reduce_stop_for_best = 16; //deprecated int32 reduce_type = 17; + // for query agg + repeated int64 group_by_field_ids = 18; + repeated plan.Aggregate aggregates = 19; } @@ -201,7 +205,7 @@ message RetrieveResults { repeated string channelIDs_retrieved = 7; repeated int64 global_sealed_segmentIDs = 8; - // query request cost + // query request cost CostAggregation costAggregation = 13; int64 all_retrieve_count = 14; bool has_more_result = 15; diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 850bbd25fdfd5..4bbb148ac2213 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -50,21 +50,22 @@ const ( ) const ( - IgnoreGrowingKey = "ignore_growing" - ReduceStopForBestKey = "reduce_stop_for_best" - IteratorField = "iterator" - GroupByFieldKey = "group_by_field" - GroupSizeKey = "group_size" - GroupStrictSize = "group_strict_size" - RankGroupScorer = "rank_group_scorer" - AnnsFieldKey = "anns_field" - TopKKey = "topk" - NQKey = "nq" - MetricTypeKey = common.MetricTypeKey - SearchParamsKey = "params" - RoundDecimalKey = "round_decimal" - OffsetKey = "offset" - LimitKey = "limit" + IgnoreGrowingKey = "ignore_growing" + ReduceStopForBestKey = "reduce_stop_for_best" + IteratorField = "iterator" + GroupByFieldKey = "group_by_field" + GroupSizeKey = "group_size" + GroupStrictSize = "group_strict_size" + RankGroupScorer = "rank_group_scorer" + AnnsFieldKey = "anns_field" + TopKKey = "topk" + NQKey = "nq" + MetricTypeKey = common.MetricTypeKey + SearchParamsKey = 
"params" + RoundDecimalKey = "round_decimal" + OffsetKey = "offset" + LimitKey = "limit" + QueryGroupByFieldsKey = "group_by_fields" InsertTaskName = "InsertTask" CreateCollectionTaskName = "CreateCollectionTask" diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 63e5a9a67a607..077de51605297 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -79,15 +79,16 @@ type queryTask struct { } type queryParams struct { - limit int64 - offset int64 - reduceType reduce.IReduceType + limit int64 + offset int64 + reduceType reduce.IReduceType + groupByFields []string } -// translateToOutputFieldIDs translates output fields name to output fields id. -func translateToOutputFieldIDs(outputFields []string, schema *schemapb.CollectionSchema) ([]UniqueID, error) { - outputFieldIDs := make([]UniqueID, 0, len(outputFields)+1) - if len(outputFields) == 0 { +// translateToFieldIDs translates output fields name to output fields id. +func translateToFieldIDs(fieldNames []string, schema *schemapb.CollectionSchema, forQuery bool) ([]UniqueID, error) { + outputFieldIDs := make([]UniqueID, 0, len(fieldNames)+1) + if len(fieldNames) == 0 && forQuery { for _, field := range schema.Fields { if field.FieldID >= common.StartOfUserFieldID && !typeutil.IsVectorType(field.DataType) { outputFieldIDs = append(outputFieldIDs, field.FieldID) @@ -95,15 +96,21 @@ func translateToOutputFieldIDs(outputFields []string, schema *schemapb.Collectio } } else { var pkFieldID UniqueID - for _, field := range schema.Fields { - if field.IsPrimaryKey { - pkFieldID = field.FieldID + if forQuery { + for _, field := range schema.Fields { + if field.IsPrimaryKey { + pkFieldID = field.FieldID + } } } - for _, reqField := range outputFields { + var pkAdded bool + for _, reqField := range fieldNames { var fieldFound bool for _, field := range schema.Fields { if reqField == field.Name { + if field.IsPrimaryKey { + pkAdded = true + } outputFieldIDs = append(outputFieldIDs, 
field.FieldID) fieldFound = true break @@ -115,15 +122,7 @@ func translateToOutputFieldIDs(outputFields []string, schema *schemapb.Collectio } // pk field needs to be in output field list - var pkFound bool - for _, outputField := range outputFieldIDs { - if outputField == pkFieldID { - pkFound = true - break - } - } - - if !pkFound { + if !pkAdded && forQuery { outputFieldIDs = append(outputFieldIDs, pkFieldID) } } @@ -188,9 +187,9 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr) } + // parse offset offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, queryParamsPair) - // if offset is provided - if err == nil { + if err == nil { // if offset is provided offset, err = strconv.ParseInt(offsetStr, 0, 64) if err != nil { return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr) @@ -202,10 +201,21 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e return nil, fmt.Errorf("invalid max query result window, %w", err) } + // parse group by fields + groupByFieldsStr, err := funcutil.GetAttrByKeyFromRepeatedKV(QueryGroupByFieldsKey, queryParamsPair) + groupByFields := make([]string, 0) + if err == nil { + fields := strings.Split(groupByFieldsStr, ",") + for _, field := range fields { + groupByFields = append(groupByFields, field) + } + } + return &queryParams{ - limit: limit, - offset: offset, - reduceType: reduceType, + limit: limit, + offset: offset, + reduceType: reduceType, + groupByFields: groupByFields, }, nil } @@ -235,19 +245,6 @@ func createCntPlan(expr string, schemaHelper *typeutil.SchemaHelper) (*planpb.Pl return plan, nil } -func (t *queryTask) organizeAggregates(userAggregates []agg.AggregateBase) map[agg.AggID]agg.AggregateBase { - realAggregates := make(map[agg.AggID]agg.AggregateBase, 0) - for _, userAgg := range userAggregates { - subAggs := userAgg.Decompose() - for _, subAgg := range subAggs { - if _, 
exist := realAggregates[subAgg.ID()]; !exist { - realAggregates[subAgg.ID()] = subAgg - } - } - } - return realAggregates -} - func (t *queryTask) createPlan(ctx context.Context) error { schema := t.schema t.plan = &planpb.PlanNode{ @@ -258,42 +255,20 @@ func (t *queryTask) createPlan(ctx context.Context) error { if t.request.GetExpr() != "" { expr, err := planparserv2.ParseExpr(schema.schemaHelper, t.request.Expr) if err != nil { - return err + return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", err)) } t.plan.GetQuery().Predicates = expr } var err error - var userAggregates []agg.AggregateBase t.request.OutputFields, t.userOutputFields, t.userDynamicFields, t.userAggregates, err = translateOutputFields(t.request.OutputFields, t.schema, true) if err != nil { return err } - outputFieldIDs, err := translateToOutputFieldIDs(t.request.GetOutputFields(), schema.CollectionSchema) - t.internalAggregates = t.organizeAggregates(t.userAggregates) + t.internalAggregates = agg.OrganizeAggregates(t.userAggregates) + t.plan.GetQuery().Aggregates = agg.AggregatesToPB(t.internalAggregates) + //t.RetrieveRequest. 
- /* schema := t.schema - cntMatch := matchCountRule(t.request.GetOutputFields()) - if cntMatch { - var err error - t.plan, err = createCntPlan(t.request.GetExpr(), schema.schemaHelper) - t.userOutputFields = []string{"count(*)"} - return err - } - - var err error - if t.plan == nil { - t.plan, err = planparserv2.CreateRetrievePlan(schema.schemaHelper, t.request.Expr) - if err != nil { - return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", err)) - } - } - - t.request.OutputFields, t.userOutputFields, t.userDynamicFields, err = translateOutputFields(t.request.OutputFields, t.schema, true) - if err != nil { - return err - } - - outputFieldIDs, err := translateToOutputFieldIDs(t.request.GetOutputFields(), schema.CollectionSchema) + outputFieldIDs, err := translateToFieldIDs(t.request.GetOutputFields(), schema.CollectionSchema, true) if err != nil { return err } @@ -303,8 +278,7 @@ func (t *queryTask) createPlan(ctx context.Context) error { t.plan.DynamicFields = t.userDynamicFields log.Ctx(ctx).Debug("translate output fields to field ids", zap.Int64s("OutputFieldsID", t.OutputFieldsId), - zap.String("requestType", "query"))*/ - + zap.String("requestType", "query")) return nil } @@ -429,7 +403,13 @@ func (t *queryTask) PreExecute(ctx context.Context) error { if err := t.createPlan(ctx); err != nil { return err } - t.plan.Node.(*planpb.PlanNode_Query).Query.Limit = t.RetrieveRequest.Limit + t.plan.GetQuery().Limit = t.RetrieveRequest.Limit + groupByFieldsIDs, err := translateToFieldIDs(t.queryParams.groupByFields, t.schema.CollectionSchema, false) + if err != nil { + return err + } + t.plan.GetQuery().GroupByFieldIds = groupByFieldsIDs + t.RetrieveRequest.GroupByFieldIds = groupByFieldsIDs if planparserv2.IsAlwaysTruePlan(t.plan) && t.RetrieveRequest.Limit == typeutil.Unlimited { return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("empty expression should be used with limit")) diff --git 
a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 9f2ec742ef9be..e927252ff1d40 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -382,7 +382,7 @@ func Test_translateToOutputFieldIDs(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ids, err := translateToOutputFieldIDs(tc.outputFields, tc.schema) + ids, err := translateToFieldIDs(tc.outputFields, tc.schema, true) if tc.expectedError { assert.Error(t, err) } else { diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 9bf90be5271e9..5e2a9b23b8d1d 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -156,7 +156,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error { } } - t.request.OutputFields, t.userOutputFields, t.userDynamicFields, err = translateOutputFields(t.request.OutputFields, t.schema, false) + t.request.OutputFields, t.userOutputFields, t.userDynamicFields, _, err = translateOutputFields(t.request.OutputFields, t.schema, false) if err != nil { log.Warn("translate output fields failed", zap.Error(err)) return err diff --git a/internal/querynodev2/segments/plan.go b/internal/querynodev2/segments/plan.go index ff94ac63c9d93..4e055791b2b85 100644 --- a/internal/querynodev2/segments/plan.go +++ b/internal/querynodev2/segments/plan.go @@ -175,7 +175,7 @@ type RetrievePlan struct { ignoreNonPk bool } -func NewRetrievePlan(ctx context.Context, col *Collection, expr []byte, timestamp Timestamp, msgID UniqueID) (*RetrievePlan, error) { +func NewRetrievePlan(ctx context.Context, col *Collection, planBytes []byte, timestamp Timestamp, msgID UniqueID) (*RetrievePlan, error) { col.mu.RLock() defer col.mu.RUnlock() @@ -184,7 +184,7 @@ func NewRetrievePlan(ctx context.Context, col *Collection, expr []byte, timestam } var cPlan C.CRetrievePlan - status := C.CreateRetrievePlanByExpr(col.collectionPtr, unsafe.Pointer(&expr[0]), (C.int64_t)(len(expr)), &cPlan) + 
status := C.CreateRetrievePlanByExpr(col.collectionPtr, unsafe.Pointer(&planBytes[0]), (C.int64_t)(len(planBytes)), &cPlan) err := HandleCStatus(ctx, &status, "Create retrieve plan by expr failed") if err != nil {