Limit array size in reduce Presto function (facebookincubator#8698)
Summary:
The reduce lambda function is inefficient and uses a lot of CPU and wall time
when applied to very large arrays.

Add a limit of 10,000 elements per array. reduce now fails with a user error
for arrays that exceed the limit.

Pull Request resolved: facebookincubator#8698

Reviewed By: Yuhta

Differential Revision: D53518840

Pulled By: mbasmanova

fbshipit-source-id: e64f9ec317e46ac60cc356f2197b6c8f5a215285
mbasmanova authored and FelixYBW committed Feb 10, 2024
1 parent 049eef3 commit f7a64c4
Showing 3 changed files with 100 additions and 21 deletions.
21 changes: 14 additions & 7 deletions velox/docs/functions/presto/array.rst
@@ -160,6 +160,7 @@ Array Functions
SELECT array_sort(ARRAY [ARRAY [1, 2], ARRAY [1, null]]); -- failed: Ordering nulls is not supported

.. function:: array_sort(array(T), function(T,U)) -> array(T)
:noindex:

Returns the array sorted by values computed using specified lambda in ascending
order. U must be an orderable type. Null elements will be placed at the end of
@@ -185,6 +186,7 @@ Array Functions
SELECT array_sort(ARRAY [ARRAY [1, 2], ARRAY [1, null]]); -- failed: Ordering nulls is not supported

.. function:: array_sort_desc(array(T), function(T,U)) -> array(T)
:noindex:

Returns the array sorted by values computed using specified lambda in descending
order. U must be an orderable type. Null elements will be placed at the end of
@@ -221,9 +223,9 @@ Array Functions
When 'element' is of complex type, throws if 'x' or 'element' contains nested nulls
and these need to be compared to produce a result. ::

SELECT contains(ARRAY[ARRAY[1, 3]], ARRAY[2, null]); -- false.
SELECT contains(ARRAY[ARRAY[2, 3]], ARRAY[2, null]); -- failed: contains does not support arrays with elements that are null or contain null
SELECT contains(ARRAY[ARRAY[2, null]], ARRAY[2, 1]); -- failed: contains does not support arrays with elements that are null or contain null
SELECT contains(ARRAY[ARRAY[1, 3]], ARRAY[2, null]); -- false.
SELECT contains(ARRAY[ARRAY[2, 3]], ARRAY[2, null]); -- failed: contains does not support arrays with elements that are null or contain null
SELECT contains(ARRAY[ARRAY[2, null]], ARRAY[2, 1]); -- failed: contains does not support arrays with elements that are null or contain null

.. function:: element_at(array(E), index) -> E

@@ -247,6 +249,7 @@ Array Functions
for no-match and first-match-is-null cases.

.. function:: find_first(array(T), index, function(T,boolean)) -> E
:noindex:

Returns the first element of ``array`` that matches the predicate.
Returns ``NULL`` if no element matches the predicate.
@@ -268,6 +271,7 @@ Array Functions
Returns ``NULL`` if no such element exists.

.. function:: find_first_index(array(T), index, function(T,boolean)) -> BIGINT
:noindex:

Returns the 1-based index of the first element of ``array`` that matches the predicate.
Returns ``NULL`` if no such element exists.
@@ -304,7 +308,9 @@ Array Functions
the element, ``inputFunction`` takes the current state, initially
``initialState``, and returns the new state. ``outputFunction`` will be
invoked to turn the final state into the result value. It may be the
identity function (``i -> i``). ::
identity function (``i -> i``).

Throws if the array has more than 10,000 elements. ::

SELECT reduce(ARRAY [], 0, (s, x) -> s + x, s -> s); -- 0
SELECT reduce(ARRAY [5, 20, 50], 0, (s, x) -> s + x, s -> s); -- 75
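
The semantics described above can be sketched in standalone C++. This is not Velox code: `reduceArray` and `kMaxReduceArraySize` are hypothetical names chosen for illustration, and the limit check simply follows the documented wording (fail for arrays with more than 10,000 elements).

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the documented limit; not a Velox identifier.
constexpr std::size_t kMaxReduceArraySize = 10'000;

// Folds 'array' into a state with 'inputFunction', starting from
// 'initialState', then maps the final state to the result with
// 'outputFunction'. Throws for arrays above the limit.
template <typename T, typename S, typename InputFn, typename OutputFn>
auto reduceArray(
    const std::vector<T>& array,
    S initialState,
    InputFn inputFunction,
    OutputFn outputFunction) {
  if (array.size() > kMaxReduceArraySize) {
    throw std::invalid_argument(
        "reduce does not support arrays with more than 10000 elements");
  }
  S state = initialState;
  for (const auto& element : array) {
    state = inputFunction(state, element);
  }
  return outputFunction(state);
}
```

Called with a summing input function and an identity output function, the sketch returns 0 for an empty array and 75 for the elements 5, 20, 50, matching the SQL examples above.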
@@ -327,7 +333,7 @@ Array Functions

.. function:: shuffle(array(E)) -> array(E)

Generate a random permutation of the given ``array``::
Generate a random permutation of the given ``array`` ::

SELECT shuffle(ARRAY [1, 2, 3]); -- [3, 1, 2] or any other random permutation
SELECT shuffle(ARRAY [0, 0, 0]); -- [0, 0, 0]
@@ -375,7 +381,7 @@ Array Functions

.. function:: remove_nulls(x) -> array

Remove null values from an array ``array``::
Remove null values from an array ``array`` ::

SELECT remove_nulls(ARRAY[1, NULL, 3, NULL]); -- [1, 3]
SELECT remove_nulls(ARRAY[true, false, NULL]); -- [true, false]
@@ -392,7 +398,8 @@ Array Functions
.. function:: zip_with(array(T), array(U), function(T,U,R)) -> array(R)

Merges the two given arrays, element-wise, into a single array using ``function``.
If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying ``function``::
If one array is shorter, nulls are appended at the end to match the length of the
longer array, before applying ``function`` ::

SELECT zip_with(ARRAY[1, 3, 5], ARRAY['a', 'b', 'c'], (x, y) -> (y, x)); -- [ROW('a', 1), ROW('b', 3), ROW('c', 5)]
SELECT zip_with(ARRAY[1, 2], ARRAY[3, 4], (x, y) -> x + y); -- [4, 6]
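
The null-padding rule for `zip_with` can be illustrated with a small standalone C++ sketch. It is only a model of the documented behavior, not the Velox implementation: `zipWith` is a hypothetical name, and SQL nulls are represented as `std::nullopt`.

```cpp
#include <algorithm>
#include <cstddef>
#include <optional>
#include <vector>

// Merges two arrays element-wise with 'function'. The shorter array is
// treated as if it were padded with nulls (std::nullopt) up to the length of
// the longer one before 'function' is applied.
template <typename T, typename U, typename Fn>
auto zipWith(
    const std::vector<std::optional<T>>& left,
    const std::vector<std::optional<U>>& right,
    Fn function) {
  using R = decltype(function(std::optional<T>{}, std::optional<U>{}));
  const std::size_t size = std::max(left.size(), right.size());
  std::vector<R> result;
  result.reserve(size);
  for (std::size_t i = 0; i < size; ++i) {
    std::optional<T> l; // Stays null when 'left' has no i-th element.
    std::optional<U> r; // Stays null when 'right' has no i-th element.
    if (i < left.size()) {
      l = left[i];
    }
    if (i < right.size()) {
      r = right[i];
    }
    result.push_back(function(l, r));
  }
  return result;
}
```

With a function that adds its arguments and returns null when either side is null, zipping `[1, 2]` with `[3, 4]` gives `[4, 6]`, while zipping `[1, 2, 3]` with `[10]` gives `[11, null, null]`.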
73 changes: 59 additions & 14 deletions velox/functions/prestosql/Reduce.cpp
@@ -19,6 +19,34 @@
namespace facebook::velox::functions {
namespace {

// Sets a per-row user error in 'context' for every array in 'rows' that has
// more than 10K elements.
// Evaluating the 'reduce' lambda function on very large arrays is too slow.
void checkArraySizes(
const SelectivityVector& rows,
DecodedVector& decodedArray,
exec::EvalCtx& context) {
const auto* indices = decodedArray.indices();
const auto* rawSizes = decodedArray.base()->as<ArrayVector>()->rawSizes();

static const vector_size_t kMaxArraySize = 10'000;

rows.applyToSelected([&](auto row) {
if (decodedArray.isNullAt(row)) {
return;
}
const auto size = rawSizes[indices[row]];
try {
VELOX_USER_CHECK_LE(
size,
kMaxArraySize,
"reduce lambda function doesn't support arrays with more than {} elements",
kMaxArraySize);
} catch (VeloxUserError&) {
context.setError(row, std::current_exception());
}
});
}
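
The per-row check above can be modeled without Velox types. In the sketch below, `DecodedSizes` and `collectSizeErrors` are hypothetical names: the struct is a simplified stand-in for a decoded array column (an index per row into a base vector of sizes, plus a null flag), and errors are returned as optional messages rather than stored in an evaluation context. The sketch rejects arrays with strictly more than `maxArraySize` elements, as the error message words it.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a decoded array column: row 'i' maps
// to entry indices[i] of 'baseSizes'; isNull[i] marks null rows.
struct DecodedSizes {
  std::vector<int32_t> baseSizes;
  std::vector<int32_t> indices;
  std::vector<bool> isNull;
};

// Returns one slot per row: an error message for selected, non-null rows
// whose array exceeds 'maxArraySize', std::nullopt otherwise. Errors are
// recorded per row instead of aborting the whole batch.
std::vector<std::optional<std::string>> collectSizeErrors(
    const std::vector<bool>& selected,
    const DecodedSizes& decoded,
    int32_t maxArraySize) {
  std::vector<std::optional<std::string>> errors(selected.size());
  for (std::size_t row = 0; row < selected.size(); ++row) {
    if (!selected[row] || decoded.isNull[row]) {
      continue; // Unselected and null rows are never checked.
    }
    const int32_t size = decoded.baseSizes[decoded.indices[row]];
    if (size > maxArraySize) {
      errors[row] = "array has more than " + std::to_string(maxArraySize) +
          " elements";
    }
  }
  return errors;
}
```

Recording the failure per row, rather than throwing out of the loop, is what lets a caller later decide whether to surface the error or mask it for that row only.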

/// Populates indices of the n-th elements of the arrays.
/// Selects 'row' in 'arrayRows' if corresponding array has an n-th element.
/// Sets elementIndices[row] to the index of the n-th element in the 'elements'
@@ -75,6 +103,36 @@ class ReduceFunction : public exec::VectorFunction {
exec::LocalDecodedVector arrayDecoder(context, *args[0], rows);
auto& decodedArray = *arrayDecoder.get();

checkArraySizes(rows, decodedArray, context);

exec::LocalSelectivityVector remainingRows(context, rows);
context.deselectErrors(*remainingRows);

doApply(*remainingRows, args, decodedArray, outputType, context, result);
}
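
The control flow in `apply` (validate, drop the rows that failed, evaluate the rest) can be sketched in plain C++. The names `deselectErrors` and `sumRemainingRows` below are hypothetical helpers for illustration, with boolean vectors standing in for selectivity vectors and a simple sum standing in for the lambda evaluation.

```cpp
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <optional>
#include <vector>

// Keeps only rows that are selected and have no recorded error.
std::vector<bool> deselectErrors(
    const std::vector<bool>& selected,
    const std::vector<bool>& hasError) {
  std::vector<bool> remaining(selected.size());
  for (std::size_t row = 0; row < selected.size(); ++row) {
    remaining[row] = selected[row] && !hasError[row];
  }
  return remaining;
}

// Sums the array of every remaining row; other rows get no value.
std::vector<std::optional<int64_t>> sumRemainingRows(
    const std::vector<std::vector<int64_t>>& arrays,
    const std::vector<bool>& remaining) {
  std::vector<std::optional<int64_t>> result(arrays.size());
  for (std::size_t row = 0; row < arrays.size(); ++row) {
    if (remaining[row]) {
      result[row] =
          std::accumulate(arrays[row].begin(), arrays[row].end(), int64_t{0});
    }
  }
  return result;
}
```

Because the expensive work only runs on the remaining rows, a single oversized array costs nothing beyond the size check, and the other rows in the batch still produce results.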

static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
// array(T), S, function(S, T, S), function(S, R) -> R
return {exec::FunctionSignatureBuilder()
.typeVariable("T")
.typeVariable("S")
.typeVariable("R")
.returnType("R")
.argumentType("array(T)")
.argumentType("S")
.argumentType("function(S,T,S)")
.argumentType("function(S,R)")
.build()};
}

private:
void doApply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
DecodedVector& decodedArray,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const {
auto flatArray = flattenArray(rows, args[0], decodedArray);
// Identify the rows that need to be computed.
exec::LocalSelectivityVector nonNullRowsHolder(*context.execCtx());
@@ -157,6 +215,7 @@ class ReduceFunction : public exec::VectorFunction {
n++;
}
}

// Apply output function.
VectorPtr localResult;
auto outputFuncIt =
@@ -178,20 +237,6 @@ class ReduceFunction : public exec::VectorFunction {
}
context.moveOrCopyResult(localResult, rows, result);
}

static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
// array(T), S, function(S, T, S), function(S, R) -> R
return {exec::FunctionSignatureBuilder()
.typeVariable("T")
.typeVariable("S")
.typeVariable("R")
.returnType("R")
.argumentType("array(T)")
.argumentType("S")
.argumentType("function(S,T,S)")
.argumentType("function(S,R)")
.build()};
}
};
} // namespace

27 changes: 27 additions & 0 deletions velox/functions/prestosql/tests/ReduceTest.cpp
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"

using namespace facebook::velox;
@@ -242,3 +243,29 @@ TEST_F(ReduceTest, nullArray) {
assertEqualVectors(
makeNullableFlatVector<int64_t>({std::nullopt, std::nullopt}), result);
}

// Verify limit on the number of array elements.
TEST_F(ReduceTest, limit) {
// Make array vector with huge arrays in rows 2 and 4.
auto data = makeRowVector({makeArrayVector(
{0, 1'000, 10'000, 100'000, 100'010}, makeConstant(123, 1'000'000))});

VELOX_ASSERT_THROW(
evaluate("reduce(c0, 0, (s, x) -> s + x, s -> s)", data),
"reduce lambda function doesn't support arrays with more than 10000 elements");

// Exclude huge arrays.
SelectivityVector rows(4);
rows.setValid(2, false);
rows.updateBounds();
auto result = evaluate("reduce(c0, 0, (s, x) -> s + x, s -> s)", data, rows);
auto expected =
makeFlatVector<int64_t>({123 * 1'000, 123 * 9'000, -1, 123 * 10});
assertEqualVectors(expected, result, rows);

// Mask errors with TRY.
result = evaluate("TRY(reduce(c0, 0, (s, x) -> s + x, s -> s))", data);
expected = makeNullableFlatVector<int64_t>(
{123 * 1'000, 123 * 9'000, std::nullopt, 123 * 10, std::nullopt});
assertEqualVectors(expected, result);
}
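
To see why rows 2 and 4 count as huge in the test above, the array sizes can be derived from the offsets. The sketch below assumes that each array runs from its offset to the next offset and that the last array extends to the end of the elements vector; that assumption is consistent with the expected sums in the test (for example `123 * 9'000` for row 1) but is not taken from the Velox test utilities. `sizesFromOffsets` is a hypothetical helper.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Derives per-row array sizes from start offsets into a flat elements vector
// of 'totalElements' entries. Assumes the last array extends to the end.
std::vector<int64_t> sizesFromOffsets(
    const std::vector<int64_t>& offsets,
    int64_t totalElements) {
  std::vector<int64_t> sizes(offsets.size());
  for (std::size_t i = 0; i < offsets.size(); ++i) {
    const int64_t end =
        i + 1 < offsets.size() ? offsets[i + 1] : totalElements;
    sizes[i] = end - offsets[i];
  }
  return sizes;
}
```

For the offsets `{0, 1'000, 10'000, 100'000, 100'010}` over 1,000,000 elements, this gives sizes of 1,000, 9,000, 90,000, 10, and 899,990, so only rows 2 and 4 exceed the 10,000-element limit.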
