Skip to content

Commit

Permalink
[Feature] support merge dict (StarRocks#416)
Browse files Browse the repository at this point in the history
Used for collecting low-cardinality data.

agg function: dict_merge
input_type: TYPE_ARRAY (VARCHAR)
return_type: TYPE_VARCHAR (a Thrift struct serialized with the Thrift JSON protocol)
  • Loading branch information
stdpain authored Oct 12, 2021
1 parent 63c794f commit e85015a
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 0 deletions.
31 changes: 31 additions & 0 deletions be/src/exprs/agg/aggregate_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "exprs/agg/aggregate_factory.h"

#include <tuple>
#include <unordered_map>

#include "column/type_traits.h"
Expand Down Expand Up @@ -126,6 +127,10 @@ AggregateFunctionPtr AggregateFactory::MakeSumDistinctAggregateFunctionV2() {
return std::make_shared<DistinctAggregateFunctionV2<PT, AggDistinctType::SUM>>();
}

// Builds the aggregate function object that implements `dict_merge`.
AggregateFunctionPtr AggregateFactory::MakeDictMergeAggregateFunction() {
    AggregateFunctionPtr dict_merge_fn = std::make_shared<DictMergeAggregateFunction>();
    return dict_merge_fn;
}

// Builds a new HllUnionAggregateFunction instance via std::make_shared.
AggregateFunctionPtr AggregateFactory::MakeHllUnionAggregateFunction() {
    AggregateFunctionPtr hll_union_fn = std::make_shared<HllUnionAggregateFunction>();
    return hll_union_fn;
}
Expand Down Expand Up @@ -218,6 +223,14 @@ class AggregateFuncResolver {
create_object_function<arg_type, return_type, true>(name));
}

template <PrimitiveType arg_type, PrimitiveType return_type>
void add_array_mapping(std::string&& name) {
_infos_mapping.emplace(std::make_tuple(name, arg_type, return_type, false),
create_array_function<arg_type, return_type, false>(name));
_infos_mapping.emplace(std::make_tuple(name, arg_type, return_type, true),
create_array_function<arg_type, return_type, true>(name));
}

template <PrimitiveType arg_type, PrimitiveType return_type, bool is_null>
AggregateFunctionPtr create_object_function(std::string& name) {
if constexpr (is_null) {
Expand Down Expand Up @@ -276,6 +289,22 @@ class AggregateFuncResolver {
return nullptr;
}

// Creates the aggregate function registered under `name` for array-typed input.
// Only "dict_merge" is known here; any other name yields nullptr.
// When is_null is true the function is wrapped by MakeNullableAggregateFunctionUnary.
template <PrimitiveType arg_type, PrimitiveType return_type, bool is_null>
AggregateFunctionPtr create_array_function(std::string& name) {
    if (name != "dict_merge") {
        return nullptr;
    }

    auto dict_merge = AggregateFactory::MakeDictMergeAggregateFunction();
    if constexpr (is_null) {
        return AggregateFactory::MakeNullableAggregateFunctionUnary<DictMergeState>(dict_merge);
    } else {
        return dict_merge;
    }
}

// TODO(kks): simplify create_function method
template <PrimitiveType ArgPT, PrimitiveType ReturnPT, bool is_null>
std::enable_if_t<isArithmeticPT<ArgPT>, AggregateFunctionPtr> create_function(std::string& name) {
Expand Down Expand Up @@ -734,6 +763,8 @@ AggregateFuncResolver::AggregateFuncResolver() {
add_object_mapping<TYPE_DOUBLE, TYPE_DOUBLE>("percentile_approx");

add_object_mapping<TYPE_PERCENTILE, TYPE_PERCENTILE>("percentile_union");

add_array_mapping<TYPE_ARRAY, TYPE_VARCHAR>("dict_merge");
}

#undef ADD_ALL_TYPE
Expand Down
2 changes: 2 additions & 0 deletions be/src/exprs/agg/aggregate_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class AggregateFactory {
template <PrimitiveType PT>
static AggregateFunctionPtr MakeSumDistinctAggregateFunctionV2();

// Dictionary merge function (dict_merge):
static AggregateFunctionPtr MakeDictMergeAggregateFunction();

// Hyperloglog functions:
static AggregateFunctionPtr MakeHllUnionAggregateFunction();

Expand Down
110 changes: 110 additions & 0 deletions be/src/exprs/agg/distinct.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,26 @@

#pragma once

#include <cstring>
#include <limits>
#include <type_traits>

#include "column/array_column.h"
#include "column/binary_column.h"
#include "column/fixed_length_column.h"
#include "column/hash_set.h"
#include "column/type_traits.h"
#include "column/vectorized_fwd.h"
#include "exprs/agg/aggregate.h"
#include "exprs/agg/sum.h"
#include "gen_cpp/Data_types.h"
#include "glog/logging.h"
#include "gutil/casts.h"
#include "runtime/mem_pool.h"
#include "thrift/protocol/TJSONProtocol.h"
#include "udf/udf_internal.h"
#include "util/phmap/phmap_dump.h"
#include "util/slice.h"

namespace starrocks::vectorized {

Expand Down Expand Up @@ -428,4 +437,105 @@ class DistinctAggregateFunction : public TDistinctAggregateFunction<PT, Distinct
// Distinct aggregate function that plugs DistinctAggregateStateV2 into
// TDistinctAggregateFunction; it adds no members of its own.
template <PrimitiveType PT, AggDistinctType DistinctType, typename T = RunTimeCppType<PT>>
class DistinctAggregateFunctionV2 : public TDistinctAggregateFunction<PT, DistinctAggregateStateV2, DistinctType, T> {};

// Aggregation state for dict_merge. It reuses the VARCHAR distinct state as-is;
// string input is the only type supported for now.
struct DictMergeState : public DistinctAggregateStateV2<TYPE_VARCHAR> {
    DictMergeState() = default;
};

// Aggregate function `dict_merge`.
//
// Input:  an ArrayColumn whose elements are a BinaryColumn (optionally wrapped in a
//         NullableColumn). Every non-null element string is added to the state's set.
// Intermediate: the state serialized into a BinaryColumn (see serialize_to_column/merge).
// Output: a BinaryColumn value holding a TGlobalDict rendered with ThriftJSONString,
//         where ids are 1..n and strings are the set's contents in iteration order.
class DictMergeAggregateFunction final
        : public AggregateFunctionBatchHelper<DictMergeState, DictMergeAggregateFunction> {
public:
    // Row-at-a-time update is intentionally unsupported; input is consumed through
    // update_batch_single_state only.
    void update(FunctionContext* ctx, const Column** columns, AggDataPtr state, size_t row_num) const override {
        DCHECK(false) << "this method shouldn't be called";
    }

    // Adds every non-null string of the array's flattened elements column to the state.
    // Array boundaries are irrelevant here, so the elements column is walked directly
    // and `batch_size` is not consulted.
    void update_batch_single_state(FunctionContext* ctx, size_t batch_size, const Column** columns,
                                   AggDataPtr state) const override {
        size_t mem_usage = 0;
        auto& agg_state = this->data(state);
        const auto* column = down_cast<const ArrayColumn*>(columns[0]);
        MemPool* mem_pool = ctx->impl()->mem_pool();

        const auto& elements_column = column->elements();
        if (elements_column.is_nullable()) {
            const auto& null_column = down_cast<const NullableColumn&>(elements_column);
            const auto& null_data = null_column.immutable_null_column_data();
            const auto& binary_column = down_cast<const BinaryColumn&>(null_column.data_column_ref());

            const size_t num_elements = binary_column.size();
            for (size_t i = 0; i < num_elements; ++i) {
                // Null elements carry no dictionary entry.
                if (!null_data[i]) {
                    mem_usage += agg_state.update(mem_pool, binary_column.get_slice(i));
                }
            }
        } else {
            const auto& binary_column = down_cast<const BinaryColumn&>(elements_column);
            const size_t num_elements = binary_column.size();
            for (size_t i = 0; i < num_elements; ++i) {
                mem_usage += agg_state.update(mem_pool, binary_column.get_slice(i));
            }
        }

        // Report the growth of the state, mirroring what merge() does.
        ctx->impl()->add_mem_usage(mem_usage);
    }

    // Merges one serialized state (row `row_num` of a BinaryColumn) into `state`.
    void merge(FunctionContext* ctx, const Column* column, AggDataPtr state, size_t row_num) const override {
        const auto* input_column = down_cast<const BinaryColumn*>(column);
        const Slice slice = input_column->get_slice(row_num);
        const size_t mem_usage = this->data(state).deserialize_and_merge(
                ctx->impl()->mem_pool(), reinterpret_cast<const uint8_t*>(slice.data), slice.size);
        ctx->impl()->add_mem_usage(mem_usage);
    }

    // Appends the serialized state as one new value of the BinaryColumn `to`.
    void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr state, Column* to) const override {
        auto* column = down_cast<BinaryColumn*>(to);
        auto& bytes = column->get_bytes();
        const size_t old_size = bytes.size();
        const size_t new_size = old_size + this->data(state).serialize_size();
        bytes.resize(new_size);
        this->data(state).serialize(bytes.data() + old_size);
        column->get_offset().emplace_back(new_size);
    }

    // Not supported for this function.
    void convert_to_serialize_format(const Columns& src, size_t chunk_size, ColumnPtr* dst) const override {
        DCHECK(false) << "this method shouldn't be called";
    }

    // Appends the final dictionary to `to` as one JSON string value.
    // An empty state appends the column's default value instead.
    void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr state, Column* to) const override {
        const auto& distinct_values = this->data(state).set;
        const size_t dict_size = distinct_values.size();
        if (dict_size == 0) {
            to->append_default();
            return;
        }

        auto* binary_column = down_cast<BinaryColumn*>(to);

        // set dict_ids as [1...n]
        std::vector<int32_t> dict_ids(dict_size);
        for (size_t i = 0; i < dict_size; ++i) {
            dict_ids[i] = static_cast<int32_t>(i + 1);
        }

        TGlobalDict tglobal_dict;
        tglobal_dict.__isset.ids = true;
        tglobal_dict.ids = std::move(dict_ids);
        tglobal_dict.__isset.strings = true;
        // Use dict_size: dict_ids has been moved from and must not be read again.
        tglobal_dict.strings.reserve(dict_size);

        for (const auto& v : distinct_values) {
            tglobal_dict.strings.emplace_back(v.data, v.size);
        }

        const std::string result_value = apache::thrift::ThriftJSONString(tglobal_dict);

        auto& data = binary_column->get_bytes();
        const size_t old_size = data.size();
        const size_t new_size = old_size + result_value.size();

        // Grow by exactly the payload size so the byte buffer ends at the new offset.
        data.resize(new_size);
        memcpy(data.data() + old_size, result_value.data(), result_value.size());

        binary_column->get_offset().emplace_back(new_size);
    }

    std::string get_name() const override { return "dict_merge"; }
};

} // namespace starrocks::vectorized
12 changes: 12 additions & 0 deletions be/src/util/thrift_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,18 @@ void t_network_address_to_string(const TNetworkAddress& address, std::string* ou
// string representation
bool t_network_address_comparator(const TNetworkAddress& a, const TNetworkAddress& b);

template <typename ThriftStruct>
ThriftStruct from_json_string(const std::string& json_val) {
using namespace apache::thrift::transport;
using namespace apache::thrift::protocol;
ThriftStruct ts;
TMemoryBuffer* buffer = new TMemoryBuffer((uint8_t*)json_val.c_str(), (uint32_t)json_val.size());
std::shared_ptr<TTransport> trans(buffer);
TJSONProtocol protocol(trans);
ts.read(&protocol);
return ts;
}

} // namespace starrocks

#endif
63 changes: 63 additions & 0 deletions be/test/exprs/agg/aggregate_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,27 @@
#include <gtest/gtest.h>
#include <math.h>

#include <algorithm>

#include "column/array_column.h"
#include "column/column_builder.h"
#include "column/fixed_length_column.h"
#include "column/nullable_column.h"
#include "column/vectorized_fwd.h"
#include "exprs/agg/aggregate_factory.h"
#include "exprs/agg/maxmin.h"
#include "exprs/agg/nullable_aggregate.h"
#include "exprs/agg/sum.h"
#include "exprs/vectorized/arithmetic_operation.h"
#include "gen_cpp/Data_types.h"
#include "gutil/casts.h"
#include "runtime/vectorized/time_types.h"
#include "testutil/function_utils.h"
#include "udf/udf_internal.h"
#include "util/bitmap_value.h"
#include "util/slice.h"
#include "util/thrift_util.h"
#include "util/unaligned_access.h"

namespace starrocks::vectorized {

Expand Down Expand Up @@ -617,6 +628,58 @@ TEST_F(AggregateTest, test_sum_distinct) {
DecimalV2Value(21));
}

// Feeds an ARRAY<VARCHAR> column into dict_merge and checks that the finalized
// TGlobalDict contains exactly the distinct input strings, each one once.
TEST_F(AggregateTest, test_dict_merge) {
    const AggregateFunction* func = get_aggregate_function("dict_merge", TYPE_ARRAY, TYPE_VARCHAR, false);
    ASSERT_TRUE(func != nullptr);

    ColumnBuilder<TYPE_VARCHAR> builder;
    builder.append(Slice("key1"));
    builder.append(Slice("key2"));
    builder.append(Slice("starrocks-1"));
    builder.append(Slice("starrocks-starrocks"));
    builder.append(Slice("starrocks-starrocks"));
    auto data_col = builder.build(false);

    auto offsets = UInt32Column::create();
    offsets->append(0);
    offsets->append(0);
    offsets->append(2);
    offsets->append(5);
    // Resulting rows:
    // []
    // [key1, key2]
    // [starrocks-1, starrocks-starrocks, starrocks-starrocks]
    auto col = ArrayColumn::create(data_col, offsets);
    const Column* column = col.get();
    std::unique_ptr<ManagedAggregateState> state = ManagedAggregateState::Make(func);
    func->update_batch_single_state(ctx, col->size(), &column, state->mutable_data());

    auto res = BinaryColumn::create();
    func->finalize_to_column(ctx, state->data(), res.get());

    ASSERT_EQ(res->size(), 1);
    auto slice = res->get_slice(0);
    auto dict = from_json_string<TGlobalDict>(std::string(slice.data, slice.size));

    // Check the sizes match before indexing both vectors in lockstep.
    ASSERT_EQ(dict.ids.size(), dict.strings.size());

    std::map<int, std::string> datas;
    int sz = dict.ids.size();
    for (int i = 0; i < sz; ++i) {
        datas.emplace(dict.ids[i], dict.strings[i]);
    }

    // Distinct strings that were fed into the aggregate.
    std::set<std::string> origin_data;
    auto binary_column = down_cast<BinaryColumn*>(data_col.get());
    for (int i = 0; i < binary_column->size(); ++i) {
        auto value = binary_column->get_slice(i);
        origin_data.emplace(value.data, value.size);
    }

    // Every dictionary string must be an input string and must appear only once.
    for (const auto& [k, v] : datas) {
        ASSERT_TRUE(origin_data.count(v) != 0);
        origin_data.erase(v);
    }

    // No input string may be missing from the dictionary.
    ASSERT_TRUE(origin_data.empty());
}

TEST_F(AggregateTest, test_sum_nullable) {
using NullableSumInt64 = NullableAggregateFunctionState<SumAggregateState<int64_t>>;
const AggregateFunction* sum_null = get_aggregate_function("sum", TYPE_INT, TYPE_BIGINT, true);
Expand Down

0 comments on commit e85015a

Please sign in to comment.