-
Notifications
You must be signed in to change notification settings - Fork 142
Implement count
with SSE 4.2 and AVX2
#202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
76e175b
ae78b0f
71a7172
1c0899a
35532f8
13f43a4
c7481ef
90ec933
27ed60c
bba08bc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
-- | | ||
-- Copyright : (c) 2021 Georg Rudoy | ||
-- License : BSD3-style (see LICENSE) | ||
-- | ||
-- Maintainer : Georg Rudoy <0xd34df00d+github@gmail.com> | ||
-- | ||
-- Benchmark count | ||
|
||
module BenchCount (benchCount) where | ||
|
||
import Test.Tasty.Bench | ||
import qualified Data.ByteString.Char8 as B | ||
|
||
benchCount :: Benchmark | ||
benchCount = bgroup "Count" | ||
[ bgroup "no matches, same char" $ mkBenches (1 : commonSizes) (\s -> B.replicate s 'b') | ||
, bgroup "no matches, different chars" $ mkBenches commonSizes (\s -> genCyclic 10 s 'b') | ||
, bgroup "some matches, alternating" $ mkBenches commonSizes (\s -> genCyclic 2 s 'a') | ||
, bgroup "some matches, short cycle" $ mkBenches commonSizes (\s -> genCyclic 5 s 'a') | ||
, bgroup "some matches, long cycle" $ mkBenches commonSizes (\s -> genCyclic 10 s 'a') | ||
, bgroup "all matches" $ mkBenches (1 : commonSizes) (\s -> B.replicate s 'a') | ||
] | ||
where | ||
aboveSimdSwitchThreshold = 1030 -- something above the threshold of 1024 that's divisible by cycle lengths | ||
commonSizes = [ 10, 100, 1000, aboveSimdSwitchThreshold, 10000, 100000, 1000000 ] | ||
mkBenches sizes gen = [ bench (show size ++ " chars long") $ nf (B.count 'a') (gen size) | ||
| size <- sizes | ||
] | ||
genCyclic cycleLen size from = B.concat $ replicate (size `div` cycleLen) $ B.pack (take cycleLen [from..]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,8 +31,15 @@ | |
|
||
#include "fpstring.h" | ||
#if defined(__x86_64__) | ||
#include <emmintrin.h> | ||
#include <xmmintrin.h> | ||
#include <x86intrin.h> | ||
#include <cpuid.h> | ||
#endif | ||
|
||
#include <stdint.h> | ||
#include <stdbool.h> | ||
|
||
#ifndef __STDC_NO_ATOMICS__ | ||
#include <stdatomic.h> | ||
#endif | ||
|
||
/* copy a string in reverse */ | ||
|
@@ -90,19 +97,190 @@ unsigned char fps_minimum(unsigned char *p, size_t len) { | |
return c; | ||
} | ||
|
||
int fps_compare(const void *a, const void *b) { | ||
return (int)*(unsigned char*)a - (int)*(unsigned char*)b; | ||
} | ||
|
||
void fps_sort(unsigned char *p, size_t len) { | ||
return qsort(p, len, 1, fps_compare); | ||
} | ||
|
||
/* count the number of occurences of a char in a string */ | ||
size_t fps_count(unsigned char *p, size_t len, unsigned char w) { | ||
size_t fps_count_naive(unsigned char *str, size_t len, unsigned char w) { | ||
size_t c; | ||
for (c = 0; len-- != 0; ++p) | ||
if (*p == w) | ||
for (c = 0; len-- != 0; ++str) | ||
if (*str == w) | ||
++c; | ||
return c; | ||
} | ||
|
||
int fps_compare(const void *a, const void *b) { | ||
return (int)*(unsigned char*)a - (int)*(unsigned char*)b; | ||
#if defined(__x86_64__) && (__GNUC__ >= 6 || defined(__clang_major__)) && !defined(__STDC_NO_ATOMICS__) | ||
#define USE_SIMD_COUNT | ||
#endif | ||
|
||
#ifdef USE_SIMD_COUNT | ||
__attribute__((target("sse4.2"))) | ||
size_t fps_count_cmpestrm(unsigned char *str, size_t len, unsigned char w) { | ||
const __m128i pat = _mm_set1_epi8(w); | ||
|
||
size_t res = 0; | ||
|
||
size_t i = 0; | ||
|
||
for (; i < len && (intptr_t)(str + i) % 64; ++i) { | ||
res += str[i] == w; | ||
} | ||
|
||
for (size_t end = len - 128; i < end; i += 128) { | ||
__m128i p0 = _mm_load_si128((const __m128i*)(str + i + 16 * 0)); | ||
__m128i p1 = _mm_load_si128((const __m128i*)(str + i + 16 * 1)); | ||
__m128i p2 = _mm_load_si128((const __m128i*)(str + i + 16 * 2)); | ||
__m128i p3 = _mm_load_si128((const __m128i*)(str + i + 16 * 3)); | ||
__m128i p4 = _mm_load_si128((const __m128i*)(str + i + 16 * 4)); | ||
__m128i p5 = _mm_load_si128((const __m128i*)(str + i + 16 * 5)); | ||
__m128i p6 = _mm_load_si128((const __m128i*)(str + i + 16 * 6)); | ||
__m128i p7 = _mm_load_si128((const __m128i*)(str + i + 16 * 7)); | ||
// Here, cmpestrm compares two strings in the following mode: | ||
// * _SIDD_SBYTE_OPS: interprets the strings as consisting of 8-bit chars, | ||
// * _SIDD_CMP_EQUAL_EACH: computes the number of `i`s | ||
// for which `p[i]`, a part of `str`, is equal to `pat[i]` | ||
// (the latter being always equal to `w`). | ||
// | ||
// q.v. https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cmpestrm&expand=835 | ||
#define MODE _SIDD_SBYTE_OPS | _SIDD_CMP_EQUAL_EACH | ||
__m128i r0 = _mm_cmpestrm(p0, 16, pat, 16, MODE); | ||
__m128i r1 = _mm_cmpestrm(p1, 16, pat, 16, MODE); | ||
__m128i r2 = _mm_cmpestrm(p2, 16, pat, 16, MODE); | ||
__m128i r3 = _mm_cmpestrm(p3, 16, pat, 16, MODE); | ||
__m128i r4 = _mm_cmpestrm(p4, 16, pat, 16, MODE); | ||
__m128i r5 = _mm_cmpestrm(p5, 16, pat, 16, MODE); | ||
__m128i r6 = _mm_cmpestrm(p6, 16, pat, 16, MODE); | ||
__m128i r7 = _mm_cmpestrm(p7, 16, pat, 16, MODE); | ||
#undef MODE | ||
res += _popcnt64(_mm_extract_epi64(r0, 0)); | ||
res += _popcnt64(_mm_extract_epi64(r1, 0)); | ||
res += _popcnt64(_mm_extract_epi64(r2, 0)); | ||
res += _popcnt64(_mm_extract_epi64(r3, 0)); | ||
res += _popcnt64(_mm_extract_epi64(r4, 0)); | ||
res += _popcnt64(_mm_extract_epi64(r5, 0)); | ||
res += _popcnt64(_mm_extract_epi64(r6, 0)); | ||
res += _popcnt64(_mm_extract_epi64(r7, 0)); | ||
} | ||
|
||
for (; i < len; ++i) { | ||
res += str[i] == w; | ||
} | ||
|
||
return res; | ||
} | ||
|
||
void fps_sort(unsigned char *p, size_t len) { | ||
return qsort(p, len, 1, fps_compare); | ||
__attribute__((target("avx2"))) | ||
size_t fps_count_avx2(unsigned char *str, size_t len, unsigned char w) { | ||
__m256i pat = _mm256_set1_epi8(w); | ||
|
||
size_t prefix = 0, res = 0; | ||
|
||
size_t i = 0; | ||
|
||
for (; i < len && (intptr_t)(str + i) % 64; ++i) { | ||
prefix += str[i] == w; | ||
} | ||
|
||
for (size_t end = len - 128; i < end; i += 128) { | ||
__m256i p0 = _mm256_load_si256((const __m256i*)(str + i + 32 * 0)); | ||
__m256i p1 = _mm256_load_si256((const __m256i*)(str + i + 32 * 1)); | ||
__m256i p2 = _mm256_load_si256((const __m256i*)(str + i + 32 * 2)); | ||
__m256i p3 = _mm256_load_si256((const __m256i*)(str + i + 32 * 3)); | ||
__m256i r0 = _mm256_cmpeq_epi8(p0, pat); | ||
__m256i r1 = _mm256_cmpeq_epi8(p1, pat); | ||
__m256i r2 = _mm256_cmpeq_epi8(p2, pat); | ||
__m256i r3 = _mm256_cmpeq_epi8(p3, pat); | ||
res += _popcnt64(_mm256_extract_epi64(r0, 0)); | ||
res += _popcnt64(_mm256_extract_epi64(r0, 1)); | ||
res += _popcnt64(_mm256_extract_epi64(r0, 2)); | ||
res += _popcnt64(_mm256_extract_epi64(r0, 3)); | ||
res += _popcnt64(_mm256_extract_epi64(r1, 0)); | ||
res += _popcnt64(_mm256_extract_epi64(r1, 1)); | ||
res += _popcnt64(_mm256_extract_epi64(r1, 2)); | ||
res += _popcnt64(_mm256_extract_epi64(r1, 3)); | ||
res += _popcnt64(_mm256_extract_epi64(r2, 0)); | ||
res += _popcnt64(_mm256_extract_epi64(r2, 1)); | ||
res += _popcnt64(_mm256_extract_epi64(r2, 2)); | ||
res += _popcnt64(_mm256_extract_epi64(r2, 3)); | ||
res += _popcnt64(_mm256_extract_epi64(r3, 0)); | ||
res += _popcnt64(_mm256_extract_epi64(r3, 1)); | ||
res += _popcnt64(_mm256_extract_epi64(r3, 2)); | ||
res += _popcnt64(_mm256_extract_epi64(r3, 3)); | ||
} | ||
|
||
// _mm256_cmpeq_epi8(p, pat) returns a SIMD vector | ||
// with `i`th byte consisting of eight `1`s if `p[i] == pat[i]`, | ||
// and of eight `0`s otherwise, | ||
// hence each matching byte is counted 8 times by popcnt. | ||
// Dividing by 8 corrects for that. | ||
res /= 8; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I needed a moment to figure why this division is needed. Maybe add a comment that r0, r1, r2, r3 have 0xff (so 8 bits) for every matching byte and 0x00 for all the rest. |
||
|
||
res += prefix; | ||
|
||
for (; i < len; ++i) { | ||
res += str[i] == w; | ||
} | ||
|
||
return res; | ||
} | ||
|
||
typedef size_t (*fps_impl_t) (unsigned char*, size_t, unsigned char); | ||
|
||
fps_impl_t select_fps_simd_impl() { | ||
uint32_t eax = 0, ebx = 0, ecx = 0, edx = 0; | ||
|
||
uint32_t ecx1 = 0; | ||
if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) { | ||
ecx1 = ecx; | ||
} | ||
|
||
const bool has_xsave = ecx1 & (1 << 26); | ||
const bool has_popcnt = ecx1 & (1 << 23); | ||
|
||
if (__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) { | ||
const bool has_avx2 = has_xsave && (ebx & (1 << 5)); | ||
if (has_avx2 && has_popcnt) { | ||
return &fps_count_avx2; | ||
} | ||
} | ||
|
||
const bool has_sse42 = ecx1 & (1 << 19); | ||
if (has_sse42 && has_popcnt) { | ||
return &fps_count_cmpestrm; | ||
} | ||
|
||
return &fps_count_naive; | ||
} | ||
#endif | ||
|
||
|
||
|
||
size_t fps_count(unsigned char *str, size_t len, unsigned char w) { | ||
#ifndef USE_SIMD_COUNT | ||
return fps_count_naive(str, len, w); | ||
#else | ||
// 1024 is a rough guesstimate of the string length | ||
// for which the extra performance of the main SIMD loop | ||
// starts to compensate the extra work and extra branching outside the SIMD loop. | ||
// The real optimal number depends on the specific μarch | ||
// and isn't worth optimizing for in this context, | ||
// since counting characters in shorter strings is unlikely to be a hot spot. | ||
if (len <= 1024) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A comment that explains this magic 1024 would be nice to have. Did you compare different cutoff values? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I was waiting for somebody to ask! Nope, that's a rough guesstimate of how big the string should be for the SIMD stuff to compensate for all the extra work outside the main SIMD-based loop. The optimal value will very much depend on the specific hardware this is going to be run on, and, moreover, for short strings (shorter than this cutoff) the absolute difference wouldn't be that big, and I presume the user just won't care. So, given that, I wouldn't put too much thought into it. I can put a comment like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, please add a comment. 👍 |
||
return fps_count_naive(str, len, w); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh sweet, it does do the fallback for small inputs :) |
||
} | ||
|
||
static _Atomic fps_impl_t s_impl = (fps_impl_t)NULL; | ||
fps_impl_t impl = atomic_load_explicit(&s_impl, memory_order_relaxed); | ||
if (!impl) { | ||
impl = select_fps_simd_impl(); | ||
atomic_store_explicit(&s_impl, impl, memory_order_relaxed); | ||
} | ||
|
||
return (*impl)(str, len, w); | ||
#endif | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add the reference documentation link https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1885,6090,5596,5653,4115,4076,4115,4938,4956,4115,5608,5656,5608,835,2426,835&text=_mm_cmpestrm for the next person who doesn't know what
_mm_cmpestrm
is.