
Commit

all kind hash set
Signed-off-by: zombee0 <ewang2027@gmail.com>
zombee0 committed Jul 2, 2024
1 parent 7a5ba8a commit 0e245d7
Showing 2 changed files with 110 additions and 51 deletions.
be/src/exec/aggregate/agg_hash_map.h (5 changes: 0 additions & 5 deletions)
@@ -37,8 +37,6 @@

namespace starrocks {

const constexpr int32_t prefetch_threhold = 8192;

using AggDataPtr = uint8_t*;

// =====================
@@ -83,9 +81,6 @@ using SliceAggTwoLevelHashMap =
phmap::parallel_flat_hash_map<Slice, AggDataPtr, SliceHashWithSeed<seed>, SliceEqual,
phmap::priv::Allocator<phmap::priv::Pair<const Slice, AggDataPtr>>, PHMAPN>;

// This is just an empirical value based on benchmark, and you can tweak it if more proper value is found.
static constexpr size_t AGG_HASH_MAP_DEFAULT_PREFETCH_DIST = 16;

static_assert(sizeof(AggDataPtr) == sizeof(size_t));
#define AGG_HASH_MAP_PRECOMPUTE_HASH_VALUES(column, prefetch_dist) \
size_t const column_size = column->size(); \
be/src/exec/aggregate/agg_hash_set.h (156 changes: 110 additions & 46 deletions)
@@ -27,6 +27,26 @@

namespace starrocks {

const constexpr int32_t prefetch_threhold = 8192;
// This is just an empirical value based on benchmark, and you can tweak it if more proper value is found.
static constexpr size_t AGG_HASH_MAP_DEFAULT_PREFETCH_DIST = 16;

#define AGG_HASH_SET_PRECOMPUTE_HASH_VALS() \
hashes.reserve(chunk_size); \
for (size_t i = 0; i < chunk_size; i++) { \
hashes[i] = this->hash_set.hash_function()(keys[i]); \
}

#define AGG_HASH_SET_PREFETCH_HASH_VAL() \
if (i + AGG_HASH_MAP_DEFAULT_PREFETCH_DIST < chunk_size && this->hash_set.bucket_count() > prefetch_threhold) { \
this->hash_set.prefetch_hash(hashes[i + AGG_HASH_MAP_DEFAULT_PREFETCH_DIST]); \
}

#define AGG_STRING_HASH_SET_PREFETCH_HASH_VAL() \
if (i + AGG_HASH_MAP_DEFAULT_PREFETCH_DIST < chunk_size && this->hash_set.bucket_count() > prefetch_threhold) { \
this->hash_set.prefetch_hash(cache[i + AGG_HASH_MAP_DEFAULT_PREFETCH_DIST].hash); \
}
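These three macros factor out the software-prefetch pattern that previously lived inline in agg_hash_map.h: hash every key of the chunk up front, then while inserting row i, prefetch the bucket that row i + AGG_HASH_MAP_DEFAULT_PREFETCH_DIST will touch, and only once the table has grown past prefetch_threhold buckets (smaller tables are effectively cache-resident already). Below is a rough standalone sketch of the pattern the macros expand to, assuming a phmap-style set exposing hash_function(), bucket_count(), prefetch_hash() and emplace_with_hash() as the sets in this file do; the function and parameter names are illustrative, not part of the patch.

// Illustrative sketch only: bulk-insert one chunk of int64 keys with prefetch-ahead.
template <typename Set>
void insert_chunk(Set& hash_set, const int64_t* keys, size_t chunk_size) {
    std::vector<size_t> hashes(chunk_size);
    for (size_t i = 0; i < chunk_size; ++i) {
        hashes[i] = hash_set.hash_function()(keys[i]);          // one hashing pass over the whole chunk
    }
    for (size_t i = 0; i < chunk_size; ++i) {
        if (i + AGG_HASH_MAP_DEFAULT_PREFETCH_DIST < chunk_size &&
            hash_set.bucket_count() > prefetch_threhold) {
            hash_set.prefetch_hash(hashes[i + AGG_HASH_MAP_DEFAULT_PREFETCH_DIST]);  // warm the bucket 16 rows ahead
        }
        hash_set.emplace_with_hash(hashes[i], keys[i]);         // insert reusing the precomputed hash
    }
}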

// =====================
// one level agg hash set
template <PhmapSeed seed>
@@ -66,9 +86,6 @@ using SliceAggTwoLevelHashSet =
phmap::parallel_flat_hash_set<TSliceWithHash<seed>, THashOnSliceWithHash<seed>, TEqualOnSliceWithHash<seed>,
phmap::priv::Allocator<Slice>, 4>;

// This is just an empirical value based on benchmark, and you can tweak it if more proper value is found.
static constexpr size_t AGG_HASH_SET_DEFAULT_PREFETCH_DIST = 16;

// ==============================================================

template <typename HashSet, typename Impl>
@@ -88,6 +105,14 @@ struct AggHashSet {
}
};

template <typename T>
struct no_prefetch_set : std::false_type {};
template <PhmapSeed seed>
struct no_prefetch_set<Int8AggHashSet<seed>> : std::true_type {};

template <class T>
constexpr bool is_no_prefetch_set = no_prefetch_set<T>::value;
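This trait is what the build paths below test with if constexpr: only Int8AggHashSet opts out of the prefetch machinery, since int8 keys allow at most 256 distinct values and such a table never outgrows the cache. A quick illustration of how the trait evaluates; PhmapSeed1 and Int32AggHashSet are assumed here from the one-level set aliases elsewhere in this header, not shown in this hunk.

static_assert(is_no_prefetch_set<Int8AggHashSet<PhmapSeed1>>);    // tiny key space: plain insert loop
static_assert(!is_no_prefetch_set<Int32AggHashSet<PhmapSeed1>>);  // wider keys: precompute + prefetch path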

// handle one number hash key
template <LogicalType logical_type, typename HashSet>
struct AggHashSetOfOneNumberKey : public AggHashSet<HashSet, AggHashSetOfOneNumberKey<logical_type, HashSet>> {
@@ -112,13 +137,25 @@ struct AggHashSetOfOneNumberKey : public AggHashSet<HashSet, AggHashSetOfOneNumb
not_founds->assign(chunk_size, 0);
}
auto* column = down_cast<ColumnType*>(key_columns[0].get());
const size_t row_num = column->size();
auto& keys = column->get_data();
for (size_t i = 0; i < row_num; ++i) {
if constexpr (compute_and_allocate) {
this->hash_set.emplace(keys[i]);
} else {
(*not_founds)[i] = !this->hash_set.contains(keys[i]);

if constexpr (!is_no_prefetch_set<HashSet>) {
AGG_HASH_SET_PRECOMPUTE_HASH_VALS();
for (size_t i = 0; i < chunk_size; ++i) {
AGG_HASH_SET_PREFETCH_HASH_VAL();
if constexpr (compute_and_allocate) {
this->hash_set.emplace_with_hash(this->hashes[i], keys[i]);
} else {
(*not_founds)[i] = this->hash_set.find(keys[i], this->hashes[i]) == this->hash_set.end();
}
}
} else {
for (size_t i = 0; i < chunk_size; ++i) {
if constexpr (compute_and_allocate) {
this->hash_set.emplace(keys[i]);
} else {
(*not_founds)[i] = !this->hash_set.contains(keys[i]);
}
}
}
}
@@ -130,6 +167,7 @@ struct AggHashSetOfOneNumberKey : public AggHashSet<HashSet, AggHashSetOfOneNumb

static constexpr bool has_single_null_key = false;
ResultVector results;
std::vector<size_t> hashes;
};

template <LogicalType logical_type, typename HashSet>
@@ -164,9 +202,8 @@ struct AggHashSetOfOneNullableNumberKey
const auto& null_data = nullable_column->null_column_data();
auto& keys = data_column->get_data();

size_t row_num = nullable_column->size();
if (nullable_column->has_null()) {
for (size_t i = 0; i < row_num; ++i) {
for (size_t i = 0; i < chunk_size; ++i) {
if (null_data[i]) {
has_null_key = true;
} else {
@@ -178,11 +215,23 @@ struct AggHashSetOfOneNullableNumberKey
}
}
} else {
for (size_t i = 0; i < row_num; ++i) {
if constexpr (compute_and_allocate) {
this->hash_set.emplace(keys[i]);
} else {
(*not_founds)[i] = !this->hash_set.contains(keys[i]);
if constexpr (!is_no_prefetch_set<HashSet>) {
AGG_HASH_SET_PRECOMPUTE_HASH_VALS();
for (size_t i = 0; i < chunk_size; ++i) {
AGG_HASH_SET_PREFETCH_HASH_VAL();
if constexpr (compute_and_allocate) {
this->hash_set.emplace_with_hash(this->hashes[i], keys[i]);
} else {
(*not_founds)[i] = this->hash_set.find(keys[i], hashes[i]) == this->hash_set.end();
}
}
} else {
for (size_t i = 0; i < chunk_size; ++i) {
if constexpr (compute_and_allocate) {
this->hash_set.emplace(keys[i]);
} else {
(*not_founds)[i] = !this->hash_set.contains(keys[i]);
}
}
}
}
@@ -199,6 +248,7 @@ struct AggHashSetOfOneNullableNumberKey
static constexpr bool has_single_null_key = true;
bool has_null_key = false;
ResultVector results;
std::vector<size_t> hashes;
};

template <typename HashSet>
Expand All @@ -207,7 +257,7 @@ struct AggHashSetOfOneStringKey : public AggHashSet<HashSet, AggHashSetOfOneStri
using KeyType = typename HashSet::key_type;
using ResultVector = typename std::vector<Slice>;

AggHashSetOfOneStringKey(int32_t chunk_size) {}
AggHashSetOfOneStringKey(int32_t chunk_size) { cache.reserve(chunk_size); }

// When compute_and_allocate=false:
// Elements queried in HashSet will be added to HashSet
@@ -222,19 +272,22 @@ struct AggHashSetOfOneStringKey : public AggHashSet<HashSet, AggHashSetOfOneStri
not_founds->assign(chunk_size, 0);
}

size_t row_num = column->size();
for (size_t i = 0; i < row_num; ++i) {
auto tmp = column->get_slice(i);
for (size_t i = 0; i < chunk_size; ++i) {
cache[i] = KeyType(column->get_slice(i));
}

for (size_t i = 0; i < chunk_size; ++i) {
AGG_STRING_HASH_SET_PREFETCH_HASH_VAL();
auto& key = cache[i];
if constexpr (compute_and_allocate) {
KeyType key(tmp);
this->hash_set.lazy_emplace(key, [&](const auto& ctor) {
this->hash_set.lazy_emplace_with_hash(key, key.hash, [&](const auto& ctor) {
// we must persist the slice before insert
uint8_t* pos = pool->allocate(key.size);
memcpy(pos, key.data, key.size);
ctor(pos, key.size, key.hash);
});
} else {
(*not_founds)[i] = !this->hash_set.contains(tmp);
(*not_founds)[i] = this->hash_set.find(key, key.hash) == this->hash_set.end();
}
}
}
@@ -247,6 +300,7 @@ struct AggHashSetOfOneStringKey : public AggHashSet<HashSet, AggHashSetOfOneStri

static constexpr bool has_single_null_key = false;
ResultVector results;
std::vector<KeyType> cache;
};
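Every string-keyed variant below follows the same miss-only persistence idiom introduced here: the probe key is a KeyType whose Slice still points into the incoming column and whose hash was precomputed into the cache vector; lazy_emplace_with_hash invokes the ctor callback only when the key is absent, so the MemPool allocation and memcpy happen once per distinct key and hits cost nothing extra. A condensed sketch of the idiom, reusing the member names from the struct above (not a literal excerpt of the patch):

for (size_t i = 0; i < chunk_size; ++i) {
    auto& key = cache[i];                                       // slice into the chunk + precomputed hash
    hash_set.lazy_emplace_with_hash(key, key.hash, [&](const auto& ctor) {
        // runs only on a miss: persist the bytes before the set stores the slice
        uint8_t* pos = pool->allocate(key.size);
        memcpy(pos, key.data, key.size);
        ctor(pos, key.size, key.hash);
    });
}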

template <typename HashSet>
Expand All @@ -255,7 +309,7 @@ struct AggHashSetOfOneNullableStringKey : public AggHashSet<HashSet, AggHashSetO
using KeyType = typename HashSet::key_type;
using ResultVector = typename std::vector<Slice>;

AggHashSetOfOneNullableStringKey(int32_t chunk_size) {}
AggHashSetOfOneNullableStringKey(int32_t chunk_size) { cache.reserve(chunk_size); }

// When compute_and_allocate=false:
// Elements queried in HashSet will be added to HashSet
@@ -274,9 +328,8 @@ struct AggHashSetOfOneNullableStringKey : public AggHashSet<HashSet, AggHashSetO
auto* data_column = down_cast<BinaryColumn*>(nullable_column->data_column().get());
const auto& null_data = nullable_column->null_column_data();

size_t row_num = nullable_column->size();
if (nullable_column->has_null()) {
for (size_t i = 0; i < row_num; ++i) {
for (size_t i = 0; i < chunk_size; ++i) {
if (null_data[i]) {
has_null_key = true;
} else {
Expand All @@ -288,11 +341,20 @@ struct AggHashSetOfOneNullableStringKey : public AggHashSet<HashSet, AggHashSetO
}
}
} else {
for (size_t i = 0; i < row_num; ++i) {
for (size_t i = 0; i < chunk_size; ++i) {
cache[i] = KeyType(data_column->get_slice(i));
}
for (size_t i = 0; i < chunk_size; ++i) {
AGG_STRING_HASH_SET_PREFETCH_HASH_VAL();
auto& key = cache[i];
if constexpr (compute_and_allocate) {
_handle_data_key_column(data_column, i, pool, not_founds);
this->hash_set.lazy_emplace_with_hash(key, key.hash, [&](const auto& ctor) {
uint8_t* pos = pool->allocate(key.size);
memcpy(pos, key.data, key.size);
ctor(pos, key.size, key.hash);
});
} else {
_handle_data_key_column(data_column, i, not_founds);
(*not_founds)[i] = this->hash_set.find(key, key.hash) == this->hash_set.end();
}
}
}
@@ -328,6 +390,7 @@ struct AggHashSetOfOneNullableStringKey : public AggHashSet<HashSet, AggHashSetO
static constexpr bool has_single_null_key = true;
bool has_null_key = false;
ResultVector results;
std::vector<KeyType> cache;
};

template <typename HashSet>
Expand All @@ -339,7 +402,9 @@ struct AggHashSetOfSerializedKey : public AggHashSet<HashSet, AggHashSetOfSerial
AggHashSetOfSerializedKey(int32_t chunk_size)
: _mem_pool(std::make_unique<MemPool>()),
_buffer(_mem_pool->allocate(max_one_row_size * chunk_size)),
_chunk_size(chunk_size) {}
_chunk_size(chunk_size) {
cache.reserve(chunk_size);
}

// When compute_and_allocate=false:
// Elements queried in HashSet will be added to HashSet
@@ -367,17 +432,21 @@ struct AggHashSetOfSerializedKey : public AggHashSet<HashSet, AggHashSetOfSerial
}

for (size_t i = 0; i < chunk_size; ++i) {
Slice tmp = {_buffer + i * max_one_row_size, slice_sizes[i]};
cache[i] = KeyType(Slice(_buffer + i * max_one_row_size, slice_sizes[i]));
}

for (size_t i = 0; i < chunk_size; ++i) {
AGG_STRING_HASH_SET_PREFETCH_HASH_VAL();
auto& key = cache[i];
if constexpr (compute_and_allocate) {
KeyType key(tmp);
this->hash_set.lazy_emplace(key, [&](const auto& ctor) {
this->hash_set.lazy_emplace_with_hash(key, key.hash, [&](const auto& ctor) {
// we must persist the slice before insert
uint8_t* pos = pool->allocate(key.size);
memcpy(pos, key.data, key.size);
ctor(pos, key.size, key.hash);
});
} else {
(*not_founds)[i] = !this->hash_set.contains(tmp);
(*not_founds)[i] = this->hash_set.find(key, key.hash) == this->hash_set.end();
}
}
}
@@ -421,6 +490,7 @@ struct AggHashSetOfSerializedKey : public AggHashSet<HashSet, AggHashSetOfSerial
ResultVector results;

int32_t _chunk_size;
std::vector<KeyType> cache;
};

template <typename HashSet>
Expand Down Expand Up @@ -463,28 +533,22 @@ struct AggHashSetOfSerializedKeyFixedSize : public AggHashSet<HashSet, AggHashSe
key_column->serialize_batch(buffer, slice_sizes, chunk_size, max_fixed_size);
}

auto* key = reinterpret_cast<FixedSizeSliceKey*>(buffer);
auto* keys = reinterpret_cast<FixedSizeSliceKey*>(buffer);

if (has_null_column) {
for (size_t i = 0; i < chunk_size; ++i) {
key[i].u.size = slice_sizes[i];
keys[i].u.size = slice_sizes[i];
}
}

for (size_t i = 0; i < chunk_size; ++i) {
hashes[i] = this->hash_set.hash_function()(key[i]);
}

size_t __prefetch_index = AGG_HASH_SET_DEFAULT_PREFETCH_DIST;
AGG_HASH_SET_PRECOMPUTE_HASH_VALS();

for (size_t i = 0; i < chunk_size; ++i) {
if (__prefetch_index < chunk_size) {
this->hash_set.prefetch_hash(hashes[__prefetch_index++]);
}
AGG_HASH_SET_PREFETCH_HASH_VAL();
if constexpr (compute_and_allocate) {
this->hash_set.emplace_with_hash(hashes[i], key[i]);
this->hash_set.emplace_with_hash(hashes[i], keys[i]);
} else {
(*not_founds)[i] = !this->hash_set.contains(key[i]);
(*not_founds)[i] = this->hash_set.find(keys[i], hashes[i]) == this->hash_set.end();
}
}
}
