Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,18 @@
import java.util.zip.CRC32;
import org.opensearch.sql.calcite.udf.UserDefinedFunction;

/**
* Calculate a cyclic redundancy check value and returns a 32-bit unsigned value<br>
* The supported signature of crc32 function is<br>
* STRING -> LONG
*/
public class CRC32Function implements UserDefinedFunction {
@Override
public Object eval(Object... args) {
if (args.length != 1) {
throw new IllegalArgumentException(
String.format("CRC32 function requires exactly one argument, but got %d", args.length));
}
Object value = args[0];
CRC32 crc = new CRC32();
crc.update(value.toString().getBytes());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,22 @@

package org.opensearch.sql.calcite.udf.mathUDF;

import java.math.BigInteger;
import org.opensearch.sql.calcite.udf.UserDefinedFunction;

/**
* Convert number x from base a to base b<br>
* The supported signature of floor function is<br>
* (STRING, INTEGER, INTEGER) -> STRING<br>
* (INTEGER, INTEGER, INTEGER) -> STRING
*/
public class ConvFunction implements UserDefinedFunction {
@Override
public Object eval(Object... args) {
if (args.length != 3) {
throw new IllegalArgumentException(
String.format("CONV function requires exactly three arguments, but got %d", args.length));
}

Object number = args[0];
Object fromBase = args[1];
Object toBase = args[2];
Expand All @@ -32,39 +42,6 @@ public Object eval(Object... args) {
* null if bases are out of range.
*/
private static String conv(String numStr, int fromBase, int toBase) {
// Validate base ranges
if (fromBase < 2 || fromBase > 36 || toBase < 2 || toBase > 36) {
return null;
}

// Check for sign
boolean negative = false;
if (numStr.startsWith("-")) {
negative = true;
numStr = numStr.substring(1);
}

// Normalize input (e.g., remove extra whitespace, convert to uppercase)
numStr = numStr.trim().toUpperCase();

// Try parsing the input as a BigInteger of 'fromBase'
BigInteger value;
try {
value = new BigInteger(numStr, fromBase);
} catch (NumberFormatException e) {
// If numStr contains invalid characters for fromBase
return "0";
}

// Re-apply sign if needed
if (negative) {
value = value.negate();
}

// Convert to the target base; BigInteger's toString(...) yields lowercase above 9
String result = value.toString(toBase);

// Convert to uppercase to align with MySQL's behavior
return result.toUpperCase();
return Long.toString(Long.parseLong(numStr, fromBase), toBase);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@

import org.opensearch.sql.calcite.udf.UserDefinedFunction;

/** Get the Euler's number. () -> DOUBLE */
public class EulerFunction implements UserDefinedFunction {
@Override
public Object eval(Object... args) {
if (args.length != 0) {
throw new IllegalArgumentException(
String.format("Euler function takes no argument, but got %d", args.length));
}

return Math.E;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,54 @@

import org.opensearch.sql.calcite.udf.UserDefinedFunction;

/**
* Calculate the remainder of x divided by y<br>
* The supported signature of mod function is<br>
* (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE)<br>
* -> wider type between types of x and y
*/
public class ModFunction implements UserDefinedFunction {

@Override
public Object eval(Object... args) {
if (args.length < 2) {
throw new IllegalArgumentException("At least two arguments are required");
if (args.length != 2) {
throw new IllegalArgumentException(
String.format("MOD function requires exactly two arguments, but got %d", args.length));
}

// Get the two values
Object mod0 = args[0];
Object mod1 = args[1];
Object arg0 = args[0];
Object arg1 = args[1];
if (!(arg0 instanceof Number num0) || !(arg1 instanceof Number num1)) {
throw new IllegalArgumentException(
String.format(
"MOD function requires two numeric arguments, but got %s and %s",
arg0.getClass().getSimpleName(), arg1.getClass().getSimpleName()));
}

// Handle numbers dynamically
if (mod0 instanceof Integer && mod1 instanceof Integer) {
return (Integer) mod0 % (Integer) mod1;
} else if (mod0 instanceof Number && mod1 instanceof Number) {
double num0 = ((Number) mod0).doubleValue();
double num1 = ((Number) mod1).doubleValue();
// TODO: This precision check is arbitrary.
if (Math.abs(num1.doubleValue()) < 0.0000001) {
return null;
}

if (num1 == 0) {
throw new ArithmeticException("Modulo by zero is not allowed");
if (isIntegral(num0) && isIntegral(num1)) {
long l0 = num0.longValue();
long l1 = num1.longValue();
// Java returns negative values if the dividend is negative.
// We make it return positive values to align with V2's behavior
long result = (l0 % l1 + l1) % l1;
// Return the wider type between l0 and l1
if (num0 instanceof Integer && num1 instanceof Integer) {
return (int) result;
}

return num0 % num1; // Handles both float and double cases
} else {
throw new IllegalArgumentException("Invalid argument types: Expected numeric values");
return result;
}

double d0 = num0.doubleValue();
double d1 = num1.doubleValue();
return (d0 % d1 + d1) % d1;
}

private boolean isIntegral(Number n) {
return n instanceof Integer || n instanceof Long;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,28 @@

import org.opensearch.sql.calcite.udf.UserDefinedFunction;

/**
* Calculate the square root of a non-negative number x<br>
* The supported signature is<br>
* INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE It returns null if a negative parameter is provided
*/
public class SqrtFunction implements UserDefinedFunction {
@Override
public Object eval(Object... args) {
if (args.length < 1) {
throw new IllegalArgumentException("At least one argument is required");
}

// Get the input value
Object input = args[0];

// Handle numbers dynamically
if (input instanceof Number) {
double num = ((Number) input).doubleValue();

if (num < 0) {
throw new ArithmeticException("Cannot compute square root of a negative number");
return null;
}

return sqrt(num); // Computes sqrt using Math.sqrt()
return sqrt(num);
} else {
throw new IllegalArgumentException("Invalid argument type: Expected a numeric value");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.sql.calcite.utils;

import static org.opensearch.sql.calcite.utils.UserDefineFunctionUtils.TransferUserDefinedAggFunction;
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedAggFunction;

import com.google.common.collect.ImmutableList;
import java.util.List;
Expand Down Expand Up @@ -58,7 +58,7 @@ static RelBuilder.AggCall translate(
return TransferUserDefinedAggFunction(
TakeAggFunction.class,
"TAKE",
UserDefineFunctionUtils.getReturnTypeInferenceForArray(),
UserDefinedFunctionUtils.getReturnTypeInferenceForArray(),
List.of(field),
argList,
context.relBuilder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
package org.opensearch.sql.calcite.utils;

import static java.lang.Math.E;
import static org.opensearch.sql.calcite.utils.UserDefineFunctionUtils.TransferUserDefinedFunction;
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedFunction;
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.getLeastRestrictiveReturnTypeAmongArgsAt;

import java.math.BigDecimal;
import java.util.ArrayList;
Expand All @@ -20,7 +21,11 @@
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName;
import org.opensearch.sql.calcite.CalcitePlanContext;
import org.opensearch.sql.calcite.udf.mathUDF.*;
import org.opensearch.sql.calcite.udf.mathUDF.CRC32Function;
import org.opensearch.sql.calcite.udf.mathUDF.ConvFunction;
import org.opensearch.sql.calcite.udf.mathUDF.EulerFunction;
import org.opensearch.sql.calcite.udf.mathUDF.ModFunction;
import org.opensearch.sql.calcite.udf.mathUDF.SqrtFunction;

public interface BuiltinFunctionUtils {

Expand Down Expand Up @@ -72,7 +77,9 @@ static SqlOperator translate(String op) {
case "CEILING":
return SqlStdOperatorTable.CEIL;
case "CONV":
return TransferUserDefinedFunction(ConvFunction.class, "CONVERT", ReturnTypes.BIGINT);
// The CONV function in PPL converts between numerical bases,
// while SqlStdOperatorTable.CONVERT converts between charsets.
return TransferUserDefinedFunction(ConvFunction.class, "CONVERT", ReturnTypes.VARCHAR);
case "COS":
return SqlStdOperatorTable.COS;
case "COT":
Expand All @@ -96,7 +103,13 @@ static SqlOperator translate(String op) {
case "LOG10":
return SqlStdOperatorTable.LOG10;
case "MOD":
return TransferUserDefinedFunction(ModFunction.class, "MOD", ReturnTypes.DOUBLE);
// The MOD function in PPL supports floating-point parameters, e.g., MOD(5.5, 2) = 1.5,
// MOD(3.1, 2.1) = 1.1,
// whereas SqlStdOperatorTable.MOD supports only integer / long parameters.
return TransferUserDefinedFunction(
ModFunction.class,
"MOD",
getLeastRestrictiveReturnTypeAmongArgsAt(List.of(0, 1), true));
case "PI":
return SqlStdOperatorTable.PI;
case "POW", "POWER":
Expand All @@ -112,7 +125,10 @@ static SqlOperator translate(String op) {
case "SIN":
return SqlStdOperatorTable.SIN;
case "SQRT":
return TransferUserDefinedFunction(SqrtFunction.class, "SQRT", ReturnTypes.DOUBLE);
// SqlStdOperatorTable.SQRT is declared but not implemented, therefore we use a custom
// implementation.
return TransferUserDefinedFunction(
SqrtFunction.class, "SQRT", ReturnTypes.DOUBLE_FORCE_NULLABLE);
case "CBRT":
return SqlStdOperatorTable.CBRT;
// Built-in Date Functions
Expand All @@ -126,7 +142,6 @@ static SqlOperator translate(String op) {
return SqlLibraryOperators.DATE_ADD_SPARK;
case "DATE_ADD":
return SqlLibraryOperators.DATEADD;
// TODO Add more, ref RexImpTable
default:
throw new IllegalArgumentException("Unsupported operator: " + op);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.opensearch.sql.calcite.udf.UserDefinedAggFunction;
import org.opensearch.sql.calcite.udf.UserDefinedFunction;

public class UserDefineFunctionUtils {
public class UserDefinedFunctionUtils {
public static RelBuilder.AggCall TransferUserDefinedAggFunction(
Class<? extends UserDefinedAggFunction> UDAF,
String functionName,
Expand Down Expand Up @@ -79,4 +79,53 @@ public static SqlReturnTypeInference getReturnTypeInferenceForArray() {
return createArrayType(typeFactory, firstArgType, true);
};
}

/**
* Infer return argument type as the type of the argument at pos
*
* @param position The argument position
* @param nullable Whether the returned value is nullable
* @return SqlReturnTypeInference
*/
public static SqlReturnTypeInference getReturnTypeBasedOnArgAt(int position, boolean nullable) {
return opBinding -> {
if (position < 0 || position >= opBinding.getOperandCount()) {
throw new IllegalArgumentException("Invalid argument position: " + position);
}
RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
RelDataType type = opBinding.getOperandType(position);
return typeFactory.createTypeWithNullability(type, nullable);
};
}

/**
* Infer return argument type as the widest return type among arguments as specified positions.
* E.g. (Integer, Long) -> Long; (Double, Float, SHORT) -> Double
*
* @param positions positions where the return type should be inferred from
* @param nullable whether the returned value is nullable
* @return The type inference
*/
public static SqlReturnTypeInference getLeastRestrictiveReturnTypeAmongArgsAt(
List<Integer> positions, boolean nullable) {
return opBinding -> {
RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
List<RelDataType> types = new ArrayList<>();

for (int position : positions) {
if (position < 0 || position >= opBinding.getOperandCount()) {
throw new IllegalArgumentException("Invalid argument position: " + position);
}
types.add(opBinding.getOperandType(position));
}

RelDataType widerType = typeFactory.leastRestrictive(types);
if (widerType == null) {
throw new IllegalArgumentException(
"Cannot determine a common type for the given positions.");
}

return typeFactory.createTypeWithNullability(widerType, nullable);
};
}
}
Loading
Loading