Skip to content

Commit

Permalink
Implement array_sum function in Java from SQL
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavmuk04 authored and kaikalur committed Jun 14, 2024
1 parent fd8d572 commit 1e80972
Show file tree
Hide file tree
Showing 8 changed files with 217 additions and 58 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.benchmark;

import com.facebook.presto.testing.LocalQueryRunner;

import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner;

/**
 * Benchmark that runs a SQL query exercising {@code ARRAY_SUM} through a {@link LocalQueryRunner}.
 */
public class SqlArraySumBenchmark
        extends AbstractSqlBenchmark
{
    /**
     * Creates the benchmark.
     *
     * Note that the parent constructor receives the arguments in a different order
     * (runner, name, 10, 10, query) than this constructor declares them.
     * The two literal 10s are presumably warm-up and measured iteration counts -- TODO confirm
     * against AbstractSqlBenchmark.
     *
     * @param localQueryRunner runner used to execute the query
     * @param query SQL text to benchmark
     * @param name benchmark name forwarded to the parent class
     */
    public SqlArraySumBenchmark(LocalQueryRunner localQueryRunner, String query, String name)
    {
        super(localQueryRunner, name, 10, 10, query);
    }

    /**
     * Runs the benchmark once and writes the result lines to standard output.
     */
    public static void main(String[] args)
    {
        // Each row of "part" is paired with one generated 10000-element array of random() values.
        String benchmarkQuery = "SELECT ARRAY_SUM(x) FROM part cross join (SELECT transform(sequence(1,10000),y->random()) AS x)";
        LocalQueryRunner queryRunner = createLocalQueryRunner();
        SqlArraySumBenchmark benchmark = new SqlArraySumBenchmark(queryRunner, benchmarkQuery, "sql_array_sum");
        benchmark.runBenchmark(new SimpleLineBenchmarkResultWriter(System.out));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@
import com.facebook.presto.operator.scalar.ArraySliceFunction;
import com.facebook.presto.operator.scalar.ArraySortComparatorFunction;
import com.facebook.presto.operator.scalar.ArraySortFunction;
import com.facebook.presto.operator.scalar.ArraySumBigIntFunction;
import com.facebook.presto.operator.scalar.ArraySumDoubleFunction;
import com.facebook.presto.operator.scalar.ArrayTrimFunction;
import com.facebook.presto.operator.scalar.ArrayUnionFunction;
import com.facebook.presto.operator.scalar.ArraysOverlapFunction;
Expand Down Expand Up @@ -855,6 +857,8 @@ private List<? extends SqlFunction> getBuiltInFunctions(FeaturesConfig featuresC
.scalar(ArrayFilterFunction.class)
.scalar(ArrayPositionFunction.class)
.scalar(ArrayPositionWithIndexFunction.class)
.scalar(ArraySumBigIntFunction.class)
.scalar(ArraySumDoubleFunction.class)
.scalars(CombineHashFunction.class)
.scalars(JsonOperators.class)
.scalar(JsonOperators.JsonDistinctFromOperator.class)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.OperatorDependency;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

import java.lang.invoke.MethodHandle;

import static com.facebook.presto.common.function.OperatorType.ADD;
import static com.facebook.presto.util.Failures.internalError;

@Description("Returns the sum of all array elements, or 0 if the array is empty. Ignores null elements.")
@ScalarFunction(value = "array_sum")
public final class ArraySumBigIntFunction
{
    private ArraySumBigIntFunction() {}

    /**
     * Sums the non-null positions of a bigint array block.
     *
     * The running total starts at 0 and is combined with each non-null element by invoking
     * the injected bigint ADD operator handle, so an empty or all-null block yields 0.
     *
     * @param addFunction handle for the (bigint, bigint) ADD operator
     * @param elementType element type supplied by the function framework; not read by this method
     * @param arrayBlock block holding the array elements
     * @return the accumulated total
     */
    @SqlType("bigint")
    public static long arraySumBigInt(
            @OperatorDependency(operator = ADD, argumentTypes = {"bigint", "bigint"}) MethodHandle addFunction,
            @TypeParameter("bigint") Type elementType,
            @SqlType("array(bigint)") Block arrayBlock)
    {
        long total = 0;
        int elementCount = arrayBlock.getPositionCount();
        for (int position = 0; position < elementCount; position++) {
            if (arrayBlock.isNull(position)) {
                continue;
            }
            try {
                total = (long) addFunction.invoke(total, arrayBlock.getLong(position));
            }
            catch (Throwable failure) {
                // Anything raised while reading or adding is routed through internalError.
                throw internalError(failure);
            }
        }
        return total;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.OperatorDependency;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

import java.lang.invoke.MethodHandle;

import static com.facebook.presto.common.function.OperatorType.ADD;
import static com.facebook.presto.common.type.TypeUtils.readNativeValue;
import static com.facebook.presto.util.Failures.internalError;

@Description("Returns the sum of all array elements, or 0 if the array is empty. Ignores null elements.")
@ScalarFunction(value = "array_sum")
public final class ArraySumDoubleFunction
{
    private ArraySumDoubleFunction() {}

    /**
     * Sums the non-null positions of a double array block.
     *
     * The running total starts at 0.0 and is combined with each non-null element by invoking
     * the injected double ADD operator handle, so an empty or all-null block yields 0.0.
     *
     * @param addFunction handle for the (double, double) ADD operator
     * @param elementType element type used to read each value out of the block
     * @param arrayBlock block holding the array elements
     * @return the accumulated total
     */
    @SqlType("double")
    public static double arraySumDouble(
            @OperatorDependency(operator = ADD, argumentTypes = {"double", "double"}) MethodHandle addFunction,
            @TypeParameter("double") Type elementType,
            @SqlType("array(double)") Block arrayBlock)
    {
        double total = 0.0;
        int elementCount = arrayBlock.getPositionCount();
        for (int position = 0; position < elementCount; position++) {
            if (arrayBlock.isNull(position)) {
                continue;
            }
            try {
                // The explicit primitive cast keeps the handle invocation typed as (double, double).
                total = (double) addFunction.invoke(total, (double) readNativeValue(elementType, arrayBlock, position));
            }
            catch (Throwable failure) {
                // Anything raised while reading or adding is routed through internalError.
                throw internalError(failure);
            }
        }
        return total;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,6 @@ public class ArraySqlFunctions
{
private ArraySqlFunctions() {}

/**
 * Supplies the SQL body of {@code array_sum} for {@code array<bigint>} input.
 *
 * The body folds the array with {@code reduce}, starting from {@code BIGINT '0'} and
 * substituting {@code BIGINT '0'} for null elements via {@code coalesce}.
 *
 * @return the SQL routine body text
 */
@SqlInvokedScalarFunction(value = "array_sum", deterministic = true, calledOnNullInput = false)
@Description("Returns the sum of all array elements, or 0 if the array is empty. Ignores null elements.")
@SqlParameter(name = "input", type = "array<bigint>")
@SqlType("bigint")
public static String arraySumBigint()
{
    String routineBody = "RETURN reduce(input, BIGINT '0', (s, x) -> s + coalesce(x, BIGINT '0'), s -> s)";
    return routineBody;
}

/**
 * Supplies the SQL body of {@code array_sum} for {@code array<double>} input.
 *
 * The body folds the array with {@code reduce}, starting from {@code DOUBLE '0'} and
 * substituting {@code DOUBLE '0'} for null elements via {@code coalesce}.
 *
 * @return the SQL routine body text
 */
@SqlInvokedScalarFunction(value = "array_sum", deterministic = true, calledOnNullInput = false)
@Description("Returns the sum of all array elements, or 0 if the array is empty. Ignores null elements.")
@SqlParameter(name = "input", type = "array<double>")
@SqlType("double")
public static String arraySumDouble()
{
    String routineBody = "RETURN reduce(input, DOUBLE '0', (s, x) -> s + coalesce(x, DOUBLE '0'), s -> s)";
    return routineBody;
}

@SqlInvokedScalarFunction(value = "array_average", deterministic = true, calledOnNullInput = false)
@Description("Returns the average of all array elements, or null if the array is empty. Ignores null elements.")
@SqlParameter(name = "input", type = "array<double>")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar;

import org.testng.annotations.Test;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;

/**
 * Tests for the {@code array_sum} scalar function over integer-like and floating-point arrays.
 */
public class TestArraySumFunction
        extends AbstractTestFunctions
{
    /**
     * Arrays whose elements are all integer-like types are expected to produce a BIGINT sum.
     */
    @Test
    public void testBigIntType()
    {
        // Homogeneous arrays first, then arrays mixing two integer widths.
        String[] elementLists = {
                "BIGINT '1', BIGINT '2'",
                "INTEGER '1', INTEGER '2'",
                "SMALLINT '1', SMALLINT '2'",
                "TINYINT '1', TINYINT '2'",
                "BIGINT '1', INTEGER '2'",
                "INTEGER '1', SMALLINT '2'",
                "SMALLINT '1', TINYINT '2'",
        };
        for (String elements : elementLists) {
            assertFunction("array_sum(array[" + elements + "])", BIGINT, 3L);
        }
    }

    /**
     * Arrays containing at least one non-integer numeric element are expected to produce a DOUBLE sum.
     */
    @Test
    public void testDoubleType()
    {
        // Floating-point/decimal pairs first, then integer-like values paired with them.
        String[] elementLists = {
                "DOUBLE '-2.0', DOUBLE '5.3'",
                "DOUBLE '-2.0', REAL '5.3'",
                "DOUBLE '-2.0', DECIMAL '5.3'",
                "REAL '-2.0', DECIMAL '5.3'",
                "BIGINT '-2', DOUBLE '5.3'",
                "INTEGER '-2', REAL '5.3'",
                "SMALLINT '-2', DECIMAL '5.3'",
                "TINYINT '-2', DOUBLE '5.3'",
        };
        for (String elements : elementLists) {
            assertFunctionWithError("array_sum(array[" + elements + "])", DOUBLE, 3.3);
        }
    }

    /**
     * Null input, empty arrays, and arrays containing null elements.
     */
    @Test
    public void testEdgeCases()
    {
        // A null array yields null rather than 0.
        assertFunction("array_sum(null)", BIGINT, null);

        // Empty and all-null arrays sum to 0.
        String[] zeroSumArrays = {"array[]", "array[NULL]", "array[NULL, NULL, NULL]"};
        for (String array : zeroSumArrays) {
            assertFunction("array_sum(" + array + ")", BIGINT, 0L);
        }

        // Null elements are skipped when mixed with values.
        assertFunction("array_sum(array[3, NULL, 5])", BIGINT, 8L);
        assertFunctionWithError("array_sum(array[NULL, double '1.2', double '2.3', NULL, -3])", DOUBLE, 0.5);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,36 +38,6 @@
public class TestArraySqlFunctions
extends AbstractTestFunctions
{
/**
 * Checks array_sum over integer-like arrays (BIGINT result), arrays with floating-point or
 * decimal elements (DOUBLE result), and null/empty edge cases.
 */
@Test
public void testArraySum()
{
    // Integer-like element types: homogeneous arrays, then mixed widths.
    String[] bigintElementLists = {
            "BIGINT '1', BIGINT '2'",
            "INTEGER '1', INTEGER '2'",
            "SMALLINT '1', SMALLINT '2'",
            "TINYINT '1', TINYINT '2'",
            "BIGINT '1', INTEGER '2'",
            "INTEGER '1', SMALLINT '2'",
            "SMALLINT '1', TINYINT '2'",
    };
    for (String elements : bigintElementLists) {
        assertFunction("array_sum(array[" + elements + "])", BIGINT, 3L);
    }

    // Arrays with at least one non-integer numeric element.
    String[] doubleElementLists = {
            "DOUBLE '-2.0', DOUBLE '5.3'",
            "DOUBLE '-2.0', REAL '5.3'",
            "DOUBLE '-2.0', DECIMAL '5.3'",
            "REAL '-2.0', DECIMAL '5.3'",
            "BIGINT '-2', DOUBLE '5.3'",
            "INTEGER '-2', REAL '5.3'",
            "SMALLINT '-2', DECIMAL '5.3'",
            "TINYINT '-2', DOUBLE '5.3'",
    };
    for (String elements : doubleElementLists) {
        assertFunctionWithError("array_sum(array[" + elements + "])", DOUBLE, 3.3);
    }

    // A null array yields null; empty and all-null arrays sum to 0.
    assertFunction("array_sum(null)", BIGINT, null);
    String[] zeroSumArrays = {"array[]", "array[NULL]", "array[NULL, NULL, NULL]"};
    for (String array : zeroSumArrays) {
        assertFunction("array_sum(" + array + ")", BIGINT, 0L);
    }

    // Null elements are skipped when mixed with values.
    assertFunction("array_sum(array[3, NULL, 5])", BIGINT, 8L);
    assertFunctionWithError("array_sum(array[NULL, double '1.2', double '2.3', NULL, -3])", DOUBLE, 0.5);
}

@Test
public void testArrayAverage()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,16 +158,6 @@ public void testInlineSqlFunctionCoercesConstantWithCast()
new ArrayType(BigintType.BIGINT));
}

/**
 * Checks that a call to the built-in SQL-defined array_sum is replaced by its reduce(...) body
 * when the argument is an array of INTEGER.
 */
@Test
public void testInlineBuiltinSqlFunction()
{
    String callExpression = "array_sum(x)";
    // Expected text after inlining; the lambda arguments carry the "$lambda" suffixes shown here.
    String inlinedExpression = "reduce(x, BIGINT '0', (\"s$lambda\", \"x$lambda\") -> \"s$lambda\" + COALESCE(\"x$lambda\", BIGINT '0'), \"s$lambda_0\" -> \"s$lambda_0\")";
    String variableName = "x";
    ArrayType variableType = new ArrayType(IntegerType.INTEGER);

    assertInlined(callExpression, inlinedExpression, variableName, variableType);
}

@Test
public void testNoInlineThriftFunction()
{
Expand Down

0 comments on commit 1e80972

Please sign in to comment.