Merged
Changes from 5 commits
9 changes: 0 additions & 9 deletions src/avx2-64bit-qsort.hpp
@@ -11,15 +11,6 @@
#include "xss-common-qsort.h"
#include "avx2-emu-funcs.hpp"

/*
* Constants used in sorting 8 elements in ymm registers. Based on Bitonic
* sorting network (see
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
*/
// ymm 3, 2, 1, 0
#define NETWORK_64BIT_R 0, 1, 2, 3
#define NETWORK_64BIT_1 1, 0, 3, 2
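For illustration, a reversal constant of this kind drives one compare-exchange step of the network: each lane is paired with its mirror lane, minima are kept in the lower half and maxima in the upper half. A minimal sketch (not the library's actual sort routine; plain AVX2 intrinsics on a __m256d, with the reversal written as an immediate shuffle rather than the index macro above):

    #include <immintrin.h>

    // One compare-exchange across the full register: pair lane i with lane 3-i,
    // keep minima in lanes 0..1 and maxima in lanes 2..3.
    static inline __m256d coex_reverse_pd(__m256d v)
    {
        __m256d rev = _mm256_permute4x64_pd(v, _MM_SHUFFLE(0, 1, 2, 3)); // reverse lanes
        __m256d mn = _mm256_min_pd(v, rev);
        __m256d mx = _mm256_max_pd(v, rev);
        return _mm256_blend_pd(mn, mx, 0b1100);
    }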

/*
* Assumes ymm is random and performs a full sorting network defined in
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
197 changes: 4 additions & 193 deletions src/avx512-64bit-argsort.hpp
@@ -352,195 +352,6 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr,
return l_store;
}

template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE void
argsort_8_64bit(type_t *arr, arrsize_t *arg, int32_t N)
{
using reg_t = typename vtype::reg_t;
typename vtype::opmask_t load_mask = (0x01 << N) - 0x01;
argreg_t argzmm = argtype::maskz_loadu(load_mask, arg);
reg_t arrzmm = vtype::template mask_i64gather<sizeof(type_t)>(
vtype::zmm_max(), load_mask, argzmm, arr);
arrzmm = sort_zmm_64bit<vtype, argtype>(arrzmm, argzmm);
argtype::mask_storeu(arg, load_mask, argzmm);
}
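As a worked example of the mask arithmetic above (illustrative only): with N = 5, only the first five index slots are live, the remaining gather lanes are filled with the vtype::zmm_max() sentinel so they sort to the end, and the masked store writes back just arg[0..4]:

    load_mask = (0x01 << 5) - 0x01;   // 0b00011111: low 5 bits set
    // lanes 0..4 gather arr[arg[0..4]]; lanes 5..7 hold the zmm_max() sentinel
    // sort_zmm_64bit sorts all 8 lanes; mask_storeu writes only the 5 live indices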

template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE void
argsort_16_64bit(type_t *arr, arrsize_t *arg, int32_t N)
{
if (N <= 8) {
argsort_8_64bit<vtype>(arr, arg, N);
return;
}
using reg_t = typename vtype::reg_t;
typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01;
argreg_t argzmm1 = argtype::loadu(arg);
argreg_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8);
reg_t arrzmm1 = vtype::i64gather(arr, arg);
reg_t arrzmm2 = vtype::template mask_i64gather<sizeof(type_t)>(
vtype::zmm_max(), load_mask, argzmm2, arr);
arrzmm1 = sort_zmm_64bit<vtype, argtype>(arrzmm1, argzmm1);
arrzmm2 = sort_zmm_64bit<vtype, argtype>(arrzmm2, argzmm2);
bitonic_merge_two_zmm_64bit<vtype, argtype>(
arrzmm1, arrzmm2, argzmm1, argzmm2);
argtype::storeu(arg, argzmm1);
argtype::mask_storeu(arg + 8, load_mask, argzmm2);
}

template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE void
argsort_32_64bit(type_t *arr, arrsize_t *arg, int32_t N)
{
if (N <= 16) {
argsort_16_64bit<vtype>(arr, arg, N);
return;
}
using reg_t = typename vtype::reg_t;
using opmask_t = typename vtype::opmask_t;
reg_t arrzmm[4];
argreg_t argzmm[4];

X86_SIMD_SORT_UNROLL_LOOP(2)
for (int ii = 0; ii < 2; ++ii) {
argzmm[ii] = argtype::loadu(arg + 8 * ii);
arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii);
arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
}

uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull;
opmask_t load_mask[2] = {0xFF, 0xFF};
X86_SIMD_SORT_UNROLL_LOOP(2)
for (int ii = 0; ii < 2; ++ii) {
load_mask[ii] = (combined_mask >> (ii * 8)) & 0xFF;
argzmm[ii + 2] = argtype::maskz_loadu(load_mask[ii], arg + 16 + 8 * ii);
arrzmm[ii + 2] = vtype::template mask_i64gather<sizeof(type_t)>(
vtype::zmm_max(), load_mask[ii], argzmm[ii + 2], arr);
arrzmm[ii + 2] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii + 2],
argzmm[ii + 2]);
}

bitonic_merge_two_zmm_64bit<vtype, argtype>(
arrzmm[0], arrzmm[1], argzmm[0], argzmm[1]);
bitonic_merge_two_zmm_64bit<vtype, argtype>(
arrzmm[2], arrzmm[3], argzmm[2], argzmm[3]);
bitonic_merge_four_zmm_64bit<vtype, argtype>(arrzmm, argzmm);

argtype::storeu(arg, argzmm[0]);
argtype::storeu(arg + 8, argzmm[1]);
argtype::mask_storeu(arg + 16, load_mask[0], argzmm[2]);
argtype::mask_storeu(arg + 24, load_mask[1], argzmm[3]);
}
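Similarly, a worked example of the tail-mask split in argsort_32_64bit (illustrative only): with N = 27, the 11 indices beyond the first 16 are spread across the two masked registers:

    combined_mask = (0x1ull << (27 - 16)) - 0x1ull;  // 0x7FF: low 11 bits set
    load_mask[0] = combined_mask & 0xFF;             // 0xFF: third register gathers 8 indices
    load_mask[1] = (combined_mask >> 8) & 0xFF;      // 0x07: fourth register gathers 3, rest are sentinels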

template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE void
argsort_64_64bit(type_t *arr, arrsize_t *arg, int32_t N)
{
if (N <= 32) {
argsort_32_64bit<vtype>(arr, arg, N);
return;
}
using reg_t = typename vtype::reg_t;
using opmask_t = typename vtype::opmask_t;
reg_t arrzmm[8];
argreg_t argzmm[8];

X86_SIMD_SORT_UNROLL_LOOP(4)
for (int ii = 0; ii < 4; ++ii) {
argzmm[ii] = argtype::loadu(arg + 8 * ii);
arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii);
arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
}

opmask_t load_mask[4] = {0xFF, 0xFF, 0xFF, 0xFF};
uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull;
X86_SIMD_SORT_UNROLL_LOOP(4)
for (int ii = 0; ii < 4; ++ii) {
load_mask[ii] = (combined_mask >> (ii * 8)) & 0xFF;
argzmm[ii + 4] = argtype::maskz_loadu(load_mask[ii], arg + 32 + 8 * ii);
arrzmm[ii + 4] = vtype::template mask_i64gather<sizeof(type_t)>(
vtype::zmm_max(), load_mask[ii], argzmm[ii + 4], arr);
arrzmm[ii + 4] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii + 4],
argzmm[ii + 4]);
}

X86_SIMD_SORT_UNROLL_LOOP(4)
for (int ii = 0; ii < 8; ii = ii + 2) {
bitonic_merge_two_zmm_64bit<vtype, argtype>(
arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]);
}
bitonic_merge_four_zmm_64bit<vtype, argtype>(arrzmm, argzmm);
bitonic_merge_four_zmm_64bit<vtype, argtype>(arrzmm + 4, argzmm + 4);
bitonic_merge_eight_zmm_64bit<vtype, argtype>(arrzmm, argzmm);

X86_SIMD_SORT_UNROLL_LOOP(4)
for (int ii = 0; ii < 4; ++ii) {
argtype::storeu(arg + 8 * ii, argzmm[ii]);
}
X86_SIMD_SORT_UNROLL_LOOP(4)
for (int ii = 0; ii < 4; ++ii) {
argtype::mask_storeu(arg + 32 + 8 * ii, load_mask[ii], argzmm[ii + 4]);
}
}

/* argsort 128 doesn't seem to make much of a difference to perf */
//template <typename vtype, typename type_t>
//X86_SIMD_SORT_INLINE void
//argsort_128_64bit(type_t *arr, arrsize_t *arg, int32_t N)
//{
// if (N <= 64) {
// argsort_64_64bit<vtype>(arr, arg, N);
// return;
// }
// using reg_t = typename vtype::reg_t;
// using opmask_t = typename vtype::opmask_t;
// reg_t arrzmm[16];
// argreg_t argzmm[16];
//
//X86_SIMD_SORT_UNROLL_LOOP(8)
// for (int ii = 0; ii < 8; ++ii) {
// argzmm[ii] = argtype::loadu(arg + 8*ii);
// arrzmm[ii] = vtype::i64gather(argzmm[ii], arr);
// arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
// }
//
// opmask_t load_mask[8] = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF};
// if (N != 128) {
// uarrsize_t combined_mask = (0x1ull << (N - 64)) - 0x1ull;
//X86_SIMD_SORT_UNROLL_LOOP(8)
// for (int ii = 0; ii < 8; ++ii) {
// load_mask[ii] = (combined_mask >> (ii*8)) & 0xFF;
// }
// }
//X86_SIMD_SORT_UNROLL_LOOP(8)
// for (int ii = 0; ii < 8; ++ii) {
// argzmm[ii+8] = argtype::maskz_loadu(load_mask[ii], arg + 64 + 8*ii);
// arrzmm[ii+8] = vtype::template mask_i64gather<sizeof(type_t)>(vtype::zmm_max(), load_mask[ii], argzmm[ii+8], arr);
// arrzmm[ii+8] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii+8], argzmm[ii+8]);
// }
//
//X86_SIMD_SORT_UNROLL_LOOP(8)
// for (int ii = 0; ii < 16; ii = ii + 2) {
// bitonic_merge_two_zmm_64bit<vtype, argtype>(arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]);
// }
// bitonic_merge_four_zmm_64bit<vtype, argtype>(arrzmm, argzmm);
// bitonic_merge_four_zmm_64bit<vtype, argtype>(arrzmm + 4, argzmm + 4);
// bitonic_merge_four_zmm_64bit<vtype, argtype>(arrzmm + 8, argzmm + 8);
// bitonic_merge_four_zmm_64bit<vtype, argtype>(arrzmm + 12, argzmm + 12);
// bitonic_merge_eight_zmm_64bit<vtype, argtype>(arrzmm, argzmm);
// bitonic_merge_eight_zmm_64bit<vtype, argtype>(arrzmm+8, argzmm+8);
// bitonic_merge_sixteen_zmm_64bit<vtype, argtype>(arrzmm, argzmm);
//
//X86_SIMD_SORT_UNROLL_LOOP(8)
// for (int ii = 0; ii < 8; ++ii) {
// argtype::storeu(arg + 8*ii, argzmm[ii]);
// }
//X86_SIMD_SORT_UNROLL_LOOP(8)
// for (int ii = 0; ii < 8; ++ii) {
// argtype::mask_storeu(arg + 64 + 8*ii, load_mask[ii], argzmm[ii + 8]);
// }
//}

template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
arrsize_t *arg,
@@ -585,8 +396,8 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
/*
* Base case: use bitonic networks to sort arrays <= 256
*/
if (right + 1 - left <= 64) {
argsort_64_64bit<vtype>(arr, arg + left, (int32_t)(right + 1 - left));
if (right + 1 - left <= 256) {
argsort_n<vtype, 256>(arr, arg + left, (int32_t)(right + 1 - left));
return;
}
type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
@@ -618,8 +429,8 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
/*
* Base case: use bitonic networks to sort arrays <= 256
*/
if (right + 1 - left <= 64) {
argsort_64_64bit<vtype>(arr, arg + left, (int32_t)(right + 1 - left));
if (right + 1 - left <= 256) {
argsort_n<vtype, 256>(arr, arg + left, (int32_t)(right + 1 - left));
return;
}
type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
40 changes: 40 additions & 0 deletions src/avx512-64bit-common.h
@@ -8,6 +8,7 @@
#define AVX512_64BIT_COMMON

#include "xss-common-includes.h"
#include "avx2-32bit-qsort.hpp"

/*
* Constants used in sorting 8 elements in ZMM registers. Based on Bitonic
@@ -194,6 +195,19 @@ struct ymm_vector<float> {
{
_mm256_storeu_ps((float *)mem, x);
}
static reg_t cast_from(__m256i v)
{
return _mm256_castsi256_ps(v);
}
static __m256i cast_to(reg_t v)
{
return _mm256_castps_si256(v);
}
static reg_t reverse(reg_t ymm)
{
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
return permutexvar(rev_index, ymm);
}
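A usage sketch for the helpers added above (illustrative; assumes this header is included so ymm_vector<float> is in scope): reverse() flips the eight 32-bit lanes, and cast_to()/cast_from() reinterpret between __m256 and __m256i without moving data:

    __m256 v = _mm256_setr_ps(0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f);
    __m256 r = ymm_vector<float>::reverse(v);          // lanes become 7, 6, 5, 4, 3, 2, 1, 0
    __m256i bits = ymm_vector<float>::cast_to(r);      // same bits, viewed as __m256i
    __m256 back = ymm_vector<float>::cast_from(bits);  // round-trips to the float view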
};
template <>
struct ymm_vector<uint32_t> {
@@ -354,6 +368,19 @@ struct ymm_vector<uint32_t> {
{
_mm256_storeu_si256((__m256i *)mem, x);
}
static reg_t cast_from(__m256i v)
{
return v;
}
static __m256i cast_to(reg_t v)
{
return v;
}
static reg_t reverse(reg_t ymm)
{
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
return permutexvar(rev_index, ymm);
}
};
template <>
struct ymm_vector<int32_t> {
@@ -514,6 +541,19 @@ struct ymm_vector<int32_t> {
{
_mm256_storeu_si256((__m256i *)mem, x);
}
static reg_t cast_from(__m256i v)
{
return v;
}
static __m256i cast_to(reg_t v)
{
return v;
}
static reg_t reverse(reg_t ymm)
{
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
return permutexvar(rev_index, ymm);
}
};
template <>
struct zmm_vector<int64_t> {