From 1e80972b373b7133ebd5c29e4987e3265638ae69 Mon Sep 17 00:00:00 2001 From: abhinavmuk04 <69162586+abhinavmuk04@users.noreply.github.com> Date: Thu, 13 Jun 2024 13:03:24 -0700 Subject: [PATCH] Implement array_sum function in Java from SQL --- .../benchmark/SqlArraySumBenchmark.java | 32 ++++++++++ ...uiltInTypeAndFunctionNamespaceManager.java | 4 ++ .../scalar/ArraySumBigIntFunction.java | 59 ++++++++++++++++++ .../scalar/ArraySumDoubleFunction.java | 61 +++++++++++++++++++ .../scalar/sql/ArraySqlFunctions.java | 18 ------ .../operator/scalar/TestArraySumFunction.java | 61 +++++++++++++++++++ .../scalar/sql/TestArraySqlFunctions.java | 30 --------- .../rule/TestInlineSqlFunctions.java | 10 --- 8 files changed, 217 insertions(+), 58 deletions(-) create mode 100644 presto-benchmark/src/main/java/com/facebook/presto/benchmark/SqlArraySumBenchmark.java create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySumBigIntFunction.java create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySumDoubleFunction.java create mode 100644 presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArraySumFunction.java diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/SqlArraySumBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/SqlArraySumBenchmark.java new file mode 100644 index 0000000000000..b681b31643cd0 --- /dev/null +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/SqlArraySumBenchmark.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.benchmark; + +import com.facebook.presto.testing.LocalQueryRunner; + +import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; + +public class SqlArraySumBenchmark + extends AbstractSqlBenchmark +{ + public SqlArraySumBenchmark(LocalQueryRunner localQueryRunner, String query, String name) + { + super(localQueryRunner, name, 10, 10, query); + } + + public static void main(String[] args) + { + new SqlArraySumBenchmark(createLocalQueryRunner(), "SELECT ARRAY_SUM(x) FROM part cross join (SELECT transform(sequence(1,10000),y->random()) AS x)", "sql_array_sum").runBenchmark(new SimpleLineBenchmarkResultWriter(System.out)); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 384d0e4f026d0..9a6c01095e1fc 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -148,6 +148,8 @@ import com.facebook.presto.operator.scalar.ArraySliceFunction; import com.facebook.presto.operator.scalar.ArraySortComparatorFunction; import com.facebook.presto.operator.scalar.ArraySortFunction; +import com.facebook.presto.operator.scalar.ArraySumBigIntFunction; +import com.facebook.presto.operator.scalar.ArraySumDoubleFunction; import com.facebook.presto.operator.scalar.ArrayTrimFunction; import com.facebook.presto.operator.scalar.ArrayUnionFunction; import com.facebook.presto.operator.scalar.ArraysOverlapFunction; @@ -855,6 +857,8 @@ private List getBuiltInFunctions(FeaturesConfig featuresC .scalar(ArrayFilterFunction.class) .scalar(ArrayPositionFunction.class) .scalar(ArrayPositionWithIndexFunction.class) + .scalar(ArraySumBigIntFunction.class) + .scalar(ArraySumDoubleFunction.class) .scalars(CombineHashFunction.class) .scalars(JsonOperators.class) .scalar(JsonOperators.JsonDistinctFromOperator.class) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySumBigIntFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySumBigIntFunction.java new file mode 100644 index 0000000000000..0fd1153d3cb20 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySumBigIntFunction.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.OperatorDependency; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.function.OperatorType.ADD; +import static com.facebook.presto.util.Failures.internalError; + +@Description("Returns the sum of all array elements, or 0 if the array is empty. Ignores null elements.") +@ScalarFunction(value = "array_sum") +public final class ArraySumBigIntFunction +{ + private ArraySumBigIntFunction() {} + + @SqlType("bigint") + public static long arraySumBigInt( + @OperatorDependency(operator = ADD, argumentTypes = {"bigint", "bigint"}) MethodHandle addFunction, + @TypeParameter("bigint") Type elementType, + @SqlType("array(bigint)") Block arrayBlock) + { + int positionCount = arrayBlock.getPositionCount(); + if (positionCount == 0) { + return 0; + } + + long sum = 0; + for (int i = 0; i < positionCount; i++) { + if (!arrayBlock.isNull(i)) { + try { + sum = (long) addFunction.invoke(sum, arrayBlock.getLong(i)); + } + catch (Throwable throwable) { + throw internalError(throwable); + } + } + } + return sum; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySumDoubleFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySumDoubleFunction.java new file mode 100644 index 0000000000000..ef819cce58d8e --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArraySumDoubleFunction.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.OperatorDependency; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.function.OperatorType.ADD; +import static com.facebook.presto.common.type.TypeUtils.readNativeValue; +import static com.facebook.presto.util.Failures.internalError; + +@Description("Returns the sum of all array elements, or 0 if the array is empty. Ignores null elements.") +@ScalarFunction(value = "array_sum") +public final class ArraySumDoubleFunction +{ + private ArraySumDoubleFunction() {} + + @SqlType("double") + public static double arraySumDouble( + @OperatorDependency(operator = ADD, argumentTypes = {"double", "double"}) MethodHandle addFunction, + @TypeParameter("double") Type elementType, + @SqlType("array(double)") Block arrayBlock) + { + int positionCount = arrayBlock.getPositionCount(); + if (positionCount == 0) { + return 0.0; + } + + double sum = 0.0; + for (int i = 0; i < positionCount; i++) { + if (!arrayBlock.isNull(i)) { + try { + Object newValue = readNativeValue(elementType, arrayBlock, i); + sum = (double) addFunction.invoke(sum, (double) newValue); + } + catch (Throwable throwable) { + throw internalError(throwable); + } + } + } + return sum; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java index 8c8c73c37dcc1..fa275eca5362f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java @@ -24,24 +24,6 @@ public class ArraySqlFunctions { private ArraySqlFunctions() {} - @SqlInvokedScalarFunction(value = "array_sum", deterministic = true, calledOnNullInput = false) - @Description("Returns the sum of all array elements, or 0 if the array is empty. Ignores null elements.") - @SqlParameter(name = "input", type = "array") - @SqlType("bigint") - public static String arraySumBigint() - { - return "RETURN reduce(input, BIGINT '0', (s, x) -> s + coalesce(x, BIGINT '0'), s -> s)"; - } - - @SqlInvokedScalarFunction(value = "array_sum", deterministic = true, calledOnNullInput = false) - @Description("Returns the sum of all array elements, or 0 if the array is empty. Ignores null elements.") - @SqlParameter(name = "input", type = "array") - @SqlType("double") - public static String arraySumDouble() - { - return "RETURN reduce(input, DOUBLE '0', (s, x) -> s + coalesce(x, DOUBLE '0'), s -> s)"; - } - @SqlInvokedScalarFunction(value = "array_average", deterministic = true, calledOnNullInput = false) @Description("Returns the average of all array elements, or null if the array is empty. Ignores null elements.") @SqlParameter(name = "input", type = "array") diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArraySumFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArraySumFunction.java new file mode 100644 index 0000000000000..6ba315c035ca2 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArraySumFunction.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; + +public class TestArraySumFunction + extends AbstractTestFunctions +{ + @Test + public void testBigIntType() + { + assertFunction("array_sum(array[BIGINT '1', BIGINT '2'])", BIGINT, 3L); + assertFunction("array_sum(array[INTEGER '1', INTEGER '2'])", BIGINT, 3L); + assertFunction("array_sum(array[SMALLINT '1', SMALLINT '2'])", BIGINT, 3L); + assertFunction("array_sum(array[TINYINT '1', TINYINT '2'])", BIGINT, 3L); + + assertFunction("array_sum(array[BIGINT '1', INTEGER '2'])", BIGINT, 3L); + assertFunction("array_sum(array[INTEGER '1', SMALLINT '2'])", BIGINT, 3L); + assertFunction("array_sum(array[SMALLINT '1', TINYINT '2'])", BIGINT, 3L); + } + + @Test + public void testDoubleType() + { + assertFunctionWithError("array_sum(array[DOUBLE '-2.0', DOUBLE '5.3'])", DOUBLE, 3.3); + assertFunctionWithError("array_sum(array[DOUBLE '-2.0', REAL '5.3'])", DOUBLE, 3.3); + assertFunctionWithError("array_sum(array[DOUBLE '-2.0', DECIMAL '5.3'])", DOUBLE, 3.3); + assertFunctionWithError("array_sum(array[REAL '-2.0', DECIMAL '5.3'])", DOUBLE, 3.3); + + assertFunctionWithError("array_sum(array[BIGINT '-2', DOUBLE '5.3'])", DOUBLE, 3.3); + assertFunctionWithError("array_sum(array[INTEGER '-2', REAL '5.3'])", DOUBLE, 3.3); + assertFunctionWithError("array_sum(array[SMALLINT '-2', DECIMAL '5.3'])", DOUBLE, 3.3); + assertFunctionWithError("array_sum(array[TINYINT '-2', DOUBLE '5.3'])", DOUBLE, 3.3); + } + + @Test + public void testEdgeCases() + { + assertFunction("array_sum(null)", BIGINT, null); + assertFunction("array_sum(array[])", BIGINT, 0L); + assertFunction("array_sum(array[NULL])", BIGINT, 0L); + assertFunction("array_sum(array[NULL, NULL, NULL])", BIGINT, 0L); + assertFunction("array_sum(array[3, NULL, 5])", BIGINT, 8L); + assertFunctionWithError("array_sum(array[NULL, double '1.2', double '2.3', NULL, -3])", DOUBLE, 0.5); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java index 8344f854f0395..a0da487dae0a4 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java @@ -38,36 +38,6 @@ public class TestArraySqlFunctions extends AbstractTestFunctions { - @Test - public void testArraySum() - { - assertFunction("array_sum(array[BIGINT '1', BIGINT '2'])", BIGINT, 3L); - assertFunction("array_sum(array[INTEGER '1', INTEGER '2'])", BIGINT, 3L); - assertFunction("array_sum(array[SMALLINT '1', SMALLINT '2'])", BIGINT, 3L); - assertFunction("array_sum(array[TINYINT '1', TINYINT '2'])", BIGINT, 3L); - - assertFunction("array_sum(array[BIGINT '1', INTEGER '2'])", BIGINT, 3L); - assertFunction("array_sum(array[INTEGER '1', SMALLINT '2'])", BIGINT, 3L); - assertFunction("array_sum(array[SMALLINT '1', TINYINT '2'])", BIGINT, 3L); - - assertFunctionWithError("array_sum(array[DOUBLE '-2.0', DOUBLE '5.3'])", DOUBLE, 3.3); - assertFunctionWithError("array_sum(array[DOUBLE '-2.0', REAL '5.3'])", DOUBLE, 3.3); - assertFunctionWithError("array_sum(array[DOUBLE '-2.0', DECIMAL '5.3'])", DOUBLE, 3.3); - assertFunctionWithError("array_sum(array[REAL '-2.0', DECIMAL '5.3'])", DOUBLE, 3.3); - - assertFunctionWithError("array_sum(array[BIGINT '-2', DOUBLE '5.3'])", DOUBLE, 3.3); - assertFunctionWithError("array_sum(array[INTEGER '-2', REAL '5.3'])", DOUBLE, 3.3); - assertFunctionWithError("array_sum(array[SMALLINT '-2', DECIMAL '5.3'])", DOUBLE, 3.3); - assertFunctionWithError("array_sum(array[TINYINT '-2', DOUBLE '5.3'])", DOUBLE, 3.3); - - assertFunction("array_sum(null)", BIGINT, null); - assertFunction("array_sum(array[])", BIGINT, 0L); - assertFunction("array_sum(array[NULL])", BIGINT, 0L); - assertFunction("array_sum(array[NULL, NULL, NULL])", BIGINT, 0L); - assertFunction("array_sum(array[3, NULL, 5])", BIGINT, 8L); - assertFunctionWithError("array_sum(array[NULL, double '1.2', double '2.3', NULL, -3])", DOUBLE, 0.5); - } - @Test public void testArrayAverage() { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineSqlFunctions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineSqlFunctions.java index dc9ca779761db..b5ab33b609904 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineSqlFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineSqlFunctions.java @@ -158,16 +158,6 @@ public void testInlineSqlFunctionCoercesConstantWithCast() new ArrayType(BigintType.BIGINT)); } - @Test - public void testInlineBuiltinSqlFunction() - { - assertInlined( - "array_sum(x)", - "reduce(x, BIGINT '0', (\"s$lambda\", \"x$lambda\") -> \"s$lambda\" + COALESCE(\"x$lambda\", BIGINT '0'), \"s$lambda_0\" -> \"s$lambda_0\")", - "x", - new ArrayType(IntegerType.INTEGER)); - } - @Test public void testNoInlineThriftFunction() {