Skip to content

Implement count with SSE 4.2 and AVX2 #202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 15, 2021
2 changes: 2 additions & 0 deletions bench/BenchAll.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import Foreign
import System.Random

import BenchBoundsCheckFusion
import BenchCount
import BenchCSV
import BenchIndices

Expand Down Expand Up @@ -437,6 +438,7 @@ main = do
, bench "map (+1) small" $ nf (S.map (+ 1)) smallTraversalInput
]
, benchBoundsCheckFusion
, benchCount
, benchCSV
, benchIndices
]
29 changes: 29 additions & 0 deletions bench/BenchCount.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
-- |
-- Copyright : (c) 2021 Georg Rudoy
-- License : BSD3-style (see LICENSE)
--
-- Maintainer : Georg Rudoy <0xd34df00d+github@gmail.com>
--
-- Benchmark count

module BenchCount (benchCount) where

import Test.Tasty.Bench
import qualified Data.ByteString.Char8 as B

-- | Benchmarks for 'B.count' across inputs of varying length and match
-- density, exercising both the scalar path (short strings) and the
-- SIMD path (strings above the 1024-byte cutoff).
benchCount :: Benchmark
benchCount = bgroup "Count"
    [ bgroup "no matches, same char" $ benchesOver (1 : standardSizes) (\n -> B.replicate n 'b')
    , bgroup "no matches, different chars" $ benchesOver standardSizes (cyclic 10 'b')
    , bgroup "some matches, alternating" $ benchesOver standardSizes (cyclic 2 'a')
    , bgroup "some matches, short cycle" $ benchesOver standardSizes (cyclic 5 'a')
    , bgroup "some matches, long cycle" $ benchesOver standardSizes (cyclic 10 'a')
    , bgroup "all matches" $ benchesOver (1 : standardSizes) (\n -> B.replicate n 'a')
    ]
  where
    -- Something above the SIMD switch threshold of 1024 that is
    -- divisible by every cycle length used above.
    justAboveSimdCutoff = 1030
    standardSizes = [10, 100, 1000, justAboveSimdCutoff, 10000, 100000, 1000000]
    -- One benchmark per requested size; the counted needle is always 'a'.
    benchesOver sizes mkInput =
      map (\n -> bench (show n ++ " chars long") $ nf (B.count 'a') (mkInput n)) sizes
    -- A string of (roughly) the given size built by repeating a cycle of
    -- @cycleLen@ consecutive characters starting at @from@.
    cyclic cycleLen from n =
      B.concat $ replicate (n `div` cycleLen) $ B.pack (take cycleLen [from ..])
2 changes: 2 additions & 0 deletions bytestring.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ library

c-sources: cbits/fpstring.c
cbits/itoa.c
cc-options: -std=c11
include-dirs: include
includes: fpstring.h
install-includes: fpstring.h
Expand Down Expand Up @@ -167,6 +168,7 @@ test-suite test-builder
benchmark bytestring-bench
main-is: BenchAll.hs
other-modules: BenchBoundsCheckFusion
BenchCount
BenchCSV
BenchIndices
type: exitcode-stdio-1.0
Expand Down
196 changes: 187 additions & 9 deletions cbits/fpstring.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,15 @@

#include "fpstring.h"
#if defined(__x86_64__)
#include <emmintrin.h>
#include <xmmintrin.h>
#include <x86intrin.h>
#include <cpuid.h>
#endif

#include <stdint.h>
#include <stdbool.h>

#ifndef __STDC_NO_ATOMICS__
#include <stdatomic.h>
#endif

/* copy a string in reverse */
Expand Down Expand Up @@ -90,19 +97,190 @@ unsigned char fps_minimum(unsigned char *p, size_t len) {
return c;
}

/* qsort(3) comparator: orders bytes by their unsigned value.
   The subtraction cannot overflow since both operands are in [0, 255]. */
int fps_compare(const void *a, const void *b) {
    return (int)*(unsigned char*)a - (int)*(unsigned char*)b;
}

/* sort the bytes of p[0..len) in ascending order, in place */
void fps_sort(unsigned char *p, size_t len) {
    /* Previously written as `return qsort(...)`: a return statement with
       an expression in a void function is a C11 constraint violation
       (6.8.6.4), and this file is built with -std=c11. Call it plainly. */
    qsort(p, len, 1, fps_compare);
}

/* count the number of occurrences of a char in a string */
size_t fps_count_naive(unsigned char *str, size_t len, unsigned char w) {
    size_t matches = 0;
    for (size_t i = 0; i < len; ++i) {
        if (str[i] == w)
            ++matches;
    }
    return matches;
}

int fps_compare(const void *a, const void *b) {
return (int)*(unsigned char*)a - (int)*(unsigned char*)b;
#if defined(__x86_64__) && (__GNUC__ >= 6 || defined(__clang_major__)) && !defined(__STDC_NO_ATOMICS__)
#define USE_SIMD_COUNT
#endif

#ifdef USE_SIMD_COUNT
__attribute__((target("sse4.2")))
/* count occurrences of w in str[0..len) using SSE 4.2 string compares,
   processing 128 aligned bytes per iteration of the main loop */
size_t fps_count_cmpestrm(unsigned char *str, size_t len, unsigned char w) {
    const __m128i pat = _mm_set1_epi8(w);

    size_t res = 0;

    size_t i = 0;

    /* Scalar prologue: advance i to a 64-byte boundary so that the
       aligned loads below are valid. */
    for (; i < len && (intptr_t)(str + i) % 64; ++i) {
        res += str[i] == w;
    }

    /* FIX: `len - 128` underflows (size_t wraps to a huge value) when
       len < 128, which would send the loop far past the buffer. The
       dispatcher only calls us for len > 1024, but keep the function
       memory-safe on its own. */
    if (len >= 128) {
        for (size_t end = len - 128; i < end; i += 128) {
            __m128i p0 = _mm_load_si128((const __m128i*)(str + i + 16 * 0));
            __m128i p1 = _mm_load_si128((const __m128i*)(str + i + 16 * 1));
            __m128i p2 = _mm_load_si128((const __m128i*)(str + i + 16 * 2));
            __m128i p3 = _mm_load_si128((const __m128i*)(str + i + 16 * 3));
            __m128i p4 = _mm_load_si128((const __m128i*)(str + i + 16 * 4));
            __m128i p5 = _mm_load_si128((const __m128i*)(str + i + 16 * 5));
            __m128i p6 = _mm_load_si128((const __m128i*)(str + i + 16 * 6));
            __m128i p7 = _mm_load_si128((const __m128i*)(str + i + 16 * 7));
            // Here, cmpestrm compares two strings in the following mode:
            // * _SIDD_SBYTE_OPS: interprets the strings as consisting of 8-bit chars,
            // * _SIDD_CMP_EQUAL_EACH: computes a bitmask of the `i`s
            //   for which `p[i]`, a part of `str`, is equal to `pat[i]`
            //   (the latter being always equal to `w`).
            //
            // q.v. https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cmpestrm&expand=835
#define MODE _SIDD_SBYTE_OPS | _SIDD_CMP_EQUAL_EACH
            __m128i r0 = _mm_cmpestrm(p0, 16, pat, 16, MODE);
            __m128i r1 = _mm_cmpestrm(p1, 16, pat, 16, MODE);
            __m128i r2 = _mm_cmpestrm(p2, 16, pat, 16, MODE);
            __m128i r3 = _mm_cmpestrm(p3, 16, pat, 16, MODE);
            __m128i r4 = _mm_cmpestrm(p4, 16, pat, 16, MODE);
            __m128i r5 = _mm_cmpestrm(p5, 16, pat, 16, MODE);
            __m128i r6 = _mm_cmpestrm(p6, 16, pat, 16, MODE);
            __m128i r7 = _mm_cmpestrm(p7, 16, pat, 16, MODE);
#undef MODE
            /* Each result holds a 16-bit match mask in its low lane;
               popcnt of that mask is the number of matching bytes. */
            res += _popcnt64(_mm_extract_epi64(r0, 0));
            res += _popcnt64(_mm_extract_epi64(r1, 0));
            res += _popcnt64(_mm_extract_epi64(r2, 0));
            res += _popcnt64(_mm_extract_epi64(r3, 0));
            res += _popcnt64(_mm_extract_epi64(r4, 0));
            res += _popcnt64(_mm_extract_epi64(r5, 0));
            res += _popcnt64(_mm_extract_epi64(r6, 0));
            res += _popcnt64(_mm_extract_epi64(r7, 0));
        }
    }

    /* Scalar epilogue for the remaining tail bytes. */
    for (; i < len; ++i) {
        res += str[i] == w;
    }

    return res;
}

void fps_sort(unsigned char *p, size_t len) {
return qsort(p, len, 1, fps_compare);
__attribute__((target("avx2")))
/* count occurrences of w in str[0..len) using AVX2 byte compares,
   processing 128 aligned bytes per iteration of the main loop */
size_t fps_count_avx2(unsigned char *str, size_t len, unsigned char w) {
    __m256i pat = _mm256_set1_epi8(w);

    /* The scalar prefix count is kept separate from `res` because `res`
       is divided by 8 below; only SIMD-derived counts belong in it. */
    size_t prefix = 0, res = 0;

    size_t i = 0;

    /* Scalar prologue: advance i to a 64-byte boundary so that the
       aligned loads below are valid. */
    for (; i < len && (intptr_t)(str + i) % 64; ++i) {
        prefix += str[i] == w;
    }

    /* FIX: `len - 128` underflows (size_t wraps to a huge value) when
       len < 128, which would send the loop far past the buffer. The
       dispatcher only calls us for len > 1024, but keep the function
       memory-safe on its own. */
    if (len >= 128) {
        for (size_t end = len - 128; i < end; i += 128) {
            __m256i p0 = _mm256_load_si256((const __m256i*)(str + i + 32 * 0));
            __m256i p1 = _mm256_load_si256((const __m256i*)(str + i + 32 * 1));
            __m256i p2 = _mm256_load_si256((const __m256i*)(str + i + 32 * 2));
            __m256i p3 = _mm256_load_si256((const __m256i*)(str + i + 32 * 3));
            __m256i r0 = _mm256_cmpeq_epi8(p0, pat);
            __m256i r1 = _mm256_cmpeq_epi8(p1, pat);
            __m256i r2 = _mm256_cmpeq_epi8(p2, pat);
            __m256i r3 = _mm256_cmpeq_epi8(p3, pat);
            res += _popcnt64(_mm256_extract_epi64(r0, 0));
            res += _popcnt64(_mm256_extract_epi64(r0, 1));
            res += _popcnt64(_mm256_extract_epi64(r0, 2));
            res += _popcnt64(_mm256_extract_epi64(r0, 3));
            res += _popcnt64(_mm256_extract_epi64(r1, 0));
            res += _popcnt64(_mm256_extract_epi64(r1, 1));
            res += _popcnt64(_mm256_extract_epi64(r1, 2));
            res += _popcnt64(_mm256_extract_epi64(r1, 3));
            res += _popcnt64(_mm256_extract_epi64(r2, 0));
            res += _popcnt64(_mm256_extract_epi64(r2, 1));
            res += _popcnt64(_mm256_extract_epi64(r2, 2));
            res += _popcnt64(_mm256_extract_epi64(r2, 3));
            res += _popcnt64(_mm256_extract_epi64(r3, 0));
            res += _popcnt64(_mm256_extract_epi64(r3, 1));
            res += _popcnt64(_mm256_extract_epi64(r3, 2));
            res += _popcnt64(_mm256_extract_epi64(r3, 3));
        }
    }

    // _mm256_cmpeq_epi8(p, pat) returns a SIMD vector
    // with `i`th byte consisting of eight `1`s if `p[i] == pat[i]`,
    // and of eight `0`s otherwise,
    // hence each matching byte is counted 8 times by popcnt.
    // Dividing by 8 corrects for that.
    res /= 8;
    res += prefix;

    /* Scalar epilogue for the remaining tail bytes. */
    for (; i < len; ++i) {
        res += str[i] == w;
    }

    return res;
}

/* Signature shared by all fps_count implementations. */
typedef size_t (*fps_impl_t) (unsigned char*, size_t, unsigned char);

/* Pick the fastest available fps_count implementation via CPUID.
   Returns a function pointer; never NULL. */
fps_impl_t select_fps_simd_impl() {
uint32_t eax = 0, ebx = 0, ecx = 0, edx = 0;

/* ECX of CPUID leaf 1 carries the XSAVE/POPCNT/SSE4.2 feature bits. */
uint32_t ecx1 = 0;
if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) {
ecx1 = ecx;
}

/* Leaf 1 ECX bit 26 = XSAVE, bit 23 = POPCNT.
   NOTE(review): this tests the CPU's XSAVE capability, not that the OS
   has actually enabled YMM state saving (leaf 1 ECX bit 27 OSXSAVE plus
   an xgetbv check of XCR0) — confirm whether that stricter check is
   needed before returning the AVX2 implementation. */
const bool has_xsave = ecx1 & (1 << 26);
const bool has_popcnt = ecx1 & (1 << 23);

/* Leaf 7 subleaf 0, EBX bit 5 = AVX2. */
if (__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) {
const bool has_avx2 = has_xsave && (ebx & (1 << 5));
if (has_avx2 && has_popcnt) {
return &fps_count_avx2;
}
}

/* Leaf 1 ECX bit 19 = SSE 4.2 (needed for cmpestrm). */
const bool has_sse42 = ecx1 & (1 << 19);
if (has_sse42 && has_popcnt) {
return &fps_count_cmpestrm;
}

/* Portable fallback when no suitable SIMD support is present. */
return &fps_count_naive;
}
#endif



/* count the number of occurrences of a char in a string, dispatching
   to the fastest implementation available on the running CPU */
size_t fps_count(unsigned char *str, size_t len, unsigned char w) {
#ifdef USE_SIMD_COUNT
    // 1024 is a rough guesstimate of the string length
    // for which the extra performance of the main SIMD loop
    // starts to compensate the extra work and extra branching outside the SIMD loop.
    // The real optimal number depends on the specific μarch
    // and isn't worth optimizing for in this context,
    // since counting characters in shorter strings is unlikely to be a hot spot.
    if (len > 1024) {
        /* Cache the CPUID-based selection across calls. Relaxed
           ordering suffices: selection is idempotent, so concurrent
           first calls may race but all store the same pointer. */
        static _Atomic fps_impl_t cached_impl = (fps_impl_t)NULL;

        fps_impl_t counter = atomic_load_explicit(&cached_impl, memory_order_relaxed);
        if (counter == NULL) {
            counter = select_fps_simd_impl();
            atomic_store_explicit(&cached_impl, counter, memory_order_relaxed);
        }

        return counter(str, len, w);
    }
#endif
    /* Short strings — or builds without SIMD support — take the
       portable scalar path. */
    return fps_count_naive(str, len, w);
}
5 changes: 5 additions & 0 deletions tests/Properties/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,11 @@ tests =
\x -> B.length x === fromIntegral (length (B.unpack x))
, testProperty "count" $
\(toElem -> c) x -> B.count c x === fromIntegral (length (elemIndices c (B.unpack x)))
-- for long strings, the multiplier is non-round (and not power of 2)
-- to ensure non-trivial prefix or suffix of the string is handled outside any possible SIMD-based loop,
-- which typically handles chunks of 16 or 32 or 64 etc bytes.
, testProperty "count (long strings)" $
\(toElem -> c) x (Positive n) -> B.count c x * fromIntegral n === B.count c (B.concat $ replicate n x)
, testProperty "filter" $
\f x -> B.unpack (B.filter f x) === filter f (B.unpack x)
, testProperty "filter compose" $
Expand Down