Skip to content

Commit

Permalink
[Feature] improve hash join performance by coroutine-based interleavi…
Browse files Browse the repository at this point in the history
…ng (StarRocks#27907)

* [Feature] improve hash join performance by coroutine-based interleaving
interleaving_group_size controls whether interleaving is enabled (default: 10). Its magnitude is the number of coroutines to interleave per thread.

0 means disable interleaving;
positive values mean adaptive interleaving with group = interleaving_group_size; it is enabled only if the hash table is large enough (>32MB);
negative values force interleaving with group = abs(interleaving_group_size);
---------

Signed-off-by: Zhuhe Fang <fzhedu@gmail.com>
  • Loading branch information
fzhedu authored Sep 7, 2023
1 parent 8b27648 commit e056a01
Show file tree
Hide file tree
Showing 11 changed files with 934 additions and 301 deletions.
2 changes: 2 additions & 0 deletions be/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,8 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# ignore warning from apache-orc
set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-unsafe-buffer-usage")
endif ()
else ()
set(CXX_GCC_FLAGS "${CXX_GCC_FLAGS} -fcoroutines")
endif()

set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -DBOOST_DATE_TIME_POSIX_TIME_STD_CONFIG")
Expand Down
11 changes: 4 additions & 7 deletions be/src/exec/hash_joiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ void HashJoinBuildMetrics::prepare(RuntimeProfile* runtime_profile) {
build_conjunct_evaluate_timer = ADD_TIMER(runtime_profile, "BuildConjunctEvaluateTime");
build_buckets_counter = ADD_COUNTER(runtime_profile, "BuildBuckets", TUnit::UNIT);
runtime_filter_num = ADD_COUNTER(runtime_profile, "RuntimeFilterNum", TUnit::UNIT);
build_keys_per_bucket = ADD_COUNTER(runtime_profile, "BuildKeysPerBucket%", TUnit::UNIT);
}

HashJoiner::HashJoiner(const HashJoinerParam& param)
Expand Down Expand Up @@ -211,6 +212,7 @@ Status HashJoiner::build_ht(RuntimeState* state) {
RETURN_IF_ERROR(_hash_join_builder->build(state));
size_t bucket_size = _hash_join_builder->hash_table().get_bucket_size();
COUNTER_SET(build_metrics().build_buckets_counter, static_cast<int64_t>(bucket_size));
COUNTER_SET(build_metrics().build_keys_per_bucket, static_cast<int64_t>(100 * avg_keys_per_bucket()));
}

return Status::OK();
Expand Down Expand Up @@ -346,14 +348,9 @@ void HashJoiner::decr_prober(RuntimeState* state) {
}
}

size_t HashJoiner::avg_keys_perf_bucket() const {
float HashJoiner::avg_keys_per_bucket() const {
const auto& hash_table = _hash_join_builder->hash_table();
size_t used_bucket_count = hash_table.get_used_bucket_count();
if (used_bucket_count == 0) {
return 0;
}
size_t count = hash_table.get_row_count() / used_bucket_count;
return count;
return hash_table.get_keys_per_bucket();
}

Status HashJoiner::reset_probe(starrocks::RuntimeState* state) {
Expand Down
3 changes: 2 additions & 1 deletion be/src/exec/hash_joiner.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ struct HashJoinBuildMetrics {
RuntimeProfile::Counter* build_conjunct_evaluate_timer = nullptr;
RuntimeProfile::Counter* build_buckets_counter = nullptr;
RuntimeProfile::Counter* runtime_filter_num = nullptr;
RuntimeProfile::Counter* build_keys_per_bucket = nullptr;

void prepare(RuntimeProfile* runtime_profile);
};
Expand Down Expand Up @@ -232,7 +233,7 @@ class HashJoiner final : public pipeline::ContextWithDependency {
Columns string_key_columns() { return _string_key_columns; }
Status reset_probe(RuntimeState* state);

size_t avg_keys_perf_bucket() const;
float avg_keys_per_bucket() const;

const HashJoinBuildMetrics& build_metrics() { return *_build_metrics; }
const HashJoinProbeMetrics& probe_metrics() { return *_probe_metrics; }
Expand Down
56 changes: 50 additions & 6 deletions be/src/exec/join_hash_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,52 @@

namespace starrocks {

// if the same hash values are clustered, after the first probe, all related hash buckets are cached, without too many
// misses. So check time locality of probe keys here.
void HashTableProbeState::consider_probe_time_locality() {
if (active_coroutines > 0) {
// redo decision
if ((probe_chunks & (detect_step - 1)) == 0) {
int window_size = std::min(active_coroutines * 4, 50);
if (probe_row_count > window_size) {
phmap::flat_hash_map<uint32_t, uint32_t> occurrence;
occurrence.reserve(probe_row_count);
uint32_t unique_size = 0;
bool enable_interleaving = true;
uint32_t target = probe_row_count >> 3;
for (auto i = 0; i < probe_row_count; i++) {
if (occurrence[next[i]] == 0) {
++unique_size;
if (unique_size >= target) {
break;
}
}
occurrence[next[i]]++;
if (i >= window_size) {
occurrence[next[i - window_size]]--;
}
}
if (unique_size < target) {
active_coroutines = 0;
enable_interleaving = false;
}
// enlarge step if the decision is the same, otherwise reduce it
if (enable_interleaving == last_enable_interleaving) {
detect_step = detect_step >= 1024 ? detect_step : (detect_step << 1);
} else {
last_enable_interleaving = enable_interleaving;
detect_step = 1;
}
} else {
active_coroutines = 0;
}
} else if (!last_enable_interleaving) {
active_coroutines = 0;
}
}
++probe_chunks;
}

void SerializedJoinBuildFunc::prepare(RuntimeState* state, JoinHashTableItems* table_items) {
table_items->bucket_size = JoinHashMapHelper::calc_bucket_size(table_items->row_count + 1);
table_items->first.resize(table_items->bucket_size, 0);
Expand Down Expand Up @@ -82,6 +128,7 @@ void SerializedJoinBuildFunc::construct_hash_table(RuntimeState* state, JoinHash
}
_build_columns(table_items, probe_state, data_columns, 1 + state->chunk_size() * quo, rem, &ptr);
}
table_items->calculate_ht_info(serialize_size);
}

void SerializedJoinBuildFunc::_build_columns(JoinHashTableItems* table_items, HashTableProbeState* probe_state,
Expand Down Expand Up @@ -163,6 +210,7 @@ void SerializedJoinProbeFunc::lookup_init(const JoinHashTableItems& table_items,
} else {
_probe_column(table_items, probe_state, data_columns, ptr);
}
probe_state->consider_probe_time_locality();
}

void SerializedJoinProbeFunc::_probe_column(const JoinHashTableItems& table_items, HashTableProbeState* probe_state,
Expand Down Expand Up @@ -249,12 +297,8 @@ void JoinHashTable::set_probe_profile(RuntimeProfile::Counter* search_ht_timer,
_probe_state->output_build_column_timer = output_build_column_timer;
}

size_t JoinHashTable::get_used_bucket_count() const {
size_t count = 0;
for (const auto value : _table_items->first) {
count += value != 0;
}
return count;
float JoinHashTable::get_keys_per_bucket() const {
return _table_items->get_keys_per_bucket();
}

void JoinHashTable::close() {
Expand Down
126 changes: 97 additions & 29 deletions be/src/exec/join_hash_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
#include <runtime/descriptors.h>
#include <runtime/runtime_state.h>

#include <coroutine>
#include <cstdint>
#include <set>

#include "column/chunk.h"
#include "column/column_hash.h"
Expand Down Expand Up @@ -112,7 +114,7 @@ struct JoinHashTableItems {
Buffer<uint32_t> first;
Buffer<uint32_t> next;
Buffer<Slice> build_slice;
ColumnPtr build_key_column;
ColumnPtr build_key_column = nullptr;
uint32_t bucket_size = 0;
uint32_t row_count = 0; // real row count
size_t build_column_count = 0;
Expand All @@ -122,6 +124,24 @@ struct JoinHashTableItems {
bool left_to_nullable = false;
bool right_to_nullable = false;
bool has_large_column = false;
// Average number of build rows per occupied bucket (row_count / used_buckets), or 0 when no
// bucket is occupied. Filled in by calculate_ht_info().
float keys_per_bucket = 0;
// Number of non-zero entries in `first`, counted by calculate_ht_info().
size_t used_buckets = 0;
// Set by calculate_ht_info() from its byte-size, keys-per-bucket and row-count thresholds.
bool cache_miss_serious = false;

// Returns the cached keys_per_bucket; stays 0 until calculate_ht_info() has run.
float get_keys_per_bucket() const { return keys_per_bucket; }
// Returns the cached cache_miss_serious flag; stays false until calculate_ht_info() has run.
bool ht_cache_miss_serious() const { return cache_miss_serious; }

// Computes and caches hash-table statistics: used_buckets, keys_per_bucket and
// cache_miss_serious.
//
// key_bytes: byte size of the key data as supplied by the caller (SerializedJoinBuildFunc passes
//            its serialized key size). It is added to row_count * sizeof(uint32_t) to estimate
//            the bytes touched while probing.
//
// The work is skipped when used_buckets is already non-zero, so repeated calls are cheap once at
// least one bucket is occupied.
void calculate_ht_info(size_t key_bytes) {
    if (used_buckets != 0) {
        return; // statistics were already computed
    }

    // A bucket is occupied when its head index in `first` is non-zero.
    for (const auto head : first) {
        if (head != 0) {
            ++used_buckets;
        }
    }

    if (used_buckets == 0) {
        keys_per_bucket = 0;
    } else {
        keys_per_bucket = row_count * 1.0 / used_buckets;
    }

    const size_t probe_bytes = key_bytes + row_count * sizeof(uint32_t);
    const bool above_32mb = probe_bytes > (1UL << 25);
    const bool above_64mb = probe_bytes > (1UL << 26);
    const bool crowded_buckets = keys_per_bucket > 1.5;
    const bool many_rows = row_count > (1UL << 18);

    // Serious cache misses are expected only for tables with more than 2^18 rows that are either
    // above 64MB, or above 32MB with more than 1.5 keys per occupied bucket.
    cache_miss_serious = (above_64mb || (above_32mb && crowded_buckets)) && many_rows;
}

TJoinOp::type join_type = TJoinOp::INNER_JOIN;

Expand Down Expand Up @@ -173,7 +193,32 @@ struct HashTableProbeState {
RuntimeProfile::Counter* output_build_column_timer = nullptr;

HashTableProbeState() = default;
~HashTableProbeState() = default;

// Return type of the coroutine-based probe functions. It is a thin, non-owning wrapper around the
// coroutine handle: its destructor does not destroy the coroutine frame.
struct ProbeCoroutine {
    struct ProbePromise {
        // Builds the wrapper from the handle that belongs to this promise.
        ProbeCoroutine get_return_object() { return std::coroutine_handle<ProbePromise>::from_promise(*this); }
        // The coroutine is created suspended; nothing runs until the handle is resumed.
        std::suspend_always initial_suspend() { return {}; }
        // The coroutine also suspends at its end, so the frame is not released automatically:
        // whoever holds the handle must call destroy() on it.
        std::suspend_always final_suspend() noexcept { return {}; }
        // An exception escaping the coroutine body is captured here instead of propagating.
        void unhandled_exception() { exception = std::current_exception(); }
        void return_void() {}
        std::exception_ptr exception = nullptr;
    };

    using promise_type = ProbePromise;
    // Intentionally implicit: get_return_object() relies on this conversion.
    ProbeCoroutine(std::coroutine_handle<ProbePromise> h) : handle(h) {}
    ~ProbeCoroutine() = default;
    std::coroutine_handle<ProbePromise> handle;
    // Returns a copy of the handle. Moving would gain nothing: the member is const in this context
    // and a coroutine handle is trivially copyable.
    operator std::coroutine_handle<promise_type>() const { return handle; }
};
uint32_t match_count = 0;
// Positive while interleaving is in effect; consider_probe_time_locality() sets it to 0 when it
// decides against interleaving. NOTE(review): presumably the number of coroutines to interleave -
// confirm against the probe code.
int active_coroutines = 0;
// used to adaptively detect time locality
// Incremented once per call to consider_probe_time_locality().
size_t probe_chunks = 0;
// Locality is re-detected when probe_chunks is a multiple of this value. It doubles (up to 1024)
// while the decision is stable and resets to 1 when the decision changes.
uint32_t detect_step = 1;
// Result of the most recent locality detection; reused on chunks where no detection runs.
bool last_enable_interleaving = true;

// Coroutine handles that ~HashTableProbeState() destroys.
std::set<std::coroutine_handle<ProbeCoroutine::ProbePromise>> handles;

HashTableProbeState(const HashTableProbeState& rhs)
: is_nulls(rhs.is_nulls),
Expand Down Expand Up @@ -205,6 +250,15 @@ struct HashTableProbeState {
// Disable move ctor and assignment.
HashTableProbeState(HashTableProbeState&&) = delete;
HashTableProbeState& operator=(HashTableProbeState&&) = delete;

void consider_probe_time_locality();

// Destroys every coroutine frame still registered in `handles`, then empties the set.
// The probe coroutines suspend at their final suspend point, so their frames are only released by
// an explicit destroy().
~HashTableProbeState() {
    for (const auto& pending : handles) {
        pending.destroy();
    }
    handles.clear();
}
};

struct HashTableParam {
Expand Down Expand Up @@ -583,81 +637,96 @@ class JoinHashMap {
template <bool first_probe>
void _probe_from_ht(RuntimeState* state, const Buffer<CppType>& build_data, const Buffer<CppType>& probe_data);

HashTableProbeState::ProbeCoroutine _probe_from_ht(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);

template <bool first_probe, bool init_match = false>
void _probe_coroutine(RuntimeState* state, const Buffer<CppType>& build_data, const Buffer<CppType>& probe_data);

// for one key left outer join
template <bool first_probe>
void _probe_from_ht_for_left_outer_join(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);

HashTableProbeState::ProbeCoroutine _probe_from_ht_for_left_outer_join(RuntimeState* state,
const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);
// for one key left semi join
template <bool first_probe>
void _probe_from_ht_for_left_semi_join(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);

HashTableProbeState::ProbeCoroutine _probe_from_ht_for_left_semi_join(RuntimeState* state,
const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);
// for one key left anti join
template <bool first_probe>
void _probe_from_ht_for_left_anti_join(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);
HashTableProbeState::ProbeCoroutine _probe_from_ht_for_left_anti_join(RuntimeState* state,
const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);

// for one key right outer join
template <bool first_probe>
void _probe_from_ht_for_right_outer_join(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);
HashTableProbeState::ProbeCoroutine _probe_from_ht_for_right_outer_join(RuntimeState* state,
const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);

// for one key right semi join
template <bool first_probe>
void _probe_from_ht_for_right_semi_join(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);
HashTableProbeState::ProbeCoroutine _probe_from_ht_for_right_semi_join(RuntimeState* state,
const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);

// for one key right anti join
template <bool first_probe>
void _probe_from_ht_for_right_anti_join(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);
HashTableProbeState::ProbeCoroutine _probe_from_ht_for_right_anti_join(RuntimeState* state,
const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);

// for one key full outer join
template <bool first_probe>
void _probe_from_ht_for_full_outer_join(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);

// for left outer join with other join conjunct
template <bool first_probe>
void _probe_from_ht_for_left_outer_join_with_other_conjunct(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);
HashTableProbeState::ProbeCoroutine _probe_from_ht_for_full_outer_join(RuntimeState* state,
const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);

// for left semi join with other join conjunct
template <bool first_probe>
void _probe_from_ht_for_left_semi_join_with_other_conjunct(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);

// for left anti join with other join conjunct
template <bool first_probe>
void _probe_from_ht_for_left_anti_join_with_other_conjunct(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);
HashTableProbeState::ProbeCoroutine _probe_from_ht_for_left_semi_join_with_other_conjunct(
RuntimeState* state, const Buffer<CppType>& build_data, const Buffer<CppType>& probe_data);

// for null aware anti join with other join conjunct
template <bool first_probe>
void _probe_from_ht_for_null_aware_anti_join_with_other_conjunct(RuntimeState* state,
const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);
HashTableProbeState::ProbeCoroutine _probe_from_ht_for_null_aware_anti_join_with_other_conjunct(
RuntimeState* state, const Buffer<CppType>& build_data, const Buffer<CppType>& probe_data);

// for one key right outer join with other conjunct
template <bool first_probe>
void _probe_from_ht_for_right_outer_join_with_other_conjunct(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);

// for one key right semi join with other join conjunct
template <bool first_probe>
void _probe_from_ht_for_right_semi_join_with_other_conjunct(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);

// for one key right anti join with other join conjunct
template <bool first_probe>
void _probe_from_ht_for_right_anti_join_with_other_conjunct(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);
void _probe_from_ht_for_right_outer_right_semi_right_anti_join_with_other_conjunct(
RuntimeState* state, const Buffer<CppType>& build_data, const Buffer<CppType>& probe_data);
HashTableProbeState::ProbeCoroutine _probe_from_ht_for_right_outer_right_semi_right_anti_join_with_other_conjunct(
RuntimeState* state, const Buffer<CppType>& build_data, const Buffer<CppType>& probe_data);

// for one key full outer join with other join conjunct
template <bool first_probe>
void _probe_from_ht_for_full_outer_join_with_other_conjunct(RuntimeState* state, const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);
void _probe_from_ht_for_left_outer_left_anti_full_outer_join_with_other_conjunct(RuntimeState* state,
const Buffer<CppType>& build_data,
const Buffer<CppType>& probe_data);
HashTableProbeState::ProbeCoroutine _probe_from_ht_for_left_outer_left_anti_full_outer_join_with_other_conjunct(
RuntimeState* state, const Buffer<CppType>& build_data, const Buffer<CppType>& probe_data);

JoinHashTableItems* _table_items = nullptr;
HashTableProbeState* _probe_state = nullptr;
Expand Down Expand Up @@ -705,8 +774,7 @@ class JoinHashTable {
size_t get_probe_column_count() const { return _table_items->probe_column_count; }
size_t get_build_column_count() const { return _table_items->build_column_count; }
size_t get_bucket_size() const { return _table_items->bucket_size; }
size_t get_used_bucket_count() const;

float get_keys_per_bucket() const;
void remove_duplicate_index(Filter* filter);

int64_t mem_usage() const;
Expand Down
Loading

0 comments on commit e056a01

Please sign in to comment.