[Feature] group_concat() support distinct and order by
Signed-off-by: Zhuhe Fang <fzhedu@gmail.com>
fzhedu committed Aug 23, 2023
1 parent ef6c338 commit e31fbe8
Showing 15 changed files with 333 additions and 39 deletions.
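
In short, group_concat() now accepts multiple arguments plus DISTINCT, ORDER BY, and an explicit SEPARATOR, so queries of roughly this shape become possible (illustrative only; the table and column names are hypothetical):

SELECT dept, GROUP_CONCAT(DISTINCT name ORDER BY id SEPARATOR ', ') AS members
FROM employees
GROUP BY dept;
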
4 changes: 3 additions & 1 deletion be/src/exec/aggregator.cpp
@@ -129,7 +129,7 @@ void AggregatorParams::init() {

bool is_input_nullable = has_outer_join_child || desc.nodes[0].has_nullable_child;
agg_fn_types[i] = {return_type, serde_type, arg_typedescs, is_input_nullable, desc.nodes[0].is_nullable};
if (fn.name.function_name == "array_agg") {
if (fn.name.function_name == "array_agg" || fn.name.function_name == "group_concat") {
// set order by info
if (fn.aggregate_fn.__isset.is_asc_order && fn.aggregate_fn.__isset.nulls_first &&
!fn.aggregate_fn.is_asc_order.empty()) {
@@ -868,6 +868,7 @@ Status Aggregator::output_chunk_by_streaming(Chunk* input_chunk, ChunkPtr* chunk
DCHECK(!_group_by_columns.empty());

RETURN_IF_ERROR(evaluate_agg_fn_exprs(input_chunk));
std::cout << fmt::format("convert_to_serialize_format by streaming0") << std::endl;

const auto num_rows = _group_by_columns[0]->size();
Columns agg_result_column = _create_agg_result_columns(num_rows, true);
@@ -878,6 +879,7 @@
DCHECK(i < _agg_input_columns.size() && _agg_input_columns[i].size() >= 1);
result_chunk->append_column(std::move(_agg_input_columns[i][0]), slot_id);
} else {
std::cout << fmt::format("convert_to_serialize_format by streaming1") << std::endl;
_agg_functions[i]->convert_to_serialize_format(_agg_fn_ctxs[i], _agg_input_columns[i],
result_chunk->num_rows(), &agg_result_column[i]);
result_chunk->append_column(std::move(agg_result_column[i]), slot_id);
6 changes: 6 additions & 0 deletions be/src/exprs/agg/factory/aggregate_factory.cpp
@@ -134,6 +134,12 @@ static const AggregateFunction* get_function(const std::string& name, LogicalTyp
}
}

if (func_version > 6) {
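// from func_version 7 on, group_concat is resolved to the new implementation registered as "group_concat2" in aggregate_resolver_others.cpp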
if (name == "group_concat") {
func_name = "group_concat2";
}
}

if (binary_type == TFunctionBinaryType::BUILTIN) {
auto func = AggregateFuncResolver::instance()->get_aggregate_info(func_name, arg_type, return_type,
is_window_function, is_null);
4 changes: 4 additions & 0 deletions be/src/exprs/agg/factory/aggregate_factory.hpp
@@ -114,6 +114,10 @@ class AggregateFactory {
return std::make_shared<ArrayAggAggregateFunctionV2>();
}

static AggregateFunctionPtr MakeGroupConcatAggregateFunctionV2() {
return std::make_shared<GroupConcatAggregateFunctionV2>();
}

template <LogicalType LT>
static auto MakeMaxAggregateFunction();

2 changes: 2 additions & 0 deletions be/src/exprs/agg/factory/aggregate_resolver_others.cpp
@@ -75,6 +75,8 @@ void AggregateFuncResolver::register_others() {

add_general_mapping<AnyValueSemiState>("any_value", false, AggregateFactory::MakeAnyValueSemiAggregateFunction());
add_general_mapping_notnull("array_agg2", false, AggregateFactory::MakeArrayAggAggregateFunctionV2());
add_general_mapping_nullable_variadic<GroupConcatAggregateStateV2>(
"group_concat2", false, AggregateFactory::MakeGroupConcatAggregateFunctionV2());
}

} // namespace starrocks
235 changes: 235 additions & 0 deletions be/src/exprs/agg/group_concat.h
@@ -16,10 +16,16 @@

#include <cmath>

#include "column/array_column.h"
#include "column/binary_column.h"
#include "column/column_helper.h"
#include "column/struct_column.h"
#include "column/type_traits.h"
#include "exec/sorting/sorting.h"
#include "exprs/agg/aggregate.h"
#include "exprs/function_context.h"
#include "gutil/casts.h"
#include "runtime/runtime_state.h"

namespace starrocks {
template <LogicalType LT, typename = guard::Guard>
@@ -292,4 +298,233 @@ class GroupConcatAggregateFunction
std::string get_name() const override { return "group concat"; }
};

// Input columns are collected into the intermediate result struct{array[col0], array[col1], array[col2], ..., array[coln]}
// and returned as an ordered string "col0col1...colnSEPcol0col1...coln..." (SEP is the separator).
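// For example (illustrative): for arguments (a, b) ordered by c, data_columns holds {a, b, c}; finalize sorts all
// rows by c, drops the c column, and concatenates each row's a and b values into the final string.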
struct GroupConcatAggregateStateV2 {
void update(FunctionContext* ctx, const Column& column, size_t index, size_t offset, size_t count) {
auto notnull = ColumnHelper::get_data_column(&column);
(*data_columns)[index]->append(*notnull, offset, count);
}

// release the trailing order-by columns, keeping only the first output_col_num output columns
void release_order_by_columns(int output_col_num) const {
DCHECK(data_columns != nullptr);
for (auto i = output_col_num; i < data_columns->size(); ++i) {
data_columns->at(i).reset();
}
data_columns->resize(output_col_num);
}

~GroupConcatAggregateStateV2() {
if (data_columns != nullptr) {
for (auto& col : *data_columns) {
col.reset();
}
data_columns->clear();
delete data_columns;
data_columns = nullptr;
}
}
// Use a pointer rather than an inline vector because the number of argument columns is variadic.
// For group_concat(a, b ORDER BY c, d), the columns a, b, c, d are put into data_columns in that order.
Columns* data_columns = nullptr; // the stored columns are non-nullable data columns
};

class GroupConcatAggregateFunctionV2
: public AggregateFunctionBatchHelper<GroupConcatAggregateStateV2, GroupConcatAggregateFunctionV2> {
public:
void reset(FunctionContext* ctx, const Columns& args, AggDataPtr __restrict state) const override {
auto& state_impl = this->data(state);
if (state_impl.data_columns != nullptr) {
for (auto& col : *state_impl.data_columns) {
col->resize(0);
}
}
}

void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state,
size_t row_num) const override {
auto num = ctx->get_num_args();
GroupConcatAggregateStateV2& state_impl = this->data(state);
if (state_impl.data_columns == nullptr) { // init
state_impl.data_columns = new Columns;
for (auto i = 0; i < num; ++i) {
state_impl.data_columns->emplace_back(ctx->create_column(*ctx->get_arg_type(i), false));
}
DCHECK(ctx->get_is_asc_order().size() == ctx->get_nulls_first().size());
}
for (auto i = 0; i < num; ++i) {
auto* data_col = columns[i];
auto tmp_row_num = row_num;
if (columns[i]->is_constant()) {
// just copy the first const value.
data_col = down_cast<const ConstColumn*>(columns[i])->data_column().get();
tmp_row_num = 0;
}
this->data(state).update(ctx, *data_col, i, tmp_row_num, 1);
std::cout << fmt::format("after update {} -th col = {} id = {}, result {}", i, columns[i]->debug_string(),
row_num, (*state_impl.data_columns)[i]->debug_string())
<< std::endl;
}
}

void update_batch_single_state(FunctionContext* ctx, size_t chunk_size, const Column** columns,
AggDataPtr __restrict state) const override {
auto& state_impl = this->data(state);
for (auto& col : *state_impl.data_columns) {
col->resize(col->size() + chunk_size);
}
for (size_t i = 0; i < chunk_size; ++i) {
update(ctx, columns, state, i);
}
}

void update_batch_single_state_with_frame(FunctionContext* ctx, AggDataPtr __restrict state, const Column** columns,
int64_t peer_group_start, int64_t peer_group_end, int64_t frame_start,
int64_t frame_end) const override {
for (size_t i = frame_start; i < frame_end; ++i) {
update(ctx, columns, state, i);
}
}

void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override {
if (row_num >= column->size()) {
std::cout << fmt::format("row num {} >= size {}", row_num, column->size()) << std::endl;
throw std::runtime_error("merge error");
return;
}
std::cout << fmt::format("merge {}", column->debug_string()) << std::endl;
auto& input_columns = down_cast<const StructColumn*>(ColumnHelper::get_data_column(column))->fields();
auto& state_impl = this->data(state);
if (state_impl.data_columns == nullptr) {
auto num = ctx->get_num_args();
state_impl.data_columns = new Columns;
for (auto i = 0; i < num; ++i) {
state_impl.data_columns->emplace_back(ctx->create_column(*ctx->get_arg_type(i), false));
}
DCHECK(ctx->get_is_asc_order().size() == ctx->get_nulls_first().size());
}
for (auto i = 0; i < input_columns.size(); ++i) {
auto array_column = down_cast<const ArrayColumn*>(ColumnHelper::get_data_column(input_columns[i].get()));
auto& offsets = array_column->offsets().get_data();
state_impl.update(ctx, array_column->elements(), i, offsets[row_num],
offsets[row_num + 1] - offsets[row_num]);
}
}

void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {
std::cout << fmt::format("ser to name = {}, val = {}", to->get_name(), to->debug_string()) << std::endl;
auto& state_impl = this->data(state);
auto& columns = down_cast<StructColumn*>(ColumnHelper::get_data_column(to))->fields_column();
if (to->is_nullable()) {
down_cast<NullableColumn*>(to)->null_column_data().emplace_back(0);
}
for (auto i = 0; i < columns.size(); ++i) {
auto elem_size = (*state_impl.data_columns)[i]->size();
auto array_col = down_cast<ArrayColumn*>(ColumnHelper::get_data_column(columns[i].get()));
if (columns[i]->is_nullable()) {
down_cast<NullableColumn*>(columns[i].get())->null_column_data().emplace_back(0);
}

array_col->elements_column()->append(
*ColumnHelper::unpack_and_duplicate_const_column(elem_size, (*state_impl.data_columns)[i]), 0,
elem_size);

std::cout << fmt::format("after ser {} res {}", (*state_impl.data_columns)[i]->debug_string(),
array_col->elements_column()->debug_string())
<< std::endl;
auto& offsets = array_col->offsets_column()->get_data();
offsets.push_back(offsets.back() + elem_size);
}
std::cout << fmt::format("ser to name = {}, val = {}", to->get_name(), to->debug_string()) << std::endl;
}

void convert_to_serialize_format(FunctionContext* ctx, const Columns& src, size_t chunk_size,
ColumnPtr* dst) const override {
auto columns = down_cast<StructColumn*>(ColumnHelper::get_data_column(dst->get()))->fields_column();
if (dst->get()->is_nullable()) {
for (size_t i = 0; i < chunk_size; i++) {
down_cast<NullableColumn*>(dst->get())->null_column_data().emplace_back(0);
}
}
for (auto j = 0; j < columns.size(); ++j) {
auto array_col = down_cast<ArrayColumn*>(ColumnHelper::get_data_column(columns[j].get()));
if (columns[j].get()->is_nullable()) {
for (size_t i = 0; i < chunk_size; i++) {
down_cast<NullableColumn*>(columns[j].get())->null_column_data().emplace_back(0);
}
}
auto& element_column = array_col->elements_column();
auto& offsets = array_col->offsets_column()->get_data();
for (size_t i = 0; i < chunk_size; i++) {
element_column->append_datum(src[j]->get(i));
offsets.emplace_back(offsets.back() + 1);
}
std::cout << fmt::format("conv ser src {} to {}", src[j]->debug_string(), array_col->debug_string())
<< std::endl;
}
}

void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {
auto& state_impl = this->data(state);
auto elem_size = (*state_impl.data_columns)[0]->size();
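// the trailing ctx->get_is_asc_order().size() columns are order-by keys and are excluded from the output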
auto output_col_num = state_impl.data_columns->size() - ctx->get_is_asc_order().size();
Columns outputs;
outputs.resize(output_col_num);
for (auto i = 0; i < output_col_num; ++i) {
outputs[i] = (*state_impl.data_columns)[i];
std::cout << fmt::format("finalize input i = {}, output = {}", i, outputs[i]->debug_string()) << std::endl;
}
if (!ctx->get_is_asc_order().empty()) {
for (auto i = 0; i < output_col_num; ++i) {
outputs[i] = (*state_impl.data_columns)[i]->clone_empty();
}
Permutation perm;
Columns order_by_columns;
SortDescs sort_desc(ctx->get_is_asc_order(), ctx->get_nulls_first());
order_by_columns.assign(state_impl.data_columns->begin() + output_col_num, state_impl.data_columns->end());
Status st = sort_and_tie_columns(ctx->state()->cancelled_ref(), order_by_columns, sort_desc, &perm);
// release order-by columns early
order_by_columns.clear();
state_impl.release_order_by_columns(output_col_num);
DCHECK(ctx->state()->cancelled_ref() || st.ok());
for (auto i = 0; i < output_col_num; ++i) {
materialize_column_by_permutation(outputs[i].get(), {(*state_impl.data_columns)[i]}, perm);
}
}
std::cout << fmt::format("finalize to name = {}, val = {}", to->get_name(), to->debug_string()) << std::endl;

auto* string = down_cast<BinaryColumn*>(ColumnHelper::get_data_column(to));
if (to->is_nullable()) {
down_cast<NullableColumn*>(to)->null_column_data().emplace_back(0);
}
/// TODO(fzh): only string-typed (VARCHAR) output columns are handled for now
Bytes& bytes = string->get_bytes();
size_t offset = bytes.size();
size_t length = 0;
for (auto i = 0; i < output_col_num; ++i) {
std::cout << i << " th col " << outputs[i]->get_name() << std::endl;
if (outputs[i]->is_binary()) {
length += down_cast<BinaryColumn*>(outputs[i].get())->get_bytes().size();
}
}

bytes.resize(offset + length);
Slice bstr(bytes.data(), offset);
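// concatenate row by row: for each row j, append the j-th value of every binary output column back to back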
for (auto j = 0; j < elem_size; ++j) {
for (auto i = 0; i < output_col_num; ++i) {
if (outputs[i]->is_binary()) {
auto str = down_cast<BinaryColumn*>(outputs[i].get())->get_slice(j);
memcpy(bytes.data() + offset, str.get_data(), str.get_size());
offset += str.get_size();
}
}
}
string->get_offset().emplace_back(offset);
std::cout << fmt::format("from {} to {}", bstr.to_string(), string->debug_string()) << std::endl;
}

std::string get_name() const override { return "group concat2"; }
};

} // namespace starrocks
1 change: 1 addition & 0 deletions be/src/exprs/agg/nullable_aggregate.h
@@ -896,6 +896,7 @@ class NullableAggregateFunctionVariadic final
this->nested_function->convert_to_serialize_format(ctx, data_columns, chunk_size,
&dst_nullable_column->data_column());
}
std::cout << fmt::format("convert_to_serialize_format res = {}", (*dst)->debug_string()) << std::endl;
}

void retract(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state,
@@ -98,6 +98,10 @@ public List<OrderByElement> getOrderByElements() {
return orderByElements == null ? null : orderByElements.isEmpty() ? null : orderByElements;
}

public int getOrderByElemNum() {
return orderByElements == null ? 0 : orderByElements.size();
}

public String getOrderByStringToSql() {
if (orderByElements != null && !orderByElements.isEmpty()) {
StringBuilder sb = new StringBuilder();
12 changes: 4 additions & 8 deletions fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java
@@ -826,6 +826,10 @@ private void initAggregateBuiltins() {
Lists.newArrayList(Type.ANY_ELEMENT), Type.ANY_ARRAY, Type.ANY_STRUCT, true,
true, false, false));

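// group_concat: arguments of any element type, ANY_STRUCT intermediate state, VARCHAR result (the fixed VARCHAR-only overloads are removed below)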
addBuiltin(AggregateFunction.createBuiltin(GROUP_CONCAT,
Lists.newArrayList(Type.ANY_ELEMENT), Type.VARCHAR, Type.ANY_STRUCT, true,
false, false, false));

for (Type t : Type.getSupportedTypes()) {
if (t.isFunctionType()) {
continue;
@@ -971,14 +975,6 @@ private void initAggregateBuiltins() {
addBuiltin(AggregateFunction.createBuiltin(RETENTION, Lists.newArrayList(Type.ARRAY_BOOLEAN),
Type.ARRAY_BOOLEAN, Type.BIGINT, false, false, false));

// Group_concat(string)
addBuiltin(AggregateFunction.createBuiltin(GROUP_CONCAT,
Lists.newArrayList(Type.VARCHAR), Type.VARCHAR, Type.VARCHAR,
false, false, false));
// Group_concat(string, string)
addBuiltin(AggregateFunction.createBuiltin(GROUP_CONCAT,
Lists.newArrayList(Type.VARCHAR, Type.VARCHAR), Type.VARCHAR, Type.VARCHAR,
false, false, false));

// Type.DATE must come before Type.DATETIME, because DATE could be considered as DATETIME.
addBuiltin(AggregateFunction.createBuiltin(WINDOW_FUNNEL,
@@ -917,12 +917,25 @@ public String visitFunctionCall(FunctionCallExpr node, Void context) {
StringLiteral boundary = (StringLiteral) node.getChild(3);
sb.append(", ").append(boundary.getValue());
sb.append(")");
} else if (functionName.equalsIgnoreCase(FunctionSet.ARRAY_AGG)) {
sb.append(visit(node.getChild(0)));
} else if (functionName.equals(FunctionSet.ARRAY_AGG) || functionName.equals(FunctionSet.GROUP_CONCAT)) {
int end = 1;
if (functionName.equals(FunctionSet.GROUP_CONCAT)) {
end = fnParams.exprs().size() - fnParams.getOrderByElemNum() - 1;
}
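// Illustration (assuming this child layout): for group_concat(a, b ORDER BY c SEPARATOR ', ') the children are
// [a, b, ', ', c], so end = 4 - 1 - 1 = 2; children [0, end) are the projected args and child(end) is the separator.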
for (int i = 0; i < end; ++i) {
if (i != 0) {
sb.append(",");
}
sb.append(visit(node.getChild(i)));
}
List<OrderByElement> sortClause = fnParams.getOrderByElements();
if (sortClause != null) {
sb.append(" ORDER BY ").append(visitAstList(sortClause));
}
if (functionName.equals(FunctionSet.GROUP_CONCAT)) {
sb.append(" SEPARATOR ");
sb.append(visit(node.getChild(end)));
}
sb.append(")");
} else {
List<String> p = node.getChildren().stream().map(this::visit).collect(Collectors.toList());