diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java index d555acb32d..bb562ac07b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java @@ -240,7 +240,8 @@ private static DefaultFunctionResolver floor() { */ private static DefaultFunctionResolver ln() { return baseMathFunction(BuiltinFunctionName.LN.getName(), - v -> new ExprDoubleValue(Math.log(v.doubleValue())), DOUBLE); + v -> v.doubleValue() <= 0 ? ExprNullValue.of() : + new ExprDoubleValue(Math.log(v.doubleValue())), DOUBLE); } /** @@ -255,7 +256,8 @@ private static DefaultFunctionResolver log() { // build unary log(x), SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE for (ExprType type : ExprCoreType.numberTypes()) { builder.add(FunctionDSL.impl(FunctionDSL - .nullMissingHandling(v -> new ExprDoubleValue(Math.log(v.doubleValue()))), + .nullMissingHandling(v -> v.doubleValue() <= 0 ? ExprNullValue.of() : + new ExprDoubleValue(Math.log(v.doubleValue()))), DOUBLE, type)); } @@ -263,7 +265,8 @@ private static DefaultFunctionResolver log() { for (ExprType baseType : ExprCoreType.numberTypes()) { for (ExprType numberType : ExprCoreType.numberTypes()) { builder.add(FunctionDSL.impl(FunctionDSL - .nullMissingHandling((b, x) -> new ExprDoubleValue( + .nullMissingHandling((b, x) -> b.doubleValue() <= 0 || x.doubleValue() <= 0 + ? ExprNullValue.of() : new ExprDoubleValue( Math.log(x.doubleValue()) / Math.log(b.doubleValue()))), DOUBLE, baseType, numberType)); } @@ -278,7 +281,8 @@ private static DefaultFunctionResolver log() { */ private static DefaultFunctionResolver log10() { return baseMathFunction(BuiltinFunctionName.LOG10.getName(), - v -> new ExprDoubleValue(Math.log10(v.doubleValue())), DOUBLE); + v -> v.doubleValue() <= 0 ? ExprNullValue.of() : + new ExprDoubleValue(Math.log10(v.doubleValue())), DOUBLE); } /** @@ -287,7 +291,8 @@ private static DefaultFunctionResolver log10() { */ private static DefaultFunctionResolver log2() { return baseMathFunction(BuiltinFunctionName.LOG2.getName(), - v -> new ExprDoubleValue(Math.log(v.doubleValue()) / Math.log(2)), DOUBLE); + v -> v.doubleValue() <= 0 ? ExprNullValue.of() : + new ExprDoubleValue(Math.log(v.doubleValue()) / Math.log(2)), DOUBLE); } /** diff --git a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java index 4e42286141..47ea5057c9 100644 --- a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java @@ -81,6 +81,12 @@ private static Stream testLogDoubleArguments() { return builder.add(Arguments.of(2D, 2D)).build(); } + private static Stream testLogInvalidDoubleArguments() { + return Stream.of(Arguments.of(0D, -2D), + Arguments.of(0D, 2D), + Arguments.of(2D, 0D)); + } + private static Stream trigonometricArguments() { Stream.Builder builder = Stream.builder(); return builder @@ -725,7 +731,7 @@ public void floor_missing_value() { * Test ln with integer value. */ @ParameterizedTest(name = "ln({0})") - @ValueSource(ints = {2, -2}) + @ValueSource(ints = {2, 3}) public void ln_int_value(Integer value) { FunctionExpression ln = DSL.ln(DSL.literal(value)); assertThat( @@ -738,7 +744,7 @@ public void ln_int_value(Integer value) { * Test ln with long value. */ @ParameterizedTest(name = "ln({0})") - @ValueSource(longs = {2L, -2L}) + @ValueSource(longs = {2L, 3L}) public void ln_long_value(Long value) { FunctionExpression ln = DSL.ln(DSL.literal(value)); assertThat( @@ -751,7 +757,7 @@ public void ln_long_value(Long value) { * Test ln with float value. */ @ParameterizedTest(name = "ln({0})") - @ValueSource(floats = {2F, -2F}) + @ValueSource(floats = {2F, 3F}) public void ln_float_value(Float value) { FunctionExpression ln = DSL.ln(DSL.literal(value)); assertThat( @@ -764,7 +770,7 @@ public void ln_float_value(Float value) { * Test ln with double value. */ @ParameterizedTest(name = "ln({0})") - @ValueSource(doubles = {2D, -2D}) + @ValueSource(doubles = {2D, 3D}) public void ln_double_value(Double value) { FunctionExpression ln = DSL.ln(DSL.literal(value)); assertThat( @@ -773,6 +779,17 @@ public void ln_double_value(Double value) { assertEquals(String.format("ln(%s)", value.toString()), ln.toString()); } + /** + * Test ln with invalid value. + */ + @ParameterizedTest(name = "ln({0})") + @ValueSource(doubles = {0D, -3D}) + public void ln_invalid_value(Double value) { + FunctionExpression ln = DSL.ln(DSL.literal(value)); + assertEquals(DOUBLE, ln.type()); + assertTrue(ln.valueOf(valueEnv()).isNull()); + } + /** * Test ln with null value. */ @@ -853,6 +870,17 @@ public void log_double_value(Double v) { assertEquals(String.format("log(%s)", v.toString()), log.toString()); } + /** + * Test log with 1 invalid value. + */ + @ParameterizedTest(name = "log({0})") + @ValueSource(doubles = {0D, -3D}) + public void log_invalid_value(Double value) { + FunctionExpression log = DSL.log(DSL.literal(value)); + assertEquals(DOUBLE, log.type()); + assertTrue(log.valueOf(valueEnv()).isNull()); + } + /** * Test log with 1 null value argument. */ @@ -931,6 +959,17 @@ public void log_two_double_value(Double v1, Double v2) { assertEquals(String.format("log(%s, %s)", v1.toString(), v2.toString()), log.toString()); } + /** + * Test log with 2 invalid double arguments. + */ + @ParameterizedTest(name = "log({0}, {2})") + @MethodSource("testLogInvalidDoubleArguments") + public void log_two_invalid_double_value(Double v1, Double v2) { + FunctionExpression log = DSL.log(DSL.literal(v1), DSL.literal(v2)); + assertEquals(log.type(), DOUBLE); + assertTrue(log.valueOf(valueEnv()).isNull()); + } + /** * Test log with 2 null value arguments. */ @@ -1051,6 +1090,17 @@ public void log10_double_value(Double v) { assertEquals(String.format("log10(%s)", v.toString()), log.toString()); } + /** + * Test log10 with 1 invalid double argument. + */ + @ParameterizedTest(name = "log10({0})") + @ValueSource(doubles = {0D, -3D}) + public void log10_two_invalid_value(Double v) { + FunctionExpression log = DSL.log10(DSL.literal(v)); + assertEquals(log.type(), DOUBLE); + assertTrue(log.valueOf(valueEnv()).isNull()); + } + /** * Test log10 with null value. */ @@ -1133,6 +1183,17 @@ public void log2_double_value(Double v) { assertEquals(String.format("log2(%s)", v.toString()), log.toString()); } + /** + * Test log2 with an invalid double value. + */ + @ParameterizedTest(name = "log2({0})") + @ValueSource(doubles = {0D, -2D}) + public void log2_invalid_double_value(Double v) { + FunctionExpression log = DSL.log2(DSL.literal(v)); + assertEquals(log.type(), DOUBLE); + assertTrue(log.valueOf(valueEnv()).isNull()); + } + /** * Test log2 with null value. */ diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java index 5df76134f9..1bc9bd09b6 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java @@ -197,19 +197,6 @@ public void testAtan() throws IOException { verifyDataRows(result, rows(Math.atan2(2, 3))); } - protected JSONObject executeQuery(String query) throws IOException { - Request request = new Request("POST", QUERY_API_ENDPOINT); - request.setJsonEntity(String.format(Locale.ROOT, "{\n" + " \"query\": \"%s\"\n" + "}", query)); - - RequestOptions.Builder restOptionsBuilder = RequestOptions.DEFAULT.toBuilder(); - restOptionsBuilder.addHeader("Content-Type", "application/json"); - request.setOptions(restOptionsBuilder); - - Response response = client().performRequest(request); - return new JSONObject(getResponseBody(response)); - } - - @Test public void testCbrt() throws IOException { JSONObject result = executeQuery("select cbrt(8)"); @@ -224,4 +211,52 @@ public void testCbrt() throws IOException { verifySchema(result, schema("cbrt(-27)", "double")); verifyDataRows(result, rows(-3.0)); } + + @Test + public void testLnReturnsNull() throws IOException { + JSONObject result = executeQuery("select ln(0), ln(-2)"); + verifySchema(result, + schema("ln(0)", "double"), + schema("ln(-2)", "double")); + verifyDataRows(result, rows(null, null)); + } + + @Test + public void testLogReturnsNull() throws IOException { + JSONObject result = executeQuery("select log(0), log(-2)"); + verifySchema(result, + schema("log(0)", "double"), + schema("log(-2)", "double")); + verifyDataRows(result, rows(null, null)); + } + + @Test + public void testLog10ReturnsNull() throws IOException { + JSONObject result = executeQuery("select log10(0), log10(-2)"); + verifySchema(result, + schema("log10(0)", "double"), + schema("log10(-2)", "double")); + verifyDataRows(result, rows(null, null)); + } + + @Test + public void testLog2ReturnsNull() throws IOException { + JSONObject result = executeQuery("select log2(0), log2(-2)"); + verifySchema(result, + schema("log2(0)", "double"), + schema("log2(-2)", "double")); + verifyDataRows(result, rows(null, null)); + } + + protected JSONObject executeQuery(String query) throws IOException { + Request request = new Request("POST", QUERY_API_ENDPOINT); + request.setJsonEntity(String.format(Locale.ROOT, "{\n" + " \"query\": \"%s\"\n" + "}", query)); + + RequestOptions.Builder restOptionsBuilder = RequestOptions.DEFAULT.toBuilder(); + restOptionsBuilder.addHeader("Content-Type", "application/json"); + request.setOptions(restOptionsBuilder); + + Response response = client().performRequest(request); + return new JSONObject(getResponseBody(response)); + } }