From fb9ff8ce8f0778a5010436f3469a7e03277b5472 Mon Sep 17 00:00:00 2001 From: Guian Gumpac Date: Mon, 13 Feb 2023 11:25:04 -0800 Subject: [PATCH] Fixed bug and added tests Signed-off-by: Guian Gumpac --- .../arthmetic/MathematicalFunction.java | 15 ++-- .../arthmetic/MathematicalFunctionTest.java | 70 +++++++++++++++++-- .../sql/sql/MathematicalFunctionIT.java | 61 ++++++++++++---- 3 files changed, 124 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java index 20b7928307..56d666fba3 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java @@ -217,7 +217,8 @@ private static DefaultFunctionResolver ln() { return FunctionDSL.define(BuiltinFunctionName.LN.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprDoubleValue(Math.log(v.doubleValue()))), + v -> v.doubleValue() <= 0 ? ExprNullValue.of() : + new ExprDoubleValue(Math.log(v.doubleValue()))), type, DOUBLE)).collect(Collectors.toList())); } @@ -233,7 +234,8 @@ private static DefaultFunctionResolver log() { // build unary log(x), SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE for (ExprType type : ExprCoreType.numberTypes()) { builder.add(FunctionDSL.impl(FunctionDSL - .nullMissingHandling(v -> new ExprDoubleValue(Math.log(v.doubleValue()))), + .nullMissingHandling(v -> v.doubleValue() <= 0 ? ExprNullValue.of() : + new ExprDoubleValue(Math.log(v.doubleValue()))), DOUBLE, type)); } @@ -241,7 +243,8 @@ private static DefaultFunctionResolver log() { for (ExprType baseType : ExprCoreType.numberTypes()) { for (ExprType numberType : ExprCoreType.numberTypes()) { builder.add(FunctionDSL.impl(FunctionDSL - .nullMissingHandling((b, x) -> new ExprDoubleValue( + .nullMissingHandling((b, x) -> b.doubleValue() <= 0 || x.doubleValue() <= 0 + ? ExprNullValue.of() : new ExprDoubleValue( Math.log(x.doubleValue()) / Math.log(b.doubleValue()))), DOUBLE, baseType, numberType)); } @@ -258,7 +261,8 @@ private static DefaultFunctionResolver log10() { return FunctionDSL.define(BuiltinFunctionName.LOG10.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprDoubleValue(Math.log10(v.doubleValue()))), + v -> v.doubleValue() <= 0 ? ExprNullValue.of() : + new ExprDoubleValue(Math.log10(v.doubleValue()))), type, DOUBLE)).collect(Collectors.toList())); } @@ -270,7 +274,8 @@ private static DefaultFunctionResolver log2() { return FunctionDSL.define(BuiltinFunctionName.LOG2.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( - v -> new ExprDoubleValue(Math.log(v.doubleValue()) / Math.log(2))), DOUBLE, type)) + v -> v.doubleValue() <= 0 ? ExprNullValue.of() : + new ExprDoubleValue(Math.log(v.doubleValue()) / Math.log(2))), DOUBLE, type)) .collect(Collectors.toList())); } diff --git a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java index 3a03ba79ad..fcfa05a0bc 100644 --- a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java @@ -81,6 +81,13 @@ private static Stream testLogDoubleArguments() { return builder.add(Arguments.of(2D, 2D)).build(); } + private static Stream testLogInvalidDoubleArguments() { + Stream.Builder builder = Stream.builder(); + return builder.add(Arguments.of(0D, -2D)) + .add(Arguments.of(0D, 2D)) + .add(Arguments.of(2D, 0D)).build(); + } + private static Stream trigonometricArguments() { Stream.Builder builder = Stream.builder(); return builder @@ -641,7 +648,7 @@ public void floor_missing_value() { * Test ln with integer value. */ @ParameterizedTest(name = "ln({0})") - @ValueSource(ints = {2, -2}) + @ValueSource(ints = {2, 3}) public void ln_int_value(Integer value) { FunctionExpression ln = DSL.ln(DSL.literal(value)); assertThat( @@ -654,7 +661,7 @@ public void ln_int_value(Integer value) { * Test ln with long value. */ @ParameterizedTest(name = "ln({0})") - @ValueSource(longs = {2L, -2L}) + @ValueSource(longs = {2L, 3L}) public void ln_long_value(Long value) { FunctionExpression ln = DSL.ln(DSL.literal(value)); assertThat( @@ -667,7 +674,7 @@ public void ln_long_value(Long value) { * Test ln with float value. */ @ParameterizedTest(name = "ln({0})") - @ValueSource(floats = {2F, -2F}) + @ValueSource(floats = {2F, 3F}) public void ln_float_value(Float value) { FunctionExpression ln = DSL.ln(DSL.literal(value)); assertThat( @@ -680,7 +687,7 @@ public void ln_float_value(Float value) { * Test ln with double value. */ @ParameterizedTest(name = "ln({0})") - @ValueSource(doubles = {2D, -2D}) + @ValueSource(doubles = {2D, 3D}) public void ln_double_value(Double value) { FunctionExpression ln = DSL.ln(DSL.literal(value)); assertThat( @@ -689,6 +696,17 @@ public void ln_double_value(Double value) { assertEquals(String.format("ln(%s)", value.toString()), ln.toString()); } + /** + * Test ln with invalid value. + */ + @ParameterizedTest(name = "ln({0})") + @ValueSource(doubles = {0D, -3D}) + public void ln_invalid_value(Double value) { + FunctionExpression ln = DSL.ln(DSL.literal(value)); + assertEquals(DOUBLE, ln.type()); + assertTrue(ln.valueOf(valueEnv()).isNull()); + } + /** * Test ln with null value. */ @@ -769,6 +787,17 @@ public void log_double_value(Double v) { assertEquals(String.format("log(%s)", v.toString()), log.toString()); } + /** + * Test log with 1 invalid value. + */ + @ParameterizedTest(name = "log({0})") + @ValueSource(doubles = {0D, -3D}) + public void log_invalid_value(Double value) { + FunctionExpression log = DSL.log(DSL.literal(value)); + assertEquals(DOUBLE, log.type()); + assertTrue(log.valueOf(valueEnv()).isNull()); + } + /** * Test log with 1 null value argument. */ @@ -847,6 +876,17 @@ public void log_two_double_value(Double v1, Double v2) { assertEquals(String.format("log(%s, %s)", v1.toString(), v2.toString()), log.toString()); } + /** + * Test log with 2 invalid double arguments. + */ + @ParameterizedTest(name = "log({0}, {2})") + @MethodSource("testLogInvalidDoubleArguments") + public void log_two_invalid_double_value(Double v1, Double v2) { + FunctionExpression log = DSL.log(DSL.literal(v1), DSL.literal(v2)); + assertEquals(log.type(), DOUBLE); + assertTrue(log.valueOf(valueEnv()).isNull()); + } + /** * Test log with 2 null value arguments. */ @@ -967,6 +1007,17 @@ public void log10_double_value(Double v) { assertEquals(String.format("log10(%s)", v.toString()), log.toString()); } + /** + * Test log10 with 1 invalid double argument. + */ + @ParameterizedTest(name = "log10({0})") + @ValueSource(doubles = {0D, -3D}) + public void log10_two_invalid_value(Double v) { + FunctionExpression log = DSL.log10(DSL.literal(v)); + assertEquals(log.type(), DOUBLE); + assertTrue(log.valueOf(valueEnv()).isNull()); + } + /** * Test log10 with null value. */ @@ -1049,6 +1100,17 @@ public void log2_double_value(Double v) { assertEquals(String.format("log2(%s)", v.toString()), log.toString()); } + /** + * Test log2 with an invalid double value. + */ + @ParameterizedTest(name = "log2({0})") + @ValueSource(doubles = {0D, -2D}) + public void log2_invalid_double_value(Double v) { + FunctionExpression log = DSL.log2(DSL.literal(v)); + assertEquals(log.type(), DOUBLE); + assertTrue(log.valueOf(valueEnv()).isNull()); + } + /** * Test log2 with null value. */ diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java index b8767eb2f1..db79850214 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java @@ -190,19 +190,6 @@ public void testAtan() throws IOException { verifyDataRows(result, rows(Math.atan2(2, 3))); } - protected JSONObject executeQuery(String query) throws IOException { - Request request = new Request("POST", QUERY_API_ENDPOINT); - request.setJsonEntity(String.format(Locale.ROOT, "{\n" + " \"query\": \"%s\"\n" + "}", query)); - - RequestOptions.Builder restOptionsBuilder = RequestOptions.DEFAULT.toBuilder(); - restOptionsBuilder.addHeader("Content-Type", "application/json"); - request.setOptions(restOptionsBuilder); - - Response response = client().performRequest(request); - return new JSONObject(getResponseBody(response)); - } - - @Test public void testCbrt() throws IOException { JSONObject result = executeQuery("select cbrt(8)"); @@ -217,4 +204,52 @@ public void testCbrt() throws IOException { verifySchema(result, schema("cbrt(-27)", "double")); verifyDataRows(result, rows(-3.0)); } + + @Test + public void testLnReturnsNull() throws IOException { + JSONObject result = executeQuery("select ln(0), ln(-2)"); + verifySchema(result, + schema("ln(0)", "double"), + schema("ln(-2)", "double")); + verifyDataRows(result, rows(null, null)); + } + + @Test + public void testLogReturnsNull() throws IOException { + JSONObject result = executeQuery("select log(0), log(-2)"); + verifySchema(result, + schema("log(0)", "double"), + schema("log(-2)", "double")); + verifyDataRows(result, rows(null, null)); + } + + @Test + public void testLog10ReturnsNull() throws IOException { + JSONObject result = executeQuery("select log10(0), log10(-2)"); + verifySchema(result, + schema("log10(0)", "double"), + schema("log10(-2)", "double")); + verifyDataRows(result, rows(null, null)); + } + + @Test + public void testLog2ReturnsNull() throws IOException { + JSONObject result = executeQuery("select log2(0), log2(-2)"); + verifySchema(result, + schema("log2(0)", "double"), + schema("log2(-2)", "double")); + verifyDataRows(result, rows(null, null)); + } + + protected JSONObject executeQuery(String query) throws IOException { + Request request = new Request("POST", QUERY_API_ENDPOINT); + request.setJsonEntity(String.format(Locale.ROOT, "{\n" + " \"query\": \"%s\"\n" + "}", query)); + + RequestOptions.Builder restOptionsBuilder = RequestOptions.DEFAULT.toBuilder(); + restOptionsBuilder.addHeader("Content-Type", "application/json"); + request.setOptions(restOptionsBuilder); + + Response response = client().performRequest(request); + return new JSONObject(getResponseBody(response)); + } }