diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 4d928ef20f..7b18fc0d4b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -174,6 +174,10 @@ public static FunctionExpression exp(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.EXP, expressions); } + public static FunctionExpression expm1(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.EXPM1, expressions); + } + public static FunctionExpression floor(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.FLOOR, expressions); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index f9d38a0da3..a901868698 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -28,6 +28,7 @@ public enum BuiltinFunctionName { CRC32(FunctionName.of("crc32")), E(FunctionName.of("e")), EXP(FunctionName.of("exp")), + EXPM1(FunctionName.of("expm1")), FLOOR(FunctionName.of("floor")), LN(FunctionName.of("ln")), LOG(FunctionName.of("log")), diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java index 20b7928307..d555acb32d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java @@ -32,6 +32,7 @@ import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprShortValue; import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.function.BuiltinFunctionName; @@ -59,6 +60,7 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(crc32()); repository.register(euler()); repository.register(exp()); + repository.register(expm1()); repository.register(floor()); repository.register(ln()); repository.register(log()); @@ -85,6 +87,23 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(tan()); } + /** + * Base function for math functions with similar formats that return DOUBLE. + * + * @param functionName BuiltinFunctionName of math function. + * @param formula lambda function of math formula. + * @param returnType data type return type of the calling function + * @return DefaultFunctionResolver for math functions. + */ + private static DefaultFunctionResolver baseMathFunction( + FunctionName functionName, SerializableFunction formula, ExprCoreType returnType) { + return FunctionDSL.define(functionName, + ExprCoreType.numberTypes().stream().map(type -> FunctionDSL.impl( + FunctionDSL.nullMissingHandling(formula), + returnType, type)).collect(Collectors.toList())); + } + /** * Definition of abs() function. The supported signature of abs() function are INT -> INT LONG -> * LONG FLOAT -> FLOAT DOUBLE -> DOUBLE @@ -186,15 +205,21 @@ private static DefaultFunctionResolver euler() { } /** - * Definition of exp(x) function. Calculate exponent function e to the x The supported signature - * of exp function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + * Definition of exp(x) function. Calculate exponent function e to the x + * The supported signature of exp function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ private static DefaultFunctionResolver exp() { - return FunctionDSL.define(BuiltinFunctionName.EXP.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprDoubleValue(Math.exp(v.doubleValue()))), - type, DOUBLE)).collect(Collectors.toList())); + return baseMathFunction(BuiltinFunctionName.EXP.getName(), + v -> new ExprDoubleValue(Math.exp(v.doubleValue())), DOUBLE); + } + + /** + * Definition of expm1(x) function. Calculate exponent function e to the x, minus 1 + * The supported signature of exp function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver expm1() { + return baseMathFunction(BuiltinFunctionName.EXPM1.getName(), + v -> new ExprDoubleValue(Math.expm1(v.doubleValue())), DOUBLE); } /** @@ -214,11 +239,8 @@ private static DefaultFunctionResolver floor() { * ln function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ private static DefaultFunctionResolver ln() { - return FunctionDSL.define(BuiltinFunctionName.LN.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprDoubleValue(Math.log(v.doubleValue()))), - type, DOUBLE)).collect(Collectors.toList())); + return baseMathFunction(BuiltinFunctionName.LN.getName(), + v -> new ExprDoubleValue(Math.log(v.doubleValue())), DOUBLE); } /** @@ -255,11 +277,8 @@ private static DefaultFunctionResolver log() { * log function is SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ private static DefaultFunctionResolver log10() { - return FunctionDSL.define(BuiltinFunctionName.LOG10.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprDoubleValue(Math.log10(v.doubleValue()))), - type, DOUBLE)).collect(Collectors.toList())); + return baseMathFunction(BuiltinFunctionName.LOG10.getName(), + v -> new ExprDoubleValue(Math.log10(v.doubleValue())), DOUBLE); } /** @@ -267,11 +286,8 @@ private static DefaultFunctionResolver log10() { * function is SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ private static DefaultFunctionResolver log2() { - return FunctionDSL.define(BuiltinFunctionName.LOG2.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprDoubleValue(Math.log(v.doubleValue()) / Math.log(2))), DOUBLE, type)) - .collect(Collectors.toList())); + return baseMathFunction(BuiltinFunctionName.LOG2.getName(), + v -> new ExprDoubleValue(Math.log(v.doubleValue()) / Math.log(2)), DOUBLE); } /** @@ -450,11 +466,8 @@ private static DefaultFunctionResolver round() { * SHORT/INTEGER/LONG/FLOAT/DOUBLE -> INTEGER */ private static DefaultFunctionResolver sign() { - return FunctionDSL.define(BuiltinFunctionName.SIGN.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprIntegerValue(Math.signum(v.doubleValue()))), - INTEGER, type)).collect(Collectors.toList())); + return baseMathFunction(BuiltinFunctionName.SIGN.getName(), + v -> new ExprIntegerValue(Math.signum(v.doubleValue())), INTEGER); } /** @@ -464,12 +477,9 @@ private static DefaultFunctionResolver sign() { * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ private static DefaultFunctionResolver sqrt() { - return FunctionDSL.define(BuiltinFunctionName.SQRT.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> v.doubleValue() < 0 ? ExprNullValue.of() : - new ExprDoubleValue(Math.sqrt(v.doubleValue()))), - DOUBLE, type)).collect(Collectors.toList())); + return baseMathFunction(BuiltinFunctionName.SQRT.getName(), + v -> v.doubleValue() < 0 ? ExprNullValue.of() : + new ExprDoubleValue(Math.sqrt(v.doubleValue())), DOUBLE); } /** @@ -479,11 +489,8 @@ private static DefaultFunctionResolver sqrt() { * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ private static DefaultFunctionResolver cbrt() { - return FunctionDSL.define(BuiltinFunctionName.CBRT.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprDoubleValue(Math.cbrt(v.doubleValue()))), - DOUBLE, type)).collect(Collectors.toList())); + return baseMathFunction(BuiltinFunctionName.CBRT.getName(), + v -> new ExprDoubleValue(Math.cbrt(v.doubleValue())), DOUBLE); } /** @@ -606,11 +613,8 @@ private static DefaultFunctionResolver atan2() { * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ private static DefaultFunctionResolver cos() { - return FunctionDSL.define(BuiltinFunctionName.COS.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprDoubleValue(Math.cos(v.doubleValue()))), - DOUBLE, type)).collect(Collectors.toList())); + return baseMathFunction(BuiltinFunctionName.COS.getName(), + v -> new ExprDoubleValue(Math.cos(v.doubleValue())), DOUBLE); } /** @@ -641,11 +645,8 @@ private static DefaultFunctionResolver cot() { * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ private static DefaultFunctionResolver degrees() { - return FunctionDSL.define(BuiltinFunctionName.DEGREES.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprDoubleValue(Math.toDegrees(v.doubleValue()))), - type, DOUBLE)).collect(Collectors.toList())); + return baseMathFunction(BuiltinFunctionName.DEGREES.getName(), + v -> new ExprDoubleValue(Math.toDegrees(v.doubleValue())), DOUBLE); } /** @@ -655,11 +656,8 @@ private static DefaultFunctionResolver degrees() { * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ private static DefaultFunctionResolver radians() { - return FunctionDSL.define(BuiltinFunctionName.RADIANS.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprDoubleValue(Math.toRadians(v.doubleValue()))), - DOUBLE, type)).collect(Collectors.toList())); + return baseMathFunction(BuiltinFunctionName.RADIANS.getName(), + v -> new ExprDoubleValue(Math.toRadians(v.doubleValue())), DOUBLE); } /** @@ -669,11 +667,8 @@ private static DefaultFunctionResolver radians() { * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ private static DefaultFunctionResolver sin() { - return FunctionDSL.define(BuiltinFunctionName.SIN.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprDoubleValue(Math.sin(v.doubleValue()))), - DOUBLE, type)).collect(Collectors.toList())); + return baseMathFunction(BuiltinFunctionName.SIN.getName(), + v -> new ExprDoubleValue(Math.sin(v.doubleValue())), DOUBLE); } /** @@ -683,10 +678,7 @@ private static DefaultFunctionResolver sin() { * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ private static DefaultFunctionResolver tan() { - return FunctionDSL.define(BuiltinFunctionName.TAN.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprDoubleValue(Math.tan(v.doubleValue()))), - DOUBLE, type)).collect(Collectors.toList())); + return baseMathFunction(BuiltinFunctionName.TAN.getName(), + v -> new ExprDoubleValue(Math.tan(v.doubleValue())), DOUBLE); } } diff --git a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java index 3a03ba79ad..4e42286141 100644 --- a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java @@ -565,6 +565,90 @@ public void exp_missing_value() { assertTrue(exp.valueOf(valueEnv()).isMissing()); } + /** + * Test expm1 with integer value. + */ + @ParameterizedTest(name = "expm1({0})") + @ValueSource(ints = { + -1, 0, 1, Integer.MAX_VALUE, Integer.MIN_VALUE}) + public void expm1_int_value(Integer value) { + FunctionExpression expm1 = DSL.expm1(DSL.literal(value)); + assertThat( + expm1.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.expm1(value)))); + assertEquals(String.format("expm1(%s)", value), expm1.toString()); + } + + /** + * Test expm1 with long value. + */ + @ParameterizedTest(name = "expm1({0})") + @ValueSource(longs = { + -1L, 0L, 1L, Long.MAX_VALUE, Long.MIN_VALUE}) + public void expm1_long_value(Long value) { + FunctionExpression expm1 = DSL.expm1(DSL.literal(value)); + assertThat( + expm1.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.expm1(value)))); + assertEquals(String.format("expm1(%s)", value), expm1.toString()); + } + + /** + * Test expm1 with float value. + */ + @ParameterizedTest(name = "expm1({0})") + @ValueSource(floats = { + -1.5F, -1F, 0F, 1F, 1.5F, Float.MAX_VALUE, Float.MIN_VALUE}) + public void expm1_float_value(Float value) { + FunctionExpression expm1 = DSL.expm1(DSL.literal(value)); + assertThat( + expm1.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.expm1(value)))); + assertEquals(String.format("expm1(%s)", value), expm1.toString()); + } + + /** + * Test expm1 with double value. + */ + @ParameterizedTest(name = "expm1({0})") + @ValueSource(doubles = { + -1.5D, -1D, 0D, 1D, 1.5D, Double.MAX_VALUE, Double.MIN_VALUE}) + public void expm1_double_value(Double value) { + FunctionExpression expm1 = DSL.expm1(DSL.literal(value)); + assertThat( + expm1.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.expm1(value)))); + assertEquals(String.format("expm1(%s)", value), expm1.toString()); + } + + /** + * Test expm1 with short value. + */ + @ParameterizedTest(name = "expm1({0})") + @ValueSource(shorts = { + -1, 0, 1, Short.MAX_VALUE, Short.MIN_VALUE}) + public void expm1_short_value(Short value) { + FunctionExpression expm1 = DSL.expm1(DSL.literal(value)); + assertThat( + expm1.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.expm1(value)))); + assertEquals(String.format("expm1(%s)", value), expm1.toString()); + } + + /** + * Test expm1 with short value. + */ + @ParameterizedTest(name = "expm1({0})") + @ValueSource(bytes = { + -1, 0, 1, Byte.MAX_VALUE, Byte.MIN_VALUE}) + public void expm1_byte_value(Byte value) { + FunctionExpression expm1 = DSL.expm1(DSL.literal(value)); + assertThat( + expm1.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.expm1(value)))); + assertEquals(String.format("expm1(%s)", value), expm1.toString()); + } + /** * Test floor with integer value. */ @@ -575,7 +659,7 @@ public void floor_int_value(Integer value) { assertThat( floor.valueOf(valueEnv()), allOf(hasType(LONG), hasValue((long) Math.floor(value)))); - assertEquals(String.format("floor(%s)", value.toString()), floor.toString()); + assertEquals(String.format("floor(%s)", value), floor.toString()); } /** diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index 749017078b..4ab66fb0e1 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -443,10 +443,21 @@ EXPM1 Description >>>>>>>>>>> -Specifications: +Usage: EXPM1(NUMBER T) returns the exponential of T, minus 1. + +Argument type: INTEGER/LONG/FLOAT/DOUBLE -1. EXPM1(NUMBER T) -> T +Return type: DOUBLE +Example:: + + os> SELECT EXPM1(-1), EXPM1(0), EXPM1(1), EXPM1(1.5) + fetched rows / total rows = 1/1 + +---------------------+------------+-------------------+-------------------+ + | EXPM1(-1) | EXPM1(0) | EXPM1(1) | EXPM1(1.5) | + |---------------------+------------+-------------------+-------------------| + | -0.6321205588285577 | 0.0 | 1.718281828459045 | 3.481689070338065 | + +---------------------+------------+-------------------+-------------------+ FLOOR ----- diff --git a/docs/user/ppl/functions/math.rst b/docs/user/ppl/functions/math.rst index 20bd1d6a70..9d0b088e80 100644 --- a/docs/user/ppl/functions/math.rst +++ b/docs/user/ppl/functions/math.rst @@ -187,6 +187,7 @@ Example:: | c | 44 | 1100 | 15 | +----------------------+----------------------+-------------------+---------------------+ + COS --- @@ -278,6 +279,7 @@ Example:: | 89.95437383553924 | +-------------------+ + E - @@ -309,7 +311,7 @@ Usage: exp(x) return e raised to the power of x. Argument type: INTEGER/LONG/FLOAT/DOUBLE -Return type: INTEGER +Return type: DOUBLE Example:: diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java index b8767eb2f1..5df76134f9 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java @@ -80,6 +80,13 @@ public void testE() throws IOException { verifyDataRows(result, rows(Math.E)); } + @Test + public void testExpm1() throws IOException { + JSONObject result = executeQuery("select expm1(account_number) FROM " + TEST_INDEX_BANK + " LIMIT 2"); + verifySchema(result, schema("expm1(account_number)", null, "double")); + verifyDataRows(result, rows(Math.expm1(1)), rows(Math.expm1(6))); + } + @Test public void testMod() throws IOException { JSONObject result = executeQuery("select mod(3, 2)"); diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 3fa223c584..f1a6ec1104 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -406,7 +406,7 @@ aggregationFunctionName ; mathematicalFunctionName - : ABS | CBRT | CEIL | CEILING | CONV | CRC32 | E | EXP | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI | POW | POWER + : ABS | CBRT | CEIL | CEILING | CONV | CRC32 | E | EXP | EXPM1 | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI | POW | POWER | RAND | ROUND | SIGN | SQRT | TRUNCATE | trigonometricFunctionName ;