Skip to content

Commit

Permalink
add agg function registration
Browse files Browse the repository at this point in the history
  • Loading branch information
MrPresent-Han committed Nov 12, 2024
1 parent 3c22f65 commit f2cb481
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 1 deletion.
11 changes: 11 additions & 0 deletions internal/core/src/common/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,4 +317,15 @@ int comparePrimitiveAsc(const T& left, const T& right) {
}
return left < right ? -1 : left == right ? 0 : 1;
}

inline std::string lowerString(const std::string& str) {
std::string ret;
ret.reserve(str.size());
std::transform(str.begin(), str.end(), ret.begin(), [](unsigned char c) {
return std::tolower(c);
});
return ret;
}

} // namespace milvus

28 changes: 28 additions & 0 deletions internal/core/src/exec/operator/query-agg/Aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Created by hanchun on 24-10-22.
//
#include "Aggregate.h"
#include "AggregateUtil.h"

namespace milvus{
namespace exec{
Expand Down Expand Up @@ -37,6 +38,33 @@ bool isPartialOutput(milvus::plan::AggregationNode::Step step) {
step == milvus::plan::AggregationNode::Step::kIntermediate;
}

AggregateRegistrationResult registerAggregateFunction(const std::string& name,
const std::vector<std::shared_ptr<expr::AggregateFunctionSignature>>& signatures,
const AggregateFunctionFactory& factory,
bool registerCompanionFunctions,
bool overwrite){
auto realName = lowerString(name);
AggregateRegistrationResult registered;
if (overwrite) {
aggregateFunctions().withWLock([&](auto& aggFunctionMap){
aggFunctionMap[name] = {signatures, factory};
});
registered.mainFunction = true;
} else {
auto inserted = aggregateFunctions().withWLock([&](auto& aggFunctionMap){
auto [_, func_inserted] = aggFunctionMap.insert({name, {signatures, factory}});
return func_inserted;
});
registered.mainFunction = inserted;
}
return registered;
}

AggregateFunctionMap& aggregateFunctions() {
static AggregateFunctionMap aggFunctionMap;
return aggFunctionMap;
}

}
}

Expand Down
31 changes: 31 additions & 0 deletions internal/core/src/exec/operator/query-agg/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@

#include "common/Types.h"
#include "plan/PlanNode.h"
#include "AggregateUtil.h"
#include "expr/FunctionSignature.h"
#include "plan/PlanNode.h"
#include "exec/QueryContext.h"
#include <folly/Synchronized.h>


namespace milvus{
namespace exec{
Expand Down Expand Up @@ -124,6 +130,31 @@ class Aggregate {
virtual void initializeNewGroupsInternal(char** groups, folly::Range<const vector_size_t*> indices) = 0;
};

using AggregateFunctionFactory = std::function<std::unique_ptr<Aggregate>(plan::AggregationNode::Step step,
const std::vector<DataType>& argTypes,
DataType resultType,
const QueryConfig& config)>;

struct AggregateFunctionEntry {
std::vector<expr::AggregateFunctionSignaturePtr> signatures;
AggregateFunctionFactory factory;
};

using AggregateFunctionMap = folly::Synchronized<std::unordered_map<std::string, AggregateFunctionEntry>>;

AggregateFunctionMap& aggregateFunctions();

/// Register an aggregate function with the specified name and signatures. If
/// registerCompanionFunctions is true, also register companion aggregate and
/// scalar functions with it. When functions with `name` already exist, if
/// overwrite is true, existing registration will be replaced. Otherwise, return
/// false without overwriting the registry.
AggregateRegistrationResult registerAggregateFunction(const std::string& name,
const std::vector<std::shared_ptr<expr::AggregateFunctionSignature>>& signatures,
const AggregateFunctionFactory& factory,
bool registerCompanionFunctions,
bool overwrite);

bool isRawInput(milvus::plan::AggregationNode::Step step);

bool isPartialOutput(milvus::plan::AggregationNode::Step step);
Expand Down
29 changes: 29 additions & 0 deletions internal/core/src/exec/operator/query-agg/SumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,35 @@ AggregateRegistrationResult registerSum(
.intermediateType(DataType::INT64)
.intermediateType(DataType::INT64).build());
}
return exec::registerAggregateFunction(name,
signatures,
[name](plan::AggregationNode::Step step,
const std::vector<DataType>& argumentTypes,
DataType resultType,
const QueryConfig& config)->std::unique_ptr<Aggregate>{
AssertInfo(argumentTypes.size()==1, "function:{} only accept one argument", name);
auto inputType = argumentTypes[0];
switch (inputType) {
case DataType::INT8:
return std::make_unique<T<int8_t, int64_t, int64_t>>(DataType::INT64);
case DataType::INT16:
return std::make_unique<T<int16_t, int64_t, int64_t>>(DataType::INT64);
case DataType::INT32:
return std::make_unique<T<int32_t, int64_t, int64_t>>(DataType::INT64);
case DataType::INT64:
return std::make_unique<T<int64_t, int64_t, int64_t>>(DataType::INT64);
case DataType::DOUBLE:
return std::make_unique<T<double, double, double>>(DataType::DOUBLE);
case DataType::FLOAT:
return std::make_unique<T<float, double, double>>(DataType::DOUBLE);
default:
PanicInfo(DataTypeInvalid, "Unknown input type for {} aggregation {}",
name,
GetDataTypeName(inputType));
}
},
withCompanionFunctions,
overwrite);
};


Expand Down
6 changes: 5 additions & 1 deletion internal/core/src/exec/operator/query-agg/SumAggregateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ class SumAggregateBase: public SimpleNumericAggregate<TInput, TAccumulator, Resu
return 1;
}

void extractValues(char** groups, int32_t numGroups, VectorPtr* result) {
void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override{
BaseAggregate::template doExtractValues<TAccumulator>(
groups, numGroups, result, [&](char* group) {
return (ResultType)(*BaseAggregate::Aggregate::template value<TAccumulator>(group));
});
}

void initializeNewGroupsInternal(char** groups, folly::Range<const vector_size_t*> indices) override {

}
};
}
}
2 changes: 2 additions & 0 deletions internal/core/src/expr/FunctionSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class AggregateFunctionSignature : public FunctionSignature {
DataType intermediateType_;
};

using AggregateFunctionSignaturePtr = std::shared_ptr<AggregateFunctionSignature>;

class AggregateFunctionSignatureBuilder {
public:
AggregateFunctionSignatureBuilder& returnType(DataType returnType) {
Expand Down

0 comments on commit f2cb481

Please sign in to comment.