Skip to content

Commit

Permalink
IVFPQ AVX2 optimization for PQ, including polysemous filtering (faceb…
Browse files Browse the repository at this point in the history
…ookresearch#2277)

Summary:
Pull Request resolved: facebookresearch#2277

* extend the specialized AVX2 version of IVFPQScannerT::scan_list_with_table to cover IVFPQScannerT::scan_list_polysemous_hc as well
* lower the comparison precision in the test_lowlevel_ivf tests from EXPECT_EQ to EXPECT_FLOAT_EQ because of the AVX2 change in IVFPQScannerT::scan_list_polysemous_hc; otherwise the tests fail

Reviewed By: mdouze

Differential Revision: D34964138

fbshipit-source-id: 1d304a8f6eda040fa4c626676b4d492f2c12f04f
  • Loading branch information
alexanderguzhva authored and facebook-github-bot committed Mar 24, 2022
1 parent 291353c commit 438b64c
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 127 deletions.
243 changes: 117 additions & 126 deletions faiss/IndexIVFPQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -871,156 +871,153 @@ struct IVFPQScannerT : QueryTables {
*****************************************************/

#ifdef __AVX2__
/// version of the scan where we use precomputed tables.
/// non PQDecoder8 version.
/// Returns the distance to a single code.
/// General-purpose version.
template <class SearchResultType, typename T = PQDecoder>
typename std::enable_if<!(std::is_same<T, PQDecoder8>::value), void>::type
scan_list_with_table(
size_t ncode,
const uint8_t* codes,
SearchResultType& res) const {
for (size_t j = 0; j < ncode; j++) {
PQDecoder decoder(codes, pq.nbits);
codes += pq.code_size;
float dis = dis0;
const float* tab = sim_table;
typename std::enable_if<!(std::is_same<T, PQDecoder8>::value), float>::
type inline distance_single_code(const uint8_t* code) const {
PQDecoder decoder(code, pq.nbits);

for (size_t m = 0; m < pq.M; m++) {
dis += tab[decoder.decode()];
tab += pq.ksub;
}
const float* tab = sim_table;
float result = 0;

res.add(j, dis);
for (size_t m = 0; m < pq.M; m++) {
result += tab[decoder.decode()];
tab += pq.ksub;
}

return result;
}

/// version of the scan where we use precomputed tables.
/// AVX2 PQDecoder8 version.
/// Returns the distance to a single code.
/// Specialized AVX2 PQDecoder8 version.
template <class SearchResultType, typename T = PQDecoder>
typename std::enable_if<std::is_same<T, PQDecoder8>::value, void>::type
scan_list_with_table(
size_t ncode,
const uint8_t* codes,
SearchResultType& res) const {
for (size_t j = 0; j < ncode; j++) {
float dis = dis0;

//
size_t m = 0;
const size_t pqM16 = pq.M / 16;

const float* tab = sim_table;

if (pqM16 > 0) {
// process 16 values per loop

const __m256i ksub = _mm256_set1_epi32(pq.ksub);
__m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
offsets_0 = _mm256_mullo_epi32(offsets_0, ksub);

// accumulators of partial sums
__m256 partialSum = _mm256_setzero_ps();

// loop
for (m = 0; m < pqM16 * 16; m += 16) {
// load 16 uint8 values
const __m128i mm1 =
_mm_loadu_si128((const __m128i_u*)(codes + m));
{
// convert uint8 values (low part of __m128i) to int32
// values
const __m256i idx1 = _mm256_cvtepu8_epi32(mm1);

// add offsets
const __m256i indices_to_read_from =
_mm256_add_epi32(idx1, offsets_0);
typename std::enable_if<(std::is_same<T, PQDecoder8>::value), float>::
type inline distance_single_code(const uint8_t* code) const {
float result = 0;

size_t m = 0;
const size_t pqM16 = pq.M / 16;

const float* tab = sim_table;

if (pqM16 > 0) {
// process 16 values per loop

const __m256i ksub = _mm256_set1_epi32(pq.ksub);
__m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
offsets_0 = _mm256_mullo_epi32(offsets_0, ksub);

// accumulators of partial sums
__m256 partialSum = _mm256_setzero_ps();

// loop
for (m = 0; m < pqM16 * 16; m += 16) {
// load 16 uint8 values
const __m128i mm1 =
_mm_loadu_si128((const __m128i_u*)(code + m));
{
// convert uint8 values (low part of __m128i) to int32
// values
const __m256i idx1 = _mm256_cvtepu8_epi32(mm1);

// add offsets
const __m256i indices_to_read_from =
_mm256_add_epi32(idx1, offsets_0);

// gather 8 values, similar to 8 operations of tab[idx]
__m256 collected = _mm256_i32gather_ps(
tab, indices_to_read_from, sizeof(float));
tab += pq.ksub * 8;

// collect partial sums
partialSum = _mm256_add_ps(partialSum, collected);
}

// gather 8 values, similar to 8 operations of tab[idx]
__m256 collected = _mm256_i32gather_ps(
tab, indices_to_read_from, sizeof(float));
tab += pq.ksub * 8;
// move high 8 uint8 to low ones
const __m128i mm2 =
_mm_unpackhi_epi64(mm1, _mm_setzero_si128());
{
// convert uint8 values (low part of __m128i) to int32
// values
const __m256i idx1 = _mm256_cvtepu8_epi32(mm2);

// add offsets
const __m256i indices_to_read_from =
_mm256_add_epi32(idx1, offsets_0);

// gather 8 values, similar to 8 operations of tab[idx]
__m256 collected = _mm256_i32gather_ps(
tab, indices_to_read_from, sizeof(float));
tab += pq.ksub * 8;

// collect partial sums
partialSum = _mm256_add_ps(partialSum, collected);
}
}

// collect partial sums
partialSum = _mm256_add_ps(partialSum, collected);
}
// horizontal sum for partialSum
const __m256 h0 = _mm256_hadd_ps(partialSum, partialSum);
const __m256 h1 = _mm256_hadd_ps(h0, h0);

// move high 8 uint8 to low ones
const __m128i mm2 =
_mm_unpackhi_epi64(mm1, _mm_setzero_si128());
{
// convert uint8 values (low part of __m128i) to int32
// values
const __m256i idx1 = _mm256_cvtepu8_epi32(mm2);

// add offsets
const __m256i indices_to_read_from =
_mm256_add_epi32(idx1, offsets_0);

// gather 8 values, similar to 8 operations of tab[idx]
__m256 collected = _mm256_i32gather_ps(
tab, indices_to_read_from, sizeof(float));
tab += pq.ksub * 8;

// collect partial sums
partialSum = _mm256_add_ps(partialSum, collected);
}
}
// extract high and low __m128 regs from __m256
const __m128 h2 = _mm256_extractf128_ps(h1, 1);
const __m128 h3 = _mm256_castps256_ps128(h1);

// horizontal sum for partialSum
const __m256 h0 = _mm256_hadd_ps(partialSum, partialSum);
const __m256 h1 = _mm256_hadd_ps(h0, h0);
// get a final hsum into all 4 regs
const __m128 h4 = _mm_add_ss(h2, h3);

// extract high and low __m128 regs from __m256
const __m128 h2 = _mm256_extractf128_ps(h1, 1);
const __m128 h3 = _mm256_castps256_ps128(h1);
// extract f[0] from __m128
const float hsum = _mm_cvtss_f32(h4);
result += hsum;
}

// get a final hsum into all 4 regs
const __m128 h4 = _mm_add_ss(h2, h3);
//
if (m < pq.M) {
// process leftovers
PQDecoder decoder(code + m, pq.nbits);

// extract f[0] from __m128
const float hsum = _mm_cvtss_f32(h4);
dis += hsum;
for (; m < pq.M; m++) {
result += tab[decoder.decode()];
tab += pq.ksub;
}
}

//
if (m < pq.M) {
// process leftovers
PQDecoder decoder(codes + m, pq.nbits);
return result;
}

for (; m < pq.M; m++) {
dis += tab[decoder.decode()];
tab += pq.ksub;
}
}
#else
/// Returns the distance to a single code.
/// General-purpose version.
template <class SearchResultType>
inline float distance_single_code(const uint8_t* code) const {
PQDecoder decoder(code, pq.nbits);

codes += pq.code_size;
const float* tab = sim_table;
float result = 0;

// done
res.add(j, dis);
for (size_t m = 0; m < pq.M; m++) {
result += tab[decoder.decode()];
tab += pq.ksub;
}

return result;
}
#else
/// version of the scan where we use precomputed tables
#endif

/// version of the scan where we use precomputed tables.
template <class SearchResultType>
void scan_list_with_table(
size_t ncode,
const uint8_t* codes,
SearchResultType& res) const {
for (size_t j = 0; j < ncode; j++) {
PQDecoder decoder(codes, pq.nbits);
float dis = dis0 + distance_single_code<SearchResultType>(codes);
codes += pq.code_size;
float dis = dis0;
const float* tab = sim_table;

for (size_t m = 0; m < pq.M; m++) {
dis += tab[decoder.decode()];
tab += pq.ksub;
}

res.add(j, dis);
}
}
#endif

/// tables are not precomputed, but pointers are provided to the
/// relevant X_c|x_r tables
Expand Down Expand Up @@ -1101,15 +1098,9 @@ struct IVFPQScannerT : QueryTables {
int hd = hc.hamming(b_code);
if (hd < ht) {
n_hamming_pass++;
PQDecoder decoder(codes, pq.nbits);

float dis = dis0;
const float* tab = sim_table;

for (size_t m = 0; m < pq.M; m++) {
dis += tab[decoder.decode()];
tab += pq.ksub;
}
float dis =
dis0 + distance_single_code<SearchResultType>(codes);

res.add(j, dis);
}
Expand Down
2 changes: 1 addition & 1 deletion tests/test_lowlevel_ivf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ void test_lowlevel_access(const char* index_key, MetricType metric) {
float computed_D = scanner->distance_to_code(
codes.data() + vno * il->code_size);

EXPECT_EQ(computed_D, D[jj]);
EXPECT_FLOAT_EQ(computed_D, D[jj]);
}
}
}
Expand Down

0 comments on commit 438b64c

Please sign in to comment.