diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java index 3aab7d73fb1..2d6ff97d24d 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java @@ -74,4 +74,14 @@ public static Object caseWhen(Object... objs) { // with or without else statement. return objs.length % 2 == 0 ? null : objs[objs.length - 1]; } + + @Nullable + @ScalarFunction(nullableParameters = true) + public static Object nullIf(Object obj1, Object obj2) { + if (obj1 == null) { + return null; + } else { + return obj1.equals(obj2) ? null : obj1; + } + } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/ObjectFunctionsTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/ObjectFunctionsTest.java index 231b6078575..e71bd740f00 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/ObjectFunctionsTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/ObjectFunctionsTest.java @@ -97,7 +97,7 @@ public Object[][] objectFunctionsDataProvider() { oneValue.putValue("value2", null); inputs.add(new Object[]{ - "coalesce(null0,null1, null2, value1, value2)", Lists.newArrayList("null0", "null1", "null2", "value1", + "coalesce(null0, null1, null2, value1, value2)", Lists.newArrayList("null0", "null1", "null2", "value1", "value2"), oneValue, 1 }); @@ -156,6 +156,27 @@ public Object[][] objectFunctionsDataProvider() { "value1"), caseWhenCaseMultipleExpression2, "fifteen" }); + // NULLIF + GenericRow nullIf = new GenericRow(); + nullIf.putValue("value1", 1); + nullIf.putValue("value2", 1); + inputs.add(new Object[]{"NULLIF(value1, value2)", Lists.newArrayList("value1", "value2"), nullIf, null}); + + GenericRow nullIf2 = new GenericRow(); + nullIf2.putValue("value1", 1); + nullIf2.putValue("value2", 2); + inputs.add(new Object[]{"NULLIF(value1, value2)", Lists.newArrayList("value1", "value2"), nullIf2, 1}); + + GenericRow nullIf3 = new GenericRow(); + nullIf3.putValue("value1", null); + nullIf3.putValue("value2", 2); + inputs.add(new Object[]{"NULLIF(value1, value2)", Lists.newArrayList("value1", "value2"), nullIf3, null}); + + GenericRow nullIf4 = new GenericRow(); + nullIf4.putValue("value1", 1); + nullIf4.putValue("value2", null); + inputs.add(new Object[]{"NULLIF(value1, value2)", Lists.newArrayList("value1", "value2"), nullIf4, 1}); + return inputs.toArray(new Object[0][]); } } diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java index b58ae2dee2e..48cd23ea7d9 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java @@ -30,6 +30,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Pattern; import javax.annotation.Nullable; import org.apache.commons.io.FileUtils; @@ -1010,6 +1011,30 @@ public void testCaseWhenWithLargeNumberOfWhenThenClauses() } } + @Test + public void testNullIf() + throws Exception { + // Calls to the Calcite NULLIF operator are rewritten to the equivalent CASE WHEN expressions. This test verifies + // that the rewrite works correctly. + String sqlQuery = "SET " + CommonConstants.Broker.Request.QueryOptionKey.ENABLE_NULL_HANDLING + + "=true; SELECT NULLIF(ArrDelay, 0) FROM mytable"; + JsonNode result = postQuery(sqlQuery); + assertNoError(result); + + JsonNode rows = result.get("resultTable").get("rows"); + AtomicInteger nullRows = new AtomicInteger(); + rows.elements().forEachRemaining(row -> { + if (row.get(0).isNull()) { + nullRows.getAndIncrement(); + } + }); + + sqlQuery = "SELECT COUNT(*) FROM mytable WHERE ArrDelay = 0"; + result = postQuery(sqlQuery); + assertNoError(result); + assertEquals(nullRows.get(), result.get("resultTable").get("rows").get(0).get(0).asInt()); + } + @Test public void testMVNumericCastInFilter() throws Exception { String sqlQuery = "SELECT COUNT(*) FROM mytable WHERE ARRAY_TO_MV(CAST(DivAirportIDs AS BIGINT ARRAY)) > 0"; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java index 0c1a8d8a485..ecaab85affd 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java @@ -161,6 +161,7 @@ public static PinotOperatorTable instance() { SqlStdOperatorTable.LIKE, // SqlStdOperatorTable.CASE, SqlStdOperatorTable.OVER, + SqlStdOperatorTable.NULLIF, // FUNCTIONS // String functions