Commit f028325

Changed pivot code back to previous logic for performance reasons
1 parent 0aaead0 commit f028325
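
The diff below renames the previous generic sampling routine to get_pivot_scalar and restores the width-specific pivot selection: a median of 8, 16, or 32 evenly spaced samples for 64-bit, 32-bit, and 16-bit element types, with the samples sorted inside a single 512-bit register. As a plain-C++ illustration of that median-of-N idea (not the library's code; the function name and the use of std::sort are purely illustrative):

// Illustrative sketch only: a scalar median-of-N pivot picker showing the
// idea that the AVX-512 routines in the diff implement in-register.
#include <algorithm>
#include <cstdint>

template <typename T, int N>
T median_of_n_pivot(const T *arr, int64_t left, int64_t right)
{
    // Sample N elements at an even stride across [left, right].
    int64_t size = (right - left) / N;
    T samples[N];
    for (int i = 0; i < N; ++i) {
        samples[i] = arr[left + (i + 1) * size];
    }
    // The library sorts the samples inside one zmm register with a sorting
    // network (sort_zmm_*bit); a scalar sort of the tiny array stands in here.
    std::sort(samples, samples + N);
    return samples[N / 2]; // middle sample approximates the true median
}

For example, median_of_n_pivot<double, 8>(arr, left, right) computes roughly what get_pivot_64bit below does with an i64gather plus sort_zmm_64bit.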

File tree

1 file changed: +126 -3 lines changed

src/avx512-common-qsort.h

Lines changed: 126 additions & 3 deletions
@@ -688,7 +688,7 @@ static inline int64_t partition_avx512(type_t1 *keys,
 }
 
 template <typename vtype, typename type_t>
-X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr,
+X86_SIMD_SORT_INLINE type_t get_pivot_scalar(type_t *arr,
                                       const int64_t left,
                                       const int64_t right)
 {
@@ -703,9 +703,132 @@ X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr,
 
     auto vec = vtype::loadu(samples);
     vec = vtype::sort_vec(vec);
-    vtype::storeu(samples, vec);
+    return ((type_t *)&vec)[numSamples / 2];
+}
+
+template <typename vtype, typename reg_t>
+X86_SIMD_SORT_INLINE reg_t sort_zmm_16bit(reg_t zmm);
+
+template <typename vtype, typename reg_t>
+X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm);
+
+template <typename vtype, typename reg_t>
+X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm);
+
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr,
+                                            const int64_t left,
+                                            const int64_t right)
+{
+    // median of 32
+    int64_t size = (right - left) / 32;
+    type_t vec_arr[32] = {arr[left],
+                          arr[left + size],
+                          arr[left + 2 * size],
+                          arr[left + 3 * size],
+                          arr[left + 4 * size],
+                          arr[left + 5 * size],
+                          arr[left + 6 * size],
+                          arr[left + 7 * size],
+                          arr[left + 8 * size],
+                          arr[left + 9 * size],
+                          arr[left + 10 * size],
+                          arr[left + 11 * size],
+                          arr[left + 12 * size],
+                          arr[left + 13 * size],
+                          arr[left + 14 * size],
+                          arr[left + 15 * size],
+                          arr[left + 16 * size],
+                          arr[left + 17 * size],
+                          arr[left + 18 * size],
+                          arr[left + 19 * size],
+                          arr[left + 20 * size],
+                          arr[left + 21 * size],
+                          arr[left + 22 * size],
+                          arr[left + 23 * size],
+                          arr[left + 24 * size],
+                          arr[left + 25 * size],
+                          arr[left + 26 * size],
+                          arr[left + 27 * size],
+                          arr[left + 28 * size],
+                          arr[left + 29 * size],
+                          arr[left + 30 * size],
+                          arr[left + 31 * size]};
+    typename vtype::reg_t rand_vec = vtype::loadu(vec_arr);
+    typename vtype::reg_t sort = sort_zmm_16bit<vtype>(rand_vec);
+    return ((type_t *)&sort)[16];
+}
+
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr,
+                                            const int64_t left,
+                                            const int64_t right)
+{
+    // median of 16
+    int64_t size = (right - left) / 16;
+    using zmm_t = typename vtype::reg_t;
+    using ymm_t = typename vtype::halfreg_t;
+    __m512i rand_index1 = _mm512_set_epi64(left + size,
+                                           left + 2 * size,
+                                           left + 3 * size,
+                                           left + 4 * size,
+                                           left + 5 * size,
+                                           left + 6 * size,
+                                           left + 7 * size,
+                                           left + 8 * size);
+    __m512i rand_index2 = _mm512_set_epi64(left + 9 * size,
+                                           left + 10 * size,
+                                           left + 11 * size,
+                                           left + 12 * size,
+                                           left + 13 * size,
+                                           left + 14 * size,
+                                           left + 15 * size,
+                                           left + 16 * size);
+    ymm_t rand_vec1
+            = vtype::template i64gather<sizeof(type_t)>(rand_index1, arr);
+    ymm_t rand_vec2
+            = vtype::template i64gather<sizeof(type_t)>(rand_index2, arr);
+    zmm_t rand_vec = vtype::merge(rand_vec1, rand_vec2);
+    zmm_t sort = sort_zmm_32bit<vtype>(rand_vec);
+    // pivot will never be a nan, since there are no nan's!
+    return ((type_t *)&sort)[8];
+}
+
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
+                                            const int64_t left,
+                                            const int64_t right)
+{
+    // median of 8
+    int64_t size = (right - left) / 8;
+    using zmm_t = typename vtype::reg_t;
+    __m512i rand_index = _mm512_set_epi64(left + size,
+                                          left + 2 * size,
+                                          left + 3 * size,
+                                          left + 4 * size,
+                                          left + 5 * size,
+                                          left + 6 * size,
+                                          left + 7 * size,
+                                          left + 8 * size);
+    zmm_t rand_vec = vtype::template i64gather<sizeof(type_t)>(rand_index, arr);
+    // pivot will never be a nan, since there are no nan's!
+    zmm_t sort = sort_zmm_64bit<vtype>(rand_vec);
+    return ((type_t *)&sort)[4];
+}
 
-    return samples[numSamples / 2];
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr,
+                                      const int64_t left,
+                                      const int64_t right)
+{
+    if constexpr (vtype::numlanes == 8)
+        return get_pivot_64bit<vtype>(arr, left, right);
+    else if constexpr (vtype::numlanes == 16)
+        return get_pivot_32bit<vtype>(arr, left, right);
+    else if constexpr (vtype::numlanes == 32)
+        return get_pivot_16bit<vtype>(arr, left, right);
+    else
+        return get_pivot_scalar<vtype>(arr, left, right);
 }
 
 template <typename vtype, int64_t maxN>
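
For reference, the new get_pivot wrapper chooses a routine at compile time from vtype::numlanes: a 512-bit register holds 8 lanes of 64-bit, 16 lanes of 32-bit, or 32 lanes of 16-bit elements. A minimal standalone sketch of that if constexpr dispatch pattern, using a hypothetical lane-count trait rather than the library's vtype classes:

// Standalone illustration of compile-time dispatch on a lane-count trait.
// lanes_per_zmm and pivot_strategy are hypothetical stand-ins, not part of
// avx512-common-qsort.h.
#include <cstdint>
#include <cstdio>

template <typename T>
struct lanes_per_zmm {
    static constexpr int value = 64 / sizeof(T); // 512 bits / element width
};

template <typename T>
const char *pivot_strategy()
{
    if constexpr (lanes_per_zmm<T>::value == 8)
        return "median of 8 (64-bit elements)";
    else if constexpr (lanes_per_zmm<T>::value == 16)
        return "median of 16 (32-bit elements)";
    else if constexpr (lanes_per_zmm<T>::value == 32)
        return "median of 32 (16-bit elements)";
    else
        return "scalar fallback";
}

int main()
{
    std::printf("%s\n", pivot_strategy<double>());   // median of 8
    std::printf("%s\n", pivot_strategy<int32_t>());  // median of 16
    std::printf("%s\n", pivot_strategy<uint16_t>()); // median of 32
    return 0;
}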
