From f2cb481b07cad803bf85ae81943a4006274303ed Mon Sep 17 00:00:00 2001
From: MrPresent-Han
Date: Mon, 11 Nov 2024 22:23:18 -0500
Subject: [PATCH] add agg function registration

---
 internal/core/src/common/Utils.h              | 11 +++++++
 .../src/exec/operator/query-agg/Aggregate.cpp | 28 +++++++++++++++++
 .../src/exec/operator/query-agg/Aggregate.h   | 31 +++++++++++++++++++
 .../exec/operator/query-agg/SumAggregate.cpp  | 29 +++++++++++++++++
 .../operator/query-agg/SumAggregateBase.h     |  6 +++-
 internal/core/src/expr/FunctionSignature.h    |  2 ++
 6 files changed, 106 insertions(+), 1 deletion(-)

diff --git a/internal/core/src/common/Utils.h b/internal/core/src/common/Utils.h
index cdaf7e3885304..430f70f8670bc 100644
--- a/internal/core/src/common/Utils.h
+++ b/internal/core/src/common/Utils.h
@@ -317,4 +317,15 @@ int comparePrimitiveAsc(const T& left, const T& right) {
     }
     return left < right ? -1 : left == right ? 0 : 1;
 }
+
+/// Returns a copy of `str` with every byte mapped through std::tolower.
+inline std::string lowerString(const std::string& str) {
+    std::string ret(str);  // sized copy; reserve() alone leaves size() at 0
+    std::transform(ret.begin(), ret.end(), ret.begin(), [](unsigned char c) {
+        return static_cast<char>(std::tolower(c));
+    });
+    return ret;
+}
+
 } // namespace milvus
+
diff --git a/internal/core/src/exec/operator/query-agg/Aggregate.cpp b/internal/core/src/exec/operator/query-agg/Aggregate.cpp
index fc0e7146e3c1f..6ff8eddd4b71d 100644
--- a/internal/core/src/exec/operator/query-agg/Aggregate.cpp
+++ b/internal/core/src/exec/operator/query-agg/Aggregate.cpp
@@ -2,6 +2,7 @@
 // Created by hanchun on 24-10-22.
// #include "Aggregate.h" +#include "AggregateUtil.h" namespace milvus{ namespace exec{ @@ -37,6 +38,33 @@ bool isPartialOutput(milvus::plan::AggregationNode::Step step) { step == milvus::plan::AggregationNode::Step::kIntermediate; } +AggregateRegistrationResult registerAggregateFunction(const std::string& name, + const std::vector>& signatures, + const AggregateFunctionFactory& factory, + bool registerCompanionFunctions, + bool overwrite){ + auto realName = lowerString(name); + AggregateRegistrationResult registered; + if (overwrite) { + aggregateFunctions().withWLock([&](auto& aggFunctionMap){ + aggFunctionMap[name] = {signatures, factory}; + }); + registered.mainFunction = true; + } else { + auto inserted = aggregateFunctions().withWLock([&](auto& aggFunctionMap){ + auto [_, func_inserted] = aggFunctionMap.insert({name, {signatures, factory}}); + return func_inserted; + }); + registered.mainFunction = inserted; + } + return registered; +} + +AggregateFunctionMap& aggregateFunctions() { + static AggregateFunctionMap aggFunctionMap; + return aggFunctionMap; +} + } } diff --git a/internal/core/src/exec/operator/query-agg/Aggregate.h b/internal/core/src/exec/operator/query-agg/Aggregate.h index b26afd9b005b2..aca4b75488c55 100644 --- a/internal/core/src/exec/operator/query-agg/Aggregate.h +++ b/internal/core/src/exec/operator/query-agg/Aggregate.h @@ -12,6 +12,12 @@ #include "common/Types.h" #include "plan/PlanNode.h" +#include "AggregateUtil.h" +#include "expr/FunctionSignature.h" +#include "plan/PlanNode.h" +#include "exec/QueryContext.h" +#include + namespace milvus{ namespace exec{ @@ -124,6 +130,31 @@ class Aggregate { virtual void initializeNewGroupsInternal(char** groups, folly::Range indices) = 0; }; +using AggregateFunctionFactory = std::function(plan::AggregationNode::Step step, + const std::vector& argTypes, + DataType resultType, + const QueryConfig& config)>; + +struct AggregateFunctionEntry { + std::vector signatures; + AggregateFunctionFactory 
factory; +}; + +using AggregateFunctionMap = folly::Synchronized>; + +AggregateFunctionMap& aggregateFunctions(); + +/// Register an aggregate function with the specified name and signatures. If +/// registerCompanionFunctions is true, also register companion aggregate and +/// scalar functions with it. When functions with `name` already exist, if +/// overwrite is true, existing registration will be replaced. Otherwise, return +/// false without overwriting the registry. +AggregateRegistrationResult registerAggregateFunction(const std::string& name, + const std::vector>& signatures, + const AggregateFunctionFactory& factory, + bool registerCompanionFunctions, + bool overwrite); + bool isRawInput(milvus::plan::AggregationNode::Step step); bool isPartialOutput(milvus::plan::AggregationNode::Step step); diff --git a/internal/core/src/exec/operator/query-agg/SumAggregate.cpp b/internal/core/src/exec/operator/query-agg/SumAggregate.cpp index bce2a5d585fd7..3fa3b0a56f26c 100644 --- a/internal/core/src/exec/operator/query-agg/SumAggregate.cpp +++ b/internal/core/src/exec/operator/query-agg/SumAggregate.cpp @@ -43,6 +43,35 @@ AggregateRegistrationResult registerSum( .intermediateType(DataType::INT64) .intermediateType(DataType::INT64).build()); } + return exec::registerAggregateFunction(name, + signatures, + [name](plan::AggregationNode::Step step, + const std::vector& argumentTypes, + DataType resultType, + const QueryConfig& config)->std::unique_ptr{ + AssertInfo(argumentTypes.size()==1, "function:{} only accept one argument", name); + auto inputType = argumentTypes[0]; + switch (inputType) { + case DataType::INT8: + return std::make_unique>(DataType::INT64); + case DataType::INT16: + return std::make_unique>(DataType::INT64); + case DataType::INT32: + return std::make_unique>(DataType::INT64); + case DataType::INT64: + return std::make_unique>(DataType::INT64); + case DataType::DOUBLE: + return std::make_unique>(DataType::DOUBLE); + case DataType::FLOAT: + return 
std::make_unique>(DataType::DOUBLE); + default: + PanicInfo(DataTypeInvalid, "Unknown input type for {} aggregation {}", + name, + GetDataTypeName(inputType)); + } + }, + withCompanionFunctions, + overwrite); }; diff --git a/internal/core/src/exec/operator/query-agg/SumAggregateBase.h b/internal/core/src/exec/operator/query-agg/SumAggregateBase.h index 7640f45b0cdbe..6d1e82a9cc727 100644 --- a/internal/core/src/exec/operator/query-agg/SumAggregateBase.h +++ b/internal/core/src/exec/operator/query-agg/SumAggregateBase.h @@ -31,12 +31,16 @@ class SumAggregateBase: public SimpleNumericAggregate( groups, numGroups, result, [&](char* group) { return (ResultType)(*BaseAggregate::Aggregate::template value(group)); }); } + + void initializeNewGroupsInternal(char** groups, folly::Range indices) override { + + } }; } } diff --git a/internal/core/src/expr/FunctionSignature.h b/internal/core/src/expr/FunctionSignature.h index acdb8a65fbf1f..c320524e88346 100644 --- a/internal/core/src/expr/FunctionSignature.h +++ b/internal/core/src/expr/FunctionSignature.h @@ -51,6 +51,8 @@ class AggregateFunctionSignature : public FunctionSignature { DataType intermediateType_; }; +using AggregateFunctionSignaturePtr = std::shared_ptr; + class AggregateFunctionSignatureBuilder { public: AggregateFunctionSignatureBuilder& returnType(DataType returnType) {