Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
08837b2
add math udfs
xinyual Mar 5, 2025
404f43d
add log argument
xinyual Mar 5, 2025
17976ef
Add math function unit tests
yuancu Mar 5, 2025
f04ca61
Add integration tests for Calcite math functions
yuancu Mar 5, 2025
ccbfefc
Rename CalcitePPLMathFunctionsIT to CalcitePPLBuiltinFunctionIT
yuancu Mar 6, 2025
66859e4
Merge pull request #2 from yuancu/addMathUDF
xinyual Mar 6, 2025
d059299
add license
xinyual Mar 6, 2025
e644401
apply spot
xinyual Mar 6, 2025
17b1462
Update the implementation of CONV function to align with v2's behavior
yuancu Mar 7, 2025
2e94770
Improve code style:
yuancu Mar 7, 2025
3a49d0f
Simplify Calcite PPL math function unit tests
yuancu Mar 7, 2025
1341b75
Alter MOD and SQRT UDF to conform to documented behaviors
yuancu Mar 10, 2025
dcd6a66
Complicate math integration tests
yuancu Mar 10, 2025
c7e32b4
Handle NULL return in ASIN, ACOS, SQRT and POW by convert returned Do…
yuancu Mar 10, 2025
8728534
Apply spotless on math UDFs and their tests
yuancu Mar 11, 2025
8f56ced
Remove unnecessary Double cast in SQRT UDF
yuancu Mar 11, 2025
0515377
Merge pull request #3 from yuancu/addMathUDF
xinyual Mar 11, 2025
0e7fe63
merge from origin
xinyual Mar 17, 2025
eb3e669
Convert returned Double.NaN and Float.NaN from math UDFs to LITERAL_NULL
yuancu Mar 17, 2025
7dae0f7
Correct math UDF integration tests
yuancu Mar 17, 2025
584ca60
Update MOD UDF
yuancu Mar 17, 2025
efc1ab8
Modify substring ITs
yuancu Mar 17, 2025
6eb74d1
Merge pull request #6 from yuancu/addMathUDF
xinyual Mar 17, 2025
bb77593
apply spot
xinyual Mar 18, 2025
557794b
Replace containsMessage with verifyErrorMessageContains in math ITs
yuancu Mar 18, 2025
58e42da
Correct MOD return types
yuancu Mar 19, 2025
2a16176
Merge pull request #9 from yuancu/addMathUDF
xinyual Mar 19, 2025
7b619e5
fix UT
xinyual Mar 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.udf.mathUDF;

import java.nio.charset.StandardCharsets;
import java.util.zip.CRC32;
import org.opensearch.sql.calcite.udf.UserDefinedFunction;

/**
 * Calculates a cyclic redundancy check value and returns it as a 32-bit unsigned value<br>
 * The supported signature of the crc32 function is<br>
 * STRING -> LONG
 *
 * <p>The argument is rendered with {@code toString()} and encoded as UTF-8 before it is hashed,
 * so the checksum does not depend on the JVM's default charset.
 */
public class CRC32Function implements UserDefinedFunction {
  /**
   * Computes the CRC-32 checksum of the single argument.
   *
   * @param args exactly one element; its {@code toString()} form is the text that is hashed
   * @return the checksum as a {@code Long} (the value of {@link CRC32#getValue()})
   * @throws IllegalArgumentException if the number of arguments is not exactly one
   * @throws NullPointerException if the argument is {@code null}
   */
  @Override
  public Object eval(Object... args) {
    if (args.length != 1) {
      throw new IllegalArgumentException(
          String.format("CRC32 function requires exactly one argument, but got %d", args.length));
    }
    Object value = args[0];
    CRC32 crc = new CRC32();
    // Pin the encoding explicitly: the no-arg getBytes() uses the JVM default charset, which
    // would make the checksum of non-ASCII text vary with -Dfile.encoding / platform settings.
    crc.update(value.toString().getBytes(StandardCharsets.UTF_8));
    return crc.getValue();
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.udf.mathUDF;

import org.opensearch.sql.calcite.udf.UserDefinedFunction;

/**
* Convert number x from base a to base b<br>
* The supported signature of floor function is<br>
* (STRING, INTEGER, INTEGER) -> STRING<br>
* (INTEGER, INTEGER, INTEGER) -> STRING
*/
public class ConvFunction implements UserDefinedFunction {
@Override
public Object eval(Object... args) {
if (args.length != 3) {
throw new IllegalArgumentException(
String.format("CONV function requires exactly three arguments, but got %d", args.length));
}

Object number = args[0];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not check args.length in ConvFunction? and check args.length in ModFunction.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And need some UT or IT for edge case

Object fromBase = args[1];
Object toBase = args[2];

String numStr = number.toString();
int fromBaseInt = Integer.parseInt(fromBase.toString());
int toBaseInt = Integer.parseInt(toBase.toString());
return conv(numStr, fromBaseInt, toBaseInt);
}

/**
* Convert numStr from fromBase to toBase
*
* @param numStr the number to convert (case-insensitive for alphanumeric digits, may have a
* leading '-')
* @param fromBase base of the input number (2 to 36)
* @param toBase target base (2 to 36)
* @return the converted number in the target base (uppercase), "0" if the input is invalid, or
* null if bases are out of range.
*/
private static String conv(String numStr, int fromBase, int toBase) {
return Long.toString(Long.parseLong(numStr, fromBase), toBase);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.udf.mathUDF;

import org.opensearch.sql.calcite.udf.UserDefinedFunction;

/**
 * Returns Euler's number.<br>
 * The supported signature of the e function is<br>
 * () -> DOUBLE
 */
public class EulerFunction implements UserDefinedFunction {
  /**
   * Returns the constant {@link Math#E}.
   *
   * @param args must be empty
   * @return {@code Math.E} as a {@code Double}
   * @throws IllegalArgumentException if any argument is supplied
   */
  @Override
  public Object eval(Object... args) {
    final int argCount = args.length;
    if (argCount == 0) {
      return Math.E;
    }
    throw new IllegalArgumentException(
        String.format("Euler function takes no argument, but got %d", argCount));
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.udf.mathUDF;

import java.math.BigDecimal;
import org.opensearch.sql.calcite.udf.UserDefinedFunction;

/**
* Calculate the remainder of x divided by y<br>
* The supported signature of mod function is<br>
* (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE)<br>
* -> wider type between types of x and y
*/
public class ModFunction implements UserDefinedFunction {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider add UT?


/**
 * Computes the remainder of the first argument divided by the second.
 *
 * <ul>
 *   <li>If either argument is {@code null}, the result is {@code null}.
 *   <li>If the absolute value of the divisor is below 1e-7 (treated as zero), the result is
 *       {@code null}.
 *   <li>If both arguments are Byte/Short/Integer/Long, the remainder is computed with {@code long}
 *       arithmetic and returned as {@code Long} when either argument is a {@code Long}, otherwise
 *       as {@code Integer}. The sign of the result follows the dividend.
 *   <li>Otherwise the remainder is computed in decimal arithmetic and returned as {@code Double}
 *       when either argument is a {@code Double}, otherwise as {@code Float}.
 * </ul>
 *
 * @param args exactly two elements: the dividend and the divisor
 * @return the remainder, or {@code null} as described above
 * @throws IllegalArgumentException if the number of arguments is not two, or if a non-null
 *     argument is not a {@link Number}
 */
@Override
public Object eval(Object... args) {
  if (args.length != 2) {
    throw new IllegalArgumentException(
        String.format("MOD function requires exactly two arguments, but got %d", args.length));
  }

  Object arg0 = args[0];
  Object arg1 = args[1];
  // A null operand yields a null result. Without this guard a null fails the instanceof test
  // below and then triggers a NullPointerException while the error message is being built.
  if (arg0 == null || arg1 == null) {
    return null;
  }
  if (!(arg0 instanceof Number num0) || !(arg1 instanceof Number num1)) {
    throw new IllegalArgumentException(
        String.format(
            "MOD function requires two numeric arguments, but got %s and %s",
            arg0.getClass().getSimpleName(), arg1.getClass().getSimpleName()));
  }

  // TODO: This precision check is arbitrary.
  if (Math.abs(num1.doubleValue()) < 0.0000001) {
    return null;
  }

  if (isIntegral(num0) && isIntegral(num1)) {
    long l0 = num0.longValue();
    long l1 = num1.longValue();
    // It returns negative values when l0 is negative
    long result = l0 % l1;
    // Return the wider type between l0 and l1
    if (num0 instanceof Long || num1 instanceof Long) {
      return result;
    }
    return (int) result;
  }

  boolean resultIsDouble = num0 instanceof Double || num1 instanceof Double;

  // NaN and +/-Infinity have no BigDecimal representation (the constructor below would throw
  // NumberFormatException), so fall back to the IEEE 754 remainder for those operands.
  if (isNonFinite(num0) || isNonFinite(num1)) {
    double ieeeRemainder = num0.doubleValue() % num1.doubleValue();
    if (resultIsDouble) {
      return ieeeRemainder;
    }
    return (float) ieeeRemainder;
  }

  // Work on the operands' decimal text form rather than their binary floating-point value.
  BigDecimal b0 = new BigDecimal(num0.toString());
  BigDecimal b1 = new BigDecimal(num1.toString());
  BigDecimal result = b0.remainder(b1);
  if (resultIsDouble) {
    return result.doubleValue();
  }
  return result.floatValue();
}

/** Tells whether n is a Float or Double holding NaN or an infinity. */
private static boolean isNonFinite(Number n) {
  return (n instanceof Double d && !Double.isFinite(d))
      || (n instanceof Float f && !Float.isFinite(f));
}

/**
 * Tells whether n is one of the boxed integer types Byte, Short, Integer or Long.
 *
 * @param n the number to classify
 * @return true for Byte, Short, Integer and Long instances; false for anything else, including
 *     null
 */
private boolean isIntegral(Number n) {
  if (n instanceof Long || n instanceof Integer) {
    return true;
  }
  return n instanceof Short || n instanceof Byte;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.STDDEV_SAMP_NULLABLE;
import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.VAR_POP_NULLABLE;
import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.VAR_SAMP_NULLABLE;
import static org.opensearch.sql.calcite.utils.UserDefineFunctionUtils.TransferUserDefinedAggFunction;
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedAggFunction;

import com.google.common.collect.ImmutableList;
import java.util.List;
Expand Down Expand Up @@ -65,7 +65,7 @@ static RelBuilder.AggCall translate(
return TransferUserDefinedAggFunction(
TakeAggFunction.class,
"TAKE",
UserDefineFunctionUtils.getReturnTypeInferenceForArray(),
UserDefinedFunctionUtils.getReturnTypeInferenceForArray(),
List.of(field),
argList,
context.relBuilder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,29 @@

package org.opensearch.sql.calcite.utils;

import static org.opensearch.sql.calcite.utils.UserDefineFunctionUtils.TransferUserDefinedFunction;
import static java.lang.Math.E;
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.*;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlLibraryOperators;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlTrimFunction;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName;
import org.opensearch.sql.calcite.CalcitePlanContext;
import org.opensearch.sql.calcite.udf.conditionUDF.IfFunction;
import org.opensearch.sql.calcite.udf.conditionUDF.IfNullFunction;
import org.opensearch.sql.calcite.udf.conditionUDF.NullIfFunction;
import org.opensearch.sql.calcite.udf.mathUDF.CRC32Function;
import org.opensearch.sql.calcite.udf.mathUDF.ConvFunction;
import org.opensearch.sql.calcite.udf.mathUDF.EulerFunction;
import org.opensearch.sql.calcite.udf.mathUDF.ModFunction;
import org.opensearch.sql.calcite.udf.mathUDF.SqrtFunction;

public interface BuiltinFunctionUtils {
Expand Down Expand Up @@ -82,13 +89,69 @@ static SqlOperator translate(String op) {
// Built-in Math Functions
case "ABS":
return SqlStdOperatorTable.ABS;
case "SQRT":
return TransferUserDefinedFunction(
SqrtFunction.class, "SQRT", ReturnTypes.DOUBLE_FORCE_NULLABLE);
case "ACOS":
return SqlStdOperatorTable.ACOS;
case "ASIN":
return SqlStdOperatorTable.ASIN;
case "ATAN", "ATAN2":
return SqlStdOperatorTable.ATAN2;
case "CEILING":
return SqlStdOperatorTable.CEIL;
case "CONV":
// The CONV function in PPL converts between numerical bases,
// while SqlStdOperatorTable.CONVERT converts between charsets.
return TransferUserDefinedFunction(ConvFunction.class, "CONVERT", ReturnTypes.VARCHAR);
case "COS":
return SqlStdOperatorTable.COS;
case "COT":
return SqlStdOperatorTable.COT;
case "CRC32":
return TransferUserDefinedFunction(CRC32Function.class, "CRC32", ReturnTypes.BIGINT);
case "DEGREES":
return SqlStdOperatorTable.DEGREES;
case "E":
return TransferUserDefinedFunction(EulerFunction.class, "E", ReturnTypes.DOUBLE);
case "EXP":
return SqlStdOperatorTable.EXP;
case "FLOOR":
return SqlStdOperatorTable.FLOOR;
case "LN":
return SqlStdOperatorTable.LN;
case "LOG":
return SqlLibraryOperators.LOG;
case "LOG2":
return SqlLibraryOperators.LOG2;
case "LOG10":
return SqlStdOperatorTable.LOG10;
case "MOD", "%":
// The MOD function in PPL supports floating-point parameters, e.g., MOD(5.5, 2) = 1.5,
// MOD(3.1, 2.1) = 1.1,
// whereas SqlStdOperatorTable.MOD supports only integer / long parameters.
return TransferUserDefinedFunction(
ModFunction.class,
"MOD",
getLeastRestrictiveReturnTypeAmongArgsAt(List.of(0, 1), true));
case "PI":
return SqlStdOperatorTable.PI;
case "POW", "POWER":
return SqlStdOperatorTable.POWER;
case "RADIANS":
return SqlStdOperatorTable.RADIANS;
case "RAND":
return SqlStdOperatorTable.RAND;
case "ROUND":
return SqlStdOperatorTable.ROUND;
case "SIGN":
return SqlStdOperatorTable.SIGN;
case "SIN":
return SqlStdOperatorTable.SIN;
case "SQRT":
// SqlStdOperatorTable.SQRT is declared but not implemented, therefore we use a custom
// implementation.
return TransferUserDefinedFunction(
SqrtFunction.class, "SQRT", ReturnTypes.DOUBLE_FORCE_NULLABLE);
case "CBRT":
return SqlStdOperatorTable.CBRT;
// Built-in Date Functions
case "CURRENT_TIMESTAMP":
return SqlStdOperatorTable.CURRENT_TIMESTAMP;
Expand All @@ -102,14 +165,13 @@ static SqlOperator translate(String op) {
return SqlLibraryOperators.DATEADD;
// Built-in condition functions
case "IF":
return TransferUserDefinedFunction(
IfFunction.class, "if", UserDefineFunctionUtils.getReturnTypeInference(1));
return TransferUserDefinedFunction(IfFunction.class, "if", getReturnTypeInference(1));
case "IFNULL":
return TransferUserDefinedFunction(
IfNullFunction.class, "ifnull", UserDefineFunctionUtils.getReturnTypeInference(1));
IfNullFunction.class, "ifnull", getReturnTypeInference(1));
case "NULLIF":
return TransferUserDefinedFunction(
NullIfFunction.class, "ifnull", UserDefineFunctionUtils.getReturnTypeInference(0));
NullIfFunction.class, "ifnull", getReturnTypeInference(0));
case "IS NOT NULL":
return SqlStdOperatorTable.IS_NOT_NULL;
case "IS NULL":
Expand Down Expand Up @@ -167,6 +229,21 @@ static List<RexNode> translateArgument(
AtanArgs.add(context.rexBuilder.makeBigintLiteral(divideNumber));
}
return AtanArgs;
case "LOG":
List<RexNode> LogArgs = new ArrayList<>();
RelDataTypeFactory typeFactory = context.rexBuilder.getTypeFactory();
if (argList.size() == 1) {
LogArgs.add(argList.getFirst());
LogArgs.add(
context.rexBuilder.makeExactLiteral(
BigDecimal.valueOf(E), typeFactory.createSqlType(SqlTypeName.DOUBLE)));
} else if (argList.size() == 2) {
LogArgs.add(argList.get(1));
LogArgs.add(argList.get(0));
} else {
throw new IllegalArgumentException("Log cannot accept argument list: " + argList);
}
return LogArgs;
default:
return argList;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.opensearch.sql.calcite.udf.UserDefinedAggFunction;
import org.opensearch.sql.calcite.udf.UserDefinedFunction;

public class UserDefineFunctionUtils {
public class UserDefinedFunctionUtils {
public static RelBuilder.AggCall TransferUserDefinedAggFunction(
Class<? extends UserDefinedAggFunction> UDAF,
String functionName,
Expand Down Expand Up @@ -80,6 +80,37 @@ public static SqlReturnTypeInference getReturnTypeInferenceForArray() {
};
}

/**
 * Builds a return type inference that yields the least restrictive (widest) type among the
 * operands at the given positions. E.g. (Integer, Long) -> Long; (Double, Float, Short) ->
 * Double
 *
 * <p>The positions are checked when the inference runs, against the operand count of the call
 * being bound.
 *
 * @param positions operand positions whose types take part in the inference
 * @param nullable whether the inferred type is nullable
 * @return the type inference; it throws {@link IllegalArgumentException} when a position is out
 *     of range or when the type factory reports no common type for the selected operands
 */
public static SqlReturnTypeInference getLeastRestrictiveReturnTypeAmongArgsAt(
    List<Integer> positions, boolean nullable) {
  return opBinding -> {
    RelDataTypeFactory factory = opBinding.getTypeFactory();

    // Collect the operand types at the requested positions, rejecting out-of-range positions.
    List<RelDataType> operandTypes = new ArrayList<>();
    for (int position : positions) {
      boolean inRange = 0 <= position && position < opBinding.getOperandCount();
      if (!inRange) {
        throw new IllegalArgumentException("Invalid argument position: " + position);
      }
      operandTypes.add(opBinding.getOperandType(position));
    }

    RelDataType commonType = factory.leastRestrictive(operandTypes);
    if (commonType != null) {
      return factory.createTypeWithNullability(commonType, nullable);
    }
    throw new IllegalArgumentException("Cannot determine a common type for the given positions.");
  };
}

/**
* For some udf/udaf, when giving a list of arguments, we need to infer the return type from the
* arguments.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,18 @@ public static ExprValue fromObjectValue(Object o) {
return longValue(((Long) o));
} else if (o instanceof Boolean) {
return booleanValue((Boolean) o);
} else if (o instanceof Double) {
return doubleValue((Double) o);
} else if (o instanceof Double d) {
if (Double.isNaN(d)) {
return LITERAL_NULL;
}
return doubleValue(d);
} else if (o instanceof String) {
return stringValue((String) o);
} else if (o instanceof Float) {
return floatValue((Float) o);
} else if (o instanceof Float f) {
if (Float.isNaN(f)) {
return LITERAL_NULL;
}
return floatValue(f);
} else if (o instanceof Date) {
return dateValue(((Date) o).toLocalDate());
} else if (o instanceof LocalDate) {
Expand Down
Loading
Loading