Skip to content

Commit 1a368de

Browse files
committed
Optimizations
1 parent 887f09d commit 1a368de

File tree

3 files changed

+103
-52
lines changed

3 files changed

+103
-52
lines changed

src/ribbon/port.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
#pragma once
77

8+
#define CACHE_LINE_SIZE 64U
9+
810
namespace port {
911

1012
// FIXME

src/ribbon/ribbon_alg.h

Lines changed: 69 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,13 +1052,13 @@ void InterleavedBackSubst(InterleavedSolutionStorage *iss,
10521052
std::unique_ptr<CoeffRow[]> state{new CoeffRow[num_columns]()};
10531053

10541054
Index block = num_blocks;
1055-
Index segment = num_segments;
1055+
Index segment_num = num_segments;
10561056
while (block > upper_start_block) {
10571057
--block;
10581058
BackSubstBlock(state.get(), num_columns, bs, block * kCoeffBits);
1059-
segment -= num_columns;
1059+
segment_num -= num_columns;
10601060
for (Index i = 0; i < num_columns; ++i) {
1061-
iss->StoreSegment(segment + i, state[i]);
1061+
iss->StoreSegment(segment_num + i, state[i]);
10621062
}
10631063
}
10641064
// Now (if applicable), region using lower number of columns
@@ -1068,60 +1068,92 @@ void InterleavedBackSubst(InterleavedSolutionStorage *iss,
10681068
while (block > 0) {
10691069
--block;
10701070
BackSubstBlock(state.get(), num_columns, bs, block * kCoeffBits);
1071-
segment -= num_columns;
1071+
segment_num -= num_columns;
10721072
for (Index i = 0; i < num_columns; ++i) {
1073-
iss->StoreSegment(segment + i, state[i]);
1073+
iss->StoreSegment(segment_num + i, state[i]);
10741074
}
10751075
}
10761076
// Verify everything processed
10771077
assert(block == 0);
1078-
assert(segment == 0);
1078+
assert(segment_num == 0);
10791079
}
10801080

1081-
// General PHSF query a key from InterleavedSolutionStorage.
1081+
// Prepare a query for a key in InterleavedSolutionStorage: compute its hash and segment location, prefetch that memory, and save the results for the query.
10821082
template <typename InterleavedSolutionStorage, typename PhsfQueryHasher>
1083-
typename InterleavedSolutionStorage::ResultRow InterleavedPhsfQuery(
1083+
inline void InterleavedPrepareQuery(
10841084
const typename PhsfQueryHasher::Key &key, const PhsfQueryHasher &hasher,
1085-
const InterleavedSolutionStorage &iss) {
1085+
const InterleavedSolutionStorage &iss,
1086+
typename PhsfQueryHasher::Hash *saved_hash,
1087+
typename InterleavedSolutionStorage::Index *saved_segment_num,
1088+
typename InterleavedSolutionStorage::Index *saved_num_columns,
1089+
typename InterleavedSolutionStorage::Index *saved_start_bit) {
10861090
using Hash = typename PhsfQueryHasher::Hash;
1087-
10881091
using CoeffRow = typename InterleavedSolutionStorage::CoeffRow;
10891092
using Index = typename InterleavedSolutionStorage::Index;
1090-
using ResultRow = typename InterleavedSolutionStorage::ResultRow;
10911093

10921094
static_assert(sizeof(Index) == sizeof(typename PhsfQueryHasher::Index),
10931095
"must be same");
1094-
static_assert(sizeof(CoeffRow) == sizeof(typename PhsfQueryHasher::CoeffRow),
1095-
"must be same");
1096-
1097-
constexpr auto kCoeffBits = static_cast<Index>(sizeof(CoeffRow) * 8U);
10981096

10991097
const Hash hash = hasher.GetHash(key);
11001098
const Index start_slot = hasher.GetStart(hash, iss.GetNumStarts());
11011099

1100+
constexpr auto kCoeffBits = static_cast<Index>(sizeof(CoeffRow) * 8U);
1101+
11021102
const Index upper_start_block = iss.GetUpperStartBlock();
11031103
Index num_columns = iss.GetUpperNumColumns();
11041104
Index start_block_num = start_slot / kCoeffBits;
1105-
Index segment = start_block_num * num_columns -
1105+
Index segment_num = start_block_num * num_columns -
11061106
std::min(start_block_num, upper_start_block);
11071107
// Change to lower num columns if applicable.
11081108
// (This should not compile to a conditional branch.)
11091109
num_columns -= (start_block_num < upper_start_block) ? 1 : 0;
11101110

1111-
const CoeffRow cr = hasher.GetCoeffRow(hash);
11121111
Index start_bit = start_slot % kCoeffBits;
11131112

1113+
Index segment_count = num_columns + (start_bit == 0 ? 0 : num_columns);
1114+
1115+
iss.PrefetchSegmentRange(segment_num, segment_num + segment_count);
1116+
1117+
*saved_hash = hash;
1118+
*saved_segment_num = segment_num;
1119+
*saved_num_columns = num_columns;
1120+
*saved_start_bit = start_bit;
1121+
}
1122+
1123+
// General PHSF query from InterleavedSolutionStorage, using data for
1124+
// the query key from InterleavedPrepareQuery
1125+
template <typename InterleavedSolutionStorage, typename PhsfQueryHasher>
1126+
inline typename InterleavedSolutionStorage::ResultRow InterleavedPhsfQuery(
1127+
typename PhsfQueryHasher::Hash hash,
1128+
typename InterleavedSolutionStorage::Index segment_num,
1129+
typename InterleavedSolutionStorage::Index num_columns,
1130+
typename InterleavedSolutionStorage::Index start_bit,
1131+
const PhsfQueryHasher &hasher, const InterleavedSolutionStorage &iss) {
1132+
using CoeffRow = typename InterleavedSolutionStorage::CoeffRow;
1133+
using Index = typename InterleavedSolutionStorage::Index;
1134+
using ResultRow = typename InterleavedSolutionStorage::ResultRow;
1135+
1136+
static_assert(sizeof(Index) == sizeof(typename PhsfQueryHasher::Index),
1137+
"must be same");
1138+
static_assert(sizeof(CoeffRow) == sizeof(typename PhsfQueryHasher::CoeffRow),
1139+
"must be same");
1140+
1141+
constexpr auto kCoeffBits = static_cast<Index>(sizeof(CoeffRow) * 8U);
1142+
1143+
const CoeffRow cr = hasher.GetCoeffRow(hash);
1144+
11141145
ResultRow sr = 0;
1115-
const CoeffRow cr_left = cr << start_bit;
1146+
const CoeffRow cr_left = cr << static_cast<unsigned>(start_bit);
11161147
for (Index i = 0; i < num_columns; ++i) {
1117-
sr ^= BitParity(iss.LoadSegment(segment + i) & cr_left) << i;
1148+
sr ^= BitParity(iss.LoadSegment(segment_num + i) & cr_left) << i;
11181149
}
11191150

11201151
if (start_bit > 0) {
1121-
segment += num_columns;
1122-
const CoeffRow cr_right = cr >> (kCoeffBits - start_bit);
1152+
segment_num += num_columns;
1153+
const CoeffRow cr_right =
1154+
cr >> static_cast<unsigned>(kCoeffBits - start_bit);
11231155
for (Index i = 0; i < num_columns; ++i) {
1124-
sr ^= BitParity(iss.LoadSegment(segment + i) & cr_right) << i;
1156+
sr ^= BitParity(iss.LoadSegment(segment_num + i) & cr_right) << i;
11251157
}
11261158
}
11271159

@@ -1130,12 +1162,12 @@ typename InterleavedSolutionStorage::ResultRow InterleavedPhsfQuery(
11301162

11311163
// Filter query for a key from InterleavedSolutionStorage.
11321164
template <typename InterleavedSolutionStorage, typename FilterQueryHasher>
1133-
bool InterleavedFilterQuery(const typename FilterQueryHasher::Key &key,
1134-
const FilterQueryHasher &hasher,
1135-
const InterleavedSolutionStorage &iss) {
1136-
// BEGIN mostly copied from InterleavedPhsfQuery
1137-
using Hash = typename FilterQueryHasher::Hash;
1138-
1165+
inline bool InterleavedFilterQuery(
1166+
typename FilterQueryHasher::Hash hash,
1167+
typename InterleavedSolutionStorage::Index segment_num,
1168+
typename InterleavedSolutionStorage::Index num_columns,
1169+
typename InterleavedSolutionStorage::Index start_bit,
1170+
const FilterQueryHasher &hasher, const InterleavedSolutionStorage &iss) {
11391171
using CoeffRow = typename InterleavedSolutionStorage::CoeffRow;
11401172
using Index = typename InterleavedSolutionStorage::Index;
11411173
using ResultRow = typename InterleavedSolutionStorage::ResultRow;
@@ -1151,41 +1183,28 @@ bool InterleavedFilterQuery(const typename FilterQueryHasher::Key &key,
11511183

11521184
constexpr auto kCoeffBits = static_cast<Index>(sizeof(CoeffRow) * 8U);
11531185

1154-
const Hash hash = hasher.GetHash(key);
1155-
const Index start_slot = hasher.GetStart(hash, iss.GetNumStarts());
1156-
1157-
const Index upper_start_block = iss.GetUpperStartBlock();
1158-
Index num_columns = iss.GetUpperNumColumns();
1159-
Index start_block_num = start_slot / kCoeffBits;
1160-
Index segment = start_block_num * num_columns -
1161-
std::min(start_block_num, upper_start_block);
1162-
// Change to lower num columns if applicable.
1163-
// (This should not compile to a conditional branch.)
1164-
num_columns -= (start_block_num < upper_start_block) ? 1 : 0;
1165-
11661186
const CoeffRow cr = hasher.GetCoeffRow(hash);
1167-
Index start_bit = start_slot % kCoeffBits;
1168-
// END mostly copied from InterleavedPhsfQuery.
1169-
11701187
const ResultRow expected = hasher.GetResultRowFromHash(hash);
11711188

11721189
// TODO: consider optimizations such as
1173-
// * mask fetched values and shift cr, rather than shifting fetched values
11741190
// * get rid of start_bit == 0 condition with careful fetching & shifting
11751191
if (start_bit == 0) {
11761192
for (Index i = 0; i < num_columns; ++i) {
1177-
if (BitParity(iss.LoadSegment(segment + i) & cr) !=
1193+
if (BitParity(iss.LoadSegment(segment_num + i) & cr) !=
11781194
(static_cast<int>(expected >> i) & 1)) {
11791195
return false;
11801196
}
11811197
}
11821198
} else {
1199+
const CoeffRow cr_left = cr << static_cast<unsigned>(start_bit);
1200+
const CoeffRow cr_right =
1201+
cr >> static_cast<unsigned>(kCoeffBits - start_bit);
1202+
11831203
for (Index i = 0; i < num_columns; ++i) {
1184-
CoeffRow soln_col =
1185-
(iss.LoadSegment(segment + i) >> static_cast<unsigned>(start_bit)) |
1186-
(iss.LoadSegment(segment + num_columns + i)
1187-
<< static_cast<unsigned>(kCoeffBits - start_bit));
1188-
if (BitParity(soln_col & cr) != (static_cast<int>(expected >> i) & 1)) {
1204+
CoeffRow soln_data =
1205+
(iss.LoadSegment(segment_num + i) & cr_left) ^
1206+
(iss.LoadSegment(segment_num + num_columns + i) & cr_right);
1207+
if (BitParity(soln_data) != (static_cast<int>(expected >> i) & 1)) {
11891208
return false;
11901209
}
11911210
}

src/ribbon/ribbon_impl.h

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,20 @@ class SerializableInterleavedSolution {
838838
assert(data_ != nullptr); // suppress clang analyzer report
839839
EncodeFixedGeneric(data_ + segment_num * sizeof(CoeffRow), val);
840840
}
841+
void PrefetchSegmentRange(Index begin_segment_num,
842+
Index end_segment_num) const {
843+
if (end_segment_num == begin_segment_num) {
844+
// Nothing to do
845+
return;
846+
}
847+
char* cur = data_ + begin_segment_num * sizeof(CoeffRow);
848+
char* last = data_ + (end_segment_num - 1) * sizeof(CoeffRow);
849+
while (cur < last) {
850+
PREFETCH(cur, 0 /* rw */, 1 /* locality */);
851+
cur += CACHE_LINE_SIZE;
852+
}
853+
PREFETCH(last, 0 /* rw */, 1 /* locality */);
854+
}
841855

842856
// ********************************************************************
843857
// High-level API
@@ -874,7 +888,15 @@ class SerializableInterleavedSolution {
874888
return 0;
875889
} else {
876890
// Normal
877-
return InterleavedPhsfQuery(input, hasher, *this);
891+
// NOTE: passing separate locals rather than a struct, to encourage compiler optimization
892+
Hash hash;
893+
Index segment_num;
894+
Index num_columns;
895+
Index start_bit;
896+
InterleavedPrepareQuery(input, hasher, *this, &hash, &segment_num,
897+
&num_columns, &start_bit);
898+
return InterleavedPhsfQuery(hash, segment_num, num_columns, start_bit,
899+
hasher, *this);
878900
}
879901
}
880902

@@ -887,7 +909,15 @@ class SerializableInterleavedSolution {
887909
} else {
888910
// Normal, or upper_num_columns_ == 0 means "no space for data" and
889911
// thus will always return true.
890-
return InterleavedFilterQuery(input, hasher, *this);
912+
// NOTE: passing separate locals rather than a struct, to encourage compiler optimization
913+
Hash hash;
914+
Index segment_num;
915+
Index num_columns;
916+
Index start_bit;
917+
InterleavedPrepareQuery(input, hasher, *this, &hash, &segment_num,
918+
&num_columns, &start_bit);
919+
return InterleavedFilterQuery(hash, segment_num, num_columns, start_bit,
920+
hasher, *this);
891921
}
892922
}
893923

0 commit comments

Comments
 (0)