@@ -688,7 +688,7 @@ static inline int64_t partition_avx512(type_t1 *keys,
688688}
689689
690690template  <typename  vtype, typename  type_t >
691- X86_SIMD_SORT_INLINE type_t  get_pivot (type_t  *arr,
691+ X86_SIMD_SORT_INLINE type_t  get_pivot_scalar (type_t  *arr,
692692                                      const  int64_t  left,
693693                                      const  int64_t  right)
694694{
@@ -703,9 +703,132 @@ X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr,
703703
704704    auto  vec = vtype::loadu (samples);
705705    vec = vtype::sort_vec (vec);
706-     vtype::storeu (samples, vec);
706+     return  ((type_t  *)&vec)[numSamples / 2 ];
707+ }
708+ 
// Forward declarations only: no bodies are provided here. Each helper takes a
// register by value and returns a register of the same type; the
// get_pivot_16bit / get_pivot_32bit / get_pivot_64bit functions in this file
// call them on their sampled elements and read the pivot out of the result.
709+ template  <typename  vtype, typename  reg_t >
710+ X86_SIMD_SORT_INLINE reg_t  sort_zmm_16bit (reg_t  zmm);
711+ 
712+ template  <typename  vtype, typename  reg_t >
713+ X86_SIMD_SORT_INLINE reg_t  sort_zmm_32bit (reg_t  zmm);
714+ 
715+ template  <typename  vtype, typename  reg_t >
716+ X86_SIMD_SORT_INLINE reg_t  sort_zmm_64bit (reg_t  zmm);
717+ 
718+ template  <typename  vtype, typename  type_t >
719+ X86_SIMD_SORT_INLINE type_t  get_pivot_16bit (type_t  *arr,
720+                                             const  int64_t  left,
721+                                             const  int64_t  right)
722+ {
723+     //  median of 32
724+     int64_t  size = (right - left) / 32 ;
725+     type_t  vec_arr[32 ] = {arr[left],
726+                           arr[left + size],
727+                           arr[left + 2  * size],
728+                           arr[left + 3  * size],
729+                           arr[left + 4  * size],
730+                           arr[left + 5  * size],
731+                           arr[left + 6  * size],
732+                           arr[left + 7  * size],
733+                           arr[left + 8  * size],
734+                           arr[left + 9  * size],
735+                           arr[left + 10  * size],
736+                           arr[left + 11  * size],
737+                           arr[left + 12  * size],
738+                           arr[left + 13  * size],
739+                           arr[left + 14  * size],
740+                           arr[left + 15  * size],
741+                           arr[left + 16  * size],
742+                           arr[left + 17  * size],
743+                           arr[left + 18  * size],
744+                           arr[left + 19  * size],
745+                           arr[left + 20  * size],
746+                           arr[left + 21  * size],
747+                           arr[left + 22  * size],
748+                           arr[left + 23  * size],
749+                           arr[left + 24  * size],
750+                           arr[left + 25  * size],
751+                           arr[left + 26  * size],
752+                           arr[left + 27  * size],
753+                           arr[left + 28  * size],
754+                           arr[left + 29  * size],
755+                           arr[left + 30  * size],
756+                           arr[left + 31  * size]};
757+     typename  vtype::reg_t  rand_vec = vtype::loadu (vec_arr);
758+     typename  vtype::reg_t  sort = sort_zmm_16bit<vtype>(rand_vec);
759+     return  ((type_t  *)&sort)[16 ];
760+ }
761+ 
762+ template  <typename  vtype, typename  type_t >
763+ X86_SIMD_SORT_INLINE type_t  get_pivot_32bit (type_t  *arr,
764+                                             const  int64_t  left,
765+                                             const  int64_t  right)
766+ {
767+     //  median of 16
768+     int64_t  size = (right - left) / 16 ;
769+     using  zmm_t  = typename  vtype::reg_t ;
770+     using  ymm_t  = typename  vtype::halfreg_t ;
771+     __m512i rand_index1 = _mm512_set_epi64 (left + size,
772+                                            left + 2  * size,
773+                                            left + 3  * size,
774+                                            left + 4  * size,
775+                                            left + 5  * size,
776+                                            left + 6  * size,
777+                                            left + 7  * size,
778+                                            left + 8  * size);
779+     __m512i rand_index2 = _mm512_set_epi64 (left + 9  * size,
780+                                            left + 10  * size,
781+                                            left + 11  * size,
782+                                            left + 12  * size,
783+                                            left + 13  * size,
784+                                            left + 14  * size,
785+                                            left + 15  * size,
786+                                            left + 16  * size);
787+     ymm_t  rand_vec1
788+             = vtype::template  i64gather<sizeof (type_t )>(rand_index1, arr);
789+     ymm_t  rand_vec2
790+             = vtype::template  i64gather<sizeof (type_t )>(rand_index2, arr);
791+     zmm_t  rand_vec = vtype::merge (rand_vec1, rand_vec2);
792+     zmm_t  sort = sort_zmm_32bit<vtype>(rand_vec);
793+     //  pivot will never be a nan, since there are no nan's!
794+     return  ((type_t  *)&sort)[8 ];
795+ }
796+ 
797+ template  <typename  vtype, typename  type_t >
798+ X86_SIMD_SORT_INLINE type_t  get_pivot_64bit (type_t  *arr,
799+                                             const  int64_t  left,
800+                                             const  int64_t  right)
801+ {
802+     //  median of 8
803+     int64_t  size = (right - left) / 8 ;
804+     using  zmm_t  = typename  vtype::reg_t ;
805+     __m512i rand_index = _mm512_set_epi64 (left + size,
806+                                           left + 2  * size,
807+                                           left + 3  * size,
808+                                           left + 4  * size,
809+                                           left + 5  * size,
810+                                           left + 6  * size,
811+                                           left + 7  * size,
812+                                           left + 8  * size);
813+     zmm_t  rand_vec = vtype::template  i64gather<sizeof (type_t )>(rand_index, arr);
814+     //  pivot will never be a nan, since there are no nan's!
815+     zmm_t  sort = sort_zmm_64bit<vtype>(rand_vec);
816+     return  ((type_t  *)&sort)[4 ];
817+ }
707818
708-     return  samples[numSamples / 2 ];
819+ template  <typename  vtype, typename  type_t >
820+ X86_SIMD_SORT_INLINE type_t  get_pivot (type_t  *arr,
821+                                       const  int64_t  left,
822+                                       const  int64_t  right)
823+ {
824+     if  constexpr  (vtype::numlanes == 8 )
825+         return  get_pivot_64bit<vtype>(arr, left, right);
826+     else  if  constexpr  (vtype::numlanes == 16 )
827+         return  get_pivot_32bit<vtype>(arr, left, right);
828+     else  if  constexpr  (vtype::numlanes == 32 )
829+         return  get_pivot_16bit<vtype>(arr, left, right);
830+     else 
831+         return  get_pivot_scalar<vtype>(arr, left, right);
709832}
710833
711834template  <typename  vtype, int64_t  maxN>
0 commit comments