Skip to content

Commit

Permalink
coding in the secore
Browse files Browse the repository at this point in the history
Signed-off-by: MrPresent-Han <chun.han@gmail.com>
  • Loading branch information
MrPresent-Han committed Oct 15, 2024
1 parent ae6843e commit 66c555f
Show file tree
Hide file tree
Showing 11 changed files with 181 additions and 101 deletions.
23 changes: 23 additions & 0 deletions internal/agg/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,26 @@ func (max *MaxAggregate) Reduce() {
func (max *MaxAggregate) ToPB() *planpb.Aggregate {
return &planpb.Aggregate{Op: planpb.AggregateOp_max, FieldId: max.aggFieldID()}
}

func OrganizeAggregates(userAggregates []AggregateBase) map[AggID]AggregateBase {
realAggregates := make(map[AggID]AggregateBase, 0)
for _, userAgg := range userAggregates {
subAggs := userAgg.Decompose()
for _, subAgg := range subAggs {
if _, exist := realAggregates[subAgg.ID()]; !exist {
realAggregates[subAgg.ID()] = subAgg
}
}
}
return realAggregates
}

func AggregatesToPB(aggregates map[AggID]AggregateBase) []*planpb.Aggregate {
ret := make([]*planpb.Aggregate, len(aggregates))
if aggregates != nil {
for idx, agg := range aggregates {
ret[idx] = agg.ToPB()
}
}
return ret
}
17 changes: 16 additions & 1 deletion internal/core/src/common/Schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class Schema {

FieldId
get_field_id(const FieldName& field_name) const {
AssertInfo(name_ids_.count(field_name), "Cannot find field_name");
AssertInfo(name_ids_.count(field_name), "Cannot find field_name:{}", field_name.get());
return name_ids_.at(field_name);
}

Expand Down Expand Up @@ -232,6 +232,21 @@ class Schema {
field_ids_.emplace_back(field_id);
}

DataType
FieldType(const FieldId& field_id) const {
AssertInfo(fields_.count(field_id), "field_id:{} is not existed in the schema", field_id.get());
auto& meta = fields_.at(field_id);
return meta.get_data_type();
}

const std::string&
FieldName(const FieldId& field_id) const {
AssertInfo(fields_.count(field_id), "field_id:{} is not existed in the schema", field_id.get());
auto& meta = fields_.at(field_id);
return meta.get_name().get();
}


private:
int64_t debug_id = START_USER_FIELDID;
std::vector<FieldId> field_ids_;
Expand Down
15 changes: 10 additions & 5 deletions internal/core/src/expr/ITypeExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,18 @@ using CallTypeExprPtr = std::shared_ptr<const CallTypeExpr>;

class FieldAccessTypeExpr : public ITypeExpr {
public:
FieldAccessTypeExpr(DataType type, const std::string& name)
: ITypeExpr{type}, name_(name), is_input_column_(true) {
FieldAccessTypeExpr(DataType type, const std::string& name, FieldId fieldId)
: ITypeExpr{type}, name_(name), field_id_(fieldId), is_input_column_(true) {
}

FieldAccessTypeExpr(DataType type, FieldId fieldId)
: ITypeExpr{type}, name_(""), field_id_(fieldId), is_input_column_(true){}

FieldAccessTypeExpr(DataType type,
const TypedExprPtr& input,
const std::string& name)
: ITypeExpr{type, {std::move(input)}}, name_(name) {
const std::string& name,
FieldId fieldId)
: ITypeExpr{type, {std::move(input)}}, name_(name), field_id_(fieldId) {
is_input_column_ =
dynamic_cast<const InputTypeExpr*>(inputs_[0].get()) != nullptr;
}
Expand All @@ -285,11 +289,12 @@ class FieldAccessTypeExpr : public ITypeExpr {
return fmt::format("{}", name_);
}

return fmt::format("{}[{}]", inputs_[0]->ToString(), name_);
return fmt::format("{}[{}{}]", inputs_[0]->ToString(), name_, field_id_.get());
}

private:
std::string name_;
const FieldId field_id_;
bool is_input_column_;
};

Expand Down
19 changes: 13 additions & 6 deletions internal/core/src/plan/PlanNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "common/EasyAssert.h"
#include "segcore/SegmentInterface.h"
#include "plan/PlanNodeIdGenerator.h"
#include "pb/plan.pb.h"

namespace milvus {
namespace plan {
Expand Down Expand Up @@ -414,18 +415,24 @@ class AggregationNode: public PlanNode {

struct Aggregate {
/// Function name and input column names.
expr::CallTypeExprPtr call;

/// Raw input types used to properly identify aggregate function. These
/// might be different from the input types specified in 'call' when
/// aggregation step is kIntermediate or kFinal.
std::vector<DataType> rawInputTypes;
expr::CallTypeExprPtr call_;
public:
Aggregate(expr::CallTypeExprPtr call):call_(call){}
};

std::vector<PlanNodePtr> sources() const override {
return sources_;
}

AggregationNode(const PlanNodeId& id,
std::vector<expr::FieldAccessTypeExprPtr>&& groupingKeys,
std::vector<std::string>&& aggNames,
std::vector<Aggregate>&& aggregates,
std::vector<PlanNodePtr>&& sources,
RowType&& output_type)
: PlanNode(id), groupingKeys_(std::move(groupingKeys)), aggregateNames_(std::move(aggNames)), aggregates_(std::move(aggregates)),
sources_(std::move(sources)), output_type_(std::move(output_type)), ignoreNullKeys_(true){}

private:
const std::vector<expr::FieldAccessTypeExprPtr> groupingKeys_;
const std::vector<std::string> aggregateNames_;
Expand Down
47 changes: 46 additions & 1 deletion internal/core/src/query/PlanProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,22 @@
#include "query/Utils.h"
#include "knowhere/comp/materialized_view.h"
#include "plan/PlanNode.h"
#include "expr/ITypeExpr.h"

namespace milvus::query {
namespace planpb = milvus::proto::plan;

std::string getAggregateOpName(planpb::AggregateOp op) {
switch (op) {
case planpb::sum: return "sum";
case planpb::count: return "count";
case planpb::avg: return "avg";
case planpb::min: return "min";
case planpb::max: return "max";
default: return "unknown";
}
}

std::unique_ptr<VectorPlanNode>
ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
// TODO: add more buffs
Expand Down Expand Up @@ -156,7 +168,10 @@ ProtoParser::RetrievePlanNodeFromProto(
milvus::plan::GetNextPlanNodeId(), sources);
node->plannodes_ = std::move(plannode);
} else {
// mvccNode--->FilterBitsNode or
// aggNode--->mvccNode--->FilterBitsNode
auto& query = plan_node_proto.query();
// 1. FilterBitsNode
if (query.has_predicates()) {
auto& predicate_proto = query.predicates();
auto expr_parser = [&]() -> plan::PlanNodePtr {
Expand All @@ -168,16 +183,46 @@ ProtoParser::RetrievePlanNodeFromProto(
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
}

// 2. mvccNode
plannode = std::make_shared<milvus::plan::MvccNode>(
milvus::plan::GetNextPlanNodeId(), sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};

node->is_count_ = query.is_count();
// 3. aggNode
/*node->is_count_ = query.is_count();
node->limit_ = query.limit();
if (node->is_count_) {
plannode = std::make_shared<milvus::plan::CountNode>(
milvus::plan::GetNextPlanNodeId(), sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
}*/
std::vector<expr::FieldAccessTypeExprPtr> groupingKeys;
if(query.group_by_field_ids_size() > 0) {
groupingKeys.reserve(query.group_by_field_ids_size());
for(int i = 0; i < query.group_by_field_ids_size(); i++) {
auto input_field_id = query.group_by_field_ids(i);
AssertInfo(input_field_id > 0, "input field_id to group by must be positive, but is:{}", input_field_id);
auto field_id = FieldId(input_field_id);
auto field_type = schema.FieldType(field_id);
groupingKeys.emplace_back(std::make_shared<const expr::FieldAccessTypeExpr>(field_type, field_id));
}
}
std::vector<plan::AggregationNode::Aggregate> aggregates;
if (query.aggregates_size() > 0) {
aggregates.reserve(query.aggregates_size());
for(int i = 0; i < query.aggregates_size(); i++) {
auto aggregate = query.aggregates(i);
auto input_agg_field_id = aggregate.field_id();
AssertInfo(input_agg_field_id > 0, "input field_id to aggregate must be positive, but is:{}", input_agg_field_id);
auto field_id = FieldId(input_agg_field_id);
auto field_type = schema.FieldType(field_id);
auto field_name = schema.FieldName(field_id);
auto agg_name = getAggregateOpName(aggregate.op());
auto agg_input = std::make_shared<expr::FieldAccessTypeExpr>(field_type, field_name, field_id);
auto call = std::make_shared<const expr::CallTypeExpr>(field_type, std::vector<expr::TypedExprPtr>{agg_input}, agg_name);
aggregates.emplace_back(plan::AggregationNode::Aggregate{call});
//check type conversion here
}
}
node->plannodes_ = plannode;
}
Expand Down
6 changes: 5 additions & 1 deletion internal/proto/internal.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ option go_package = "github.com/milvus-io/milvus/internal/proto/internalpb";

import "common.proto";
import "schema.proto";
import "plan.proto";

message GetTimeTickChannelRequest {
}
Expand Down Expand Up @@ -188,6 +189,9 @@ message RetrieveRequest {
string username = 15;
bool reduce_stop_for_best = 16; //deprecated
int32 reduce_type = 17;
// for query agg
repeated int64 group_by_field_ids = 18;
repeated plan.Aggregate aggregates = 19;
}


Expand All @@ -201,7 +205,7 @@ message RetrieveResults {
repeated string channelIDs_retrieved = 7;
repeated int64 global_sealed_segmentIDs = 8;

// query request cost
// query request cost
CostAggregation costAggregation = 13;
int64 all_retrieve_count = 14;
bool has_more_result = 15;
Expand Down
31 changes: 16 additions & 15 deletions internal/proxy/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,22 @@ const (
)

const (
IgnoreGrowingKey = "ignore_growing"
ReduceStopForBestKey = "reduce_stop_for_best"
IteratorField = "iterator"
GroupByFieldKey = "group_by_field"
GroupSizeKey = "group_size"
GroupStrictSize = "group_strict_size"
RankGroupScorer = "rank_group_scorer"
AnnsFieldKey = "anns_field"
TopKKey = "topk"
NQKey = "nq"
MetricTypeKey = common.MetricTypeKey
SearchParamsKey = "params"
RoundDecimalKey = "round_decimal"
OffsetKey = "offset"
LimitKey = "limit"
IgnoreGrowingKey = "ignore_growing"
ReduceStopForBestKey = "reduce_stop_for_best"
IteratorField = "iterator"
GroupByFieldKey = "group_by_field"
GroupSizeKey = "group_size"
GroupStrictSize = "group_strict_size"
RankGroupScorer = "rank_group_scorer"
AnnsFieldKey = "anns_field"
TopKKey = "topk"
NQKey = "nq"
MetricTypeKey = common.MetricTypeKey
SearchParamsKey = "params"
RoundDecimalKey = "round_decimal"
OffsetKey = "offset"
LimitKey = "limit"
QueryGroupByFieldsKey = "group_by_fields"

InsertTaskName = "InsertTask"
CreateCollectionTaskName = "CreateCollectionTask"
Expand Down
Loading

0 comments on commit 66c555f

Please sign in to comment.