Skip to content

Commit 1dd5c08

Browse files
committed
More overflow-safe swiss table.
1 parent 18e8f50 commit 1dd5c08

File tree

5 files changed

+209
-203
lines changed

5 files changed

+209
-203
lines changed

cpp/src/arrow/acero/swiss_join.cc

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -643,37 +643,36 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
643643
//
644644
int source_group_id_bits =
645645
SwissTable::num_groupid_bits_from_log_blocks(source->log_blocks());
646-
uint64_t source_group_id_mask = ~0ULL >> (64 - source_group_id_bits);
647-
int64_t source_block_bytes = source_group_id_bits + 8;
646+
int source_block_bytes =
647+
SwissTable::num_block_bytes_from_num_groupid_bits(source_group_id_bits);
648648
ARROW_DCHECK(source_block_bytes % sizeof(uint64_t) == 0);
649649

650650
// Compute index of the last block in target that corresponds to the given
651651
// partition.
652652
//
653653
ARROW_DCHECK(num_partition_bits <= target->log_blocks());
654-
int64_t target_max_block_id =
654+
uint32_t target_max_block_id =
655655
((partition_id + 1) << (target->log_blocks() - num_partition_bits)) - 1;
656656

657657
overflow_group_ids->clear();
658658
overflow_hashes->clear();
659659

660660
// For each source block...
661-
int64_t source_blocks = 1LL << source->log_blocks();
662-
for (int64_t block_id = 0; block_id < source_blocks; ++block_id) {
663-
uint8_t* block_bytes = source->blocks() + block_id * source_block_bytes;
661+
uint32_t source_blocks = 1 << source->log_blocks();
662+
for (uint32_t block_id = 0; block_id < source_blocks; ++block_id) {
663+
const uint8_t* block_bytes = source->block_data(block_id, source_block_bytes);
664664
uint64_t block = *reinterpret_cast<const uint64_t*>(block_bytes);
665665

666666
// For each non-empty source slot...
667667
constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL;
668-
constexpr int kSlotsPerBlock = 8;
669-
int num_full_slots =
670-
kSlotsPerBlock - static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
668+
int num_full_slots = SwissTable::kSlotsPerBlock -
669+
static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
671670
for (int local_slot_id = 0; local_slot_id < num_full_slots; ++local_slot_id) {
672671
// Read group id and hash for this slot.
673672
//
674-
uint64_t group_id =
675-
source->extract_group_id(block_bytes, local_slot_id, source_group_id_mask);
676-
int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id;
673+
uint32_t group_id =
674+
source->extract_group_id(block_bytes, local_slot_id, source_group_id_bits);
675+
uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id);
677676
uint32_t hash = source->hashes()[global_slot_id];
678677
// Insert partition id into the highest bits of hash, shifting the
679678
// remaining hash bits right.
@@ -696,17 +695,18 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
696695
}
697696
}
698697

699-
inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_id,
700-
uint32_t hash, int64_t max_block_id) {
698+
inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint32_t group_id,
699+
uint32_t hash, uint32_t max_block_id) {
701700
// Load the first block to visit for this hash
702701
//
703-
int64_t block_id = hash >> (SwissTable::bits_hash_ - target->log_blocks());
704-
int64_t block_id_mask = ((1LL << target->log_blocks()) - 1);
702+
uint32_t block_id = SwissTable::block_id_from_hash(hash, target->log_blocks());
703+
uint32_t block_id_mask = (1 << target->log_blocks()) - 1;
705704
int num_group_id_bits =
706705
SwissTable::num_groupid_bits_from_log_blocks(target->log_blocks());
707-
int64_t num_block_bytes = num_group_id_bits + sizeof(uint64_t);
706+
int num_block_bytes =
707+
SwissTable::num_block_bytes_from_num_groupid_bits(num_group_id_bits);
708708
ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0);
709-
uint8_t* block_bytes = target->blocks() + block_id * num_block_bytes;
709+
const uint8_t* block_bytes = target->block_data(block_id, num_block_bytes);
710710
uint64_t block = *reinterpret_cast<const uint64_t*>(block_bytes);
711711

712712
// Search for the first block with empty slots.
@@ -715,25 +715,23 @@ inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_i
715715
constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL;
716716
while ((block & kHighBitOfEachByte) == 0 && block_id < max_block_id) {
717717
block_id = (block_id + 1) & block_id_mask;
718-
block_bytes = target->blocks() + block_id * num_block_bytes;
718+
block_bytes = target->block_data(block_id, num_block_bytes);
719719
block = *reinterpret_cast<const uint64_t*>(block_bytes);
720720
}
721721
if ((block & kHighBitOfEachByte) == 0) {
722722
return false;
723723
}
724-
constexpr int kSlotsPerBlock = 8;
725-
int local_slot_id =
726-
kSlotsPerBlock - static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
727-
int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id;
728-
target->insert_into_empty_slot(static_cast<uint32_t>(global_slot_id), hash,
729-
static_cast<uint32_t>(group_id));
724+
int local_slot_id = SwissTable::kSlotsPerBlock -
725+
static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
726+
uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id);
727+
target->insert_into_empty_slot(global_slot_id, hash, group_id);
730728
return true;
731729
}
732730

733731
void SwissTableMerge::InsertNewGroups(SwissTable* target,
734732
const std::vector<uint32_t>& group_ids,
735733
const std::vector<uint32_t>& hashes) {
736-
int64_t num_blocks = 1LL << target->log_blocks();
734+
uint32_t num_blocks = 1 << target->log_blocks();
737735
for (size_t i = 0; i < group_ids.size(); ++i) {
738736
std::ignore = InsertNewGroup(target, group_ids[i], hashes[i], num_blocks);
739737
}
@@ -1191,7 +1189,7 @@ Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
11911189
// We want each partition to correspond to a range of block indices,
11921190
// so we also partition on the highest bits of the hash.
11931191
//
1194-
return locals.batch_hashes[i] >> (31 - log_num_prtns_) >> 1;
1192+
return locals.batch_hashes[i] >> (SwissTable::bits_hash_ - log_num_prtns_);
11951193
},
11961194
[&locals](int64_t i, int pos) {
11971195
locals.batch_prtn_row_ids[pos] = static_cast<uint16_t>(i);

cpp/src/arrow/acero/swiss_join_internal.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,8 @@ class SwissTableMerge {
380380
// Max block id value greater than or equal to the number of blocks guarantees that
381381
// the search will not be stopped.
382382
//
383-
static inline bool InsertNewGroup(SwissTable* target, uint64_t group_id, uint32_t hash,
384-
int64_t max_block_id);
383+
static inline bool InsertNewGroup(SwissTable* target, uint32_t group_id, uint32_t hash,
384+
uint32_t max_block_id);
385385
};
386386

387387
struct SwissTableWithKeys {

0 commit comments

Comments
 (0)