Skip to content

Commit

Permalink
use vpternlogq instead of and(or(A, B), C)
Browse files Browse the repository at this point in the history
  • Loading branch information
herumi committed Aug 16, 2024
1 parent f826137 commit 8c89bd4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 9 deletions.
22 changes: 22 additions & 0 deletions src/avx512.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ inline Vec vselect(const Vmask& c, const Vec& a, const Vec& b)
return vpandq(c, a, a, b);
}

// Three-input bitwise logic on a single Vec.
// Forwards a, b, c to _mm512_ternarylogic_epi64 with `imm` as the immediate
// operand; `imm` is a template parameter so that it is a compile-time constant.
// For example, split52bit uses imm = 0xA8 for and(or(a, b), c).
template<uint8_t imm>
inline Vec vpternlogq(const Vec& a, const Vec& b, const Vec& c)
{
    const Vec result = _mm512_ternarylogic_epi64(a, b, c, imm);
    return result;
}

/////

inline VecA vmulL(const VecA& a, const VecA& b, const VecA& c)
Expand Down Expand Up @@ -434,6 +440,22 @@ inline VecA vselect(const VmaskA& c, const Vec& a, const Vec& b)
return r;
}

// Element-wise three-input bitwise logic on VecA operands.
// For each index k in [0, vN), applies vpternlogq<imm> to
// (a.v[k], b.v[k], c.v[k]) and stores the result in the k-th output slot.
template<uint8_t imm>
inline VecA vpternlogq(const VecA& a, const VecA& b, const VecA& c)
{
    VecA out;
    for (size_t k = 0; k < vN; k++) {
        out.v[k] = vpternlogq<imm>(a.v[k], b.v[k], c.v[k]);
    }
    return out;
}

// Element-wise three-input bitwise logic where the third operand is a single
// Vec shared by every element (e.g. a broadcast mask).
// For each index k in [0, vN), applies vpternlogq<imm> to (a.v[k], b.v[k], c).
template<uint8_t imm>
inline VecA vpternlogq(const VecA& a, const VecA& b, const Vec& c)
{
    VecA out;
    for (size_t k = 0; k < vN; k++) {
        out.v[k] = vpternlogq<imm>(a.v[k], b.v[k], c);
    }
    return out;
}

#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
Expand Down
32 changes: 23 additions & 9 deletions src/msm_avx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,18 +275,32 @@ inline V getUnitAt(const V *x, size_t xN, size_t bitPos)
x|52:12|40:24|28:36|16:48|4:52:8|44:20|
y|52|52 |52 |52 |52 |52|52 |20|
*/
// split52bit: repack the six 64-bit words x[0..5] (6 * 64 = 384 bits) into
// eight units y[0..7]; y[0..6] are masked to 52 bits and y[7] holds the
// remaining 20 bits (7 * 52 + 20 = 384).
// Each unit is built from the high part of one input word (right shift) and
// the low part of the next one (left shift), then masked.
// V is a vector type; presumably every 64-bit lane is processed independently
// (judging by the q-suffixed helpers) - TODO confirm.
//
// NOTE(review): this span looks like a rendered diff without +/- markers. The
// first template header and the y[k*D] lines appear to be the removed
// version; the second template header and the #if/#else/#endif body appear
// to be the added version.
template<class V, size_t D=1>
template<class V>
inline void split52bit(V y[8], const V x[6])
{
// NOTE(review): y and x are pointer parameters, so &y and &x are the addresses
// of two distinct local variables and this assertion can never fail. The
// intent is presumably to reject aliasing of the buffers (y != x) - confirm.
assert(&y != &x);
y[0*D] = vpandq(x[0], G::mask());
y[1*D] = vpandq(vporq(vpsrlq(x[0], 52), vpsllq(x[1], 12)), G::mask());
y[2*D] = vpandq(vporq(vpsrlq(x[1], 40), vpsllq(x[2], 24)), G::mask());
y[3*D] = vpandq(vporq(vpsrlq(x[2], 28), vpsllq(x[3], 36)), G::mask());
y[4*D] = vpandq(vporq(vpsrlq(x[3], 16), vpsllq(x[4], 48)), G::mask());
y[5*D] = vpandq(vpsrlq(x[4], 4), G::mask());
y[6*D] = vpandq(vporq(vpsrlq(x[4], 56), vpsllq(x[5], 8)), G::mask());
y[7*D] = vpsrlq(x[5], 44);
#if 1
// Active path: fuse and(or(A, B), C) into one vpternlogq call per unit.
// m is the 52-bit mask broadcast to a Vec (presumably getMask(52) is the low
// 52 bits set - TODO confirm against getMask).
const Vec m = vpbroadcastq(getMask(52));
// and(or(A, B), C) = andCorAB = 0xa8
// Per Intel's description of the instruction, bit ((A<<2)|(B<<1)|C) of the
// immediate is the output for that input combination. (A|B)&C is 1 for
// ABC = 011, 101, 111, i.e. bits 3, 5 and 7: 0x08 + 0x20 + 0x80 = 0xA8.
const uint8_t imm = 0xA8;
y[0] = vpandq(x[0], m);
y[1] = vpternlogq<imm>(vpsrlq(x[0], 52), vpsllq(x[1], 12), m);
y[2] = vpternlogq<imm>(vpsrlq(x[1], 40), vpsllq(x[2], 24), m);
y[3] = vpternlogq<imm>(vpsrlq(x[2], 28), vpsllq(x[3], 36), m);
y[4] = vpternlogq<imm>(vpsrlq(x[3], 16), vpsllq(x[4], 48), m);
// y[5] comes entirely from x[4], so a plain shift-and-mask is enough.
y[5] = vpandq(vpsrlq(x[4], 4), m);
y[6] = vpternlogq<imm>(vpsrlq(x[4], 56), vpsllq(x[5], 8), m);
// Shifting a 64-bit word right by 44 leaves at most 20 bits; no mask needed.
y[7] = vpsrlq(x[5], 44);
#else
// Disabled reference path: the same repacking written as separate or + and.
y[0] = vpandq(x[0], G::mask());
y[1] = vpandq(vporq(vpsrlq(x[0], 52), vpsllq(x[1], 12)), G::mask());
y[2] = vpandq(vporq(vpsrlq(x[1], 40), vpsllq(x[2], 24)), G::mask());
y[3] = vpandq(vporq(vpsrlq(x[2], 28), vpsllq(x[3], 36)), G::mask());
y[4] = vpandq(vporq(vpsrlq(x[3], 16), vpsllq(x[4], 48)), G::mask());
y[5] = vpandq(vpsrlq(x[4], 4), G::mask());
y[6] = vpandq(vporq(vpsrlq(x[4], 56), vpsllq(x[5], 8)), G::mask());
y[7] = vpsrlq(x[5], 44);
#endif
}

/*
Expand Down

0 comments on commit 8c89bd4

Please sign in to comment.