-
Couldn't load subscription status.
- Fork 176
Add udf interface #3374
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add udf interface #3374
Changes from all commits
4cbf2ca
89207a1
28869cf
d1f8f11
2464c85
df6728a
958cac9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.sql.calcite.udf; | ||
|
|
||
/**
 * Contract for a user-defined aggregate function (UDAF).
 *
 * <p>An aggregation is driven through three steps: {@link #init()} creates the accumulator,
 * {@link #add(Accumulator, Object...)} folds values into it, and {@link #result(Accumulator)}
 * extracts the final value.
 *
 * <p>Implementations are wrapped with Calcite's {@code AggregateFunctionImpl.create(...)}. The
 * method names {@code init}, {@code add} and {@code result} are the ones Calcite defines for
 * aggregate functions, so they must not be renamed (for example to match the scalar UDF's {@code
 * eval}).
 *
 * @param <S> the accumulator type that carries the aggregation state
 */
public interface UserDefinedAggFunction<S extends UserDefinedAggFunction.Accumulator> {
  /**
   * Creates the accumulator that holds the aggregation state.
   *
   * @return a new {@link Accumulator}
   */
  S init();

  /**
   * Produces the final aggregation result from the accumulated state.
   *
   * @param accumulator the {@link Accumulator} holding the accumulated state
   * @return final result
   */
  Object result(S accumulator);

  /**
   * Adds values to the accumulator. Note that constant call arguments are passed here as well,
   * such as the 50 in {@code Percentile(field, 50)}.
   *
   * @param acc the {@link Accumulator} to update
   * @param values the value to add to the accumulator, followed by any extra call arguments
   * @return the {@link Accumulator} to use for the next call
   */
  S add(S acc, Object... values);

  /** Maintains the state while a {@link UserDefinedAggFunction} adds values. */
  interface Accumulator {
    /**
     * Returns the value accumulated so far.
     *
     * @return the final aggregation value
     */
    Object value();
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.sql.calcite.udf; | ||
|
|
||
/**
 * Contract for a scalar user-defined function (UDF).
 *
 * <p>{@code UserDefineFunctionUtils.TransferUserDefinedFunction} looks the implementation up
 * reflectively by the method name {@code eval} with a single {@code Object[]} parameter, so
 * implementations must keep exactly this signature.
 */
public interface UserDefinedFunction {
  /**
   * Evaluates the function for one invocation.
   *
   * @param args the call arguments, in call order; implementations are responsible for
   *     validating their number and types
   * @return the function result
   */
  Object eval(Object... args);
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.sql.calcite.udf.mathUDF; | ||
|
|
||
| import static java.lang.Math.sqrt; | ||
|
|
||
| import org.opensearch.sql.calcite.udf.UserDefinedFunction; | ||
|
|
||
| public class SqrtFunction implements UserDefinedFunction { | ||
| @Override | ||
| public Object eval(Object... args) { | ||
| if (args.length < 1) { | ||
| throw new IllegalArgumentException("At least one argument is required"); | ||
| } | ||
|
|
||
| // Get the input value | ||
| Object input = args[0]; | ||
|
|
||
| // Handle numbers dynamically | ||
| if (input instanceof Number) { | ||
| double num = ((Number) input).doubleValue(); | ||
|
|
||
| if (num < 0) { | ||
| throw new ArithmeticException("Cannot compute square root of a negative number"); | ||
| } | ||
|
|
||
| return sqrt(num); // Computes sqrt using Math.sqrt() | ||
| } else { | ||
| throw new IllegalArgumentException("Invalid argument type: Expected a numeric value"); | ||
| } | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.sql.calcite.udf.udaf; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Copyright header missing. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
|
|
||
| import java.util.ArrayList; | ||
| import java.util.List; | ||
| import org.opensearch.sql.calcite.udf.UserDefinedAggFunction; | ||
|
|
||
| public class TakeAggFunction implements UserDefinedAggFunction<TakeAggFunction.TakeAccumulator> { | ||
|
|
||
| @Override | ||
| public TakeAccumulator init() { | ||
| return new TakeAccumulator(); | ||
| } | ||
|
|
||
| @Override | ||
| public Object result(TakeAccumulator accumulator) { | ||
| return accumulator.value(); | ||
| } | ||
|
|
||
| @Override | ||
| public TakeAccumulator add(TakeAccumulator acc, Object... values) { | ||
| Object candidateValue = values[0]; | ||
| int size = 0; | ||
| if (values.length > 1) { | ||
| size = (int) values[1]; | ||
| } else { | ||
| size = 10; | ||
| } | ||
| if (size > acc.size()) { | ||
| acc.add(candidateValue); | ||
| } | ||
| return acc; | ||
| } | ||
|
|
||
| public static class TakeAccumulator implements Accumulator { | ||
| private List<Object> hits; | ||
|
|
||
| public TakeAccumulator() { | ||
| hits = new ArrayList<>(); | ||
| } | ||
|
|
||
| @Override | ||
| public Object value() { | ||
| return hits; | ||
| } | ||
|
|
||
| public void add(Object value) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just compare the interface |
||
| hits.add(value); | ||
| } | ||
|
|
||
| public int size() { | ||
| return hits.size(); | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,10 +5,19 @@ | |
|
|
||
| package org.opensearch.sql.calcite.utils; | ||
|
|
||
| import static org.opensearch.sql.calcite.utils.UserDefineFunctionUtils.TransferUserDefinedFunction; | ||
|
|
||
| import java.math.BigDecimal; | ||
| import java.util.ArrayList; | ||
| import java.util.List; | ||
| import java.util.Locale; | ||
| import org.apache.calcite.rex.RexNode; | ||
| import org.apache.calcite.sql.SqlOperator; | ||
| import org.apache.calcite.sql.fun.SqlLibraryOperators; | ||
| import org.apache.calcite.sql.fun.SqlStdOperatorTable; | ||
| import org.apache.calcite.sql.type.ReturnTypes; | ||
| import org.opensearch.sql.calcite.CalcitePlanContext; | ||
| import org.opensearch.sql.calcite.udf.mathUDF.SqrtFunction; | ||
|
|
||
| public interface BuiltinFunctionUtils { | ||
|
|
||
|
|
@@ -51,6 +60,12 @@ static SqlOperator translate(String op) { | |
| // Built-in Math Functions | ||
| case "ABS": | ||
| return SqlStdOperatorTable.ABS; | ||
| case "SQRT": | ||
| return TransferUserDefinedFunction(SqrtFunction.class, "SQRT", ReturnTypes.DOUBLE); | ||
| case "ATAN", "ATAN2": | ||
| return SqlStdOperatorTable.ATAN2; | ||
| case "POW", "POWER": | ||
| return SqlStdOperatorTable.POWER; | ||
| // Built-in Date Functions | ||
| case "CURRENT_TIMESTAMP": | ||
| return SqlStdOperatorTable.CURRENT_TIMESTAMP; | ||
|
|
@@ -67,4 +82,32 @@ static SqlOperator translate(String op) { | |
| throw new IllegalArgumentException("Unsupported operator: " + op); | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Translates function arguments to align with Calcite's expectations, ensuring compatibility with | ||
| * PPL (Piped Processing Language). This is necessary because Calcite's input argument order or | ||
| * default values may differ from PPL's function definitions. | ||
| * | ||
| * @param op The function name as a string. | ||
| * @param argList A list of {@link RexNode} representing the parsed arguments from the PPL | ||
| * statement. | ||
| * @param context The {@link CalcitePlanContext} providing necessary utilities such as {@code | ||
| * rexBuilder}. | ||
| * @return A modified list of {@link RexNode} that correctly maps to Calcite’s function | ||
| * expectations. | ||
| */ | ||
| static List<RexNode> translateArgument( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this is a framework work, we will change this method frequently, could you add some comments to explain this method for developers There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, already added. |
||
| String op, List<RexNode> argList, CalcitePlanContext context) { | ||
| switch (op.toUpperCase(Locale.ROOT)) { | ||
| case "ATAN": | ||
| List<RexNode> AtanArgs = new ArrayList<>(argList); | ||
| if (AtanArgs.size() == 1) { | ||
| BigDecimal divideNumber = BigDecimal.valueOf(1); | ||
| AtanArgs.add(context.rexBuilder.makeBigintLiteral(divideNumber)); | ||
| } | ||
| return AtanArgs; | ||
| default: | ||
| return argList; | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.sql.calcite.utils; | ||
|
|
||
| import static org.apache.calcite.sql.type.SqlTypeUtil.createArrayType; | ||
|
|
||
| import java.util.ArrayList; | ||
| import java.util.Collections; | ||
| import java.util.List; | ||
| import org.apache.calcite.linq4j.tree.Types; | ||
| import org.apache.calcite.rel.type.RelDataType; | ||
| import org.apache.calcite.rel.type.RelDataTypeFactory; | ||
| import org.apache.calcite.rex.RexNode; | ||
| import org.apache.calcite.schema.ScalarFunction; | ||
| import org.apache.calcite.schema.impl.AggregateFunctionImpl; | ||
| import org.apache.calcite.schema.impl.ScalarFunctionImpl; | ||
| import org.apache.calcite.sql.SqlIdentifier; | ||
| import org.apache.calcite.sql.SqlKind; | ||
| import org.apache.calcite.sql.SqlOperator; | ||
| import org.apache.calcite.sql.parser.SqlParserPos; | ||
| import org.apache.calcite.sql.type.SqlReturnTypeInference; | ||
| import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction; | ||
| import org.apache.calcite.sql.validate.SqlUserDefinedFunction; | ||
| import org.apache.calcite.tools.RelBuilder; | ||
| import org.apache.calcite.util.Optionality; | ||
| import org.opensearch.sql.calcite.udf.UserDefinedAggFunction; | ||
| import org.opensearch.sql.calcite.udf.UserDefinedFunction; | ||
|
|
||
| public class UserDefineFunctionUtils { | ||
| public static RelBuilder.AggCall TransferUserDefinedAggFunction( | ||
| Class<? extends UserDefinedAggFunction> UDAF, | ||
| String functionName, | ||
| SqlReturnTypeInference returnType, | ||
| List<RexNode> fields, | ||
| List<RexNode> argList, | ||
| RelBuilder relBuilder) { | ||
| SqlUserDefinedAggFunction sqlUDAF = | ||
| new SqlUserDefinedAggFunction( | ||
| new SqlIdentifier(functionName, SqlParserPos.ZERO), | ||
| SqlKind.OTHER_FUNCTION, | ||
| returnType, | ||
| null, | ||
| null, | ||
| AggregateFunctionImpl.create(UDAF), | ||
| false, | ||
| false, | ||
| Optionality.FORBIDDEN); | ||
| List<RexNode> addArgList = new ArrayList<>(fields); | ||
| addArgList.addAll(argList); | ||
| return relBuilder.aggregateCall(sqlUDAF, addArgList); | ||
| } | ||
|
|
||
| public static SqlOperator TransferUserDefinedFunction( | ||
| Class<? extends UserDefinedFunction> UDF, | ||
| String functionName, | ||
| SqlReturnTypeInference returnType) { | ||
| final ScalarFunction udfFunction = | ||
| ScalarFunctionImpl.create(Types.lookupMethod(UDF, "eval", Object[].class)); | ||
| SqlIdentifier udfLtrimIdentifier = | ||
| new SqlIdentifier(Collections.singletonList(functionName), null, SqlParserPos.ZERO, null); | ||
| return new SqlUserDefinedFunction( | ||
| udfLtrimIdentifier, SqlKind.OTHER_FUNCTION, returnType, null, null, udfFunction); | ||
| } | ||
|
|
||
| public static SqlReturnTypeInference getReturnTypeInferenceForArray() { | ||
| return opBinding -> { | ||
| RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); | ||
|
|
||
| // Get argument types | ||
| List<RelDataType> argTypes = opBinding.collectOperandTypes(); | ||
|
|
||
| if (argTypes.isEmpty()) { | ||
| throw new IllegalArgumentException("Function requires at least one argument."); | ||
| } | ||
| RelDataType firstArgType = argTypes.getFirst(); | ||
| return createArrayType(typeFactory, firstArgType, true); | ||
| }; | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In UDF, the method name is `eval`; should we change to `eval` here as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The method names of an aggregate function are defined by Calcite here: https://github.com/apache/calcite/blob/1793ba79a328c61fb42842f443334cc1353c985f/core/src/main/java/org/apache/calcite/schema/impl/AggregateFunctionImpl.java#L91. We cannot modify them. I will leave a comment.