Skip to content

Commit 7af1d3c

Browse files
committed
Use aligned read to extract group id
1 parent f5db159 commit 7af1d3c

File tree

2 files changed

+51
-24
lines changed

2 files changed

+51
-24
lines changed

cpp/src/arrow/compute/key_map_internal.cc

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,12 @@ inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot,
9494
*out_slot = static_cast<int>(CountLeadingZeros(matches | block_high_bits) >> 3);
9595
}
9696

97-
template <bool use_selection>
97+
template <typename T, bool use_selection>
9898
void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selection,
9999
const uint32_t* hashes, const uint8_t* local_slots,
100100
uint32_t* out_group_ids) const {
101101
if (log_blocks_ == 0) {
102+
DCHECK_EQ(sizeof(T), sizeof(uint8_t));
102103
for (int i = 0; i < num_keys; ++i) {
103104
uint32_t id = use_selection ? selection[i] : i;
104105
uint32_t group_id =
@@ -108,18 +109,16 @@ void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selec
108109
}
109110
} else {
110111
int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
111-
int num_groupid_bytes = num_groupid_bits / 8;
112-
uint32_t group_id_mask = group_id_mask_from_num_groupid_bits(num_groupid_bits);
112+
DCHECK_EQ(sizeof(T) * 8, num_groupid_bits);
113113
int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits);
114114

115115
for (int i = 0; i < num_keys; ++i) {
116116
uint32_t id = use_selection ? selection[i] : i;
117117
uint32_t hash = hashes[id];
118118
uint32_t block_id = block_id_from_hash(hash, log_blocks_);
119-
uint32_t group_id = *reinterpret_cast<const uint32_t*>(
120-
block_data(block_id, num_block_bytes) + local_slots[id] * num_groupid_bytes +
121-
bytes_status_in_block_);
122-
group_id &= group_id_mask;
119+
const T* slots_base = reinterpret_cast<const T*>(
120+
block_data(block_id, num_block_bytes) + bytes_status_in_block_);
121+
uint32_t group_id = static_cast<uint32_t>(slots_base[local_slots[id]]);
123122
out_group_ids[id] = group_id;
124123
}
125124
}
@@ -137,13 +136,40 @@ void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_
137136
num_processed = extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids);
138137
}
139138
#endif
140-
if (optional_selection) {
141-
extract_group_ids_imp<true>(num_keys, optional_selection, hashes, local_slots,
142-
out_group_ids);
143-
} else {
144-
extract_group_ids_imp<false>(num_keys - num_processed, nullptr,
145-
hashes + num_processed, local_slots + num_processed,
146-
out_group_ids + num_processed);
139+
int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
140+
switch (num_groupid_bits) {
141+
case 8:
142+
if (optional_selection) {
143+
extract_group_ids_imp<uint8_t, true>(num_keys, optional_selection, hashes,
144+
local_slots, out_group_ids);
145+
} else {
146+
extract_group_ids_imp<uint8_t, false>(
147+
num_keys - num_processed, nullptr, hashes + num_processed,
148+
local_slots + num_processed, out_group_ids + num_processed);
149+
}
150+
break;
151+
case 16:
152+
if (optional_selection) {
153+
extract_group_ids_imp<uint16_t, true>(num_keys, optional_selection, hashes,
154+
local_slots, out_group_ids);
155+
} else {
156+
extract_group_ids_imp<uint16_t, false>(
157+
num_keys - num_processed, nullptr, hashes + num_processed,
158+
local_slots + num_processed, out_group_ids + num_processed);
159+
}
160+
break;
161+
case 32:
162+
if (optional_selection) {
163+
extract_group_ids_imp<uint32_t, true>(num_keys, optional_selection, hashes,
164+
local_slots, out_group_ids);
165+
} else {
166+
extract_group_ids_imp<uint32_t, false>(
167+
num_keys - num_processed, nullptr, hashes + num_processed,
168+
local_slots + num_processed, out_group_ids + num_processed);
169+
}
170+
break;
171+
default:
172+
DCHECK(false);
147173
}
148174
}
149175

cpp/src/arrow/compute/key_map_internal.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,15 @@ class ARROW_EXPORT SwissTable {
8888
/// \brief Extract group id for a given slot in a given block.
8989
///
9090
static uint32_t extract_group_id(const uint8_t* block_ptr, int local_slot,
91-
int64_t num_group_id_bits) {
92-
uint32_t group_id_mask = group_id_mask_from_num_groupid_bits(num_group_id_bits);
93-
uint32_t group_id = *reinterpret_cast<const uint32_t*>(
94-
block_ptr + bytes_status_in_block_ + local_slot * num_group_id_bits / 8);
95-
return group_id & group_id_mask;
91+
int num_group_id_bits) {
92+
// Extract group id using aligned 32-bit read.
93+
uint32_t group_id_mask = static_cast<uint32_t>((1ULL << num_group_id_bits) - 1);
94+
int slot_bit_offset = local_slot * num_group_id_bits;
95+
const uint32_t* group_id_ptr32 =
96+
reinterpret_cast<const uint32_t*>(block_ptr + bytes_status_in_block_) +
97+
(slot_bit_offset >> 5);
98+
uint32_t group_id = (*group_id_ptr32 >> (slot_bit_offset & 31)) & group_id_mask;
99+
return group_id;
96100
}
97101

98102
inline void insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id);
@@ -106,6 +110,7 @@ class ARROW_EXPORT SwissTable {
106110
}
107111

108112
static int num_groupid_bits_from_log_blocks(int log_blocks) {
113+
assert(log_blocks >= 0 && log_blocks <= 32 - 3);
109114
int required_bits = log_blocks + 3;
110115
return required_bits <= 8 ? 8
111116
: required_bits <= 16 ? 16
@@ -178,7 +183,7 @@ class ARROW_EXPORT SwissTable {
178183
const uint32_t* hashes, const uint8_t* local_slots,
179184
uint32_t* out_group_ids) const;
180185

181-
template <bool use_selection>
186+
template <typename T, bool use_selection>
182187
void extract_group_ids_imp(const int num_keys, const uint16_t* selection,
183188
const uint32_t* hashes, const uint8_t* local_slots,
184189
uint32_t* out_group_ids) const;
@@ -257,10 +262,6 @@ class ARROW_EXPORT SwissTable {
257262
return bits_stamp_;
258263
}
259264

260-
static uint32_t group_id_mask_from_num_groupid_bits(int64_t num_groupid_bits) {
261-
return static_cast<uint32_t>((1ULL << num_groupid_bits) - 1);
262-
}
263-
264265
static constexpr int bytes_status_in_block_ = 8;
265266

266267
// Number of hash bits stored in slots in a block.

0 commit comments

Comments
 (0)