[Enhancement] support push down agg distinct limit #55455

Merged (2 commits) on Jan 27, 2025
37 changes: 26 additions & 11 deletions be/src/exec/aggregate/agg_hash_set.h
@@ -19,12 +19,14 @@
#include "column/hash_set.h"
#include "column/type_traits.h"
#include "column/vectorized_fwd.h"
#include "exec/aggregate/agg_profile.h"
#include "gutil/casts.h"
#include "runtime/mem_pool.h"
#include "runtime/runtime_state.h"
#include "util/fixed_hash_map.h"
#include "util/hash_util.hpp"
#include "util/phmap/phmap.h"
#include "util/runtime_profile.h"

namespace starrocks {

@@ -91,9 +93,10 @@ using SliceAggTwoLevelHashSet =

template <typename HashSet, typename Impl>
struct AggHashSet {
-AggHashSet() = default;
+AggHashSet(size_t chunk_size, AggStatistics* agg_stat_) : agg_stat(agg_stat_) {}
using HHashSetType = HashSet;
HashSet hash_set;
+AggStatistics* agg_stat;

////// Common Methods ////////
void build_hash_set(size_t chunk_size, const Columns& key_columns, MemPool* pool) {
@@ -117,14 +120,16 @@ constexpr bool is_no_prefetch_set = no_prefetch_set<T>::value;
// handle one number hash key
template <LogicalType logical_type, typename HashSet>
struct AggHashSetOfOneNumberKey : public AggHashSet<HashSet, AggHashSetOfOneNumberKey<logical_type, HashSet>> {
+using Base = AggHashSet<HashSet, AggHashSetOfOneNumberKey<logical_type, HashSet>>;
using KeyType = typename HashSet::key_type;
using Iterator = typename HashSet::iterator;
using ColumnType = RunTimeColumnType<logical_type>;
using ResultVector = typename ColumnType::Container;
using FieldType = RunTimeCppType<logical_type>;
static_assert(sizeof(FieldType) <= sizeof(KeyType), "hash set key size needs to be larger than the actual element");

-AggHashSetOfOneNumberKey(int32_t chunk_size) {}
+template <class... Args>
+AggHashSetOfOneNumberKey(Args&&... args) : Base(std::forward<Args>(args)...) {}

// When compute_and_allocate=false:
// Elements queried in HashSet will be added to HashSet
@@ -194,6 +199,7 @@ struct AggHashSetOfOneNumberKey : public AggHashSet<HashSet, AggHashSetOfOneNumb
template <LogicalType logical_type, typename HashSet>
struct AggHashSetOfOneNullableNumberKey
: public AggHashSet<HashSet, AggHashSetOfOneNullableNumberKey<logical_type, HashSet>> {
+using Base = AggHashSet<HashSet, AggHashSetOfOneNullableNumberKey<logical_type, HashSet>>;
using KeyType = typename HashSet::key_type;
using Iterator = typename HashSet::iterator;
using ColumnType = RunTimeColumnType<logical_type>;
@@ -202,7 +208,8 @@ struct AggHashSetOfOneNullableNumberKey

static_assert(sizeof(FieldType) <= sizeof(KeyType), "hash set key size needs to be larger than the actual element");

-AggHashSetOfOneNullableNumberKey(int32_t chunk_size) {}
+template <class... Args>
+AggHashSetOfOneNullableNumberKey(Args&&... args) : Base(std::forward<Args>(args)...) {}

// When compute_and_allocate=false:
// Elements queried in HashSet will be added to HashSet
@@ -296,11 +303,13 @@ struct AggHashSetOfOneNullableNumberKey

template <typename HashSet>
struct AggHashSetOfOneStringKey : public AggHashSet<HashSet, AggHashSetOfOneStringKey<HashSet>> {
+using Base = AggHashSet<HashSet, AggHashSetOfOneStringKey<HashSet>>;
using Iterator = typename HashSet::iterator;
using KeyType = typename HashSet::key_type;
using ResultVector = Buffer<Slice>;

-AggHashSetOfOneStringKey(int32_t chunk_size) {}
+template <class... Args>
+AggHashSetOfOneStringKey(Args&&... args) : Base(std::forward<Args>(args)...) {}

// When compute_and_allocate=false:
// Elements queried in HashSet will be added to HashSet
@@ -379,12 +388,13 @@ struct AggHashSetOfOneStringKey : public AggHashSet<HashSet, AggHashSetOfOneStri

template <typename HashSet>
struct AggHashSetOfOneNullableStringKey : public AggHashSet<HashSet, AggHashSetOfOneNullableStringKey<HashSet>> {
+using Base = AggHashSet<HashSet, AggHashSetOfOneNullableStringKey<HashSet>>;
using Iterator = typename HashSet::iterator;
using KeyType = typename HashSet::key_type;
-// using ResultVector = typename std::vector<Slice>;
using ResultVector = Buffer<Slice>;

-AggHashSetOfOneNullableStringKey(int32_t chunk_size) {}
+template <class... Args>
+AggHashSetOfOneNullableStringKey(Args&&... args) : Base(std::forward<Args>(args)...) {}

// When compute_and_allocate=false:
// Elements queried in HashSet will be added to HashSet
@@ -496,13 +506,15 @@ struct AggHashSetOfOneNullableStringKey : public AggHashSet<HashSet, AggHashSetO

template <typename HashSet>
struct AggHashSetOfSerializedKey : public AggHashSet<HashSet, AggHashSetOfSerializedKey<HashSet>> {
+using Base = AggHashSet<HashSet, AggHashSetOfSerializedKey<HashSet>>;
using Iterator = typename HashSet::iterator;
-// using ResultVector = typename std::vector<Slice>;
using ResultVector = Buffer<Slice>;
using KeyType = typename HashSet::key_type;

-AggHashSetOfSerializedKey(int32_t chunk_size)
-: _mem_pool(std::make_unique<MemPool>()),
+template <class... Args>
+AggHashSetOfSerializedKey(int32_t chunk_size, Args&&... args)
+: Base(chunk_size, std::forward<Args>(args)...),
+_mem_pool(std::make_unique<MemPool>()),
_buffer(_mem_pool->allocate(max_one_row_size * chunk_size + SLICE_MEMEQUAL_OVERFLOW_PADDING)),
_chunk_size(chunk_size) {}

@@ -623,6 +635,7 @@ struct AggHashSetOfSerializedKey : public AggHashSet<HashSet, AggHashSetOfSerial

template <typename HashSet>
struct AggHashSetOfSerializedKeyFixedSize : public AggHashSet<HashSet, AggHashSetOfSerializedKeyFixedSize<HashSet>> {
+using Base = AggHashSet<HashSet, AggHashSetOfSerializedKeyFixedSize<HashSet>>;
using Iterator = typename HashSet::iterator;
using KeyType = typename HashSet::key_type;
using FixedSizeSliceKey = typename HashSet::key_type;
@@ -632,8 +645,10 @@ struct AggHashSetOfSerializedKeyFixedSize : public AggHashSet<HashSet, AggHashSe
int fixed_byte_size = -1; // unset state
static constexpr size_t max_fixed_size = sizeof(FixedSizeSliceKey);

-AggHashSetOfSerializedKeyFixedSize(int32_t chunk_size)
-: _mem_pool(std::make_unique<MemPool>()),
+template <class... Args>
+AggHashSetOfSerializedKeyFixedSize(int32_t chunk_size, Args&&... args)
+: Base(chunk_size, std::forward<Args>(args)...),
+_mem_pool(std::make_unique<MemPool>()),
buffer(_mem_pool->allocate(max_fixed_size * chunk_size + SLICE_MEMEQUAL_OVERFLOW_PADDING)),
_chunk_size(chunk_size) {
memset(buffer, 0x0, max_fixed_size * _chunk_size);
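Taken together, the changes in agg_hash_set.h are one pattern applied to every keyed hash-set specialization: the dedicated `(int32_t chunk_size)` constructors are replaced by perfect-forwarding constructors, so the new `AggStatistics*` member reaches the `AggHashSet` base without each subclass having to name it. A minimal compilable sketch of the idiom, using stand-in types rather than the real StarRocks ones:

```cpp
// Sketch only: stand-in names illustrating the forwarding-constructor pattern.
#include <cstddef>
#include <utility>

struct AggStatistics {}; // stand-in for the type from exec/aggregate/agg_profile.h

template <typename HashSet, typename Impl>
struct AggHashSetBase {
    AggHashSetBase(std::size_t /*chunk_size*/, AggStatistics* agg_stat_) : agg_stat(agg_stat_) {}
    HashSet hash_set;
    AggStatistics* agg_stat;
};

template <typename HashSet>
struct OneNumberKeySet : AggHashSetBase<HashSet, OneNumberKeySet<HashSet>> {
    using Base = AggHashSetBase<HashSet, OneNumberKeySet<HashSet>>;
    // Forward whatever the caller passes (chunk_size, agg_stat, ...) to Base;
    // adding another base-constructor argument later needs no edits here.
    template <class... Args>
    OneNumberKeySet(Args&&... args) : Base(std::forward<Args>(args)...) {}
};

int main() {
    AggStatistics stats;
    OneNumberKeySet<int> set(std::size_t{4096}, &stats); // chunk_size, stats
    return set.agg_stat == &stats ? 0 : 1;
}
```

One caveat of the idiom: an unconstrained `Args&&...` constructor also matches copy and move construction, which is harmless here because these sets are owned behind `std::unique_ptr` and never copied.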
5 changes: 3 additions & 2 deletions be/src/exec/aggregate/agg_hash_variant.cpp
@@ -241,11 +241,12 @@ size_t AggHashMapVariant::allocated_memory_usage(const MemPool* pool) const {

void AggHashSetVariant::init(RuntimeState* state, Type type, AggStatistics* agg_stat) {
_type = type;
+_agg_stat = agg_stat;
switch (_type) {
#define M(NAME) \
case Type::NAME: \
hash_set_with_key = std::make_unique<detail::AggHashSetVariantTypeTraits<Type::NAME>::HashSetWithKeyType>( \
-state->chunk_size()); \
+state->chunk_size(), _agg_stat); \
break;
APPLY_FOR_AGG_VARIANT_ALL(M)
#undef M
@@ -255,7 +256,7 @@ void AggHashSetVariant::init(RuntimeState* state, Type type, AggStatistics* agg_
#define CONVERT_TO_TWO_LEVEL_SET(DST, SRC) \
if (_type == AggHashSetVariant::Type::SRC) { \
auto dst = std::make_unique<detail::AggHashSetVariantTypeTraits<Type::DST>::HashSetWithKeyType>( \
-state->chunk_size()); \
+state->chunk_size(), _agg_stat); \
std::visit( \
[&](auto& hash_set_with_key) { \
if constexpr (std::is_same_v<typename decltype(hash_set_with_key->hash_set)::key_type, \
1 change: 1 addition & 0 deletions be/src/exec/aggregate/agg_hash_variant.h
@@ -592,6 +592,7 @@ struct AggHashSetVariant {

private:
Type _type = Type::phase1_slice;
+AggStatistics* _agg_stat = nullptr;
};

} // namespace starrocks
7 changes: 6 additions & 1 deletion be/src/exec/aggregate/aggregate_blocking_node.cpp
@@ -291,7 +291,12 @@ pipeline::OpFactories AggregateBlockingNode::decompose_to_pipeline(pipeline::Pip
_decompose_to_pipeline<StreamingAggregatorFactory, SortedAggregateStreamingSourceOperatorFactory,
SortedAggregateStreamingSinkOperatorFactory>(ops_with_sink, context, false);
} else {
-if (runtime_state()->enable_spill() && runtime_state()->enable_agg_spill() && has_group_by_keys) {
+// disable spill when group by with a small limit
+bool enable_agg_spill = runtime_state()->enable_spill() && runtime_state()->enable_agg_spill();
+if (limit() != -1 && limit() < runtime_state()->chunk_size()) {
+enable_agg_spill = false;
+}
+if (enable_agg_spill && has_group_by_keys) {
ops_with_source = _decompose_to_pipeline<AggregatorFactory, SpillableAggregateBlockingSourceOperatorFactory,
SpillableAggregateBlockingSinkOperatorFactory>(ops_with_sink,
context, false);
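The new branch narrows when the spillable sink is chosen: if the query carries a LIMIT smaller than one chunk, the hash table is bounded at fewer than `chunk_size` groups before the limit short-circuits the aggregation, so spilling can only add overhead. A hedged restatement of the gate (the free-function form and names are mine, not the tree's):

```cpp
#include <cstdint>

// Mirrors the decision added to AggregateBlockingNode::decompose_to_pipeline();
// a limit of -1 means "no limit".
bool use_spillable_agg(bool spill_enabled, bool agg_spill_enabled, bool has_group_by_keys,
                       int64_t limit, int32_t chunk_size) {
    bool enable_agg_spill = spill_enabled && agg_spill_enabled;
    // A LIMIT below one chunk keeps the aggregation state small; stay in memory.
    if (limit != -1 && limit < chunk_size) {
        enable_agg_spill = false;
    }
    return enable_agg_spill && has_group_by_keys;
}
```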
@@ -57,12 +57,12 @@ class AggregateBlockingSinkOperator : public Operator {
// - reffed at constructor() of both sink and source operator,
// - unreffed at close() of both sink and source operator.
AggregatorPtr _aggregator = nullptr;
+bool _agg_group_by_with_limit = false;

private:
// Whether prev operator has no output
std::atomic_bool _is_finished = false;
// whether enable aggregate group by limit optimize
-bool _agg_group_by_with_limit = false;
std::atomic<int64_t>& _shared_limit_countdown;
};

@@ -26,6 +26,7 @@ Status AggregateDistinctStreamingSinkOperator::prepare(RuntimeState* state) {
if (_aggregator->streaming_preaggregation_mode() == TStreamingPreaggregationMode::LIMITED_MEM) {
_limited_mem_state.limited_memory_size = config::streaming_agg_limited_memory_size;
}
+_aggregator->streaming_preaggregation_mode() = TStreamingPreaggregationMode::FORCE_PREAGGREGATION;
_aggregator->attach_sink_observer(state, this->_observer);
return _aggregator->open(state);
}
@@ -38,8 +39,13 @@ void AggregateDistinctStreamingSinkOperator::close(RuntimeState* state) {
}

Status AggregateDistinctStreamingSinkOperator::set_finishing(RuntimeState* state) {
if (_is_finished) return Status::OK();
+ONCE_DETECT(_set_finishing_once);
+auto notify = _aggregator->defer_notify_source();
-_is_finished = true;
+auto defer = DeferOp([this]() {
+_aggregator->sink_complete();
+_is_finished = true;
+});

// skip processing if cancelled
if (state->is_cancelled()) {
@@ -50,7 +56,6 @@ Status AggregateDistinctStreamingSinkOperator::set_finishing(RuntimeState* state
_aggregator->set_ht_eos();
}

-_aggregator->sink_complete();
return Status::OK();
}

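The point of the rewrite in set_finishing() is ordering on every exit path: sink_complete() used to run only at the end of the happy path, but with the new early returns it must also fire when the query is cancelled, and _is_finished should flip only after completion. The DeferOp scope guard gives exactly that; below is a minimal stand-in that only mimics its shape, with assumed names throughout:

```cpp
#include <utility>

// Minimal scope guard in the spirit of DeferOp: the callback runs when the
// scope unwinds, no matter which return statement was taken.
template <typename F>
class Defer {
public:
    explicit Defer(F f) : _f(std::move(f)) {}
    ~Defer() { _f(); }
    Defer(const Defer&) = delete;
    Defer& operator=(const Defer&) = delete;

private:
    F _f;
};

// Usage shaped like set_finishing(): completion logic is hoisted into the
// guard so the cancelled early-return path can no longer skip it.
int set_finishing_sketch(bool cancelled, bool& sink_completed, bool& finished) {
    Defer cleanup([&] {
        sink_completed = true; // _aggregator->sink_complete() in the real code
        finished = true;       // _is_finished flips only after completion
    });
    if (cancelled) {
        return 0; // guard still fires here
    }
    return 0;
}
```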
@@ -70,7 +75,14 @@ Status AggregateDistinctStreamingSinkOperator::push_chunk(RuntimeState* state, c

_aggregator->update_num_input_rows(chunk_size);
COUNTER_SET(_aggregator->input_row_count(), _aggregator->num_input_rows());
-
+bool limit_with_no_agg = _aggregator->limit() != -1;
+if (limit_with_no_agg) {
+auto size = _aggregator->hash_set_variant().size();
+if (size >= _aggregator->limit()) {
+(void)set_finishing(state);
+return Status::OK();
+}
+}
RETURN_IF_ERROR(_aggregator->evaluate_groupby_exprs(chunk.get()));

if (_aggregator->streaming_preaggregation_mode() == TStreamingPreaggregationMode::FORCE_STREAMING) {
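The push_chunk() check is the BE half of the distinct-limit pushdown: when a limit was pushed down to a distinct-only aggregation (`limit_with_no_agg`), any `limit` distinct keys form a correct answer, so the sink finishes itself as soon as the hash set reaches that size rather than consuming more input. A self-contained sketch of the predicate, with stand-ins for the aggregator accessors:

```cpp
#include <cstdint>
#include <unordered_set>

// Stand-in for the aggregator state consulted in push_chunk().
struct DistinctSinkState {
    int64_t limit = -1;                   // -1: no limit was pushed down
    std::unordered_set<int64_t> hash_set; // hash_set_variant() in the real code
};

// True once `limit` distinct keys have been collected; the sink may then
// stop consuming and report itself finished.
bool reached_distinct_limit(const DistinctSinkState& s) {
    return s.limit != -1 && static_cast<int64_t>(s.hash_set.size()) >= s.limit;
}
```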
@@ -72,6 +72,7 @@ class AggregateDistinctStreamingSinkOperator : public Operator {
// Whether prev operator has no output
bool _is_finished = false;
LimitedMemAggState _limited_mem_state;
+DECLARE_ONCE_DETECTOR(_set_finishing_once);
};

class AggregateDistinctStreamingSinkOperatorFactory final : public OperatorFactory {
@@ -112,6 +112,8 @@ Status SpillableAggregateBlockingSinkOperator::prepare(RuntimeState* state) {
_peak_revocable_mem_bytes = _unique_metrics->AddHighWaterMarkCounter(
"PeakRevocableMemoryBytes", TUnit::BYTES, RuntimeProfile::Counter::create_strategy(TUnit::BYTES));
_hash_table_spill_times = ADD_COUNTER(_unique_metrics.get(), "HashTableSpillTimes", TUnit::UNIT);
+_agg_group_by_with_limit = false;
+_aggregator->params()->enable_pipeline_share_limit = false;

return Status::OK();
}
@@ -764,6 +764,8 @@ public static MaterializedViewRewriteMode parse(String str) {

public static final String CBO_PUSHDOWN_TOPN_LIMIT = "cbo_push_down_topn_limit";

+public static final String CBO_PUSHDOWN_DISTINCT_LIMIT = "cbo_push_down_distinct_limit";
+
public static final String ENABLE_AGGREGATION_PIPELINE_SHARE_LIMIT = "enable_aggregation_pipeline_share_limit";

public static final String ENABLE_EXPR_PRUNE_PARTITION = "enable_expr_prune_partition";
@@ -1639,6 +1641,9 @@ public static MaterializedViewRewriteMode parse(String str) {
@VarAttr(name = CBO_PUSHDOWN_TOPN_LIMIT)
private long cboPushDownTopNLimit = 1000;

+@VarAttr(name = CBO_PUSHDOWN_DISTINCT_LIMIT)
+private long cboPushDownDistinctLimit = 4096;
+
@VarAttr(name = ENABLE_AGGREGATION_PIPELINE_SHARE_LIMIT, flag = VariableMgr.INVISIBLE)
private boolean enableAggregationPipelineShareLimit = true;

@@ -1709,6 +1714,10 @@ public long getCboPushDownTopNLimit() {
return cboPushDownTopNLimit;
}

+public long cboPushDownDistinctLimit() {
+return cboPushDownDistinctLimit;
+}
+
public void setCboPushDownTopNLimit(long cboPushDownTopNLimit) {
this.cboPushDownTopNLimit = cboPushDownTopNLimit;
}
@@ -92,12 +92,18 @@ public List<OptExpression> transform(OptExpression input, OptimizerContext conte
}
}

+long localAggLimit = Operator.DEFAULT_LIMIT;
+boolean isOnlyGroupBy = aggOp.getAggregations().isEmpty();
+if (isOnlyGroupBy && aggOp.getLimit() < context.getSessionVariable().cboPushDownDistinctLimit()) {
+localAggLimit = aggOp.getLimit();
+}
+
LogicalAggregationOperator local = new LogicalAggregationOperator.Builder().withOperator(aggOp)
.setType(AggType.LOCAL)
.setAggregations(createNormalAgg(AggType.LOCAL, newAggMap))
.setSplit()
.setPredicate(null)
-.setLimit(Operator.DEFAULT_LIMIT)
+.setLimit(localAggLimit)
.setProjection(null)
.build();
OptExpression localOptExpression = OptExpression.create(local, input.getInputs());
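On the FE side, the rule pushes the limit down to the phase-1 (LOCAL) aggregation only when the aggregation is distinct-only (no aggregate functions) and the limit sits below the new `cbo_push_down_distinct_limit` session variable (default 4096). The same decision as a standalone predicate, written in C++ for consistency with the BE sketches above; the authoritative version is the Java builder chain shown in the diff:

```cpp
#include <cstdint>

constexpr int64_t kDefaultLimit = -1; // Operator.DEFAULT_LIMIT on the FE

// Returns the limit to place on the LOCAL aggregation: the query limit when
// pushdown applies, otherwise "no limit". Pushdown is sound for pure DISTINCT
// because any `limit` distinct groups are an acceptable result.
int64_t local_agg_limit(bool has_agg_functions, int64_t query_limit,
                        int64_t cbo_push_down_distinct_limit) {
    bool only_group_by = !has_agg_functions;
    if (only_group_by && query_limit < cbo_push_down_distinct_limit) {
        return query_limit; // may still be kDefaultLimit, which is a no-op
    }
    return kDefaultLimit;
}
```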
@@ -2211,6 +2211,7 @@ public PlanFragment visitPhysicalHashAggregate(OptExpression optExpr, ExecPlan c
hasColocateOlapScanChildInFragment(aggregationNode)) {
aggregationNode.setColocate(!node.isWithoutColocateRequirement());
}
+aggregationNode.setLimit(node.getLimit());
} else if (node.getType().isGlobal() || (node.getType().isLocal() && !node.isSplit())) {
// Local && un-split aggregate meanings only execute local pre-aggregation, we need promise
// output type match other node, so must use `update finalized` phase
12 changes: 12 additions & 0 deletions fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java
@@ -2970,6 +2970,18 @@ public void testAvgDecimalScale() throws Exception {
" | cardinality: 1");
}

+@Test
+public void testOnlyGroupByLimit() throws Exception {
+FeConstants.runningUnitTest = true;
+String sql = "select distinct v1 + v2 as vx from t0 limit 10";
+String plan = getFragmentPlan(sql);
+assertContains(plan, " 2:AGGREGATE (update serialize)\n" +
+" | STREAMING\n" +
+" | group by: 4: expr\n" +
+" | limit: 10");
+FeConstants.runningUnitTest = false;
+}
+
@Test
public void testHavingAggregate() throws Exception {
String sql = "select * from (" +