diff --git a/src/avx512.hpp b/src/avx512.hpp index f6d1aeac..a2751c83 100644 --- a/src/avx512.hpp +++ b/src/avx512.hpp @@ -171,6 +171,12 @@ inline Vec vselect(const Vmask& c, const Vec& a, const Vec& b) return vpandq(c, a, a, b); } +template +inline Vec vpternlogq(const Vec& a, const Vec& b, const Vec& c) +{ + return _mm512_ternarylogic_epi64(a, b, c, imm); +} + ///// inline VecA vmulL(const VecA& a, const VecA& b, const VecA& c) @@ -434,6 +440,22 @@ inline VecA vselect(const VmaskA& c, const Vec& a, const Vec& b) return r; } +template +inline VecA vpternlogq(const VecA& a, const VecA& b, const VecA& c) +{ + VecA r; + for (size_t i = 0; i < vN; i++) r.v[i] = vpternlogq(a.v[i], b.v[i], c.v[i]); + return r; +} + +template +inline VecA vpternlogq(const VecA& a, const VecA& b, const Vec& c) +{ + VecA r; + for (size_t i = 0; i < vN; i++) r.v[i] = vpternlogq(a.v[i], b.v[i], c); + return r; +} + #if defined(__GNUC__) #pragma GCC diagnostic pop #endif diff --git a/src/msm_avx.cpp b/src/msm_avx.cpp index 40f2f711..52293f6b 100644 --- a/src/msm_avx.cpp +++ b/src/msm_avx.cpp @@ -275,18 +275,32 @@ inline V getUnitAt(const V *x, size_t xN, size_t bitPos) x|52:12|40:24|28:36|16:48|4:52:8|44:20| y|52|52 |52 |52 |52 |52|52 |20| */ -template +template inline void split52bit(V y[8], const V x[6]) { assert(&y != &x); - y[0*D] = vpandq(x[0], G::mask()); - y[1*D] = vpandq(vporq(vpsrlq(x[0], 52), vpsllq(x[1], 12)), G::mask()); - y[2*D] = vpandq(vporq(vpsrlq(x[1], 40), vpsllq(x[2], 24)), G::mask()); - y[3*D] = vpandq(vporq(vpsrlq(x[2], 28), vpsllq(x[3], 36)), G::mask()); - y[4*D] = vpandq(vporq(vpsrlq(x[3], 16), vpsllq(x[4], 48)), G::mask()); - y[5*D] = vpandq(vpsrlq(x[4], 4), G::mask()); - y[6*D] = vpandq(vporq(vpsrlq(x[4], 56), vpsllq(x[5], 8)), G::mask()); - y[7*D] = vpsrlq(x[5], 44); +#if 1 + const Vec m = vpbroadcastq(getMask(52)); + // and(or(A, B), C) = andCorAB = 0xa8 + const uint8_t imm = 0xA8; + y[0] = vpandq(x[0], m); + y[1] = vpternlogq(vpsrlq(x[0], 52), vpsllq(x[1], 12), m); + y[2] = vpternlogq(vpsrlq(x[1], 40), vpsllq(x[2], 24), m); + y[3] = vpternlogq(vpsrlq(x[2], 28), vpsllq(x[3], 36), m); + y[4] = vpternlogq(vpsrlq(x[3], 16), vpsllq(x[4], 48), m); + y[5] = vpandq(vpsrlq(x[4], 4), m); + y[6] = vpternlogq(vpsrlq(x[4], 56), vpsllq(x[5], 8), m); + y[7] = vpsrlq(x[5], 44); +#else + y[0] = vpandq(x[0], G::mask()); + y[1] = vpandq(vporq(vpsrlq(x[0], 52), vpsllq(x[1], 12)), G::mask()); + y[2] = vpandq(vporq(vpsrlq(x[1], 40), vpsllq(x[2], 24)), G::mask()); + y[3] = vpandq(vporq(vpsrlq(x[2], 28), vpsllq(x[3], 36)), G::mask()); + y[4] = vpandq(vporq(vpsrlq(x[3], 16), vpsllq(x[4], 48)), G::mask()); + y[5] = vpandq(vpsrlq(x[4], 4), G::mask()); + y[6] = vpandq(vporq(vpsrlq(x[4], 56), vpsllq(x[5], 8)), G::mask()); + y[7] = vpsrlq(x[5], 44); +#endif } /*