Skip to content

Commit aeb8b9d

Browse files
committed
Some more cleanup found in the last commit
1 parent 7af1d3c commit aeb8b9d

File tree

2 files changed

+68
-54
lines changed

2 files changed

+68
-54
lines changed

cpp/src/arrow/compute/key_map_internal.cc

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -272,14 +272,14 @@ void SwissTable::early_filter_imp(const int num_keys, const uint32_t* hashes,
272272
// How many groups we can keep in the hash table without the need for resizing.
273273
// When we reach this limit, we need to break processing of any further rows and resize.
274274
//
275-
uint64_t SwissTable::num_groups_for_resize() const {
275+
int64_t SwissTable::num_groups_for_resize() const {
276276
// Consider N = 9 (aka 2 ^ 9 = 512 blocks) as small.
277277
// When N = 9, a slot id takes N + 3 = 12 bits, rounded up to 16 bits. This is also the
278278
// number of bits needed for a key id. Since each slot stores a status byte and a key
279279
// id, then a slot takes 1 byte + 16 bits = 3 bytes. Therefore a block of 8 slots takes
280280
// 24 bytes. The threshold of a small hash table ends up being 24 bytes * 512 = 12 KB.
281281
constexpr int log_blocks_small_ = 9;
282-
uint64_t num_slots = 1ULL << (log_blocks_ + 3);
282+
int64_t num_slots = num_slots_from_log_blocks(log_blocks_);
283283
if (log_blocks_ <= log_blocks_small_) {
284284
// Resize small hash tables when 50% full.
285285
return num_slots / 2;
@@ -290,7 +290,8 @@ uint64_t SwissTable::num_groups_for_resize() const {
290290
}
291291

292292
uint32_t SwissTable::wrap_global_slot_id(uint32_t global_slot_id) const {
293-
uint32_t global_slot_id_mask = static_cast<uint32_t>((1ULL << (log_blocks_ + 3)) - 1);
293+
uint32_t global_slot_id_mask =
294+
static_cast<uint32_t>((1ULL << (log_blocks_ + kLogSlotsPerBlock)) - 1ULL);
294295
return global_slot_id & global_slot_id_mask;
295296
}
296297

@@ -398,18 +399,20 @@ bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_sl
398399
int local_slot;
399400
const uint8_t* blockbase;
400401
for (;;) {
401-
blockbase = block_data(start_slot_id >> 3, num_block_bytes);
402+
blockbase = block_data(start_slot_id >> kLogSlotsPerBlock, num_block_bytes);
402403
uint64_t block = *reinterpret_cast<const uint64_t*>(blockbase);
403404

404-
search_block<true>(block, stamp, start_slot_id & 7, &local_slot, &match_found);
405+
search_block<true>(block, stamp, start_slot_id & kLocalSlotMask, &local_slot,
406+
&match_found);
405407

406-
start_slot_id = wrap_global_slot_id((start_slot_id & ~7U) + local_slot + match_found);
408+
start_slot_id =
409+
wrap_global_slot_id((start_slot_id & ~kLocalSlotMask) + local_slot + match_found);
407410

408411
// Match found can be 1 in two cases:
409412
// - match was found
410413
// - match was not found in a full block
411414
// In the second case search needs to continue in the next block.
412-
if (match_found == 0 || blockbase[7 - local_slot] == stamp) {
415+
if (match_found == 0 || blockbase[kMaxLocalSlot - local_slot] == stamp) {
413416
break;
414417
}
415418
}
@@ -635,7 +638,7 @@ Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t*
635638
// First slot in the new starting block
636639
const int16_t id = ids[i];
637640
uint32_t block_id = block_id_from_hash(hashes[id], log_blocks_);
638-
slot_ids[id] = global_slot_id(block_id, 0);
641+
slot_ids[id] = global_slot_id(block_id, /*local_slot_id=*/0);
639642
}
640643
}
641644
} while (num_ids > 0);
@@ -647,7 +650,8 @@ Status SwissTable::grow_double() {
647650
// Before and after metadata
648651
int num_group_id_bits_before = num_groupid_bits_from_log_blocks(log_blocks_);
649652
int num_group_id_bits_after = num_groupid_bits_from_log_blocks(log_blocks_ + 1);
650-
uint64_t group_id_mask_before = ~0ULL >> (64 - num_group_id_bits_before);
653+
uint32_t group_id_mask_before =
654+
group_id_mask_from_num_groupid_bits(num_group_id_bits_before);
651655
int log_blocks_after = log_blocks_ + 1;
652656
int bits_shift_for_block_and_stamp_after =
653657
ComputeBitsShiftForBlockAndStamp(log_blocks_after);
@@ -657,7 +661,7 @@ Status SwissTable::grow_double() {
657661
int64_t block_size_total_after =
658662
num_bytes_total_blocks(block_size_after, log_blocks_after);
659663
int64_t hashes_size_total_after =
660-
(bits_hash_ / 8 * (1 << (log_blocks_after + 3))) + padding_;
664+
(bits_hash_ / 8 * num_slots_from_log_blocks(log_blocks_after)) + padding_;
661665
constexpr uint32_t stamp_mask = (1 << bits_stamp_) - 1;
662666

663667
// Allocate new buffers
@@ -685,7 +689,7 @@ Status SwissTable::grow_double() {
685689
util::SafeStore(double_block_base_new + block_size_after, kHighBitOfEachByte);
686690

687691
for (uint32_t j = 0; j < full_slots; ++j) {
688-
uint64_t slot_id = global_slot_id(i, j);
692+
uint32_t slot_id = global_slot_id(i, j);
689693
uint32_t hash = hashes()[slot_id];
690694
uint32_t block_id_new = block_id_from_hash(hash, log_blocks_after);
691695
bool is_overflow_entry = ((block_id_new >> 1) != static_cast<uint64_t>(i));
@@ -695,22 +699,22 @@ Status SwissTable::grow_double() {
695699

696700
uint32_t ihalf = block_id_new & 1;
697701
uint8_t stamp_new = (hash >> bits_shift_for_block_and_stamp_after) & stamp_mask;
698-
uint64_t group_id_bit_offs = j * num_group_id_bits_before;
699-
uint64_t group_id =
700-
(util::SafeLoadAs<uint64_t>(block_base + bytes_status_in_block_ +
702+
int group_id_bit_offs = j * num_group_id_bits_before;
703+
uint32_t group_id =
704+
(util::SafeLoadAs<uint32_t>(block_base + bytes_status_in_block_ +
701705
(group_id_bit_offs >> 3)) >>
702706
(group_id_bit_offs & 7)) &
703707
group_id_mask_before;
704708

705-
uint64_t slot_id_new = global_slot_id(i * 2 + ihalf, full_slots_new[ihalf]);
709+
uint32_t slot_id_new = global_slot_id(i * 2 + ihalf, full_slots_new[ihalf]);
706710
hashes_new[slot_id_new] = hash;
707711
uint8_t* block_base_new = double_block_base_new + ihalf * block_size_after;
708-
block_base_new[7 - full_slots_new[ihalf]] = stamp_new;
709-
int64_t group_id_bit_offs_new = full_slots_new[ihalf] * num_group_id_bits_after;
712+
block_base_new[kMaxLocalSlot - full_slots_new[ihalf]] = stamp_new;
713+
int group_id_bit_offs_new = full_slots_new[ihalf] * num_group_id_bits_after;
710714
uint64_t* ptr = reinterpret_cast<uint64_t*>(
711715
block_base_new + bytes_status_in_block_ + (group_id_bit_offs_new >> 3));
712-
util::SafeStore(ptr,
713-
util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7)));
716+
util::SafeStore(ptr, util::SafeLoad(ptr) | (static_cast<uint64_t>(group_id)
717+
<< (group_id_bit_offs_new & 7)));
714718
full_slots_new[ihalf]++;
715719
}
716720
}
@@ -724,17 +728,17 @@ Status SwissTable::grow_double() {
724728
uint32_t full_slots = CountLeadingZeros(block & kHighBitOfEachByte) >> 3;
725729

726730
for (uint32_t j = 0; j < full_slots; ++j) {
727-
uint64_t slot_id = global_slot_id(i, j);
731+
uint32_t slot_id = global_slot_id(i, j);
728732
uint32_t hash = hashes()[slot_id];
729733
uint32_t block_id_new = block_id_from_hash(hash, log_blocks_after);
730734
bool is_overflow_entry = ((block_id_new >> 1) != static_cast<uint64_t>(i));
731735
if (!is_overflow_entry) {
732736
continue;
733737
}
734738

735-
uint64_t group_id_bit_offs = j * num_group_id_bits_before;
736-
uint64_t group_id =
737-
(util::SafeLoadAs<uint64_t>(block_base + bytes_status_in_block_ +
739+
int group_id_bit_offs = j * num_group_id_bits_before;
740+
uint32_t group_id =
741+
(util::SafeLoadAs<uint32_t>(block_base + bytes_status_in_block_ +
738742
(group_id_bit_offs >> 3)) >>
739743
(group_id_bit_offs & 7)) &
740744
group_id_mask_before;
@@ -753,13 +757,13 @@ Status SwissTable::grow_double() {
753757
static_cast<int>(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3);
754758
}
755759

756-
hashes_new[block_id_new * 8 + full_slots_new] = hash;
757-
block_base_new[7 - full_slots_new] = stamp_new;
758-
int64_t group_id_bit_offs_new = full_slots_new * num_group_id_bits_after;
760+
hashes_new[block_id_new * kSlotsPerBlock + full_slots_new] = hash;
761+
block_base_new[kMaxLocalSlot - full_slots_new] = stamp_new;
762+
int group_id_bit_offs_new = full_slots_new * num_group_id_bits_after;
759763
uint64_t* ptr = reinterpret_cast<uint64_t*>(
760764
block_base_new + bytes_status_in_block_ + (group_id_bit_offs_new >> 3));
761-
util::SafeStore(ptr,
762-
util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7)));
765+
util::SafeStore(ptr, util::SafeLoad(ptr) | (static_cast<uint64_t>(group_id)
766+
<< (group_id_bit_offs_new & 7)));
763767
}
764768
}
765769

@@ -800,9 +804,9 @@ Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, int log_blocks
800804
if (no_hash_array) {
801805
hashes_ = nullptr;
802806
} else {
803-
uint64_t num_slots = 1ULL << (log_blocks_ + 3);
804-
const uint64_t hash_size = sizeof(uint32_t);
805-
const uint64_t hash_bytes = hash_size * num_slots + padding_;
807+
int64_t num_slots = num_slots_from_log_blocks(log_blocks);
808+
const int hash_size = bits_hash_ >> 3;
809+
const int64_t hash_bytes = hash_size * num_slots + padding_;
806810
ARROW_ASSIGN_OR_RAISE(hashes_, AllocateBuffer(hash_bytes, pool_));
807811
}
808812

cpp/src/arrow/compute/key_map_internal.h

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class ARROW_EXPORT SwissTable {
9090
static uint32_t extract_group_id(const uint8_t* block_ptr, int local_slot,
9191
int num_group_id_bits) {
9292
// Extract group id using aligned 32-bit read.
93-
uint32_t group_id_mask = static_cast<uint32_t>((1ULL << num_group_id_bits) - 1);
93+
uint32_t group_id_mask = group_id_mask_from_num_groupid_bits(num_group_id_bits);
9494
int slot_bit_offset = local_slot * num_group_id_bits;
9595
const uint32_t* group_id_ptr32 =
9696
reinterpret_cast<const uint32_t*>(block_ptr + bytes_status_in_block_) +
@@ -99,7 +99,8 @@ class ARROW_EXPORT SwissTable {
9999
return group_id;
100100
}
101101

102-
inline void insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id);
102+
inline void insert_into_empty_slot(uint32_t global_slot_id, uint32_t hash,
103+
uint32_t group_id);
103104

104105
static uint32_t block_id_from_hash(uint32_t hash, int log_blocks) {
105106
return hash >> (bits_hash_ - log_blocks);
@@ -110,22 +111,16 @@ class ARROW_EXPORT SwissTable {
110111
}
111112

112113
static int num_groupid_bits_from_log_blocks(int log_blocks) {
113-
assert(log_blocks >= 0 && log_blocks <= 32 - 3);
114-
int required_bits = log_blocks + 3;
115-
return required_bits <= 8 ? 8
116-
: required_bits <= 16 ? 16
117-
: required_bits <= 32 ? 32
118-
: 64;
114+
assert(log_blocks >= 0);
115+
int required_bits = log_blocks + kLogSlotsPerBlock;
116+
assert(required_bits <= 32);
117+
return required_bits <= 8 ? 8 : required_bits <= 16 ? 16 : 32;
119118
}
120119

121120
// Size in bytes of a single block for the given group id width (in bits).
// The result is the sum of num_groupid_bits and bytes_status_in_block_.
// NOTE(review): presumably this works because a block holds 8 slots
// (kLogSlotsPerBlock = 3), so 8 group ids of num_groupid_bits bits occupy
// num_groupid_bits bytes, and the status bytes come on top - confirm.
static int num_block_bytes_from_num_groupid_bits(int num_groupid_bits) {
122121
return num_groupid_bits + bytes_status_in_block_;
123122
}
124123

125-
static int64_t num_bytes_total_blocks(int num_block_bytes, int log_blocks) {
126-
return (static_cast<int64_t>(num_block_bytes) << log_blocks) + padding_;
127-
}
128-
129124
const uint8_t* block_data(uint32_t block_id, int num_block_bytes) const {
130125
return block_data(blocks_->data(), block_id, num_block_bytes);
131126
}
@@ -188,10 +183,24 @@ class ARROW_EXPORT SwissTable {
188183
const uint32_t* hashes, const uint8_t* local_slots,
189184
uint32_t* out_group_ids) const;
190185

191-
inline uint64_t next_slot_to_visit(uint64_t block_index, int slot,
192-
int match_found) const;
186+
static constexpr int kLogSlotsPerBlock = 3;
187+
static constexpr int kMaxLocalSlot = kSlotsPerBlock - 1;
188+
static constexpr uint32_t kLocalSlotMask = (1U << kLogSlotsPerBlock) - 1U;
189+
190+
// Total number of slots for a table with 2^log_blocks blocks.
// Each block contributes 2^kLogSlotsPerBlock slots, so the result is
// 2^(log_blocks + kLogSlotsPerBlock). The shift is performed on a 64-bit
// value (1LL) and the count is returned as int64_t.
static int64_t num_slots_from_log_blocks(int log_blocks) {
191+
return 1LL << (log_blocks + kLogSlotsPerBlock);
192+
}
193+
194+
// Total number of bytes for the block storage: num_block_bytes for each of the
// 2^log_blocks blocks, plus padding_ extra bytes.
// num_block_bytes is widened to int64_t before the shift so the product is
// computed in 64 bits rather than in int.
static int64_t num_bytes_total_blocks(int num_block_bytes, int log_blocks) {
195+
return (static_cast<int64_t>(num_block_bytes) << log_blocks) + padding_;
196+
}
197+
198+
inline int64_t num_groups_for_resize() const;
193199

194-
inline uint64_t num_groups_for_resize() const;
200+
// Returns a mask with the low num_groupid_bits bits set, used to isolate a
// single group id after it has been shifted down from a wider load.
static uint32_t group_id_mask_from_num_groupid_bits(int num_groupid_bits) {
201+
// num_groupid_bits can be as large as 32. Shifting a 32-bit 1 by 32 would be
// undefined behavior, so shift a 64-bit value and narrow the result afterwards.
202+
return static_cast<uint32_t>((1ULL << num_groupid_bits) - 1ULL);
203+
}
195204

196205
inline uint32_t wrap_global_slot_id(uint32_t global_slot_id) const;
197206

@@ -307,7 +316,7 @@ class ARROW_EXPORT SwissTable {
307316
MemoryPool* pool_;
308317
};
309318

310-
void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash,
319+
void SwissTable::insert_into_empty_slot(uint32_t global_slot_id, uint32_t hash,
311320
uint32_t group_id) {
312321
const int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
313322

@@ -317,19 +326,20 @@ void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash,
317326
num_groupid_bits == 64);
318327

319328
const int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits);
320-
constexpr uint64_t stamp_mask = 0x7f;
329+
constexpr uint32_t stamp_mask = 0x7f;
321330

322-
int start_slot = (slot_id & 7);
323-
int stamp = static_cast<int>((hash >> bits_shift_for_block_and_stamp_) & stamp_mask);
324-
uint32_t block_id = slot_id >> 3;
331+
int start_slot = (global_slot_id & kLocalSlotMask);
332+
int stamp = (hash >> bits_shift_for_block_and_stamp_) & stamp_mask;
333+
uint32_t block_id = global_slot_id >> kLogSlotsPerBlock;
325334
uint8_t* blockbase = mutable_block_data(block_id, num_block_bytes);
326335

327-
blockbase[7 - start_slot] = static_cast<uint8_t>(stamp);
328-
int groupid_bit_offset = static_cast<int>(start_slot * num_groupid_bits);
336+
blockbase[kMaxLocalSlot - start_slot] = static_cast<uint8_t>(stamp);
337+
int groupid_bit_offset = start_slot * num_groupid_bits;
329338

330339
// Block status bytes should start at an address aligned to 8 bytes
331340
assert((reinterpret_cast<uint64_t>(blockbase) & 7) == 0);
332-
uint64_t* ptr = reinterpret_cast<uint64_t*>(blockbase) + 1 + (groupid_bit_offset >> 6);
341+
uint64_t* ptr = reinterpret_cast<uint64_t*>(blockbase + bytes_status_in_block_) +
342+
(groupid_bit_offset >> 6);
333343
*ptr |= (static_cast<uint64_t>(group_id) << (groupid_bit_offset & 63));
334344
}
335345

0 commit comments

Comments
 (0)