diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h index cc1b7d88f58787..c24cd70ebead65 100644 --- a/be/src/vec/aggregate_functions/aggregate_function.h +++ b/be/src/vec/aggregate_functions/aggregate_function.h @@ -106,7 +106,7 @@ class IAggregateFunction { * row_num is number of row which should be added. * Additional parameter arena should be used instead of standard memory allocator if the addition requires memory allocation. */ - virtual void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + virtual void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const = 0; virtual void add_many(AggregateDataPtr __restrict place, const IColumn** columns, diff --git a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h index 03e1cc3df1394b..d0f5bce81a02be 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h +++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h @@ -95,7 +95,7 @@ class AggregateFunctionApproxCountDistinct final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { if constexpr (IsFixLenColumnType::value) { auto column = assert_cast(columns[0]); diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.h b/be/src/vec/aggregate_functions/aggregate_function_avg.h index 61eb04bb13bb44..ca155f9d72c2fe 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.h +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.h @@ -140,7 +140,7 @@ class AggregateFunctionAvg final } } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { #ifdef __clang__ #pragma clang fp reassociate(on) diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h index fe6f50481ba69c..498ee20ccb8371 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h +++ b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h @@ -106,7 +106,7 @@ class AggregateFunctionAvgWeight final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { const auto& column = assert_cast(*columns[0]); const auto& weight = assert_cast&>(*columns[1]); diff --git a/be/src/vec/aggregate_functions/aggregate_function_binary.h b/be/src/vec/aggregate_functions/aggregate_function_binary.h index 422919c52af9ad..ca06cc1bb81a8f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_binary.h +++ b/be/src/vec/aggregate_functions/aggregate_function_binary.h @@ -69,7 +69,7 @@ struct AggregateFunctionBinary bool allocates_memory_in_arena() const override { return false; } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { this->data(place).add( static_cast( diff --git a/be/src/vec/aggregate_functions/aggregate_function_bit.h b/be/src/vec/aggregate_functions/aggregate_function_bit.h index 6d2e67b14e7219..c0b2df85ba25d2 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bit.h +++ b/be/src/vec/aggregate_functions/aggregate_function_bit.h @@ -112,7 +112,7 @@ class AggregateFunctionBitwise final DataTypePtr get_return_type() const override { return std::make_shared>(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { const auto& column = assert_cast&>(*columns[0]); this->data(place).add(column.get_data()[row_num]); diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.h b/be/src/vec/aggregate_functions/aggregate_function_bitmap.h index aa167b6571c16d..e997337769799e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.h +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.h @@ -301,7 +301,7 @@ class AggregateFunctionBitmapOp final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { const auto& column = assert_cast(*columns[0]); this->data(place).add(column.get_data()[row_num]); @@ -361,7 +361,7 @@ class AggregateFunctionBitmapCount final String get_name() const override { return "count"; } DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { if constexpr (arg_is_nullable) { auto& nullable_column = assert_cast(*columns[0]); diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.h b/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.h index a4c08aefe2ad43..000a6dab36bf0e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.h +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.h @@ -70,7 +70,7 @@ class AggregateFunctionBitmapAgg final std::string get_name() const override { return "bitmap_agg"; } DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { DCHECK_LT(row_num, columns[0]->size()); if constexpr (arg_nullable) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.h b/be/src/vec/aggregate_functions/aggregate_function_collect.h index 2188fe9b242ad9..7e3c7207a7d27c 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_collect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_collect.h @@ -469,7 +469,7 @@ class AggregateFunctionCollect bool allocates_memory_in_arena() const override { return ENABLE_ARENA; } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { auto& data = this->data(place); if constexpr (HasLimit::value) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.h b/be/src/vec/aggregate_functions/aggregate_function_count.h index 92d0f644d33534..bf44b944bda021 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count.h +++ b/be/src/vec/aggregate_functions/aggregate_function_count.h @@ -65,7 +65,7 @@ class AggregateFunctionCount final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr __restrict place, const IColumn**, size_t, Arena*) const override { + void add(AggregateDataPtr __restrict place, const IColumn**, ssize_t, Arena*) const override { ++data(place).count; } @@ -194,7 +194,7 @@ class AggregateFunctionCountNotNullUnary final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { data(place).count += !assert_cast(*columns[0]).is_null_at(row_num); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.h b/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.h index 273fa2a1e4c1ed..93a5103ef593c0 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.h +++ b/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.h @@ -159,7 +159,7 @@ class AggregateFunctionCountByEnum final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { for (int i = 0; i < arg_count; i++) { const auto* nullable_column = check_and_get_column(columns[i]); diff --git a/be/src/vec/aggregate_functions/aggregate_function_covar.h b/be/src/vec/aggregate_functions/aggregate_function_covar.h index 0c5dfd3f0377bb..31f0d7d2830e86 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_covar.h +++ b/be/src/vec/aggregate_functions/aggregate_function_covar.h @@ -273,7 +273,7 @@ class AggregateFunctionSampCovariance } } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { if constexpr (is_pop) { this->data(place).add(columns[0], columns[1], row_num); diff --git a/be/src/vec/aggregate_functions/aggregate_function_distinct.h b/be/src/vec/aggregate_functions/aggregate_function_distinct.h index 3b4968050aeb7a..c0c7a5b66dd58f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_distinct.h +++ b/be/src/vec/aggregate_functions/aggregate_function_distinct.h @@ -201,7 +201,7 @@ class AggregateFunctionDistinct prefix_size = (sizeof(Data) + nested_size - 1) / nested_size * nested_size; } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { this->data(place).add(columns, arguments_num, row_num, arena); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp b/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp new file mode 100644 index 00000000000000..e64e5900d01c2f --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Combinators/AggregateFunctionForEach.cpp +// and modified by Doris + +#include "vec/aggregate_functions/aggregate_function_foreach.h" + +#include +#include + +#include "common/logging.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/helpers.h" +#include "vec/common/typeid_cast.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_nullable.h" + +namespace doris::vectorized { + +void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFactory& factory) { + AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, + const bool result_is_nullable) -> AggregateFunctionPtr { + const std::string& suffix = AggregateFunctionForEach::AGG_FOREACH_SUFFIX; + DataTypes transform_arguments; + for (const auto& t : types) { + auto item_type = + assert_cast(remove_nullable(t).get())->get_nested_type(); + transform_arguments.push_back((item_type)); + } + auto nested_function_name = name.substr(0, name.size() - suffix.size()); + auto nested_function = + factory.get(nested_function_name, transform_arguments, result_is_nullable); + if (!nested_function) { + throw Exception( + ErrorCode::INTERNAL_ERROR, + "The combiner did not find a foreach combiner function. nested function " + "name {} , args {}", + nested_function_name, types_name(types)); + } + return creator_without_type::create(transform_arguments, true, + nested_function); + }; + factory.register_foreach_function_combinator( + creator, AggregateFunctionForEach::AGG_FOREACH_SUFFIX, true); + factory.register_foreach_function_combinator( + creator, AggregateFunctionForEach::AGG_FOREACH_SUFFIX, false); +} +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_foreach.h b/be/src/vec/aggregate_functions/aggregate_function_foreach.h new file mode 100644 index 00000000000000..039c2d507b852b --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_foreach.h @@ -0,0 +1,264 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Combinators/AggregateFunctionForEach.h +// and modified by Doris + +#pragma once + +#include "common/logging.h" +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/column_nullable.h" +#include "vec/common/assert_cast.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/functions/array/function_array_utils.h" +#include "vec/io/io_helper.h" + +namespace doris::vectorized { + +struct AggregateFunctionForEachData { + size_t dynamic_array_size = 0; + char* array_of_aggregate_datas = nullptr; +}; + +/** Adaptor for aggregate functions. + * Adding -ForEach suffix to aggregate function + * will convert that aggregate function to a function, accepting arrays, + * and applies aggregation for each corresponding elements of arrays independently, + * returning arrays of aggregated values on corresponding positions. + * + * Example: sumForEach of: + * [1, 2], + * [3, 4, 5], + * [6, 7] + * will return: + * [10, 13, 5] + * + * TODO Allow variable number of arguments. + */ +class AggregateFunctionForEach : public IAggregateFunctionDataHelper { +protected: + using Base = + IAggregateFunctionDataHelper; + + AggregateFunctionPtr nested_function; + const size_t nested_size_of_data; + const size_t num_arguments; + + AggregateFunctionForEachData& ensure_aggregate_data(AggregateDataPtr __restrict place, + size_t new_size, Arena& arena) const { + AggregateFunctionForEachData& state = data(place); + + /// Ensure we have aggregate states for new_size elements, allocate + /// from arena if needed. When reallocating, we can't copy the + /// states to new buffer with memcpy, because they may contain pointers + /// to themselves. In particular, this happens when a state contains + /// a PODArrayWithStackMemory, which stores small number of elements + /// inline. This is why we create new empty states in the new buffer, + /// and merge the old states to them. + size_t old_size = state.dynamic_array_size; + if (old_size < new_size) { + static constexpr size_t MAX_ARRAY_SIZE = 100 * 1000000000ULL; + if (new_size > MAX_ARRAY_SIZE) { + throw Exception(ErrorCode::INTERNAL_ERROR, + "Suspiciously large array size ({}) in -ForEach aggregate function", + new_size); + } + + size_t allocation_size = 0; + if (common::mul_overflow(new_size, nested_size_of_data, allocation_size)) { + throw Exception(ErrorCode::INTERNAL_ERROR, + "Allocation size ({} * {}) overflows in -ForEach aggregate " + "function, but it should've been prevented by previous checks", + new_size, nested_size_of_data); + } + + char* old_state = state.array_of_aggregate_datas; + + char* new_state = + arena.aligned_alloc(allocation_size, nested_function->align_of_data()); + + size_t i; + try { + for (i = 0; i < new_size; ++i) { + nested_function->create(&new_state[i * nested_size_of_data]); + } + } catch (...) { + size_t cleanup_size = i; + + for (i = 0; i < cleanup_size; ++i) { + nested_function->destroy(&new_state[i * nested_size_of_data]); + } + + throw; + } + + for (i = 0; i < old_size; ++i) { + nested_function->merge(&new_state[i * nested_size_of_data], + &old_state[i * nested_size_of_data], &arena); + nested_function->destroy(&old_state[i * nested_size_of_data]); + } + + state.array_of_aggregate_datas = new_state; + state.dynamic_array_size = new_size; + } + + return state; + } + +public: + constexpr static auto AGG_FOREACH_SUFFIX = "_foreach"; + AggregateFunctionForEach(AggregateFunctionPtr nested_function_, const DataTypes& arguments) + : Base(arguments), + nested_function {std::move(nested_function_)}, + nested_size_of_data(nested_function->size_of_data()), + num_arguments(arguments.size()) { + if (arguments.empty()) { + throw Exception(ErrorCode::INTERNAL_ERROR, + "Aggregate function {} require at least one argument", get_name()); + } + } + void set_version(const int version_) override { + Base::set_version(version_); + nested_function->set_version(version_); + } + + String get_name() const override { return nested_function->get_name() + AGG_FOREACH_SUFFIX; } + + DataTypePtr get_return_type() const override { + return std::make_shared(nested_function->get_return_type()); + } + + void destroy(AggregateDataPtr __restrict place) const noexcept override { + AggregateFunctionForEachData& state = data(place); + + char* nested_state = state.array_of_aggregate_datas; + for (size_t i = 0; i < state.dynamic_array_size; ++i) { + nested_function->destroy(nested_state); + nested_state += nested_size_of_data; + } + } + + bool has_trivial_destructor() const override { + return std::is_trivially_destructible_v && nested_function->has_trivial_destructor(); + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena* arena) const override { + const AggregateFunctionForEachData& rhs_state = data(rhs); + AggregateFunctionForEachData& state = + ensure_aggregate_data(place, rhs_state.dynamic_array_size, *arena); + + const char* rhs_nested_state = rhs_state.array_of_aggregate_datas; + char* nested_state = state.array_of_aggregate_datas; + + for (size_t i = 0; i < state.dynamic_array_size && i < rhs_state.dynamic_array_size; ++i) { + nested_function->merge(nested_state, rhs_nested_state, arena); + + rhs_nested_state += nested_size_of_data; + nested_state += nested_size_of_data; + } + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + const AggregateFunctionForEachData& state = data(place); + write_binary(state.dynamic_array_size, buf); + const char* nested_state = state.array_of_aggregate_datas; + for (size_t i = 0; i < state.dynamic_array_size; ++i) { + nested_function->serialize(nested_state, buf); + nested_state += nested_size_of_data; + } + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena* arena) const override { + AggregateFunctionForEachData& state = data(place); + + size_t new_size = 0; + read_binary(new_size, buf); + + ensure_aggregate_data(place, new_size, *arena); + + char* nested_state = state.array_of_aggregate_datas; + for (size_t i = 0; i < new_size; ++i) { + nested_function->deserialize(nested_state, buf, arena); + nested_state += nested_size_of_data; + } + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + const AggregateFunctionForEachData& state = data(place); + + auto& arr_to = assert_cast(to); + auto& offsets_to = arr_to.get_offsets(); + IColumn& elems_to = arr_to.get_data(); + + char* nested_state = state.array_of_aggregate_datas; + for (size_t i = 0; i < state.dynamic_array_size; ++i) { + nested_function->insert_result_into(nested_state, elems_to); + nested_state += nested_size_of_data; + } + + offsets_to.push_back(offsets_to.back() + state.dynamic_array_size); + } + + bool allocates_memory_in_arena() const override { + return nested_function->allocates_memory_in_arena(); + } + + bool is_state() const override { return nested_function->is_state(); } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena* arena) const override { + const IColumn* nested[num_arguments]; + + for (size_t i = 0; i < num_arguments; ++i) { + nested[i] = &assert_cast(*columns[i]).get_data(); + } + + const auto& first_array_column = assert_cast(*columns[0]); + const auto& offsets = first_array_column.get_offsets(); + + size_t begin = offsets[row_num - 1]; + size_t end = offsets[row_num]; + + /// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance. + for (size_t i = 1; i < num_arguments; ++i) { + const auto& ith_column = assert_cast(*columns[i]); + const auto& ith_offsets = ith_column.get_offsets(); + + if (ith_offsets[row_num] != end || + (row_num != 0 && ith_offsets[row_num - 1] != begin)) { + throw Exception(ErrorCode::INTERNAL_ERROR, + "Arrays passed to {} aggregate function have different sizes", + get_name()); + } + } + + AggregateFunctionForEachData& state = ensure_aggregate_data(place, end - begin, *arena); + + char* nested_state = state.array_of_aggregate_datas; + for (size_t i = begin; i < end; ++i) { + nested_function->add(nested_state, nested, i, arena); + nested_state += nested_size_of_data; + } + } +}; +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_concat.h b/be/src/vec/aggregate_functions/aggregate_function_group_concat.h index 6438e65a20b451..87ed907377ea36 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_concat.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_concat.h @@ -124,7 +124,7 @@ class AggregateFunctionGroupConcat final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { Impl::add(this->data(place), columns, row_num); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.h b/be/src/vec/aggregate_functions/aggregate_function_histogram.h index 295a063bc3044d..cae2a88daf0f75 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_histogram.h +++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.h @@ -184,7 +184,7 @@ class AggregateFunctionHistogram final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { if (columns[0]->is_null_at(row_num)) { return; diff --git a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h index bb2ab75d6c503e..f976e959f8558d 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h +++ b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h @@ -121,7 +121,7 @@ class AggregateFunctionHLLUnion this->data(place).insert_result_into(to); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { this->data(place).add(columns[0], row_num); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h index d79154a004c17e..4ef64aae558c3b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h +++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h @@ -295,7 +295,7 @@ class AggregateJavaUdaf final DataTypePtr get_return_type() const override { return _return_type; } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { int64_t places_address = reinterpret_cast(place); Status st = this->data(_exec_place) diff --git a/be/src/vec/aggregate_functions/aggregate_function_map.h b/be/src/vec/aggregate_functions/aggregate_function_map.h index a06173058307a9..e0a19a34207d80 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_map.h +++ b/be/src/vec/aggregate_functions/aggregate_function_map.h @@ -172,7 +172,7 @@ class AggregateFunctionMapAgg final make_nullable(argument_types[1])); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { if (columns[0]->is_nullable()) { auto& nullable_col = assert_cast(*columns[0]); diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.h b/be/src/vec/aggregate_functions/aggregate_function_min_max.h index 56714c9ee80157..dfc0cbae7f42fb 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.h +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.h @@ -526,7 +526,7 @@ class AggregateFunctionsSingleValue final DataTypePtr get_return_type() const override { return type; } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { this->data(place).change_if_better(*columns[0], row_num, arena); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h index b7a2f5c159d366..634dc171f5960c 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h @@ -167,7 +167,7 @@ class AggregateFunctionsMinMaxBy final DataTypePtr get_return_type() const override { return value_type; } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { this->data(place).change_if_better(*columns[0], *columns[1], row_num, arena); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.h b/be/src/vec/aggregate_functions/aggregate_function_null.h index 939396073825c4..a91a172fc0567b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_null.h +++ b/be/src/vec/aggregate_functions/aggregate_function_null.h @@ -200,7 +200,7 @@ class AggregateFunctionNullUnaryInline final AggregateFunctionNullUnaryInline>( nested_function_, arguments) {} - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { const ColumnNullable* column = assert_cast(columns[0]); if (!column->is_null_at(row_num)) { @@ -301,7 +301,7 @@ class AggregateFunctionNullVariadicInline final } } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { /// This container stores the columns we really pass to the nested function. const IColumn* nested_columns[number_of_arguments]; diff --git a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.h b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.h index 0204c08e020601..f0fd67f4a851da 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.h +++ b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.h @@ -344,7 +344,7 @@ class AggFunctionOrthBitmapFunc final DataTypePtr get_return_type() const override { return Impl::get_return_type(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { this->data(place).init_add_key(columns, row_num, _argument_size); this->data(place).add(columns, row_num); diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h index 1a1285a9dc88f8..2eb7cc33098c32 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h @@ -200,7 +200,7 @@ class AggregateFunctionPercentileApproxMerge : public AggregateFunctionPercentil public: AggregateFunctionPercentileApproxMerge(const DataTypes& argument_types_) : AggregateFunctionPercentileApprox(argument_types_) {} - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { LOG(FATAL) << "AggregateFunctionPercentileApproxMerge do not support add()"; } @@ -211,7 +211,7 @@ class AggregateFunctionPercentileApproxTwoParams : public AggregateFunctionPerce public: AggregateFunctionPercentileApproxTwoParams(const DataTypes& argument_types_) : AggregateFunctionPercentileApprox(argument_types_) {} - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { if constexpr (is_nullable) { double column_data[2] = {0, 0}; @@ -251,7 +251,7 @@ class AggregateFunctionPercentileApproxThreeParams : public AggregateFunctionPer public: AggregateFunctionPercentileApproxThreeParams(const DataTypes& argument_types_) : AggregateFunctionPercentileApprox(argument_types_) {} - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { if constexpr (is_nullable) { double column_data[3] = {0, 0, 0}; @@ -386,7 +386,7 @@ class AggregateFunctionPercentile final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { const auto& sources = assert_cast&>(*columns[0]); const auto& quantile = assert_cast&>(*columns[1]); @@ -431,7 +431,7 @@ class AggregateFunctionPercentileArray final return std::make_shared(make_nullable(std::make_shared())); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { const auto& sources = assert_cast&>(*columns[0]); const auto& quantile_array = assert_cast(*columns[1]); diff --git a/be/src/vec/aggregate_functions/aggregate_function_product.h b/be/src/vec/aggregate_functions/aggregate_function_product.h index 4b0365a1b6d12b..22a217263b2274 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_product.h +++ b/be/src/vec/aggregate_functions/aggregate_function_product.h @@ -131,7 +131,7 @@ class AggregateFunctionProduct final } } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { const auto& column = assert_cast(*columns[0]); this->data(place).add(TResult(column.get_data()[row_num]), multiplier); diff --git a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h index 7c6a9cb9da0926..14250087d2bd74 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h +++ b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h @@ -111,7 +111,7 @@ class AggregateFunctionQuantileStateOp final return std::make_shared(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { if constexpr (arg_is_nullable) { auto& nullable_column = assert_cast(*columns[0]); diff --git a/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h b/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h index 9077a009a7a572..bbf62b09222857 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h +++ b/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h @@ -215,7 +215,7 @@ class ReaderFunctionData final this->data(place).insert_result_into(to); } - void add(AggregateDataPtr place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { this->data(place).add(row_num, columns); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_retention.h b/be/src/vec/aggregate_functions/aggregate_function_retention.h index f595a1ad72648c..f38f1cf45a00d1 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_retention.h +++ b/be/src/vec/aggregate_functions/aggregate_function_retention.h @@ -124,7 +124,7 @@ class AggregateFunctionRetention } void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, const size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, const ssize_t row_num, Arena*) const override { for (int i = 0; i < get_argument_types().size(); i++) { auto event = assert_cast*>(columns[i])->get_data()[row_num]; diff --git a/be/src/vec/aggregate_functions/aggregate_function_rpc.h b/be/src/vec/aggregate_functions/aggregate_function_rpc.h index 21e5aa290d06d0..c92e96aaf9d935 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_rpc.h +++ b/be/src/vec/aggregate_functions/aggregate_function_rpc.h @@ -357,7 +357,7 @@ class AggregateRpcUdaf final DataTypePtr get_return_type() const override { return _return_type; } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { static_cast( this->data(place).buffer_add(columns, row_num, row_num + 1, argument_types)); diff --git a/be/src/vec/aggregate_functions/aggregate_function_sequence_match.h b/be/src/vec/aggregate_functions/aggregate_function_sequence_match.h index 064c1e9979a231..101c2c16fd00c8 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sequence_match.h +++ b/be/src/vec/aggregate_functions/aggregate_function_sequence_match.h @@ -599,7 +599,7 @@ class AggregateFunctionSequenceBase void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, const size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, const ssize_t row_num, Arena*) const override { std::string pattern = assert_cast(columns[0])->get_data_at(0).to_string(); diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index abb844919895cc..c33b8b50609635 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -28,6 +28,7 @@ namespace doris::vectorized { void register_aggregate_function_combinator_sort(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory); @@ -107,6 +108,8 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_functions_corr(instance); register_aggregate_function_covar_pop(instance); register_aggregate_function_covar_samp(instance); + + register_aggregate_function_combinator_foreach(instance); }); return instance; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h index dccbd9a4d575fa..635709f3594e48 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h @@ -37,6 +37,14 @@ using DataTypes = std::vector; using AggregateFunctionCreator = std::function; +inline std::string types_name(const DataTypes& types) { + std::string name; + for (auto&& type : types) { + name += type->get_name(); + } + return name; +} + class AggregateFunctionSimpleFactory { public: using Creator = AggregateFunctionCreator; @@ -78,6 +86,21 @@ class AggregateFunctionSimpleFactory { } } + void register_foreach_function_combinator(const Creator& creator, const std::string& suffix, + bool nullable = false) { + auto& functions = nullable ? nullable_aggregate_functions : aggregate_functions; + std::vector need_insert; + for (const auto& entity : aggregate_functions) { + std::string target_value = entity.first + suffix; + if (functions.find(target_value) == functions.end()) { + need_insert.emplace_back(std::move(target_value)); + } + } + for (const auto& function_name : need_insert) { + register_function(function_name, creator, nullable); + } + } + AggregateFunctionPtr get(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable = false, int be_version = BeExecVersionManager::get_newest_version(), @@ -97,7 +120,7 @@ class AggregateFunctionSimpleFactory { } temporary_function_update(be_version, name_str); - if (function_alias.count(name)) { + if (function_alias.contains(name)) { name_str = function_alias[name]; } @@ -148,7 +171,6 @@ class AggregateFunctionSimpleFactory { } } -public: static AggregateFunctionSimpleFactory& instance(); }; }; // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_sort.h b/be/src/vec/aggregate_functions/aggregate_function_sort.h index 02106b75e60290..07b57e41359661 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sort.h +++ b/be/src/vec/aggregate_functions/aggregate_function_sort.h @@ -138,7 +138,7 @@ class AggregateFunctionSort } } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { this->data(place).add(columns, _arguments.size(), row_num); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_state_union.h b/be/src/vec/aggregate_functions/aggregate_function_state_union.h index 4134b7f79d1b4e..3c9e2ed3767f85 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_state_union.h +++ b/be/src/vec/aggregate_functions/aggregate_function_state_union.h @@ -53,7 +53,7 @@ class AggregateStateUnion : public IAggregateFunctionHelper DataTypePtr get_return_type() const override { return _return_type; } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { //the range is [begin, end] _function->deserialize_and_merge_from_column_range(place, *columns[0], row_num, row_num, diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.h b/be/src/vec/aggregate_functions/aggregate_function_stddev.h index c84e67a7d6d442..456e91c3f6a2de 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_stddev.h +++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.h @@ -296,7 +296,7 @@ class AggregateFunctionSampVariance } } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { if constexpr (is_pop) { this->data(place).add(columns[0], row_num); diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.h b/be/src/vec/aggregate_functions/aggregate_function_sum.h index 41677dd419bf2f..b53d011e5f1268 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sum.h +++ b/be/src/vec/aggregate_functions/aggregate_function_sum.h @@ -98,7 +98,7 @@ class AggregateFunctionSum final } } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { const auto& column = assert_cast(*columns[0]); this->data(place).add(TResult(column.get_data()[row_num])); diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.h b/be/src/vec/aggregate_functions/aggregate_function_topn.h index 633a36231a776f..6c7502c99a38fa 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_topn.h +++ b/be/src/vec/aggregate_functions/aggregate_function_topn.h @@ -287,7 +287,7 @@ class AggregateFunctionTopNBase : IAggregateFunctionDataHelper, AggregateFunctionTopNBase>(argument_types_) {} - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { Impl::add(this->data(place), columns, row_num); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_uniq.h b/be/src/vec/aggregate_functions/aggregate_function_uniq.h index 3ef0359461b442..727a145c45a7ce 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_uniq.h +++ b/be/src/vec/aggregate_functions/aggregate_function_uniq.h @@ -115,7 +115,7 @@ class AggregateFunctionUniq final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { detail::OneAdder::add(this->data(place), *columns[0], row_num); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.h b/be/src/vec/aggregate_functions/aggregate_function_window.h index 5ce46495464ed3..3b0748d519f0af 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_window.h +++ b/be/src/vec/aggregate_functions/aggregate_function_window.h @@ -65,7 +65,7 @@ class WindowFunctionRowNumber final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const override { + void add(AggregateDataPtr place, const IColumn**, ssize_t, Arena*) const override { ++data(place).count; } @@ -103,7 +103,7 @@ class WindowFunctionRank final : public IAggregateFunctionDataHelper(); } - void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const override { + void add(AggregateDataPtr place, const IColumn**, ssize_t, Arena*) const override { ++data(place).rank; } @@ -148,7 +148,7 @@ class WindowFunctionDenseRank final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const override { + void add(AggregateDataPtr place, const IColumn**, ssize_t, Arena*) const override { ++data(place).rank; } @@ -197,7 +197,7 @@ class WindowFunctionPercentRank final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const override {} + void add(AggregateDataPtr place, const IColumn**, ssize_t, Arena*) const override {} void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start, int64_t frame_end, AggregateDataPtr place, const IColumn** columns, @@ -255,7 +255,7 @@ class WindowFunctionCumeDist final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const override {} + void add(AggregateDataPtr place, const IColumn**, ssize_t, Arena*) const override {} void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start, int64_t frame_end, AggregateDataPtr place, const IColumn** columns, @@ -299,7 +299,7 @@ class WindowFunctionNTile final DataTypePtr get_return_type() const override { return std::make_shared(); } - void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const override {} + void add(AggregateDataPtr place, const IColumn**, ssize_t, Arena*) const override {} void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start, int64_t frame_end, AggregateDataPtr place, const IColumn** columns, @@ -556,7 +556,7 @@ class WindowFunctionData final this->data(place).insert_result_into(to); } - void add(AggregateDataPtr place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { LOG(FATAL) << "WindowFunctionLeadLagData do not support add"; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h index 253677bbc3c6ab..d11b45caef68ab 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h +++ b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h @@ -279,7 +279,7 @@ class AggregateFunctionWindowFunnel void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { const auto& window = assert_cast&>(*columns[0]).get_data()[row_num]; diff --git a/be/src/vec/data_types/serde/data_type_nullable_serde.cpp b/be/src/vec/data_types/serde/data_type_nullable_serde.cpp index 6bdadfe23a76d0..07f3c5edbd4d99 100644 --- a/be/src/vec/data_types/serde/data_type_nullable_serde.cpp +++ b/be/src/vec/data_types/serde/data_type_nullable_serde.cpp @@ -286,7 +286,7 @@ template Status DataTypeNullableSerDe::_write_column_to_mysql(const IColumn& column, MysqlRowBuffer& result, int row_idx, bool col_const) const { - auto& col = static_cast(column); + auto& col = assert_cast(column); auto& nested_col = col.get_nested_column(); col_const = col_const || is_column_const(nested_col); const auto col_index = index_check_const(row_idx, col_const); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java index d89ae3ffffae59..e38b7143608c35 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java @@ -88,6 +88,7 @@ public abstract class Expr extends TreeNode implements ParseNode, Cloneabl public static final String AGG_STATE_SUFFIX = "_state"; public static final String AGG_UNION_SUFFIX = "_union"; public static final String AGG_MERGE_SUFFIX = "_merge"; + public static final String AGG_FOREACH_SUFFIX = "_foreach"; public static final String DEFAULT_EXPR_NAME = "expr"; protected boolean disableTableName = false; diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java index 2ee5485d7f8055..7dbf3a0ec0a1a9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java @@ -43,6 +43,7 @@ import java.io.DataOutput; import java.io.DataOutputStream; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -878,4 +879,17 @@ public static FunctionCallExpr convertToUnionCombinator(FunctionCallExpr fnCall) fnCall.setType(fnCall.getChildren().get(0).getType()); return fnCall; } + + public static FunctionCallExpr convertForEachCombinator(FunctionCallExpr fnCall) { + Function aggFunction = fnCall.getFn(); + aggFunction.setName(new FunctionName(aggFunction.getFunctionName().getFunction() + Expr.AGG_FOREACH_SUFFIX)); + List argTypes = new ArrayList(); + for (Type type : aggFunction.argTypes) { + argTypes.add(new ArrayType(type)); + } + aggFunction.setArgs(argTypes); + aggFunction.setReturnType(new ArrayType(aggFunction.getReturnType(), fnCall.isNullable())); + aggFunction.setNullableMode(NullableMode.ALWAYS_NULLABLE); + return fnCall; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java index fc8efef8d3e34a..82a09d7e04b06b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java @@ -20,7 +20,7 @@ import org.apache.doris.mysql.privilege.PrivPredicate; import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.exceptions.AnalysisException; -import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder; +import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder; import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; import org.apache.doris.nereids.trees.expressions.functions.udf.UdfBuilder; import org.apache.doris.nereids.types.DataType; @@ -93,13 +93,13 @@ public FunctionBuilder findFunctionBuilder(String dbName, String name, List a if (StringUtils.isEmpty(dbName)) { // search internal function only if dbName is empty functionBuilders = name2InternalBuiltinBuilders.get(name.toLowerCase()); - if (CollectionUtils.isEmpty(functionBuilders) && AggStateFunctionBuilder.isAggStateCombinator(name)) { - String nestedName = AggStateFunctionBuilder.getNestedName(name); - String combinatorSuffix = AggStateFunctionBuilder.getCombinatorSuffix(name); + if (CollectionUtils.isEmpty(functionBuilders) && AggCombinerFunctionBuilder.isAggStateCombinator(name)) { + String nestedName = AggCombinerFunctionBuilder.getNestedName(name); + String combinatorSuffix = AggCombinerFunctionBuilder.getCombinatorSuffix(name); functionBuilders = name2InternalBuiltinBuilders.get(nestedName.toLowerCase()); if (functionBuilders != null) { functionBuilders = functionBuilders.stream() - .map(builder -> new AggStateFunctionBuilder(combinatorSuffix, builder)) + .map(builder -> new AggCombinerFunctionBuilder(combinatorSuffix, builder)) .filter(functionBuilder -> functionBuilder.canApply(arguments)) .collect(Collectors.toList()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java index 2d77124fa58be0..cc2034aa746b40 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java @@ -83,6 +83,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.combinator.ForEachCombinator; import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator; import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator; import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator; @@ -612,6 +613,16 @@ public Expr visitUnionCombinator(UnionCombinator combinator, PlanTranslatorConte new FunctionParams(false, arguments))); } + @Override + public Expr visitForEachCombinator(ForEachCombinator combinator, PlanTranslatorContext context) { + List arguments = combinator.children().stream() + .map(arg -> new SlotRef(arg.getDataType().toCatalogDataType(), arg.nullable())) + .collect(ImmutableList.toImmutableList()); + return Function.convertForEachCombinator( + new FunctionCallExpr(visitAggregateFunction(combinator.getNestedFunction(), context).getFn(), + new FunctionParams(false, arguments))); + } + @Override public Expr visitAggregateFunction(AggregateFunction function, PlanTranslatorContext context) { List arguments = function.children().stream() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggStateFunctionBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggCombinerFunctionBuilder.java similarity index 74% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggStateFunctionBuilder.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggCombinerFunctionBuilder.java index 054aa4767b6018..3c514475eedde2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggStateFunctionBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggCombinerFunctionBuilder.java @@ -18,11 +18,15 @@ package org.apache.doris.nereids.trees.expressions.functions; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.combinator.ForEachCombinator; import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator; import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator; import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator; import org.apache.doris.nereids.types.AggStateType; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.DataType; import java.util.List; import java.util.Objects; @@ -31,28 +35,30 @@ /** * This class used to resolve AggState's combinators */ -public class AggStateFunctionBuilder extends FunctionBuilder { +public class AggCombinerFunctionBuilder extends FunctionBuilder { public static final String COMBINATOR_LINKER = "_"; public static final String STATE = "state"; public static final String MERGE = "merge"; public static final String UNION = "union"; + public static final String FOREACH = "foreach"; public static final String STATE_SUFFIX = COMBINATOR_LINKER + STATE; public static final String MERGE_SUFFIX = COMBINATOR_LINKER + MERGE; public static final String UNION_SUFFIX = COMBINATOR_LINKER + UNION; + public static final String FOREACH_SUFFIX = COMBINATOR_LINKER + FOREACH; private final FunctionBuilder nestedBuilder; private final String combinatorSuffix; - public AggStateFunctionBuilder(String combinatorSuffix, FunctionBuilder nestedBuilder) { + public AggCombinerFunctionBuilder(String combinatorSuffix, FunctionBuilder nestedBuilder) { this.combinatorSuffix = Objects.requireNonNull(combinatorSuffix, "combinatorSuffix can not be null"); this.nestedBuilder = Objects.requireNonNull(nestedBuilder, "nestedBuilder can not be null"); } @Override public boolean canApply(List arguments) { - if (combinatorSuffix.equals(STATE)) { + if (combinatorSuffix.equals(STATE) || combinatorSuffix.equals(FOREACH)) { return nestedBuilder.canApply(arguments); } else { if (arguments.size() != 1) { @@ -71,6 +77,23 @@ private AggregateFunction buildState(String nestedName, List a return (AggregateFunction) nestedBuilder.build(nestedName, arguments); } + private AggregateFunction buildForEach(String nestedName, List arguments) { + List forEachargs = arguments.stream().map(expr -> { + if (!(expr instanceof SlotReference)) { + throw new IllegalStateException( + "Can not build foreach nested function: '" + nestedName); + } + DataType arrayType = (((Expression) expr).getDataType()); + if (!(arrayType instanceof ArrayType)) { + throw new IllegalStateException( + "foreach must be input array type: '" + nestedName); + } + DataType itemType = ((ArrayType) arrayType).getItemType(); + return new SlotReference("mocked", itemType, (((ArrayType) arrayType).containsNull())); + }).collect(Collectors.toList()); + return (AggregateFunction) nestedBuilder.build(nestedName, forEachargs); + } + private AggregateFunction buildMergeOrUnion(String nestedName, List arguments) { if (arguments.size() != 1 || !(arguments.get(0) instanceof Expression) || !((Expression) arguments.get(0)).getDataType().isAggStateType()) { @@ -105,13 +128,16 @@ public BoundFunction build(String name, List arguments) { } else if (combinatorSuffix.equals(UNION)) { AggregateFunction nestedFunction = buildMergeOrUnion(nestedName, arguments); return new UnionCombinator((List) arguments, nestedFunction); + } else if (combinatorSuffix.equals(FOREACH)) { + AggregateFunction nestedFunction = buildForEach(nestedName, arguments); + return new ForEachCombinator((List) arguments, nestedFunction); } return null; } public static boolean isAggStateCombinator(String name) { return name.toLowerCase().endsWith(STATE_SUFFIX) || name.toLowerCase().endsWith(MERGE_SUFFIX) - || name.toLowerCase().endsWith(UNION_SUFFIX); + || name.toLowerCase().endsWith(UNION_SUFFIX) || name.toLowerCase().endsWith(FOREACH_SUFFIX); } public static String getNestedName(String name) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java new file mode 100644 index 00000000000000..fbbf51eb909941 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.combinator; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.DataType; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +/** + * combinator foreach + */ +public class ForEachCombinator extends AggregateFunction + implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { + + private final AggregateFunction nested; + + /** + * constructor of ForEachCombinator + */ + public ForEachCombinator(List arguments, AggregateFunction nested) { + super(nested.getName() + AggCombinerFunctionBuilder.FOREACH_SUFFIX, arguments); + + this.nested = Objects.requireNonNull(nested, "nested can not be null"); + } + + public static ForEachCombinator create(AggregateFunction nested) { + return new ForEachCombinator(nested.getArguments(), nested); + } + + @Override + public ForEachCombinator withChildren(List children) { + return new ForEachCombinator(children, nested); + } + + @Override + public List getSignatures() { + return nested.getSignatures().stream().map(sig -> { + return sig.withReturnType(ArrayType.of(sig.returnType)).withArgumentTypes(false, + sig.argumentsTypes.stream().map(arg -> { + return ArrayType.of(arg); + }).collect(ImmutableList.toImmutableList())); + }).collect(ImmutableList.toImmutableList()); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitForEachCombinator(this, context); + } + + @Override + public DataType getDataType() { + return ArrayType.of(nested.getDataType(), nested.nullable()); + } + + public AggregateFunction getNestedFunction() { + return nested; + } + + @Override + public AggregateFunction withDistinctAndChildren(boolean distinct, List children) { + throw new UnsupportedOperationException("Unimplemented method 'withDistinctAndChildren'"); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/MergeCombinator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/MergeCombinator.java index b529ae2de3f804..a9b2d13d0d5319 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/MergeCombinator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/MergeCombinator.java @@ -19,7 +19,7 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder; +import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder; import org.apache.doris.nereids.trees.expressions.functions.ComputeNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; @@ -43,7 +43,7 @@ public class MergeCombinator extends AggregateFunction private final AggStateType inputType; public MergeCombinator(List arguments, AggregateFunction nested) { - super(nested.getName() + AggStateFunctionBuilder.MERGE_SUFFIX, arguments); + super(nested.getName() + AggCombinerFunctionBuilder.MERGE_SUFFIX, arguments); this.nested = Objects.requireNonNull(nested, "nested can not be null"); inputType = (AggStateType) arguments.get(0).getDataType(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java index db001a6793c831..877824822c51ec 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java @@ -19,7 +19,7 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder; +import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; @@ -47,7 +47,7 @@ public class StateCombinator extends ScalarFunction * constructor of StateCombinator */ public StateCombinator(List arguments, AggregateFunction nested) { - super(nested.getName() + AggStateFunctionBuilder.STATE_SUFFIX, arguments); + super(nested.getName() + AggCombinerFunctionBuilder.STATE_SUFFIX, arguments); this.nested = Objects.requireNonNull(nested, "nested can not be null"); this.returnType = new AggStateType(nested.getName(), arguments.stream().map(arg -> { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/UnionCombinator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/UnionCombinator.java index 67f09a50ebf985..e1138dd4851a8a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/UnionCombinator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/UnionCombinator.java @@ -19,7 +19,7 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder; +import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; @@ -43,7 +43,7 @@ public class UnionCombinator extends AggregateFunction private final AggStateType inputType; public UnionCombinator(List arguments, AggregateFunction nested) { - super(nested.getName() + AggStateFunctionBuilder.UNION_SUFFIX, arguments); + super(nested.getName() + AggCombinerFunctionBuilder.UNION_SUFFIX, arguments); this.nested = Objects.requireNonNull(nested, "nested can not be null"); inputType = (AggStateType) arguments.get(0).getDataType(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java index 4a1830341b9b5b..594f9c754335aa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java @@ -72,6 +72,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Variance; import org.apache.doris.nereids.trees.expressions.functions.agg.VarianceSamp; import org.apache.doris.nereids.trees.expressions.functions.agg.WindowFunnel; +import org.apache.doris.nereids.trees.expressions.functions.combinator.ForEachCombinator; import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator; import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator; import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf; @@ -305,6 +306,10 @@ default R visitUnionCombinator(UnionCombinator combinator, C context) { return visitAggregateFunction(combinator, context); } + default R visitForEachCombinator(ForEachCombinator combinator, C context) { + return visitAggregateFunction(combinator, context); + } + default R visitJavaUdaf(JavaUdaf javaUdaf, C context) { return visitAggregateFunction(javaUdaf, context); } diff --git a/regression-test/data/function_p0/test_agg_foreach.out b/regression-test/data/function_p0/test_agg_foreach.out new file mode 100644 index 00000000000000..849d49bc9df3fe --- /dev/null +++ b/regression-test/data/function_p0/test_agg_foreach.out @@ -0,0 +1,28 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sql -- +[1, 2, 3] [1, 2, 3] [100, 2, 3] [100, 2, 3] [40.333333333333336, 2, 3] [85.95867768595042, 2, 3] + +-- !sql -- +[121, 4, 3] [42.897811391983879, 0, 0] [52.538874496255943, 0, null] [1840.2222222222219, 0, 0] [2760.333333333333, 0, null] + +-- !sql -- +[1840.2222222222222, 0, 0] [2760.3333333333335, 0, null] [1, 0, 0] + +-- !sql -- +["{"20":1,"100":1,"1":1}", "{"2":2}", "{"3":1}"] ["{"20":1,"100":1,"1":1}", "{"2":2}", "{"3":1}"] [[100, 20, 1], [2], [3]] [[100, 20, 1], [2], [3]] + +-- !sql -- +[3, 2, 1] ["[{"cbe":{"100":1,"1":1,"20":1},"notnull":3,"null":1,"all":4}]", "[{"cbe":{"2":2},"notnull":2,"null":0,"all":2}]", "[{"cbe":{"3":1},"notnull":1,"null":0,"all":1}]"] + +-- !sql -- +[100, 2, 3] + +-- !sql -- +[[1], [2, 2, 2], [3]] + +-- !sql -- +[null, null, null] + +-- !sql -- +[0, 2, 3] [117, 2, 3] [113, 0, 3] + diff --git a/regression-test/suites/function_p0/test_agg_foreach.groovy b/regression-test/suites/function_p0/test_agg_foreach.groovy new file mode 100644 index 00000000000000..eec05fcde9e3f8 --- /dev/null +++ b/regression-test/suites/function_p0/test_agg_foreach.groovy @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_agg_foreach") { + // for nereids_planner + // now support min min_by maxmax_by avg avg_weighted sum stddev stddev_samp_foreach variance var_samp + // covar covar_samp corr + // topn topn_array topn_weighted + // count count_by_enum + // PERCENTILE PERCENTILE_ARRAY PERCENTILE_APPROX + // histogram + // GROUP_BIT_AND GROUP_BIT_OR GROUP_BIT_XOR + + sql """ set enable_nereids_planner=true;""" + sql """ set enable_fallback_to_original_planner=false;""" + + sql """ + drop table if exists foreach_table; + """ + + sql """ + CREATE TABLE IF NOT EXISTS foreach_table ( + `id` INT(11) null COMMENT "", + `a` array null COMMENT "", + `b` array> null COMMENT "", + `s` array null COMMENT "" + ) ENGINE=OLAP + DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "storage_format" = "V2" + ); + """ + sql """ + insert into foreach_table values + (1,[1,2,3],[[1],[1,2,3],[2]],["ab","123"]), + (2,[20],[[2]],["cd"]), + (3,[100],[[1]],["efg"]) , + (4,null,[null],null), + (5,[null,2],[[2],null],[null,'c']); + """ + + + qt_sql """ + select min_foreach(a), min_by_foreach(a,a),max_foreach(a),max_by_foreach(a,a) , avg_foreach(a),avg_weighted_foreach(a,a) from foreach_table ; + """ + + qt_sql """ + select sum_foreach(a) , stddev_foreach(a) ,stddev_samp_foreach(a) , variance_foreach(a) , var_samp_foreach(a) from foreach_table ; + """ + + qt_sql """ + select covar_foreach(a,a) , covar_samp_foreach(a,a) , corr_foreach(a,a) from foreach_table ; + """ + qt_sql """ + select topn_foreach(a,a) ,topn_foreach(a,a,a) , topn_array_foreach(a,a) ,topn_array_foreach(a,a,a)from foreach_table ; + """ + + + qt_sql """ + select count_foreach(a) , count_by_enum_foreach(a) from foreach_table; + """ + + qt_sql """ + select PERCENTILE_foreach(a,a) from foreach_table; + """ + + qt_sql """ + select PERCENTILE_ARRAY_foreach(a,b) from foreach_table where id = 1; + """ + + qt_sql """ + + select PERCENTILE_APPROX_foreach(a,a) from foreach_table; + """ + + qt_sql """ + select GROUP_BIT_AND_foreach(a), GROUP_BIT_OR_foreach(a), GROUP_BIT_XOR_foreach(a) from foreach_table; + """ +}