Skip to content

Added AVX512 support for space_l2 and space_ip. #339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#define USE_SSE
#ifdef __AVX__
#define USE_AVX
#ifdef __AVX512F__
#define USE_AVX512
#endif
#endif
#endif
#endif
Expand All @@ -16,10 +19,16 @@
#include <x86intrin.h>
#endif

#if defined(USE_AVX512)
#include <immintrin.h>
#endif

#if defined(__GNUC__)
#define PORTABLE_ALIGN32 __attribute__((aligned(32)))
#define PORTABLE_ALIGN64 __attribute__((aligned(64)))
#else
#define PORTABLE_ALIGN32 __declspec(align(32))
#define PORTABLE_ALIGN64 __declspec(align(64))
#endif
#endif

Expand Down
39 changes: 36 additions & 3 deletions hnswlib/space_ip.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,40 @@ namespace hnswlib {

#endif

#if defined(USE_AVX)

#if defined(USE_AVX512)

static float
InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN64 TmpRes[16];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);

size_t qty16 = qty / 16;


const float *pEnd1 = pVect1 + 16 * qty16;

__m512 sum512 = _mm512_set1_ps(0);

while (pVect1 < pEnd1) {
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);

__m512 v1 = _mm512_loadu_ps(pVect1);
pVect1 += 16;
__m512 v2 = _mm512_loadu_ps(pVect2);
pVect2 += 16;
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2));
}

_mm512_store_ps(TmpRes, sum512);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15];

return 1.0f - sum;
}

#elif defined(USE_AVX)

static float
InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
Expand Down Expand Up @@ -211,7 +244,7 @@ namespace hnswlib {

#endif

#if defined(USE_SSE) || defined(USE_AVX)
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
static float
InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
Expand Down Expand Up @@ -249,7 +282,7 @@ namespace hnswlib {
public:
InnerProductSpace(size_t dim) {
fstdistfunc_ = InnerProduct;
#if defined(USE_AVX) || defined(USE_SSE)
#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
if (dim % 16 == 0)
fstdistfunc_ = InnerProductSIMD16Ext;
else if (dim % 4 == 0)
Expand Down
42 changes: 38 additions & 4 deletions hnswlib/space_l2.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,41 @@ namespace hnswlib {
return (res);
}

#if defined(USE_AVX)
#if defined(USE_AVX512)

// Favor using AVX512 if available.
static float
L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
float PORTABLE_ALIGN64 TmpRes[16];
size_t qty16 = qty >> 4;

const float *pEnd1 = pVect1 + (qty16 << 4);

__m512 diff, v1, v2;
__m512 sum = _mm512_set1_ps(0);

while (pVect1 < pEnd1) {
v1 = _mm512_loadu_ps(pVect1);
pVect1 += 16;
v2 = _mm512_loadu_ps(pVect2);
pVect2 += 16;
diff = _mm512_sub_ps(v1, v2);
// sum = _mm512_fmadd_ps(diff, diff, sum);
sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff));
}

_mm512_store_ps(TmpRes, sum);
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] +
TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] +
TmpRes[13] + TmpRes[14] + TmpRes[15];

return (res);
}

#elif defined(USE_AVX)

// Favor using AVX if available.
static float
Expand Down Expand Up @@ -106,7 +140,7 @@ namespace hnswlib {
}
#endif

#if defined(USE_SSE) || defined(USE_AVX)
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
static float
L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
Expand Down Expand Up @@ -174,7 +208,7 @@ namespace hnswlib {
public:
L2Space(size_t dim) {
fstdistfunc_ = L2Sqr;
#if defined(USE_SSE) || defined(USE_AVX)
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
if (dim % 16 == 0)
fstdistfunc_ = L2SqrSIMD16Ext;
else if (dim % 4 == 0)
Expand Down Expand Up @@ -278,4 +312,4 @@ namespace hnswlib {
};


}
}