Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.calcite;

import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder.AggCall;
import org.opensearch.sql.ast.AbstractNodeVisitor;
Expand Down Expand Up @@ -33,6 +35,10 @@ public AggCall visitAlias(Alias node, CalcitePlanContext context) {
@Override
public AggCall visitAggregateFunction(AggregateFunction node, CalcitePlanContext context) {
RexNode field = rexNodeVisitor.analyze(node.getField(), context);
return AggregateUtils.translate(node, field, context);
List<RexNode> argList = new ArrayList<>();
for (UnresolvedExpression arg : node.getArgList()) {
argList.add(rexNodeVisitor.analyze(arg, context));
}
return AggregateUtils.translate(node, field, context, argList);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.sql.ast.expression.SpanUnit.NONE;
import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN;
import static org.opensearch.sql.calcite.utils.BuiltinFunctionUtils.translateArgument;

import java.math.BigDecimal;
import java.util.List;
Expand Down Expand Up @@ -254,7 +255,8 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
List<RexNode> arguments =
node.getFuncArgs().stream().map(arg -> analyze(arg, context)).collect(Collectors.toList());
return context.rexBuilder.makeCall(
BuiltinFunctionUtils.translate(node.getFuncName()), arguments);
BuiltinFunctionUtils.translate(node.getFuncName()),
translateArgument(node.getFuncName(), arguments, context));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.udf;

public interface UserDefinedAggFunction<S extends UserDefinedAggFunction.Accumulator> {
/**
 * Creates the accumulator that will carry the state of one aggregation.
 *
 * @return a new {@link Accumulator} to be passed to subsequent {@code add} calls
 */
S init();

/**
 * Extracts the final aggregation result from the accumulator.
 *
 * <p>The method is named {@code result} (rather than {@code eval}, as in scalar UDFs) because
 * Calcite's {@code AggregateFunctionImpl} defines the method names it expects on an aggregate
 * function class; they cannot be changed here.
 *
 * @param accumulator the {@link Accumulator} holding the state built up by {@code add}
 * @return final result
 */
Object result(S accumulator);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In UDF the method name is eval; should we rename this one to eval as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method names of an aggregate function are defined by Calcite here: https://github.com/apache/calcite/blob/1793ba79a328c61fb42842f443334cc1353c985f/core/src/main/java/org/apache/calcite/schema/impl/AggregateFunctionImpl.java#L91. We cannot modify them. I will leave a comment.


/**
 * Adds values to the accumulator.
 *
 * <p>Note that constant arguments of the aggregate call are passed here as well, on every
 * invocation, e.g. the {@code 50} in {@code Percentile(field, 50)}.
 *
 * @param acc the {@link Accumulator} to update
 * @param values the values to add to the accumulator, followed by any extra call arguments
 * @return the {@link Accumulator} to use for the next call
 */
S add(S acc, Object... values);

/** Maintains the intermediate state while a {@link UserDefinedAggFunction} adds values. */
interface Accumulator {
  /**
   * Exposes the value computed from the state held by this accumulator.
   *
   * @return the final aggregation value
   */
  Object value();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.udf;

public interface UserDefinedFunction {
/**
 * Evaluates this scalar function for one invocation.
 *
 * <p>NOTE(review): {@code UserDefineFunctionUtils} looks this method up by the name {@code eval}
 * with an {@code Object[]} parameter, so the name and signature should not be changed.
 *
 * @param args the call arguments, in the order they were supplied
 * @return the function result
 */
Object eval(Object... args);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comments.

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.udf.mathUDF;

import static java.lang.Math.sqrt;

import org.opensearch.sql.calcite.udf.UserDefinedFunction;

/**
 * Scalar UDF computing the square root of its first argument.
 *
 * <p>Only {@code args[0]} is inspected; any further arguments are ignored.
 */
public class SqrtFunction implements UserDefinedFunction {
  /**
   * Computes the square root of {@code args[0]}.
   *
   * @param args call arguments; the first one must be a {@link Number}
   * @return the square root as a {@code double}
   * @throws IllegalArgumentException if no argument is given or the first one is not a number
   * @throws ArithmeticException if the first argument is negative
   */
  @Override
  public Object eval(Object... args) {
    if (args.length == 0) {
      throw new IllegalArgumentException("At least one argument is required");
    }

    final Object operand = args[0];
    // Reject anything that cannot be widened to a double.
    if (!(operand instanceof Number)) {
      throw new IllegalArgumentException("Invalid argument type: Expected a numeric value");
    }

    final double radicand = ((Number) operand).doubleValue();
    if (radicand < 0) {
      throw new ArithmeticException("Cannot compute square root of a negative number");
    }

    return sqrt(radicand);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.udf.udaf;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copyright header missing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


import java.util.ArrayList;
import java.util.List;
import org.opensearch.sql.calcite.udf.UserDefinedAggFunction;

public class TakeAggFunction implements UserDefinedAggFunction<TakeAggFunction.TakeAccumulator> {

/**
 * Creates the accumulator for a new {@code take} aggregation.
 *
 * @return a new {@link TakeAccumulator} with no collected values
 */
@Override
public TakeAccumulator init() {
return new TakeAccumulator();
}

/**
 * Returns the values collected by the accumulator.
 *
 * @param accumulator the accumulator filled by {@code add}
 * @return whatever {@link TakeAccumulator#value()} returns, i.e. the list of collected values
 */
@Override
public Object result(TakeAccumulator accumulator) {
return accumulator.value();
}

/**
 * Offers one candidate value to the accumulator, keeping it only while fewer than the requested
 * number of values have been collected.
 *
 * @param acc the accumulator to update
 * @param values {@code values[0]} is the candidate value; the optional {@code values[1]} is the
 *     maximum number of values to keep and must be a {@link Number}. When it is absent the limit
 *     defaults to 10.
 * @return the same accumulator instance that was passed in
 */
@Override
public TakeAccumulator add(TakeAccumulator acc, Object... values) {
  Object candidateValue = values[0];
  // Convert through Number instead of casting straight to int: a direct (int) cast on an Object
  // only succeeds for java.lang.Integer and would throw ClassCastException for a Long, Short or
  // BigDecimal size argument.
  int size = values.length > 1 ? ((Number) values[1]).intValue() : 10;
  if (size > acc.size()) {
    acc.add(candidateValue);
  }
  return acc;
}

public static class TakeAccumulator implements Accumulator {
private List<Object> hits;

/** Creates an accumulator backed by an empty {@link ArrayList}. */
public TakeAccumulator() {
hits = new ArrayList<>();
}

/**
 * Returns the collected values.
 *
 * @return the backing list itself (not a copy), in insertion order
 */
@Override
public Object value() {
return hits;
}

public void add(Object value) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should add be a part of interface method signature?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just compared this with the interface AggregationState. I think it is enough to make sure we have add in UserDefinedAggFunction; each Accumulator can implement its own methods.

hits.add(value);
}

/**
 * @return the number of values collected so far
 */
public int size() {
return hits.size();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

package org.opensearch.sql.calcite.utils;

import static org.opensearch.sql.calcite.utils.UserDefineFunctionUtils.TransferUserDefinedAggFunction;

import com.google.common.collect.ImmutableList;
import java.util.List;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rex.RexInputRef;
Expand All @@ -15,12 +18,13 @@
import org.apache.calcite.tools.RelBuilder;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.calcite.CalcitePlanContext;
import org.opensearch.sql.calcite.udf.udaf.TakeAggFunction;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

public interface AggregateUtils {

static RelBuilder.AggCall translate(
AggregateFunction agg, RexNode field, CalcitePlanContext context) {
AggregateFunction agg, RexNode field, CalcitePlanContext context, List<RexNode> argList) {
if (BuiltinFunctionName.ofAggregation(agg.getFuncName()).isEmpty())
throw new IllegalStateException("Unexpected value: " + agg.getFuncName());

Expand Down Expand Up @@ -50,6 +54,14 @@ static RelBuilder.AggCall translate(
// case PERCENTILE_APPROX:
// return
// context.relBuilder.aggregateCall(SqlStdOperatorTable.PERCENTILE_CONT, field);
case TAKE:
return TransferUserDefinedAggFunction(
TakeAggFunction.class,
"TAKE",
UserDefineFunctionUtils.getReturnTypeInferenceForArray(),
List.of(field),
argList,
context.relBuilder);
case PERCENTILE_APPROX:
throw new UnsupportedOperationException("PERCENTILE_APPROX is not supported in PPL");
// case APPROX_COUNT_DISTINCT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,19 @@

package org.opensearch.sql.calcite.utils;

import static org.opensearch.sql.calcite.utils.UserDefineFunctionUtils.TransferUserDefinedFunction;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlLibraryOperators;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.ReturnTypes;
import org.opensearch.sql.calcite.CalcitePlanContext;
import org.opensearch.sql.calcite.udf.mathUDF.SqrtFunction;

public interface BuiltinFunctionUtils {

Expand Down Expand Up @@ -51,6 +60,12 @@ static SqlOperator translate(String op) {
// Built-in Math Functions
case "ABS":
return SqlStdOperatorTable.ABS;
case "SQRT":
return TransferUserDefinedFunction(SqrtFunction.class, "SQRT", ReturnTypes.DOUBLE);
case "ATAN", "ATAN2":
return SqlStdOperatorTable.ATAN2;
case "POW", "POWER":
return SqlStdOperatorTable.POWER;
// Built-in Date Functions
case "CURRENT_TIMESTAMP":
return SqlStdOperatorTable.CURRENT_TIMESTAMP;
Expand All @@ -67,4 +82,32 @@ static SqlOperator translate(String op) {
throw new IllegalArgumentException("Unsupported operator: " + op);
}
}

/**
* Translates function arguments to align with Calcite's expectations, ensuring compatibility with
* PPL (Piped Processing Language). This is necessary because Calcite's input argument order or
* default values may differ from PPL's function definitions.
*
* @param op The function name as a string.
* @param argList A list of {@link RexNode} representing the parsed arguments from the PPL
* statement.
* @param context The {@link CalcitePlanContext} providing necessary utilities such as {@code
* rexBuilder}.
* @return A modified list of {@link RexNode} that correctly maps to Calcite’s function
* expectations.
*/
static List<RexNode> translateArgument(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is framework code, we will change this method frequently. Could you add some comments explaining this method for developers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, already added.

String op, List<RexNode> argList, CalcitePlanContext context) {
switch (op.toUpperCase(Locale.ROOT)) {
case "ATAN":
List<RexNode> AtanArgs = new ArrayList<>(argList);
if (AtanArgs.size() == 1) {
BigDecimal divideNumber = BigDecimal.valueOf(1);
AtanArgs.add(context.rexBuilder.makeBigintLiteral(divideNumber));
}
return AtanArgs;
default:
return argList;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.utils;

import static org.apache.calcite.sql.type.SqlTypeUtil.createArrayType;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.ScalarFunction;
import org.apache.calcite.schema.impl.AggregateFunctionImpl;
import org.apache.calcite.schema.impl.ScalarFunctionImpl;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Optionality;
import org.opensearch.sql.calcite.udf.UserDefinedAggFunction;
import org.opensearch.sql.calcite.udf.UserDefinedFunction;

/** Helpers that expose project UDF and UDAF classes to Calcite as SQL operators. */
public class UserDefineFunctionUtils {

  /**
   * Wraps a {@link UserDefinedAggFunction} class as a Calcite aggregate operator and builds the
   * corresponding aggregate call.
   *
   * @param UDAF the aggregate function implementation class
   * @param functionName the name under which the operator is registered
   * @param returnType return type inference for the operator
   * @param fields the aggregated field expressions; they come first in the operand list
   * @param argList additional call arguments; they are appended after {@code fields}
   * @param relBuilder the builder used to create the aggregate call
   * @return the aggregate call on {@code fields} followed by {@code argList}
   */
  public static RelBuilder.AggCall TransferUserDefinedAggFunction(
      Class<? extends UserDefinedAggFunction> UDAF,
      String functionName,
      SqlReturnTypeInference returnType,
      List<RexNode> fields,
      List<RexNode> argList,
      RelBuilder relBuilder) {
    final SqlIdentifier aggIdentifier = new SqlIdentifier(functionName, SqlParserPos.ZERO);
    // NOTE(review): the two nulls presumably leave operand type inference and operand metadata
    // unspecified -- confirm against the SqlUserDefinedAggFunction constructor.
    final SqlUserDefinedAggFunction aggOperator =
        new SqlUserDefinedAggFunction(
            aggIdentifier,
            SqlKind.OTHER_FUNCTION,
            returnType,
            null,
            null,
            AggregateFunctionImpl.create(UDAF),
            false,
            false,
            Optionality.FORBIDDEN);

    // Operand order matters: aggregated fields first, then the extra arguments.
    final List<RexNode> operands = new ArrayList<>(fields);
    operands.addAll(argList);
    return relBuilder.aggregateCall(aggOperator, operands);
  }

  /**
   * Wraps a {@link UserDefinedFunction} class as a Calcite scalar operator. The implementation is
   * bound to the class's {@code eval(Object[])} method.
   *
   * @param UDF the scalar function implementation class
   * @param functionName the name under which the operator is registered
   * @param returnType return type inference for the operator
   * @return the operator that invokes {@code eval} on {@code UDF}
   */
  public static SqlOperator TransferUserDefinedFunction(
      Class<? extends UserDefinedFunction> UDF,
      String functionName,
      SqlReturnTypeInference returnType) {
    final ScalarFunction evalFunction =
        ScalarFunctionImpl.create(Types.lookupMethod(UDF, "eval", Object[].class));
    final List<String> nameParts = Collections.singletonList(functionName);
    final SqlIdentifier operatorName =
        new SqlIdentifier(nameParts, null, SqlParserPos.ZERO, null);
    return new SqlUserDefinedFunction(
        operatorName, SqlKind.OTHER_FUNCTION, returnType, null, null, evalFunction);
  }

  /**
   * Builds a return type inference that yields an array whose element type is the type of the
   * first operand.
   *
   * @return the inference; it throws {@link IllegalArgumentException} when the call has no
   *     operands
   */
  public static SqlReturnTypeInference getReturnTypeInferenceForArray() {
    return opBinding -> {
      final RelDataTypeFactory factory = opBinding.getTypeFactory();
      final List<RelDataType> operandTypes = opBinding.collectOperandTypes();
      if (operandTypes.isEmpty()) {
        throw new IllegalArgumentException("Function requires at least one argument.");
      }

      // The element type follows the first operand.
      final RelDataType elementType = operandTypes.get(0);
      return createArrayType(factory, elementType, true);
    };
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.opensearch.sql.util.MatcherUtils.verifySchema;

import java.io.IOException;
import java.util.List;
import org.json.JSONObject;
import org.junit.Ignore;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -221,4 +222,13 @@ public void testSimpleTwoLevelStats() {
verifySchema(actual, schema("avg_avg", "double"));
verifyDataRows(actual, rows(28432.625));
}

/** Checks that {@code take(firstname, 2)} yields one array-typed column with two first names. */
@Test
public void testTake() {
  String query = String.format("source=%s | stats take(firstname, 2) as take", TEST_INDEX_BANK);
  JSONObject response = executeQuery(query);

  verifySchema(response, schema("take", "array"));
  verifyDataRows(response, rows(List.of("Amber JOHnny", "Hattie")));
}
}
Loading
Loading