@@ -1052,13 +1052,13 @@ void InterleavedBackSubst(InterleavedSolutionStorage *iss,
1052
1052
std::unique_ptr<CoeffRow[]> state{new CoeffRow[num_columns]()};
1053
1053
1054
1054
Index block = num_blocks;
1055
- Index segment = num_segments;
1055
+ Index segment_num = num_segments;
1056
1056
while (block > upper_start_block) {
1057
1057
--block;
1058
1058
BackSubstBlock (state.get (), num_columns, bs, block * kCoeffBits );
1059
- segment -= num_columns;
1059
+ segment_num -= num_columns;
1060
1060
for (Index i = 0 ; i < num_columns; ++i) {
1061
- iss->StoreSegment (segment + i, state[i]);
1061
+ iss->StoreSegment (segment_num + i, state[i]);
1062
1062
}
1063
1063
}
1064
1064
// Now (if applicable), region using lower number of columns
@@ -1068,60 +1068,92 @@ void InterleavedBackSubst(InterleavedSolutionStorage *iss,
1068
1068
while (block > 0 ) {
1069
1069
--block;
1070
1070
BackSubstBlock (state.get (), num_columns, bs, block * kCoeffBits );
1071
- segment -= num_columns;
1071
+ segment_num -= num_columns;
1072
1072
for (Index i = 0 ; i < num_columns; ++i) {
1073
- iss->StoreSegment (segment + i, state[i]);
1073
+ iss->StoreSegment (segment_num + i, state[i]);
1074
1074
}
1075
1075
}
1076
1076
// Verify everything processed
1077
1077
assert (block == 0 );
1078
- assert (segment == 0 );
1078
+ assert (segment_num == 0 );
1079
1079
}
1080
1080
1081
- // General PHSF query a key from InterleavedSolutionStorage.
1081
+ // Prefetch memory for a key in InterleavedSolutionStorage.
// Hashes `key` with `hasher`, works out which run of segments in `iss` a
// query for that key would read, and calls iss.PrefetchSegmentRange() on it.
// The intermediate values are written through the saved_* pointers (all
// four are dereferenced unconditionally, so none may be null):
//   *saved_hash        - hasher.GetHash(key)
//   *saved_segment_num - index of the first segment in the prefetched range
//   *saved_num_columns - column count used for the key's starting block
//   *saved_start_bit   - start slot modulo the bit width of CoeffRow
1082
1082
template <typename InterleavedSolutionStorage, typename PhsfQueryHasher>
1083
- typename InterleavedSolutionStorage::ResultRow InterleavedPhsfQuery (
1083
+ inline void InterleavedPrepareQuery (
1084
1084
const typename PhsfQueryHasher::Key &key, const PhsfQueryHasher &hasher,
1085
- const InterleavedSolutionStorage &iss) {
1085
+ const InterleavedSolutionStorage &iss,
1086
+ typename PhsfQueryHasher::Hash *saved_hash,
1087
+ typename InterleavedSolutionStorage::Index *saved_segment_num,
1088
+ typename InterleavedSolutionStorage::Index *saved_num_columns,
1089
+ typename InterleavedSolutionStorage::Index *saved_start_bit) {
1086
1090
using Hash = typename PhsfQueryHasher::Hash;
1087
-
1088
1091
using CoeffRow = typename InterleavedSolutionStorage::CoeffRow;
1089
1092
using Index = typename InterleavedSolutionStorage::Index;
1090
- using ResultRow = typename InterleavedSolutionStorage::ResultRow;
1091
1093
1092
1094
static_assert (sizeof (Index) == sizeof (typename PhsfQueryHasher::Index),
1093
1095
" must be same" );
1094
- static_assert (sizeof (CoeffRow) == sizeof (typename PhsfQueryHasher::CoeffRow),
1095
- " must be same" );
1096
-
1097
- constexpr auto kCoeffBits = static_cast <Index>(sizeof (CoeffRow) * 8U );
1098
1096
1099
1097
const Hash hash = hasher.GetHash (key);
1100
1098
const Index start_slot = hasher.GetStart (hash, iss.GetNumStarts ());
1101
1099
1100
// Number of bits in one CoeffRow.
+ constexpr auto kCoeffBits = static_cast <Index>(sizeof (CoeffRow) * 8U );
1101
+
1102
1102
const Index upper_start_block = iss.GetUpperStartBlock ();
1103
1103
Index num_columns = iss.GetUpperNumColumns ();
1104
1104
// Index of the kCoeffBits-sized block that contains the start slot.
Index start_block_num = start_slot / kCoeffBits ;
1105
// First segment for that block: the upper column count per block, minus one
// for each earlier block below upper_start_block (presumably because those
// blocks are stored with one fewer column -- see the adjustment below).
- Index segment = start_block_num * num_columns -
1105
+ Index segment_num = start_block_num * num_columns -
1106
1106
std::min (start_block_num, upper_start_block);
1107
1107
// Change to lower num columns if applicable.
1108
1108
// (This should not compile to a conditional branch.)
1109
1109
num_columns -= (start_block_num < upper_start_block) ? 1 : 0 ;
1110
1110
1111
- const CoeffRow cr = hasher.GetCoeffRow (hash);
1112
1111
// Bit offset of the start slot within its block.
Index start_bit = start_slot % kCoeffBits ;
1113
1112
1113
// When start_bit is nonzero the range is doubled to 2 * num_columns
// segments; the query functions read a second run of num_columns segments
// in that case.
+ Index segment_count = num_columns + (start_bit == 0 ? 0 : num_columns);
1114
+
1115
+ iss.PrefetchSegmentRange (segment_num, segment_num + segment_count);
1116
+
1117
// Hand the computed values back to the caller for the subsequent query.
+ *saved_hash = hash;
1118
+ *saved_segment_num = segment_num;
1119
+ *saved_num_columns = num_columns;
1120
+ *saved_start_bit = start_bit;
1121
+ }
1122
+
1123
+ // General PHSF query from InterleavedSolutionStorage, using data for
1124
+ // the query key from InterleavedPrepareQuery
1125
+ template <typename InterleavedSolutionStorage, typename PhsfQueryHasher>
1126
+ inline typename InterleavedSolutionStorage::ResultRow InterleavedPhsfQuery (
1127
+ typename PhsfQueryHasher::Hash hash,
1128
+ typename InterleavedSolutionStorage::Index segment_num,
1129
+ typename InterleavedSolutionStorage::Index num_columns,
1130
+ typename InterleavedSolutionStorage::Index start_bit,
1131
+ const PhsfQueryHasher &hasher, const InterleavedSolutionStorage &iss) {
1132
+ using CoeffRow = typename InterleavedSolutionStorage::CoeffRow;
1133
+ using Index = typename InterleavedSolutionStorage::Index;
1134
+ using ResultRow = typename InterleavedSolutionStorage::ResultRow;
1135
+
1136
+ static_assert (sizeof (Index) == sizeof (typename PhsfQueryHasher::Index),
1137
+ " must be same" );
1138
+ static_assert (sizeof (CoeffRow) == sizeof (typename PhsfQueryHasher::CoeffRow),
1139
+ " must be same" );
1140
+
1141
+ constexpr auto kCoeffBits = static_cast <Index>(sizeof (CoeffRow) * 8U );
1142
+
1143
+ const CoeffRow cr = hasher.GetCoeffRow (hash);
1144
+
1114
1145
ResultRow sr = 0 ;
1115
- const CoeffRow cr_left = cr << start_bit;
1146
+ const CoeffRow cr_left = cr << static_cast < unsigned >( start_bit) ;
1116
1147
for (Index i = 0 ; i < num_columns; ++i) {
1117
- sr ^= BitParity (iss.LoadSegment (segment + i) & cr_left) << i;
1148
+ sr ^= BitParity (iss.LoadSegment (segment_num + i) & cr_left) << i;
1118
1149
}
1119
1150
1120
1151
if (start_bit > 0 ) {
1121
- segment += num_columns;
1122
- const CoeffRow cr_right = cr >> (kCoeffBits - start_bit);
1152
+ segment_num += num_columns;
1153
+ const CoeffRow cr_right =
1154
+ cr >> static_cast <unsigned >(kCoeffBits - start_bit);
1123
1155
for (Index i = 0 ; i < num_columns; ++i) {
1124
- sr ^= BitParity (iss.LoadSegment (segment + i) & cr_right) << i;
1156
+ sr ^= BitParity (iss.LoadSegment (segment_num + i) & cr_right) << i;
1125
1157
}
1126
1158
}
1127
1159
@@ -1130,12 +1162,12 @@ typename InterleavedSolutionStorage::ResultRow InterleavedPhsfQuery(
1130
1162
1131
1163
// Filter query from InterleavedSolutionStorage, using data for the query key
// from InterleavedPrepareQuery.
1132
1164
template <typename InterleavedSolutionStorage, typename FilterQueryHasher>
1133
- bool InterleavedFilterQuery (const typename FilterQueryHasher::Key &key,
1134
- const FilterQueryHasher &hasher ,
1135
- const InterleavedSolutionStorage &iss) {
1136
- // BEGIN mostly copied from InterleavedPhsfQuery
1137
- using Hash = typename FilterQueryHasher::Hash;
1138
-
1165
+ inline bool InterleavedFilterQuery (
1166
+ typename FilterQueryHasher::Hash hash ,
1167
+ typename InterleavedSolutionStorage::Index segment_num,
1168
+ typename InterleavedSolutionStorage::Index num_columns,
1169
+ typename InterleavedSolutionStorage::Index start_bit,
1170
+ const FilterQueryHasher &hasher, const InterleavedSolutionStorage &iss) {
1139
1171
using CoeffRow = typename InterleavedSolutionStorage::CoeffRow;
1140
1172
using Index = typename InterleavedSolutionStorage::Index;
1141
1173
using ResultRow = typename InterleavedSolutionStorage::ResultRow;
@@ -1151,41 +1183,28 @@ bool InterleavedFilterQuery(const typename FilterQueryHasher::Key &key,
1151
1183
1152
1184
constexpr auto kCoeffBits = static_cast <Index>(sizeof (CoeffRow) * 8U );
1153
1185
1154
- const Hash hash = hasher.GetHash (key);
1155
- const Index start_slot = hasher.GetStart (hash, iss.GetNumStarts ());
1156
-
1157
- const Index upper_start_block = iss.GetUpperStartBlock ();
1158
- Index num_columns = iss.GetUpperNumColumns ();
1159
- Index start_block_num = start_slot / kCoeffBits ;
1160
- Index segment = start_block_num * num_columns -
1161
- std::min (start_block_num, upper_start_block);
1162
- // Change to lower num columns if applicable.
1163
- // (This should not compile to a conditional branch.)
1164
- num_columns -= (start_block_num < upper_start_block) ? 1 : 0 ;
1165
-
1166
1186
const CoeffRow cr = hasher.GetCoeffRow (hash);
1167
- Index start_bit = start_slot % kCoeffBits ;
1168
- // END mostly copied from InterleavedPhsfQuery.
1169
-
1170
1187
const ResultRow expected = hasher.GetResultRowFromHash (hash);
1171
1188
1172
1189
// TODO: consider optimizations such as
1173
- // * mask fetched values and shift cr, rather than shifting fetched values
1174
1190
// * get rid of start_bit == 0 condition with careful fetching & shifting
1175
1191
if (start_bit == 0 ) {
1176
1192
for (Index i = 0 ; i < num_columns; ++i) {
1177
- if (BitParity (iss.LoadSegment (segment + i) & cr) !=
1193
+ if (BitParity (iss.LoadSegment (segment_num + i) & cr) !=
1178
1194
(static_cast <int >(expected >> i) & 1 )) {
1179
1195
return false ;
1180
1196
}
1181
1197
}
1182
1198
} else {
1199
+ const CoeffRow cr_left = cr << static_cast <unsigned >(start_bit);
1200
+ const CoeffRow cr_right =
1201
+ cr >> static_cast <unsigned >(kCoeffBits - start_bit);
1202
+
1183
1203
for (Index i = 0 ; i < num_columns; ++i) {
1184
- CoeffRow soln_col =
1185
- (iss.LoadSegment (segment + i) >> static_cast <unsigned >(start_bit)) |
1186
- (iss.LoadSegment (segment + num_columns + i)
1187
- << static_cast <unsigned >(kCoeffBits - start_bit));
1188
- if (BitParity (soln_col & cr) != (static_cast <int >(expected >> i) & 1 )) {
1204
+ CoeffRow soln_data =
1205
+ (iss.LoadSegment (segment_num + i) & cr_left) ^
1206
+ (iss.LoadSegment (segment_num + num_columns + i) & cr_right);
1207
+ if (BitParity (soln_data) != (static_cast <int >(expected >> i) & 1 )) {
1189
1208
return false ;
1190
1209
}
1191
1210
}
0 commit comments